import { StudioNodeType } from '@common/studio-types';
import { useCallback, useRef } from 'react';
import { Edge, Node, NodePositionChange, useReactFlow } from 'reactflow';

/**
 * Maximum horizontal distance between the center of a dragged node and the
 * center of a connected neighbour for the dragged node to be snapped so that
 * both centers line up. Measured in the same units as a node's `position.x`.
 * The comparison is strict: a distance of exactly THRESHOLD does not snap.
 */
const THRESHOLD = 10;

/**
 * Node types that are skipped as alignment candidates when they are the
 * source of an edge into the dragged node. Judging by the name, these nodes
 * have two bottom (outgoing) edges — presumably one per branch — so a child
 * is not expected to sit centered beneath them. TODO: confirm against the
 * node renderers.
 *
 * Declared `as const` so the list cannot be mutated at runtime.
 */
const DOUBLE_BOTTOM_EDGE_NODES = [
  StudioNodeType.CoinToss,
  StudioNodeType.ConditionCheck,
] as const;

/** Neighbour data collected for the node currently being moved. */
type DragContext = {
  /** Id of the node the data was collected for. */
  nodeId: string;
  /** Measured width of that node; null/undefined if not measured yet. */
  width: number | null | undefined;
  /** Sources of edges into the node, excluding DOUBLE_BOTTOM_EDGE_NODES. */
  previousNodes: Node[];
  /** Target of the last outgoing edge returned by `getEdges()`, if any. */
  next: Node | undefined;
};

/**
 * Returns a callback that snaps a moving node horizontally so its center
 * lines up with the center of a connected neighbour.
 *
 * Candidates are checked in order: first every source of an edge into the
 * moved node (skipping types in DOUBLE_BOTTOM_EDGE_NODES), then the target of
 * the moved node's last outgoing edge. The first candidate whose center is
 * strictly closer than THRESHOLD to the moved node's center wins.
 *
 * Neighbour lookup is cached only while a change is flagged
 * `dragging: true`. Any other change (drag end, keyboard/programmatic move)
 * clears the cache, so edges, neighbour positions or the node width that
 * changed between two drags are re-read on the next drag.
 *
 * @returns A function that maps a NodePositionChange either to itself or to
 *   a copy whose `position.x` has been snapped. Changes without a `position`,
 *   for nodes that no longer exist, or involving unmeasured widths are
 *   returned unchanged.
 */
export const useFixNodePosition = () => {
  const { getEdges, getNode } = useReactFlow();
  const dragContextRef = useRef<DragContext | undefined>(undefined);

  /** Collects the moved node's width and its alignment candidates. */
  const buildDragContext = (nodeId: string): DragContext | undefined => {
    const node = getNode(nodeId);
    if (!node) return undefined;

    const previousNodes: Node[] = [];
    let next: Node | undefined;

    for (const edge of getEdges()) {
      if (edge.source === nodeId) {
        // With several outgoing edges, the last one listed wins.
        next = getNode(edge.target);
      } else if (edge.target === nodeId) {
        const sourceNode = getNode(edge.source);
        if (
          sourceNode &&
          !DOUBLE_BOTTOM_EDGE_NODES.includes(sourceNode.data.type)
        ) {
          previousNodes.push(sourceNode);
        }
      }
    }

    return { nodeId, width: node.width, previousNodes, next };
  };

  const fixNodePosition = (change: NodePositionChange): NodePositionChange => {
    // The cache is only valid for the duration of one pointer drag. React Flow
    // ends a drag with a position-less `dragging: false` change; that (and any
    // non-drag move) invalidates it so stale neighbours are never reused.
    if (change.dragging !== true) {
      dragContextRef.current = undefined;
    }

    const { position } = change;
    if (!position) return change;

    let context = dragContextRef.current;
    if (!context || context.nodeId !== change.id) {
      context = buildDragContext(change.id);
      // The node is not in the store: nothing to align against.
      if (!context) return change;
      if (change.dragging === true) {
        dragContextRef.current = context;
      }
    }

    const { width } = context;
    // Without a measured width the node's center is unknown.
    if (width == null) return change;

    const nodeCenter = position.x + width / 2;

    /** Returns a snapped copy of `change` if `neighbour` is close enough. */
    const snapTo = (
      neighbour: Node | undefined,
    ): NodePositionChange | undefined => {
      if (!neighbour || neighbour.width == null) return undefined;

      const center = neighbour.position.x + neighbour.width / 2;
      if (Math.abs(center - nodeCenter) >= THRESHOLD) return undefined;

      return { ...change, position: { ...position, x: center - width / 2 } };
    };

    for (const previous of context.previousNodes) {
      const snapped = snapTo(previous);
      if (snapped) return snapped;
    }

    return snapTo(context.next) ?? change;
  };

  return useCallback(fixNodePosition, [getEdges, getNode]);
};
