import { Set, Stack } from "immutable";

/**
 * A node of the AST handled by this module.
 *
 * Besides the structural fields, this module reads `type` (as a label
 * fallback when building display names) and writes the rendering fields
 * `linkStyles`, `nodeSvgShape` and `_metaCollapsed` while highlighting
 * attention paths.
 */
interface INode {
  /** Node id; attention paths refer to nodes by this id (presumably unique). */
  id: number;
  /** Display label. */
  name?: string;
  /** Raw node kind (presumably from the parser); used as the label when `name` is missing. */
  type?: string;
  children?: INode[];
  /** Style of the link drawn to this node. */
  linkStyles?: { stroke: string; strokeWidth: number };
  /** SVG shape description used to draw this node. */
  nodeSvgShape?: {
    shape: string;
    shapeProps: { [prop: string]: string | number };
  };
  /** Collapse flag; presumably consumed by the tree renderer — confirm. */
  _metaCollapsed?: boolean;
}

// Attention-path colours; COLORS[i] is the stroke colour of the i-th path.
const COLOR_AP_RED = "#ae4946";
const COLOR_AP_BLUE = "#6082b6";
const COLOR_AP_GREEN = "#76a95a";
const COLOR_AP_YELLOW = "#d69a23";
const COLORS = [COLOR_AP_RED, COLOR_AP_BLUE, COLOR_AP_GREEN, COLOR_AP_YELLOW];

// Fill of every node, and the outline colour / link width used off any path.
const DEFAULT_FILL = "white";
const DEFAULT_STROKE = "black";
const STROKE_WIDTH_DEFAULT = 8;

// Node outline widths.
const NODE_RECT_STROKE = 5; // childless nodes, drawn as rectangles
const NODE_ELLIPSE_STROKE = 8; // nodes with children, off any attention path
const NODE_ELLIPSE_AP_STROKE = 14; // nodes with children, on an attention path

// Link width on an attention path: AP_STROKE_BASE + score * AP_STROKE_MULT_FACTOR.
const AP_STROKE_BASE = 10;
const AP_STROKE_MULT_FACTOR = 100;

// An AST is handled through its root node.
type AST = INode;

/** Default callback for `walk`: does nothing and returns undefined. */
const noop = (): void => undefined;

/**
 * Builds a recursive depth-first (pre-order) traversal over an AST.
 *
 * `visit` is called when a node is entered. If it returns a truthy value, that
 * node's children are skipped AND `leave` is not called for it; the traversal
 * then continues with the next sibling. Otherwise the children are walked in
 * order and `leave` is called after them.
 *
 * Callbacks are typed as returning `unknown` (not `void`) because the truthiness
 * of `visit`'s result is meaningful; a `void` result cannot be tested under strict.
 *
 * TODO: add a way to stop the whole traversal from inside a callback (a truthy
 * return from `visit` only prunes that node's subtree).
 *
 * @param visit - called on entering a node; a truthy return prunes its subtree.
 * @param leave - called after a node's children have been walked.
 * @returns a function that walks the tree rooted at the node it is given.
 */
function walk(
  visit: (node: INode) => unknown = noop,
  leave: (node: INode) => unknown = noop
) {
  return function walkRecurse(node: INode): void {
    if (visit(node)) {
      // Pruned: skip both the children and `leave`.
      return;
    }
    if (node.children) {
      node.children.forEach(walkRecurse);
    }
    leave(node);
  };
}

/**
 * Prepares a raw AST for display.
 *
 * When the given node has at least one child, that outer node is dropped and
 * its first child becomes the new root; otherwise the node itself is used.
 * The chosen subtree is deep-copied (so `ast` is not modified), and every node
 * in the copy gets a string `name`: its existing `name`, else its `type`,
 * else "".
 *
 * @param ast - raw AST; a falsy value yields `{}`.
 * @returns the relabelled copy.
 */
function parseRawAST(ast: AST) {
  if (!ast) return {};
  const { children } = ast;
  // An empty `children` array has no first child: `children[0]` would be
  // undefined and JSON.parse(JSON.stringify(undefined)) throws, so fall back to
  // the node itself, as when there are no children at all.
  const withoutRootNode =
    children && children.length > 0 ? children[0] : ast;
  // TODO: optimize (deep copy via a JSON round-trip).
  const newAST = JSON.parse(JSON.stringify(withoutRootNode));

  const visit = (node: { name?: unknown; type?: unknown }) => {
    node.name = String(node.name || node.type || "");
  };

  // Mutates the copy in place.
  walk(visit)(newAST);
  return newAST;
}

/**
 * Trims two root-to-node id paths so that both start at their lowest common
 * ancestor (LCA).
 *
 * Both paths are assumed to start at the same root: comparison begins at
 * index 1. The shared prefix above the LCA is dropped, so the LCA is the first
 * element of each returned path.
 *
 * @param path1 - first root-to-node path; not modified.
 * @param path2 - second root-to-node path; not modified.
 * @returns the two trimmed paths as new arrays, or null if either is missing.
 */
function parseForLowestCommonAncestor(
  path1: number[],
  path2: number[]
): [number[], number[]] | null {
  if (!path1 || !path2) return null;
  // Never compare past the end of the shorter path. Otherwise, identical paths
  // keep matching `undefined === undefined` and the loop never terminates.
  const limit = Math.min(path1.length, path2.length);
  let shift = 0;
  while (shift + 1 < limit && path1[shift + 1] === path2[shift + 1]) {
    shift++;
  }
  return [path1.slice(shift), path2.slice(shift)];
}

/**
 * Returns the ids on the path from the root of `ast` to the first node (in
 * pre-order) whose id is `leaf`, both ends included; [] if there is none.
 *
 * @param ast - tree to search.
 * @param leaf - id of the target node (it need not be childless).
 */
function getPathFromRoot(ast: AST, leaf: number): number[] {
  let path: number[] = [];
  const pathStack: number[] = [];

  const visit = (node: INode) => {
    // Target already found: prune every remaining subtree. Returning before
    // the push means pruned nodes never add stale entries to the stack, and a
    // later node with the same id cannot overwrite the result.
    if (path.length > 0) return true;
    pathStack.push(node.id);
    if (node.id === leaf) {
      path = pathStack.slice();
      // No need to descend below the target.
      return true;
    }
    return false;
  };
  const leave = () => {
    pathStack.pop();
  };

  walk(visit, leave)(ast);
  return path;
}

/**
 * Collects the root-to-leaf id paths of the childless nodes whose ids are in
 * `leaves`, in traversal order. It then trims the first two of those paths so
 * that each starts at their lowest common ancestor.
 *
 * @param ast - tree to search.
 * @param leaves - ids of the leaves of one attention path.
 * @returns `{ lcaPaths }`; `lcaPaths` is null when fewer than two matching
 *   leaves were found.
 */
function getPaths(ast: AST, leaves: number[]) {
  const wanted = Set(leaves);
  const stack: number[] = [];
  const found: number[][] = [];

  walk(
    node => {
      stack.push(node.id);
      // Only childless nodes count as leaves.
      if (!node.children && wanted.has(node.id)) {
        found.push(stack.slice());
      }
    },
    () => {
      stack.pop();
    }
  )(ast);

  // Drop the shared prefix above the lowest common ancestor.
  return { lcaPaths: parseForLowestCommonAncestor(found[0], found[1]) };
}

/**
 * Chooses the SVG shape description used to draw a node.
 *
 * A node with `children` is drawn as an ellipse; a node without them as a
 * rectangle. When `indexOfAP` is not -1, the outline uses COLORS[indexOfAP]
 * (and ellipses get the thicker AP stroke). Otherwise DEFAULT_STROKE is used.
 *
 * @param node - AST node; only its `children` property is inspected.
 * @param indexOfAP - index of the attention path the node is on, or -1.
 * @returns `{ shape, shapeProps }` for the renderer.
 */
function styleNodeShape(node, indexOfAP) {
  const highlighted = indexOfAP !== -1;
  const stroke = highlighted ? COLORS[indexOfAP] : DEFAULT_STROKE;
  const fill = DEFAULT_FILL;

  if (node.children) {
    const strokeWidth = highlighted
      ? NODE_ELLIPSE_AP_STROKE
      : NODE_ELLIPSE_STROKE;
    return {
      shape: "ellipse",
      shapeProps: { rx: 250, ry: 100, stroke, strokeWidth, fill }
    };
  }

  return {
    shape: "rect",
    shapeProps: {
      width: 480,
      height: 190,
      x: -235,
      y: -90,
      stroke,
      strokeWidth: NODE_RECT_STROKE,
      fill
    }
  };
}

/**
 * Highlights attention paths (APs) by mutating `ast` in place.
 *
 * Each AP is a list of leaf ids. For every AP, the root-to-leaf paths of its
 * first two matching leaves (see getPaths) are trimmed to their lowest common
 * ancestor (LCA) and joined into one LCA-rooted path.
 * - Nodes on that path get the AP's colour and shape, and are expanded. They
 *   also get a link width scaled by the AP's score; the LCA's own link is left
 *   untouched.
 * - Nodes off the path get default link and shape styles only if they have
 *   none yet, and are collapsed unless a collapse state is already set.
 *
 * APs are processed from last to first, so where several APs share a node
 * the lowest-index AP is styled last and wins.
 *
 * Finally, every node on the root path of any AP leaf is expanded, and the
 * root itself is always expanded.
 *
 * @param ast - root node; mutated in place.
 * @param apScores - apScores[i] scales the link width of AP i.
 * @param attentionPaths - leaf ids per AP; not modified.
 */
function styleNodes(ast: AST, apScores: number[], attentionPaths: number[][]) {
  // Iterate from the last AP to the first by index. The previous
  // `attentionPaths.reverse()` reversed the caller's array in place, so a
  // second call with the same array paired each AP with another AP's colour
  // and score.
  // NOTE(review): COLORS has four entries; an AP index with no colour or no
  // score yields an undefined stroke / NaN link width.
  for (let apIndex = attentionPaths.length - 1; apIndex >= 0; apIndex--) {
    const { lcaPaths: paths } = getPaths(ast, attentionPaths[apIndex]);
    if (!paths) continue;
    // Both trimmed paths start at the LCA; keep it only once.
    const path = paths[0].concat(paths[1].slice(1));
    const lca = path[0];
    const nodesSet = Set(path);
    const stroke = COLORS[apIndex];
    const strokeWidth =
      AP_STROKE_BASE + apScores[apIndex] * AP_STROKE_MULT_FACTOR;

    const visit = node => {
      if (nodesSet.includes(node.id)) {
        if (node.id !== lca) {
          node.linkStyles = { stroke, strokeWidth };
        }
        node.nodeSvgShape = styleNodeShape(node, apIndex);
        node._metaCollapsed = false;
      } else {
        // Keep whatever an earlier (higher-index) AP pass already set.
        node.linkStyles = node.linkStyles || {
          stroke: "black",
          strokeWidth: STROKE_WIDTH_DEFAULT
        };
        node.nodeSvgShape = node.nodeSvgShape || styleNodeShape(node, -1);
        if (node._metaCollapsed === undefined) {
          node._metaCollapsed = true;
        }
      }
    };
    walk(visit)(ast);
  }

  // Expand every node between the root and each AP leaf, including the
  // ancestors above the LCAs that the loop above left collapsed.
  const leaves: Set<number> = attentionPaths.reduce(
    (acc, curr) => acc.union(Set(curr)),
    Set()
  ) as Set<number>;
  const nodesInPathToLeaves = leaves
    .map(leaf => Set(getPathFromRoot(ast, leaf as number)))
    .flatten(1);

  walk(node => {
    if (nodesInPathToLeaves.includes(node.id)) {
      node._metaCollapsed = false;
    }
  })(ast);

  // Make sure the root node is not collapsed, even if it is on no AP.
  ast._metaCollapsed = false;
}

/**
 * Returns a styled deep copy of `ast` with the given attention paths
 * highlighted (see styleNodes); `ast` itself is not modified.
 *
 * @param ast - tree to style.
 * @param apScores - per-attention-path scores.
 * @param attentionPaths - leaf ids per attention path.
 * @returns `{}` if any argument is missing, otherwise the styled copy.
 */
function highlightAPs(
  ast: AST,
  apScores: number[],
  attentionPaths: number[][]
) {
  const hasInput = Boolean(ast && apScores && attentionPaths);
  if (!hasInput) {
    return {};
  }

  // Deep copy via a JSON round-trip so styling never touches the caller's
  // tree. TODO: optimize.
  const styledAST = JSON.parse(JSON.stringify(ast));
  styleNodes(styledAST, apScores, attentionPaths);
  return styledAST;
}

/**
 * Scales `items` to unit Euclidean (L2) length.
 *
 * @param items - vector to normalize; not modified.
 * @returns a new array. An all-zero (or empty) input is returned as a copy
 *   instead of NaNs from dividing by a zero length.
 */
function normalize(items: number[]) {
  const len = Math.sqrt(items.reduce((acc, curr) => acc + curr * curr, 0));
  if (len === 0) {
    // A zero vector has no direction; leave it unchanged rather than 0 / 0.
    return items.slice();
  }
  return items.map(item => item / len);
}

/**
 * Splits raw attention-path records into per-path node ids and scores.
 *
 * For every entry of a record whose key matches /token\d+/, the value is
 * treated as a token object and contributes its `node_id`. Other keys are
 * ignored; presumably `path` and `score` — verify against the producer.
 *
 * @param apArray - array of raw records, each with a numeric `score`.
 * @returns an object with two fields:
 *   - `ap`: for every record, the token `node_id`s in the record's key order.
 *   - `scores`: the L2-normalized `score`s of the first four records.
 */
function parseAttentionPaths(apArray: any) {
  const isTokenKey = (key: string) => /token\d+/.test(key);

  // Token objects are untyped input. Type them as `any` explicitly instead of
  // the undeclared `A` annotation that was hidden behind a @ts-ignore.
  const ap: number[][] = apArray.map((record: any) =>
    Object.entries(record)
      .filter(([key]) => isTokenKey(key))
      .map(([, token]: [string, any]): number => token.node_id)
  );

  // NOTE(review): only the first four scores are kept, matching COLORS'
  // four entries. `ap` keeps every record, so ap[i] has no score for i >= 4.
  const scores = normalize(
    apArray.slice(0, 4).map((record: any) => record.score)
  );
  return { ap, scores };
}

// Public API of this module.
export { highlightAPs, parseAttentionPaths, parseRawAST };
