import { Edge } from '@xyflow/react';
import { StateCreator } from 'zustand';
import { GROUP_DIMENSIONS, NODE_GAP, STICKY_NOTE_DIMENSIONS } from '../consts/whiteboard.const';
import { calculateAbsolutePosition } from '../helpers/group-helpers';
import { generateId } from '../helpers/node-helpers';
import { calculateGroupDimensions, findPosition } from '../helpers/position-helpers';
import {
  ExecutionEngineNodeStatus,
  GroupNodeType,
  NodeData,
  NodeType,
  NodeTypesEnum,
  StickyNoteNodeType,
} from '../types';
import { ExecutionSlice } from './execution.slice';
import { HistorySlice } from './history.slice';
import { ReactFlowSlice } from './react-flow.slice';
import { WhiteboardSlice } from './whiteboard.slice';

/** State owned by the node slice. */
type NodeState = {
  /**
   * Streamed output text keyed by node id. Entries are reset to `undefined`
   * when a node's output is cleared or its output-history index changes.
   */
  nodeStreamOutputs: Record<string, string | undefined>;
};

/**
 * Actions exposed by the node slice. All members use property syntax so their
 * parameters are checked contravariantly under `strictFunctionTypes`.
 */
type NodeAction = {
  /** Clears the current selection, then appends `nodes` to the canvas. */
  addNodes: (nodes: NodeType[]) => void;
  /** Shallow-merges `data` into the `data` of node `nodeId` and flags the board for saving. */
  updateNodeData: (nodeId: string, data: Partial<NodeData>) => void;
  /** Sets (or clears, with `undefined`) the streamed output of node `nodeId`. */
  updateNodeStreamOutput: (nodeId: string, streamOutput: string | undefined) => void;
  /** Like `updateNodeData`, but skipped when the node's status is DONE or ERROR. */
  updateNodeDataIfNotDone: (nodeId: string, data: Partial<NodeData>) => void;
  /** Returns the streamed output of node `nodeId`, if any. */
  getNodeStreamOutput: (nodeId: string) => string | undefined;
  /** Shows the output-history entry at `index` for node `nodeId` and records a history step. */
  setOutputHistoryIndex: (nodeId: string, index: number) => void;
  /** Returns a free position for a new top-level node. */
  calculateNewNodePosition: () => { x: number; y: number };
  /** Marks every node as not selected. */
  unselectNodes: () => void;
  /**
   * Inserts a template's nodes and edges, wrapping its top-level non-group nodes
   * in a new group named `templateName`. `templateEdges` is optional.
   */
  addTemplateNodes: (templateNodes: NodeType[], templateEdges?: Edge[], templateName?: string) => void;
  /** Returns the currently selected nodes. */
  getSelectedNodes: () => NodeType[];
  /** Returns all nodes of type Tool. */
  getToolNodes: () => NodeType[];
  /** Empties the output and streamed output of node `nodeId`. */
  clearOutput: (nodeId: string) => void;
  /**
   * Removes node `nodeId`. When it is a group, its children are detached and
   * moved to absolute coordinates.
   */
  removeNode: (nodeId: string) => void;
  /** Adds a selected sticky note; `params` overrides node fields, `data` extends the default data. */
  addStickyNote: (params: Partial<GroupNodeType>, data?: Partial<NodeData>) => void;
  /** Adds a group node at the front of the node list and returns it. */
  addGroupNode: (params: Partial<GroupNodeType>, data?: Partial<NodeData>, shouldAddToHistory?: boolean) => NodeType;
};
export type NodeSlice = NodeState & NodeAction;

/** Initial value of the node slice's own state. */
export const initialNodeState: NodeState = {
  nodeStreamOutputs: {},
};

/**
 * Zustand slice creator for whiteboard nodes.
 *
 * It is composed into a store that also contains the execution, history,
 * react-flow and whiteboard slices. From those slices it uses `nodes`, `edges`,
 * `shouldSave`, `fullSync`, `fitView`, `addToHistoryStack` and
 * `getExecutionEngineNode`.
 */
export const createNodeSlice: StateCreator<
  NodeSlice & ExecutionSlice & HistorySlice & ReactFlowSlice & WhiteboardSlice,
  [],
  [],
  NodeSlice
> = (set, get) => ({
  ...initialNodeState,

  /** Clears the selection, then appends `nodes`. */
  addNodes: (nodes) => {
    get().unselectNodes();

    set((state) => ({
      nodes: state.nodes.concat(nodes),
    }));
  },

  /** Shallow-merges `data` into node `nodeId`'s data and flags the board for saving. */
  updateNodeData: (nodeId, data) => {
    set((state) => ({
      nodes: state.nodes.map((node) =>
        node.id === nodeId ? ({ ...node, data: { ...node.data, ...data } } as NodeType) : node,
      ),
      shouldSave: true,
    }));
  },

  /** Applies `updateNodeData` unless the node is missing or already DONE / ERROR. */
  updateNodeDataIfNotDone: (nodeId, data) => {
    const node = get().nodes.find((candidate) => candidate.id === nodeId);
    if (!node) {
      return;
    }
    const { status } = node.data;
    // Presumably keeps late updates from overwriting a finished node's result.
    if (status === ExecutionEngineNodeStatus.DONE || status === ExecutionEngineNodeStatus.ERROR) {
      return;
    }
    get().updateNodeData(nodeId, data);
  },

  /** Returns the streamed output recorded for `nodeId`, if any. */
  getNodeStreamOutput: (nodeId) => get().nodeStreamOutputs[nodeId],

  /** Records (or clears, with `undefined`) the streamed output for `nodeId`. */
  updateNodeStreamOutput: (nodeId, streamOutput) => {
    // Build from the updater's `state` argument rather than calling `get()` inside `set`.
    set((state) => ({
      nodeStreamOutputs: { ...state.nodeStreamOutputs, [nodeId]: streamOutput },
    }));
  },

  /**
   * Shows entry `index` of the execution-engine tool's output history on node
   * `nodeId`. If there is no entry at `index`, the node's current output is kept.
   * Also clears the streamed output and records a history step. Does nothing
   * when the node or its execution-engine node is missing.
   */
  setOutputHistoryIndex: (nodeId, index) => {
    const node = get().nodes.find((candidate) => candidate.id === nodeId);
    if (!node) {
      return;
    }
    const executionEngineNode = get().getExecutionEngineNode(nodeId);
    if (!executionEngineNode) {
      return;
    }
    const historyEntry = executionEngineNode.tool.outputHistory?.[index];
    const output = historyEntry ? historyEntry.output : node.data.output;
    get().updateNodeData(nodeId, { outputHistoryIndex: index, output });
    get().updateNodeStreamOutput(nodeId, undefined);
    get().addToHistoryStack();
  },

  /** Returns a free position for a new node, considering only top-level (unparented) nodes. */
  calculateNewNodePosition: () => {
    return findPosition(get().nodes.filter((node) => !node.parentId));
  },

  /** Marks every node as not selected. */
  unselectNodes: () => {
    set((state) => ({
      // Nodes already explicitly unselected keep their object identity, avoiding needless re-renders.
      nodes: state.nodes.map((node) => (node.selected === false ? node : { ...node, selected: false })),
    })) ;
  },

  /**
   * Inserts a template's nodes and edges at a free spot on the canvas.
   *
   * Top-level non-group template nodes are wrapped in a new group, padded by
   * `NODE_GAP` on every side and named `templateName` (or the default group name
   * when none is given). Top-level template groups stay top-level at absolute
   * positions. Nodes that already have a parent keep their parent-relative
   * position. The view is then fitted to the result and a history step recorded.
   * An empty template is a no-op.
   */
  addTemplateNodes: (templateNodes: NodeType[], templateEdges?: Edge[], templateName?: string) => {
    if (templateNodes.length === 0) {
      return;
    }

    const { width, height, minX, minY } = calculateGroupDimensions(templateNodes);
    const position = findPosition(
      get().nodes.filter((node) => !node.parentId),
      width,
      height,
    );

    const padding = NODE_GAP;

    const ungroupedNodes = templateNodes.filter((node) => node.type !== NodeTypesEnum.Group && !node.parentId);
    const groupNode =
      ungroupedNodes.length > 0
        ? get().addGroupNode(
            { position, width: width + 2 * padding, height: height + 2 * padding },
            // Passing `{ name: undefined }` would overwrite the default group name.
            templateName !== undefined ? { name: templateName } : undefined,
            false,
          )
        : undefined;

    // Shifts the template's bounding box so it starts at (padding, padding): relative to
    // the wrapper group for wrapped nodes, offset by `position` for top-level template groups.
    const calculateNodePosition = (node: NodeType) => {
      if (node.parentId) {
        return node.position;
      }
      const offsetX = node.type === NodeTypesEnum.Group ? position.x : 0;
      const offsetY = node.type === NodeTypesEnum.Group ? position.y : 0;
      return {
        x: offsetX + node.position.x + padding - minX,
        y: offsetY + node.position.y + padding - minY,
      };
    };

    const resolveParentId = (node: NodeType) => {
      if (node.parentId) {
        return node.parentId;
      }
      return node.type === NodeTypesEnum.Group ? undefined : groupNode?.id;
    };

    const nodes = templateNodes.map((node) => ({
      ...node,
      position: calculateNodePosition(node),
      parentId: resolveParentId(node),
    }));

    set((state) => ({
      nodes: state.nodes.concat(nodes),
      edges: state.edges.concat(templateEdges ?? []),
    }));

    get().fitView(groupNode ? [groupNode] : nodes);
    get().addToHistoryStack();
  },

  /** Returns the currently selected nodes. */
  getSelectedNodes: () => {
    return get().nodes.filter((node) => node.selected);
  },

  /** Returns all nodes of type Tool. */
  getToolNodes: () => {
    return get().nodes.filter((node) => node.type === NodeTypesEnum.Tool);
  },

  /** Empties both the stored output and the streamed output of `nodeId`. */
  clearOutput: (nodeId) => {
    get().updateNodeData(nodeId, { output: '' });
    get().updateNodeStreamOutput(nodeId, undefined);
  },

  /**
   * Removes node `nodeId` together with every edge attached to it. When the node
   * is a group, its children are detached and moved to absolute coordinates so
   * they stay in place on the canvas. Flags the board for saving and records a
   * history step. Unknown ids are ignored.
   */
  removeNode: (nodeId) => {
    const { nodes, edges } = get();
    const node = nodes.find((candidate) => candidate.id === nodeId);
    if (!node) {
      return;
    }

    const isGroup = node.type === NodeTypesEnum.Group;
    const remainingNodes: NodeType[] = nodes
      .filter((candidate) => candidate.id !== nodeId)
      .map((candidate) =>
        isGroup && candidate.parentId === nodeId
          ? { ...candidate, parentId: undefined, position: calculateAbsolutePosition(candidate, node) }
          : candidate,
      );
    // Without this, edges pointing at the removed node would dangle in state (and be saved).
    const remainingEdges = edges.filter((edge) => edge.source !== nodeId && edge.target !== nodeId);

    set({ nodes: remainingNodes, edges: remainingEdges, shouldSave: true });
    get().addToHistoryStack();
  },

  /**
   * Clears the selection and adds a selected sticky note, then records a history step.
   * `params` overrides node fields, and `data` extends the default note data.
   * A missing or empty `id` is generated. Missing or nullish dimensions fall back
   * to `STICKY_NOTE_DIMENSIONS`.
   */
  addStickyNote: (params: Partial<GroupNodeType>, data?: Partial<NodeData>) => {
    get().unselectNodes();
    const stickyNoteNode = {
      type: NodeTypesEnum.StickyNote,
      data: {
        isNewlyAdded: true,
        name: 'Note',
        content: '',
        ...data,
      },
      selected: true,
      ...params,
      // Applied after the spread so an explicit `undefined` in `params` cannot erase these.
      id: params.id || generateId(NodeTypesEnum.StickyNote),
      width: params.width ?? STICKY_NOTE_DIMENSIONS.width,
      height: params.height ?? STICKY_NOTE_DIMENSIONS.height,
    } as StickyNoteNodeType;

    set((state) => ({
      nodes: state.nodes.concat(stickyNoteNode),
    }));

    get().addToHistoryStack();
  },

  /**
   * Clears the selection and adds a group node, returning it. `params` overrides
   * node fields, and `data` extends the default group data. A missing or empty
   * `id` is generated. Missing or nullish dimensions fall back to
   * `GROUP_DIMENSIONS`. A history step is recorded unless `shouldAddToHistory`
   * is false.
   */
  addGroupNode: (params: Partial<GroupNodeType>, data?: Partial<NodeData>, shouldAddToHistory = true) => {
    get().unselectNodes();
    const groupNode = {
      type: NodeTypesEnum.Group,
      data: {
        isNewlyAdded: true,
        name: 'Group',
        ...data,
      },
      ...params,
      // Applied after the spread so an explicit `undefined` in `params` cannot erase these.
      id: params.id || generateId(NodeTypesEnum.Group),
      width: params.width ?? GROUP_DIMENSIONS.width,
      height: params.height ?? GROUP_DIMENSIONS.height,
    } as GroupNodeType;

    // Group nodes must precede other nodes in the list (parents before children).
    set((state) => ({
      nodes: [groupNode, ...state.nodes],
      fullSync: true,
    }));
    if (shouldAddToHistory) {
      get().addToHistoryStack();
    }
    return groupNode;
  },
});
