import { useLayoutEffect } from 'react';
import { useReactFlow, Node, Edge } from 'reactflow';
import { stratify, tree } from 'd3-hierarchy';

import { GraphEdgeData, GraphNodeData } from '@spektr/shared/utils';
import { EDGE, NODE } from '@spektr/shared/components';

import { useNodeCount } from './useNodeCount';

// d3's tree layout grows top-to-bottom. It is given [height, width] as the node size
// (instead of [width, height]) because x and y are swapped afterwards when positioning
// (see `layoutNodes`), which turns the result into a left-to-right layout.
// Every pair of neighbouring nodes gets the same separation factor of 2.
const layout = tree<Node<GraphNodeData>>()
  .separation(() => 2)
  .nodeSize([NODE.HEIGHT, NODE.WIDTH]);

/** Root node of a laid-out hierarchy, exactly as returned by `layout`. */
type Graph = ReturnType<typeof layout>;

/**
 * Finds the detours in the process graph and the edges needed to bypass them.
 *
 * A detour is a chain of nodes connected by `fallback` edges. For each detour one
 * short-circuiting edge is created, from the node where the chain starts to the
 * successor of the last node on the chain, so the main layout can skip the detour.
 *
 * @param nodes all React Flow nodes of the process graph
 * @param edges all React Flow edges of the process graph
 * @returns a tuple of
 *   - the ids of every node that is the target of a node's (first) fallback edge,
 *     i.e. the nodes that make up the detours, and
 *   - the short-circuiting edges to add around each detour
 */
function deriveShortcircuitEdges(
  nodes: Node<GraphNodeData>[],
  edges: Edge<GraphEdgeData>[]
): [string[], Edge<GraphEdgeData>[]] {
  // Nodes whose ingoing edge is a fallback edge. These nodes make up the detours.
  const fallbackNodeIds: string[] = [];
  const shortcircuitEdges: Edge<GraphEdgeData>[] = [];

  // Index the FIRST ingoing and FIRST outgoing edge of every node once, instead of
  // scanning `edges` per node. Keeping only the first match preserves the semantics
  // of `edges.find(...)`.
  const firstSourceByTarget = new Map<string, string>();
  const firstTargetBySource = new Map<string, string>();
  for (const edge of edges) {
    if (!firstSourceByTarget.has(edge.target)) {
      firstSourceByTarget.set(edge.target, edge.source);
    }
    if (!firstTargetBySource.has(edge.source)) {
      firstTargetBySource.set(edge.source, edge.target);
    }
  }

  const hierarchy = stratify<Node<GraphNodeData>>()
    .id((descendant) => descendant.id)
    .parentId((descendant) => firstSourceByTarget.get(descendant.id))(nodes);

  // Id of the node where the currently open detour starts, if any.
  let detourStartId: string | undefined;

  // Nodes are visited in d3's `descendants()` order (root first, parents before children).
  // NOTE(review): this assumes the nodes of one detour are visited consecutively. Any
  // node without a fallback edge visited while a detour is open closes that detour.
  // Verify this holds for branching graphs.
  for (const descendant of hierarchy.descendants()) {
    const node = descendant.data.data.node;

    // The edge pointing to the next node in the fallback chain, if any.
    const fallback = node?.adj.find((edge) => edge.type === 'fallback');

    if (fallback) {
      // The first node with a fallback edge opens a detour. Remember it as the
      // source of the short-circuiting edge emitted once the detour ends.
      if (!detourStartId) detourStartId = descendant.id;
      fallbackNodeIds.push(fallback.id);
      continue;
    }

    // A detour only ends at a node that has no fallback edge of its own.
    // Such nodes are irrelevant unless a detour is currently open.
    if (!detourStartId) continue;

    // 'node' is the last node on the detour. Connect the detour start to its
    // successor (the 'add-new' pseudo node) to complete the bypass.
    const successorId = node ? firstTargetBySource.get(node.id) : undefined;
    if (!successorId) continue;

    shortcircuitEdges.push(
      createShortCircuitingEdge(detourStartId, successorId)
    );

    // reset to find the next detour
    detourStartId = undefined;
  }

  return [fallbackNodeIds, shortcircuitEdges];
}
/**
 * Computes a position for every node of the process graph.
 *
 * Documentation https://www.notion.so/Process-Layout-Algorithm-8113977a30914f6d8c979598b46df1ab?pvs=4
 *
 * 1. Detect detours (fallback chains) and the short-circuiting edges that bypass them.
 * 2. Lay out every node except the detour nodes as a d3 tree. The tree is rotated to run
 *    left-to-right and shifted right per depth to leave room for edges and their labels.
 * 3. Stack each detour node directly below the node whose fallback edge points at it.
 *
 * @param nodes current React Flow nodes
 * @param edges current React Flow edges
 * @returns the laid-out nodes, each with a new `position`; `[]` when `nodes` is empty
 */
function layoutNodes(
  nodes: Node<GraphNodeData>[],
  edges: Edge<GraphEdgeData>[]
): Node<GraphNodeData>[] {
  if (nodes.length === 0) return [];

  const [fallbackIdList, shortcircuits] = deriveShortcircuitEdges(nodes, edges);
  // Use a Set instead of an array: membership is checked once per node and once per edge below.
  const fallbackNodeIds = new Set(fallbackIdList);

  const edgesAll = [...edges, ...shortcircuits];
  // Exclude fallback nodes from the tree layout; they are placed in step 3.
  const nodesWithoutFallbacks = nodes.filter((n) => !fallbackNodeIds.has(n.id));

  // Re-do the hierarchy but without the fallback nodes. A node's parent is the source
  // of its first ingoing edge that does not come from a fallback node. Because of the
  // short-circuiting edges, the node after a detour hangs off the detour's start.
  const hierarchy = stratify<Node<GraphNodeData>>()
    .id((descendant) => descendant.id)
    .parentId(
      (descendant) =>
        edgesAll.find(
          (edge) =>
            edge.target === descendant.id && !fallbackNodeIds.has(edge.source)
        )?.source
    )(nodesWithoutFallbacks);

  const root = layout(hierarchy);
  const accumulatedOffsets = accumulatedAdditionalHorizontalOffsetByDepth(root);

  const rootAdjusted = root.descendants().map((treeNode) => {
    // The root keeps d3's coordinates (with `nodeSize`, d3 places the root at <0, 0>).
    const position = { x: treeNode.x, y: treeNode.y };

    if (treeNode.parent) {
      // for child nodes we swap x and y to make the layout horizontal
      position.x = treeNode.y;
      position.y = treeNode.x;

      // horizontally offset each child by the accumulated default edge widths from root to this child
      position.x += treeNode.depth * EDGE.WIDTH;

      // horizontally offset each child by the accumulated edge label widths from root to this child
      position.x += accumulatedOffsets.at(treeNode.depth) ?? 0;
    }

    return { ...treeNode.data, position };
  });

  // Lay out detours on top of the main layout.
  // React Flow node ids are expected to be unique, so an id -> node map is safe here.
  const nodesById = new Map(nodes.map((n) => [n.id, n] as const));
  // Ids already in the output. Each node is emitted at most once. This avoids duplicate
  // React Flow ids and stops a cyclic fallback chain from looping forever.
  const placedIds = new Set(rootAdjusted.map((n) => n.id));

  // NOTE: `rootAdjusted` intentionally grows while it is being iterated. `for...of` over
  // an array also visits elements appended during the loop. Each detour node placed here
  // is therefore visited later, and its own fallback successor is stacked below it.
  for (const placed of rootAdjusted) {
    const adj = placed.data.node?.adj ?? [];

    for (const outgoing of adj) {
      if (outgoing.type !== 'fallback') continue; // skip non-detours since they have already been laid out
      if (placedIds.has(outgoing.id)) continue; // never emit the same node twice

      // 'target' is the next node on the detour path
      const target = nodesById.get(outgoing.id);
      if (!target) continue;

      placedIds.add(target.id);

      // lay out the detour orthogonal to the main path: directly below its predecessor
      rootAdjusted.push({
        ...target,
        position: {
          x: placed.position.x,
          y: placed.position.y + NODE.HEIGHT + EDGE.WIDTH * 0.5,
        },
      });
    }
  }

  return rootAdjusted;
}

/**
 * Builds the animated "intermediary" edge that bypasses a detour by linking
 * `source` directly to `target`.
 *
 * @param source id of the node the edge starts from
 * @param target id of the node the edge leads to
 * @returns a React Flow edge whose id is derived from both endpoints
 */
function createShortCircuitingEdge(
  source: string,
  target: string
): Edge<GraphEdgeData> {
  const id = `edge-${source}-${target}`;

  return { id, source, target, type: 'edgeIntermediary', animated: true };
}

/**
 * Computes, for each depth of the laid-out tree, how far (in px) nodes at that
 * depth must be pushed right to make room for edge labels.
 *
 * For every depth, the candidate labels are the names of ALL outgoing edges (`adj`)
 * of the parents of the nodes at that depth. The longest label (in characters) is
 * converted to pixels with `EDGE.LABEL.ESTIMATED_PX_PER_CHAR`. Those per-depth values
 * are then summed from the root downwards.
 *
 * Example: a result of [0, 10, 60] means nodes at depth 0 need no extra offset,
 * nodes at depth 1 need 10px, and nodes at depth 2 need 60px.
 *
 * @param root root of the laid-out tree to traverse
 * @returns accumulated horizontal offsets, indexed by depth
 */
function accumulatedAdditionalHorizontalOffsetByDepth(root: Graph) {
  // Longest label length (in characters) seen so far for each depth.
  const longestLabelAtDepth: number[] = [];

  root.descendants().forEach(({ depth, parent }) => {
    const parentEdges = parent?.data.data.node?.adj ?? [];

    longestLabelAtDepth[depth] = parentEdges.reduce(
      (longest, edge) => Math.max(longest, edge.name?.length ?? 0),
      longestLabelAtDepth.at(depth) ?? 0
    );
  });

  // Convert characters to pixels and turn them into a running sum from the root.
  let runningOffset = 0;
  return Array.from(longestLabelAtDepth, (labelLength) => {
    runningOffset += labelLength * EDGE.LABEL.ESTIMATED_PX_PER_CHAR;
    return runningOffset;
  });
}

/**
 * Re-computes node positions with `layoutNodes` and writes them back to React Flow.
 *
 * The layout runs in a layout effect, which React executes before the browser paints.
 * It re-runs whenever `useNodeCount()` changes, or when any of the listed React Flow
 * helpers changes identity.
 */
export function useLayout() {
  const nodeCount = useNodeCount();
  const { getEdges, getNode, getNodes, setCenter, setNodes } = useReactFlow();

  useLayoutEffect(() => {
    // Arguments are evaluated left to right: nodes are read before edges.
    setNodes(layoutNodes(getNodes(), getEdges()));

    // NOTE: 'setCenter' and 'getNode' are intentionally listed as dependencies even though
    // the effect does not use them. Without them the effect does not run correctly on
    // node deletions (root cause unknown).
  }, [nodeCount, getEdges, getNodes, getNode, setNodes, setCenter]);
}
