import { createSelector } from '@reduxjs/toolkit'
import { mapValues, pick } from 'lodash'
import SeededRandomUtilities from 'seeded-random-utilities'

import {
  ScatterPlotChartScale,
  ScatterPlotScaleType,
} from 'pages/analysis/store/selectors'

import { AnalysisWorkerStore } from './AnalysisWorkerStore'

/**
 * A single plotted event as a tuple.
 * NOTE(review): presumably `[x, y?]`; it is not used in this file, so confirm against callers.
 */
export type Event = [number, number?]

/** Arguments of {@link SeriesEvents.computeEventsByLeafId}. */
type ComputeEventsByLeafId = {
  /** Key of the transformed-data column used for the x coordinate. */
  xAxis: string
  /** Key of the transformed-data column used for y; when omitted every `y` is `NaN`. */
  yAxis?: string
  /** Cluster ids to build one event list for each. */
  leafIds: string[]
  /** Clusters whose tree leaves are excluded from every returned list. */
  hiddenClustersIds: string[]
  /**
   * Percentage (of each tree leaf's events) to keep. Downsampling only
   * happens for truthy values below 100.
   */
  eventLimit?: number
}

/** Arguments of {@link SeriesEvents.computeLogarithmicScaleWorkaround}. */
type ComputeLogarithmicScaleWorkaroundProps = {
  chart: {
    /** Transformed-data column shown on the x axis. */
    xAxis: string
    /** Transformed-data column shown on the y axis. */
    yAxis: string
    /** Current axis ranges; `undefined` when the chart has no scale yet. */
    scale: ScatterPlotChartScale | undefined
    /** Per-axis scale type (e.g. `'logarithmic'`). */
    scaleType: ScatterPlotScaleType
  }
}

/**
 * Builds scatter-plot series data from the clustering output exposed by an
 * {@link AnalysisWorkerStore} (leaf labels, transformed data and cluster tree).
 */
export class SeriesEvents {
  private readonly store: AnalysisWorkerStore

  constructor(store: AnalysisWorkerStore) {
    this.store = store
  }

  /**
   * Returns a copy of `chart.scale` whose minimums can be drawn on a
   * logarithmic axis.
   *
   * For each axis whose scale type is `'logarithmic'` and whose current
   * minimum is `<= 0`, the minimum is replaced by the smallest strictly
   * positive value in that axis' transformed-data column. All other fields of
   * the scale are copied unchanged.
   *
   * NOTE(review): if a column contains no positive value the replacement
   * minimum is `Infinity` (see `computeAxisPositiveMin`).
   *
   * @param props.chart - axis column names, current scale and scale types.
   * @returns the adjusted scale, or `undefined` when `chart.scale` is unset.
   */
  public async computeLogarithmicScaleWorkaround({
    chart,
  }: ComputeLogarithmicScaleWorkaroundProps): Promise<
    ScatterPlotChartScale | undefined
  > {
    if (!chart.scale) {
      return undefined
    }
    let xMin = chart.scale.xAxis.min
    if (chart.scaleType.xAxis === 'logarithmic' && xMin <= 0) {
      xMin = await this.computeAxisPositiveMin(chart.xAxis)
    }
    let yMin = chart.scale.yAxis.min
    if (chart.scaleType.yAxis === 'logarithmic' && yMin <= 0) {
      yMin = await this.computeAxisPositiveMin(chart.yAxis)
    }
    return {
      xAxis: {
        ...chart.scale.xAxis,
        min: xMin,
      },
      yAxis: {
        ...chart.scale.yAxis,
        min: yMin,
      },
    }
  }

  /**
   * Collects the coordinates of the events belonging to each requested
   * cluster.
   *
   * Every id in `leafIds` is expanded to the tree leaves beneath it. Tree
   * leaves that also lie under one of `hiddenClustersIds` are skipped. When
   * `eventLimit` is truthy and below 100, it is treated as a percentage:
   * each tree leaf's events are sampled down to that share using a
   * generator with a fixed seed.
   *
   * Ids that are unknown to the cluster tree, and tree leaves that no event is
   * labelled with, contribute no events instead of throwing.
   *
   * @returns events keyed by requested leaf id; `y` is `NaN` when no `yAxis`
   *   is given.
   */
  public async computeEventsByLeafId({
    xAxis,
    yAxis,
    hiddenClustersIds,
    leafIds,
    eventLimit,
  }: ComputeEventsByLeafId): Promise<
    Record<string, { id: number; x: number; y: number }[]>
  > {
    await this.store.waitForClusteringFiles()
    const leafLabels = this.store.getLeafLabels()
    const transformedData = this.store.getTransformedData()
    const clusterTree = this.store.getClusterTree()

    const treeLeafIdsByClusterId = this.computeTreeLeafIdsByClusterId({
      clusterTree,
    })
    const eventIdsByLeafId = this.computeEventIdsByLeafId({
      leafLabels,
    })

    // A hidden id may not (or no longer) exist in the tree; treat it as empty
    // rather than iterating `undefined`.
    const treeLeafIdsToRemove = new Set<string>(
      hiddenClustersIds.flatMap(
        hiddenClusterId => treeLeafIdsByClusterId[hiddenClusterId] ?? [],
      ),
    )

    // `undefined`, `0` and values >= 100 all mean "keep every event".
    const downsampleLimit =
      eventLimit && eventLimit < 100 ? eventLimit : undefined

    // Column lookups do not depend on the event, so resolve them once.
    const xValues = transformedData[xAxis]
    const yValues = yAxis ? transformedData[yAxis] : undefined

    const eventsByLeafId: Record<
      string,
      { id: number; x: number; y: number }[]
    > = {}
    for (const leafId of leafIds) {
      const events: { id: number; x: number; y: number }[] = []
      eventsByLeafId[leafId] = events

      for (const treeLeafId of treeLeafIdsByClusterId[leafId] ?? []) {
        if (treeLeafIdsToRemove.has(treeLeafId)) {
          continue
        }

        // Tree leaves that no event is labelled with have no entry at all.
        const leafEventIds = eventIdsByLeafId[treeLeafId] ?? []
        const eventIds =
          downsampleLimit === undefined
            ? leafEventIds
            : this.downsampleEventIds(leafEventIds, downsampleLimit)

        for (const eventId of eventIds) {
          events.push({
            id: eventId,
            x: xValues[eventId],
            y: yValues === undefined ? NaN : yValues[eventId],
          })
        }
      }
    }

    return eventsByLeafId
  }

  /**
   * Looks up the `downsampled_ratio` of each requested cluster.
   *
   * @param leafIds - cluster ids to look up; ids absent from the cluster tree
   *   are omitted from the result (lodash `pick` skips missing keys).
   * @returns ratio per id, defaulting to `1` when the node has none.
   */
  public async computeDownsampledRatioByLeafIds(
    leafIds: string[],
  ): Promise<Record<string, number>> {
    await this.store.waitForClusteringFiles()
    const clusterTree = this.store.getClusterTree()
    const clusterById = this.computeClusterById({ clusterTree })

    return mapValues(
      pick(clusterById, leafIds),
      cluster => cluster.downsampled_ratio ?? 1,
    )
  }

  /**
   * Smallest strictly positive value in the transformed-data column `axis`.
   *
   * NOTE(review): returns `Infinity` when the column has no positive value.
   */
  private async computeAxisPositiveMin(axis: string): Promise<number> {
    await this.store.waitForClusteringFiles()
    const values = this.store.getTransformedData()[axis]

    let min = Infinity
    for (let i = 0; i < values.length; i++) {
      const value = values[i]
      if (value > 0 && value < min) {
        min = value
      }
    }
    return min
  }

  /**
   * Groups event indices by their leaf label. Memoized on the `leafLabels`
   * reference by reselect.
   */
  private computeEventIdsByLeafId = createSelector(
    (props: { leafLabels: Analysis.LeafLabels }) => props.leafLabels,
    (leafLabels: Analysis.LeafLabels) => {
      const eventIdsByLeafId: Record<string, number[]> = {}

      for (let eventId = 0; eventId < leafLabels.length; eventId++) {
        const leafId = leafLabels[eventId].toString()
        if (!eventIdsByLeafId[leafId]) {
          eventIdsByLeafId[leafId] = []
        }
        eventIdsByLeafId[leafId].push(eventId)
      }

      return eventIdsByLeafId
    },
  )

  /**
   * Maps every node id in the cluster tree to the ids of the leaf nodes
   * beneath it; a leaf maps to itself. Memoized on the `clusterTree`
   * reference.
   */
  private computeTreeLeafIdsByClusterId = createSelector(
    (props: { clusterTree: Analysis.ClusterTree }) => props.clusterTree,
    tree => {
      const leavesByClusterId: Record<string, string[]> = {}

      // Post-order: children are resolved before their parent concatenates them.
      const traverseTree = (id: string, node: Analysis.ClusterNode) => {
        const childIds = Object.keys(node.children)
        if (childIds.length === 0) {
          leavesByClusterId[id] = [id]
          return
        }
        for (const [childId, child] of Object.entries(node.children)) {
          traverseTree(childId, child)
        }
        leavesByClusterId[id] = childIds.flatMap(
          childId => leavesByClusterId[childId],
        )
      }

      for (const [nodeId, node] of Object.entries(tree)) {
        traverseTree(nodeId, node)
      }

      return leavesByClusterId
    },
  )

  /**
   * Flattens the cluster tree into a lookup of node by id (every depth).
   * Memoized on the `clusterTree` reference.
   */
  private computeClusterById = createSelector(
    (props: { clusterTree: Analysis.ClusterTree }) => props.clusterTree,
    tree => {
      const clusterById: Record<string, Analysis.ClusterNode> = {}

      const traverseTree = (id: string, node: Analysis.ClusterNode) => {
        clusterById[id] = node
        for (const [childId, child] of Object.entries(node.children)) {
          traverseTree(childId, child)
        }
      }

      for (const [nodeId, node] of Object.entries(tree)) {
        traverseTree(nodeId, node)
      }

      return clusterById
    },
  )

  /**
   * Selects `eventLimit` percent of `eventIds` using a generator freshly
   * seeded with the constant `'seed'` on every call.
   *
   * NOTE(review): assumes `SeededRandomUtilities` is deterministic for a
   * given seed. The requested count may be fractional, and how
   * `selectUniqueRandomElements` rounds it is up to the library; confirm.
   */
  private downsampleEventIds(eventIds: number[], eventLimit: number) {
    const numberOfEvents = (eventIds.length * eventLimit) / 100
    const seedUtil = new SeededRandomUtilities('seed')
    return seedUtil.selectUniqueRandomElements(eventIds, numberOfEvents)
  }
}
