import * as Rf from 'reactflow';
import { create } from 'zustand';
import { v4 as uuidv4 } from 'uuid';
import { ModelNodeType, ModelNodeProps, RelationshipTypes, Relation } from '../components';

/**
 * Template for a freshly added model node.
 *
 * `addNode` spreads this object and replaces `id` and `position` for every new node,
 * so the values here only serve as defaults.
 */
const genericNode: ModelNodeType = {
  id: uuidv4(),
  type: 'model',
  position: { x: 0, y: 0 },
  data: { title: 'New Model', fields: [], relations: [] },
};

/** Nodes the store starts with (an empty canvas). */
export const initialNodes: ModelNodeType[] = [];

/** Edges the store starts with (an empty canvas); typed to match `RFState.edges`. */
export const initialEdges: Rf.Edge[] = [];

/** Shape of the global graph store: model nodes, relation edges and React Flow handlers. */
export type RFState = {
  /** All model nodes currently on the canvas. */
  nodes: ModelNodeType[];
  /** Returns the current `nodes` array. */
  getNodes: () => ModelNodeType[];
  /**
   * Appends a new model node. When a viewport is given, the node is positioned relative
   * to it (plus a small random offset); otherwise near the flow origin.
   */
  addNode: (viewport?: { x: number; y: number; zoom: number }) => void;
  /** Replaces the `data` of the node with id `nodeId`; no-op if no such node exists. */
  setNode: (nodeId: string, data: ModelNodeProps) => void;
  /** Replaces the whole `nodes` array. */
  setNodes: (nodes: ModelNodeType[]) => void;
  /** Removes the node `nodeId` and drops relations on other nodes that point to it. */
  deleteNode: (nodeId: string) => void;
  /** React Flow node-change handler (drag, select, remove, ...). */
  onNodesChange: Rf.OnNodesChange;

  /** All relation edges currently on the canvas. */
  edges: Rf.Edge[];
  /** Returns the current `edges` array. */
  getEdges: () => Rf.Edge[];
  /** Replaces the `data` of the edge with id `edgeId`; no-op if no such edge exists. */
  setEdge: (edgeId: string, data: unknown) => void;
  /** Replaces the whole `edges` array. */
  setEdges: (edges: Rf.Edge[]) => void;
  /**
   * React Flow edge-change handler. For removed edges it also drops the corresponding
   * relations from both endpoint nodes.
   */
  onEdgesChange: Rf.OnEdgesChange;

  /**
   * React Flow connect handler. Unless the two nodes are already connected, it records a
   * one-to-one relation on both endpoints and adds a `'labeled'` edge between them.
   */
  onConnect: Rf.OnConnect;
};

/**
 * Global zustand store for the model graph.
 *
 * A relation between two models is represented twice: as a React Flow edge in `edges`,
 * and as an entry in each endpoint node's `data.relations`. The actions below keep both
 * representations in sync.
 */
export const useStore = create<RFState>()((set, get) => {
  /**
   * Replaces the relations of node `nodeId` with `update(currentRelations)`.
   * Always reads the node from the latest state, so consecutive updates to the same node
   * (e.g. a self-connection) do not overwrite each other. No-op if the node is missing.
   */
  const updateRelations = (nodeId: string, update: (relations: Relation[]) => Relation[]): void => {
    const node = get().nodes.find((n) => n.id === nodeId);
    if (!node) return;
    get().setNode(nodeId, { ...node.data, relations: update(node.data.relations) });
  };

  /** Drops relations whose `to` is empty (presumably incomplete drafts — confirm with UI). */
  const withoutDangling = (relations: Relation[]): Relation[] => relations.filter((r) => r.to);

  return {
    nodes: initialNodes,

    getNodes: () => get().nodes,

    setNode: (nodeId, data) => set({ nodes: get().nodes.map((n) => (n.id === nodeId ? { ...n, data } : n)) }),

    setNodes: (nodes) => set({ nodes }),

    addNode: (viewport) => {
      // Translate the viewport transform back into flow coordinates (presumably the
      // top-left of the visible area), then jitter so new nodes don't stack exactly.
      const originX = viewport ? -viewport.x / viewport.zoom : 0;
      const originY = viewport ? -viewport.y / viewport.zoom : 0;
      const position = { x: originX + Math.random() * 50, y: originY + Math.random() * 50 };
      const node: ModelNodeType = {
        ...genericNode,
        id: uuidv4(),
        position,
        // Copy the template's arrays so separate nodes never share `fields`/`relations`.
        data: {
          ...genericNode.data,
          fields: [...genericNode.data.fields],
          relations: [...genericNode.data.relations],
        },
      };
      set({ nodes: [...get().nodes, node] });
    },

    deleteNode: (nodeId) => {
      const { nodes, edges } = get();
      set({
        nodes: nodes
          .filter((n) => n.id !== nodeId)
          .map((n) => ({
            ...n,
            data: { ...n.data, relations: n.data.relations.filter((r: Relation) => r.to !== nodeId) },
          })),
        // Also drop edges attached to the deleted node so none reference a missing node.
        edges: edges.filter((e) => e.source !== nodeId && e.target !== nodeId),
      });
    },

    onNodesChange: (changes: Rf.NodeChange[]) => set({ nodes: Rf.applyNodeChanges(changes, get().nodes) }),

    edges: initialEdges,

    getEdges: () => get().edges,

    setEdge: (edgeId, data) => set({ edges: get().edges.map((e) => (e.id === edgeId ? { ...e, data } : e)) }),

    setEdges: (edges) => set({ edges }),

    onEdgesChange: (changes: Rf.EdgeChange[]) => {
      // Before the edges disappear, remove the matching relation from both endpoints.
      changes
        .filter((change): change is Rf.EdgeRemoveChange => change.type === 'remove')
        .forEach((change) => {
          const edge = get().edges.find((e) => e.id === change.id);
          if (!edge) return;
          updateRelations(edge.source, (relations) => relations.filter((r) => r.to !== edge.target));
          updateRelations(edge.target, (relations) => relations.filter((r) => r.to !== edge.source));
        });
      set({ edges: Rf.applyEdgeChanges(changes, get().edges) });
    },

    onConnect: (connection: Rf.Connection) => {
      const { source, target } = connection;
      // Rf.addEdge ignores connections with a missing endpoint; don't record relations for them either.
      if (!source || !target) return;

      const alreadyConnected = get().edges.some(
        (e) => (e.source === source && e.target === target) || (e.source === target && e.target === source)
      );

      if (alreadyConnected) {
        updateRelations(source, withoutDangling);
        return;
      }

      updateRelations(source, (relations) => [
        ...withoutDangling(relations),
        { to: target, sourceOrTarget: 'source', type: RelationshipTypes.OneToOne, required: true },
      ]);
      updateRelations(target, (relations) => [
        ...withoutDangling(relations),
        { to: source, sourceOrTarget: 'target', type: RelationshipTypes.OneToOne, required: true },
      ]);
      set({
        edges: Rf.addEdge(
          {
            ...connection,
            sourceHandle: `${source}-relation-${target}`,
            targetHandle: `${target}-relation-${source}`,
            data: {
              startLabel: '1',
              endLabel: '1',
            },
            type: 'labeled',
          },
          get().edges
        ),
      });
    },
  };
});
