import { useLayoutEffect, useRef, useState } from 'react';
import {
  useReactFlow,
  useStore,
  type Node,
  type ReactFlowState,
} from 'reactflow';

import { stratify, tree } from 'd3-hierarchy';

import { NODE_HEIGHT, NODE_SEPARATION, NODE_WIDTH } from '../constants/graph';

import { type EdgeData } from '../types/EdgeData';
import { type NodeData } from '../types/NodeData';

const layout = tree<Node<NodeData>>()
  .nodeSize([NODE_WIDTH, NODE_HEIGHT])
  .separation(() => NODE_SEPARATION);

export function useGraphLayout(spektrId: string) {
  const nodeLength = useStore(
    (state: ReactFlowState) => state.nodeInternals.size
  );
  const { getNodes, getNode, setNodes, getEdges, setCenter } = useReactFlow<
    NodeData,
    EdgeData
  >();
  const [wasCentered, setWasCentered] = useState(false);

  useLayoutEffect(() => {
    if (nodeLength === 0) return;

    const nodesBeforeLayout = getNodes();
    const edgesBeforeLayout = getEdges();

    const hierarchy = stratify<Node<NodeData>>()
      .id((node) => node.id)
      .parentId(
        (node) =>
          edgesBeforeLayout.find((edge) => edge.target === node.id)?.source
      );
    const root = hierarchy(nodesBeforeLayout);

    const nodesAfterLayout = layout(root)
      .descendants()
      .map((node) => ({
        ...node.data,
        position: { x: node.x, y: node.y * 2.5 },
      }));

    const currentNode = nodesAfterLayout.find((node) => node.id === spektrId);

    if (currentNode && !wasCentered) {
      setCenter(
        currentNode.position.x + NODE_WIDTH / 2,
        currentNode.position.y + NODE_HEIGHT / 2,
        {
          zoom: 1,
        }
      );
      setWasCentered(true);
    }

    setNodes(nodesAfterLayout);
  }, [
    getEdges,
    getNodes,
    nodeLength,
    setNodes,
    getNode,
    setCenter,
    spektrId,
    wasCentered,
  ]);
}
