import { NumericGridTableBlockRow } from "../../../../api/inputFormTypes";

/** One cell of a metric row, expressed as a `[columnId, value]` pair. */
type ValuePerColumn = [columnId: string, value: number];

/**
 * Builds a lookup from row id to that row's `[columnId, value]` pairs.
 *
 * Only rows of type "Metric" or "MetricExtension" are included. A nullish cell
 * value is stored as 0, so the totals computed from these values never add null/undefined.
 * If two metric rows share an id, the later row wins.
 *
 * @param rows - All rows of the table block.
 * @returns Map keyed by metric row id; pairs are in `Object.entries` order of the row's `values`.
 */
const collectMetricValues = <T extends NumericGridTableBlockRow>(rows: T[]): Map<string, ValuePerColumn[]> => {
  const valuesByRowId = new Map<string, ValuePerColumn[]>();

  rows.forEach((row) => {
    // Skip anything that does not carry raw metric input (totals, sections, ...).
    if (row.type !== "Metric" && row.type !== "MetricExtension") {
      return;
    }

    const cells: ValuePerColumn[] = [];
    for (const [columnId, cellValue] of Object.entries(row.values)) {
      cells.push([columnId, cellValue ?? 0]);
    }
    valuesByRowId.set(row.id, cells);
  });

  return valuesByRowId;
};

/**
 * Maps every aggregating row to the ids of the rows it directly references.
 *
 * - "Total" rows map to their own `rowIds`.
 * - "ExtendedMetricSection" rows map to the ids of all "MetricExtension" rows whose
 *   `metricId` equals the section's `metricId`, in table order. The list is empty if
 *   there are none.
 *
 * Rows of any other type do not appear in the result.
 *
 * @param rows - All rows of the table block.
 * @returns Map from aggregating row id to the ids it references.
 */
const getRowIdToReferenceRowsMap = <T extends NumericGridTableBlockRow>(rows: T[]): Map<string, string[]> => {
  // Index extension row ids by metricId in a single pass. Re-filtering `rows` for
  // every section would be quadratic in the number of rows.
  const extensionIdsByMetricId = new Map<unknown, string[]>();
  for (const row of rows) {
    if (row.type === "MetricExtension") {
      const ids = extensionIdsByMetricId.get(row.metricId);
      if (ids === undefined) {
        extensionIdsByMetricId.set(row.metricId, [row.id]);
      } else {
        ids.push(row.id);
      }
    }
  }

  const result = new Map<string, string[]>();
  for (const row of rows) {
    if (row.type === "Total") {
      result.set(row.id, row.rowIds);
    } else if (row.type === "ExtendedMetricSection") {
      // Copy so that sections sharing a metricId never share one array instance.
      result.set(row.id, [...(extensionIdsByMetricId.get(row.metricId) ?? [])]);
    }
  }

  return result;
};

/**
 * Expands a row id into the ids of the leaf rows it transitively references.
 *
 * A row counts as a leaf when it has no entry in `referenceMap`; a leaf expands to itself.
 * Expansion is depth-first in reference order. A leaf reachable through several paths
 * appears once per path, so its values are counted once for each subtotal that includes it.
 *
 * @param rowId - The row to expand.
 * @param referenceMap - Map from aggregating row id to the ids it directly references.
 * @returns The leaf row ids, possibly with duplicates.
 * @throws Error if the references form a cycle. Without this check the recursion would
 *   otherwise never terminate and surface as a stack overflow.
 */
const getMetricIdsForRowId = (rowId: string, referenceMap: Map<string, string[]>): string[] => {
  // Ids on the current expansion path. Only these indicate a cycle; shared
  // (diamond-shaped) references are legitimate.
  const onPath = new Set<string>();

  const expand = (id: string): string[] => {
    const referencedRowIds = referenceMap.get(id);
    if (referencedRowIds === undefined) {
      return [id];
    }
    if (onPath.has(id)) {
      throw new Error(`Circular row reference detected while expanding row "${id}"`);
    }

    onPath.add(id);
    const result: string[] = [];
    for (const referencedRowId of referencedRowIds) {
      result.push(...expand(referencedRowId));
    }
    onPath.delete(id);

    return result;
  };

  return expand(rowId);
};

/**
 * Resolves, for every "Total" row, the ids of the leaf rows whose values it sums.
 *
 * Nested Total and ExtendedMetricSection references are followed until rows that
 * reference nothing are reached.
 *
 * @param rows - All rows of the table block.
 * @returns Map from Total row id to the resolved leaf row ids.
 */
const mapTotalRowsToMetricIds = <T extends NumericGridTableBlockRow>(rows: T[]): Map<string, string[]> => {
  const references = getRowIdToReferenceRowsMap(rows);
  const leavesByTotalId = new Map<string, string[]>();

  rows
    .filter((row) => row.type === "Total")
    .forEach((totalRow) => {
      leavesByTotalId.set(totalRow.id, getMetricIdsForRowId(totalRow.id, references));
    });

  return leavesByTotalId;
};

/**
 * Recomputes the values of every "Total" row from the metric rows it references.
 *
 * Each returned row is a shallow copy of the input row, and the input array is not modified.
 * For Total rows, `values` is replaced with the column-wise sum of the referenced
 * "Metric"/"MetricExtension" values, where nullish cells count as 0. A column appears in
 * a total only if at least one referenced metric row has it. Non-total rows are copied unchanged.
 *
 * @param rows - All rows of the table block.
 * @returns A new array, in input order, with recalculated Total rows.
 */
export const calculateTotals = <T extends NumericGridTableBlockRow>(rows: T[]): T[] => {
  const valuesByMetricId = collectMetricValues(rows);
  const metricIdsByTotalId = mapTotalRowsToMetricIds(rows);

  return rows.map((row): T => {
    if (row.type !== "Total") {
      return { ...row };
    }

    const sums: Record<string, number> = {};
    for (const metricId of metricIdsByTotalId.get(row.id) ?? []) {
      for (const [columnId, amount] of valuesByMetricId.get(metricId) ?? []) {
        sums[columnId] = (sums[columnId] ?? 0) + amount;
      }
    }

    return { ...row, values: sums };
  });
};
