import { useCallback, useEffect, useMemo, useRef } from 'react'

import { CellClassParams, GridApi, IRowNode } from 'ag-grid-community'
import { AgGridReact } from 'ag-grid-react'

import { CHANGE_KEY } from '@hooks/use-change-detection'
import { useLatestRef } from '@hooks/useLatestRef'

import { standardDeviation, valueWithinStdDev } from 'utils/number-utils'

import { ITableColumn } from '../../types/table-builder-types'
import { heatMapColumn } from '../utils'
import { isColumnPivotRowTotal, isPivotRowTotal } from './use-grand-totals'

/**
 * Per-row heatmap statistics, keyed by ag-grid row node id.
 * Each entry holds the mean and standard deviation of that row's heatmap-column values.
 */
type DataMeanStddev = Record<string, { mean: number; stddev: number }>

/**
 * Collects the `field` of every column flagged with `customData.heatmap`.
 *
 * Flagged columns without a usable `field` are skipped. An `undefined` id would be read
 * as `NaN` from leaf row data, which turns the whole row's mean/stddev into NaN. An empty
 * id would match every aggregate key, because `key.endsWith('')` is always true.
 *
 * @param columns - table column definitions
 * @returns heatmap column field names, in column order
 */
function getHeatmapColumnsIds(columns: ITableColumn[]): string[] {
  return columns
    .filter((column) => column.customData?.heatmap)
    .map((column) => column.field)
    .filter((field): field is string => typeof field === 'string' && field.length > 0)
}

/**
 * Extracts the numeric values of a row's heatmap columns.
 *
 * Group rows:
 * - Values come from `rowNode.aggData`.
 * - An entry is included when its key ends with any heatmap column id (pivot result keys
 *   carry the value column id as a suffix).
 * - Each entry is counted at most once, even when several heatmap ids are suffixes of the
 *   same key (e.g. `amount` and `total_amount`).
 * - Pivot row total entries are ignored.
 *
 * Leaf rows read `rowNode.data[columnId]` for each heatmap column. Any other row yields
 * no values.
 *
 * Values are coerced with `Number()`. Results that are not finite (e.g. from `undefined`
 * or non-numeric text) are dropped, so one blank cell does not make the row's mean/stddev
 * NaN. Note that `null` coerces to `0` and is kept.
 *
 * @param rowNode - the grid row node to read from
 * @param heatmapColumnsId - field ids of the heatmap columns
 * @returns the row's finite heatmap values
 */
function getDataFromRowNode(rowNode: IRowNode, heatmapColumnsId: string[]): number[] {
  let rawValues: unknown[]
  if (rowNode.group) {
    const aggData = (rowNode.aggData ?? {}) as Record<string, unknown>
    rawValues = Object.entries(aggData)
      .filter(
        ([key]) =>
          !isPivotRowTotal(key) && // ignore pivot row totals
          heatmapColumnsId.some((columnId) => key.endsWith(columnId))
      )
      .map(([, value]) => value)
  } else if (rowNode.data) {
    rawValues = heatmapColumnsId.map((columnId) => rowNode.data[columnId])
  } else {
    rawValues = []
  }
  return rawValues.map(Number).filter(Number.isFinite)
}

/**
 * Computes the mean and standard deviation of each row's heatmap values, for every grid
 * row node that has an id.
 *
 * @param gridApi - grid whose nodes are visited via `forEachNode`
 * @param heatmapColumnsIds - field ids of the heatmap columns
 * @returns `null` when there are no heatmap columns. Otherwise a map from row node id to
 *   that row's stats, which is empty when no visited node has an id.
 */
function calculateMeanStddev(gridApi: GridApi, heatmapColumnsIds: string[]) {
  if (heatmapColumnsIds.length === 0) {
    return null
  }

  const statsByNodeId: DataMeanStddev = {}
  gridApi.forEachNode((node) => {
    const { id } = node
    if (!id) {
      return // nodes without an id cannot be looked up later, so skip them
    }

    const values = getDataFromRowNode(node, heatmapColumnsIds)
    statsByNodeId[id] = {
      mean: _.mean(values),
      stddev: standardDeviation(values)
    }
  })

  return statsByNodeId
}

/**
 * Heatmap colouring for the table grid.
 *
 * Each row's mean and standard deviation are computed over its heatmap-column values.
 * This happens lazily, on the first `getColorClass` call after an invalidation, and the
 * result is cached. A cell is then given a class according to how many standard
 * deviations its value lies from its row's mean.
 *
 * The cache is invalidated, and all cells force-refreshed, whenever `columns` or
 * `externalFiltersChanged` change. When a recomputation yields no stats, the last
 * non-empty stats are reused.
 *
 * @returns `getColorClass(params, column)`, which returns a heatmap CSS class or
 *   `undefined`. It appears intended as a column `cellClass` callback (TODO confirm at
 *   call site).
 */
export default function useHeatmap({
  columns,
  gridRef,
  externalFiltersChanged
}: {
  columns: ITableColumn[]
  gridRef: React.RefObject<AgGridReact | null>
  externalFiltersChanged: CHANGE_KEY
}) {
  // Last known non-empty stats; used when a recomputation comes back empty.
  const prevDataMeanStddevRef = useRef<DataMeanStddev | null>(null)
  // Current stats cache; `null` means stale, so it is recomputed on next use.
  const dataMeanStddevRef = useRef<DataMeanStddev | null>(null)
  const heatmapColumnIds = useMemo(() => getHeatmapColumnsIds(columns), [columns])

  const heatmapColumnIdsRef = useLatestRef(heatmapColumnIds) // ag-grid seems to cache the cellClass function, so we need to maintain a ref to the heatmap columns

  useEffect(() => {
    // Changing columns or filters changes the data, so drop the cached stats and
    // force-refresh the cells to recalculate the heatmap.
    // Only remember non-null stats as the fallback. If this effect runs twice before any
    // cell is classed, the cache is already null, and copying it would discard the last
    // good stats.
    if (dataMeanStddevRef.current) {
      prevDataMeanStddevRef.current = dataMeanStddevRef.current
    }
    dataMeanStddevRef.current = null
    gridRef.current?.api?.refreshCells({ force: true })
  }, [columns, gridRef, externalFiltersChanged])

  const getColorClass = useCallback(
    (params: CellClassParams, column: ITableColumn) => {
      const nodeId = params.node.id
      const value = params.value

      if (!dataMeanStddevRef.current) {
        // Cache is stale: recompute from the grid, falling back to the last good stats when empty
        const newDataMeanStdDev = calculateMeanStddev(params.api, heatmapColumnIdsRef.current)
        dataMeanStddevRef.current = _.isEmpty(newDataMeanStdDev)
          ? prevDataMeanStddevRef.current
          : newDataMeanStdDev
      }

      // A node without an id has no cached stats
      const rowStats = nodeId == null ? undefined : dataMeanStddevRef.current?.[nodeId]
      if (
        !rowStats ||
        !heatMapColumn(column) ||
        isColumnPivotRowTotal(params.column) // Don't apply heatmap to pivot row totals
      ) {
        return
      }

      const { mean, stddev } = rowStats

      const valueWithinStdDevCurried = (value: number, stddevCount: number) =>
        valueWithinStdDev({ value, stddevCount, mean, stddev })

      if (valueWithinStdDevCurried(value, 1) || valueWithinStdDevCurried(value, -1)) {
        return 'bg-heatmap-divergent'
      } else if (valueWithinStdDevCurried(value, -2)) {
        return 'bg-heatmap-divergent-neg1'
      } else if (valueWithinStdDevCurried(value, 2)) {
        return 'bg-heatmap-divergent-pos1'
      } else if (valueWithinStdDevCurried(value, -3)) {
        return 'bg-heatmap-divergent-neg2'
      } else if (valueWithinStdDevCurried(value, 3)) {
        return 'bg-heatmap-divergent-pos2'
      } else if (value > mean) {
        return 'bg-heatmap-divergent-pos3'
      } else if (value < mean) {
        return 'bg-heatmap-divergent-neg3'
      }
    },
    [heatmapColumnIdsRef]
  )

  return {
    getColorClass
  }
}
