import { useCallback, useEffect, useState } from 'react';
import type { Edge, Node, NodeAddChange, NodeChange, NodeResetChange, OnNodesChange } from 'reactflow';
import { applyNodeChanges, getConnectedEdges } from 'reactflow';

import type { Doc } from 'yjs';
import type { YMap } from 'yjs/dist/src/internals';

/** Type guard: true when the ReactFlow node change is an `'add'` change. */
function isNodeAddChange(change: NodeChange): change is NodeAddChange {
  const { type } = change;
  return type === 'add';
}

/** Type guard: true when the ReactFlow node change is a `'reset'` change. */
function isNodeResetChange(change: NodeChange): change is NodeResetChange {
  const { type } = change;
  return type === 'reset';
}

/**
 * Keeps ReactFlow node state in sync with the `'nodes'` map of a shared Yjs document.
 *
 * React state mirrors the shared map and is refreshed whenever the map changes. Local
 * ReactFlow changes are written back into the map through `onNodesChanges`.
 *
 * @param doc - Yjs document that owns the shared `'nodes'` map.
 * @param edgesSharedState - shared map of edges keyed by edge id; edges connected to a
 *   removed node are deleted from it. NOTE(review): presumably a map of the same `doc`
 *   (so it joins the same transaction) — confirm at call sites.
 * @returns `nodes` (current nodes, falsy entries dropped), `nodesSharedState` (the shared
 *   map), and `onNodesChanges` (an `OnNodesChange` handler for ReactFlow).
 */
function useNodesStateSynced(doc: Doc, edgesSharedState: YMap<Edge>) {
  const [nodes, setNodes] = useState<Node[]>([]);
  // Yjs caches top-level types by name, so this is the same Y.Map on every render for a given doc.
  const nodesSharedState = doc.getMap<Node>('nodes');

  const onNodesChanges: OnNodesChange = useCallback(
    (changes) => {
      const currentNodes = Array.from(nodesSharedState.values());
      const nextNodes = applyNodeChanges(changes, currentNodes);

      // Apply every write in a single Yjs transaction so observers (and peers) receive one
      // batched update instead of one per change.
      doc.transact(() => {
        changes.forEach((change: NodeChange) => {
          // 'add' and 'reset' changes carry no `id` and are not mirrored into the shared map here.
          if (isNodeAddChange(change) || isNodeResetChange(change)) {
            return;
          }

          if (change.type === 'remove') {
            // Read the node before deleting it so its connected edges can still be found.
            const deletedNode = nodesSharedState.get(change.id);
            nodesSharedState.delete(change.id);
            // When a node is removed, also remove the edges connected to it.
            const edges = Array.from(edgesSharedState.values());
            const connectedEdges = getConnectedEdges(deletedNode ? [deletedNode] : [], edges);
            connectedEdges.forEach((edge) => edgesSharedState.delete(edge.id));
            return;
          }

          // Any other change (position, dimensions, selection, ...): store the updated node.
          const updatedNode = nextNodes.find((n) => n.id === change.id);
          if (updatedNode) {
            nodesSharedState.set(change.id, updatedNode);
          }
        });
      });
    },
    [doc, edgesSharedState, nodesSharedState],
  );

  // Mirror the shared map into React state: seed once, then refresh on every map change.
  // Depends on nodesSharedState so a new `doc` re-subscribes instead of observing a stale map.
  useEffect(() => {
    const syncFromSharedState = () => {
      setNodes(Array.from(nodesSharedState.values()));
    };

    syncFromSharedState();
    nodesSharedState.observe(syncFromSharedState);

    return () => nodesSharedState.unobserve(syncFromSharedState);
  }, [nodesSharedState]);

  // Drop falsy entries defensively before handing the list to ReactFlow.
  return { nodes: nodes.filter((n) => n), nodesSharedState, onNodesChanges };
}

export default useNodesStateSynced;
