import {
  ScatterPlotChartScale,
  ScatterPlotScaleType,
} from 'pages/analysis/store/selectors'

/** Input for `groupScatterEventsIntoGrid`. */
type GroupScatterEventsIntoGridProps = {
  /** Events to bucket, keyed by the id of the series they belong to. */
  eventsBySeriesId: Record<string, Event[]>
  /** Axis bounds (`xAxis`/`yAxis` min and max) used to project events to pixels. */
  scale: ScatterPlotChartScale
  /** Per-axis projection type; compared against 'linear' / 'logarithmic'. */
  scaleType: ScatterPlotScaleType
  /** Number of pixel columns in the grid. */
  plotWidth: number
  /** Number of pixel rows in the grid. */
  plotHeight: number
}

// A single scatter point in data coordinates.
// NOTE(review): this module-local alias shadows the global DOM `Event` type
// when the DOM lib is enabled; consider renaming (e.g. `ScatterEvent`).
type Event = { id: number; x: number; y: number }

// Outer key: pixel position formatted as "column,row".
// Inner key: series id -> events of that series falling in the pixel.
type EventsBySeriesIdByPixelPosition = Map<string, Map<string, Event[]>>

/** Inclusive range check against an axis' bounds; false for NaN values. */
const isWithinAxisRange = (
  value: number,
  axis: { min: number; max: number },
): boolean => value >= axis.min && value <= axis.max

/** Restricts `value` to the closed interval [lower, upper]. */
const clampIndex = (value: number, lower: number, upper: number): number =>
  Math.min(Math.max(value, lower), upper)

/**
 * Buckets scatter events into a `plotWidth` x `plotHeight` pixel grid,
 * grouping the events that fall into each cell by series id.
 *
 * An event is skipped when it cannot be placed on the plot:
 * - it has a non-positive value on a logarithmic axis;
 * - a coordinate lies outside the scale's inclusive [min, max] range, or is NaN;
 * - its projection yields a non-finite pixel index (e.g. a degenerate scale).
 *
 * Values exactly on an axis maximum are kept and placed in the last column
 * or the top row.
 *
 * @returns Map keyed by pixel position `"column,row"` (column from the left,
 *   row from the top). Each value maps series id to that series' events in
 *   the cell. Cells without events are absent. An empty grid is returned
 *   when either plot dimension is not positive.
 */
export const groupScatterEventsIntoGrid = ({
  eventsBySeriesId,
  scale,
  scaleType,
  plotWidth,
  plotHeight,
}: GroupScatterEventsIntoGridProps): EventsBySeriesIdByPixelPosition => {
  const eventsBySeriesIdByPixelPosition: EventsBySeriesIdByPixelPosition =
    new Map()

  // No cells to place events into (also rejects NaN dimensions).
  if (!(plotWidth > 0 && plotHeight > 0)) {
    return eventsBySeriesIdByPixelPosition
  }

  const translateEventToPixel = createTranslateEventToPixel({
    scale,
    scaleType,
    plotWidth,
    plotHeight,
  })

  for (const [seriesId, seriesEvents] of Object.entries(eventsBySeriesId)) {
    for (const event of seriesEvents) {
      if (
        shouldSkipEvent(event, scaleType) ||
        !isWithinAxisRange(event.x, scale.xAxis) ||
        !isWithinAxisRange(event.y, scale.yAxis)
      ) {
        continue
      }

      const [rawXIndex, rawYIndex] = translateEventToPixel(event)
      if (!Number.isFinite(rawXIndex) || !Number.isFinite(rawYIndex)) {
        continue
      }

      // A value equal to the axis max projects one cell past the edge
      // (column plotWidth / row -1); clamp it onto the edge cell.
      const xIndex = clampIndex(rawXIndex, 0, plotWidth - 1)
      const yIndex = clampIndex(rawYIndex, 0, plotHeight - 1)
      const pixelPosition = `${xIndex},${yIndex}`

      let cell = eventsBySeriesIdByPixelPosition.get(pixelPosition)
      if (!cell) {
        cell = new Map<string, Event[]>()
        eventsBySeriesIdByPixelPosition.set(pixelPosition, cell)
      }
      let cellEvents = cell.get(seriesId)
      if (!cellEvents) {
        cellEvents = []
        cell.set(seriesId, cellEvents)
      }
      cellEvents.push(event)
    }
  }

  return eventsBySeriesIdByPixelPosition
}

/** Input for `createTranslateEventToPixel`. */
type CreateTranslateEventToPixelProps = {
  /** Axis bounds (`xAxis`/`yAxis` min and max) defining the visible range. */
  scale: ScatterPlotChartScale
  /** Chooses linear or log10 projection independently for each axis. */
  scaleType: ScatterPlotScaleType
  /** Number of pixel columns. */
  plotWidth: number
  /** Number of pixel rows. */
  plotHeight: number
}

/**
 * Builds a projector from an event's data coordinates to integer pixel
 * indices `[column, row]` on a `plotWidth` x `plotHeight` grid.
 *
 * - Columns count from the left: the x-axis min maps to column 0.
 * - Rows count from the top: the y-axis min maps to row `plotHeight - 1`.
 * - Each axis uses linear projection when its scale type is 'linear', and
 *   log10 projection otherwise.
 * - Results are not clamped. Values outside [min, max], or at the max edge,
 *   can yield indices outside the grid.
 */
export const createTranslateEventToPixel = ({
  scale,
  scaleType,
  plotWidth,
  plotHeight,
}: CreateTranslateEventToPixelProps) => {
  const { xAxis, yAxis } = scale

  // Data units covered by one pixel (linear projection).
  const xStep = (xAxis.max - xAxis.min) / plotWidth
  const yStep = (yAxis.max - yAxis.min) / plotHeight

  // Pixels per decade (log10 projection).
  const xLogMin = Math.log10(xAxis.min)
  const yLogMin = Math.log10(yAxis.min)
  const xLogUnityLength = plotWidth / (Math.log10(xAxis.max) - xLogMin)
  const yLogUnityLength = plotHeight / (Math.log10(yAxis.max) - yLogMin)

  // Resolve the projection for each axis once instead of on every call.
  const columnOf =
    scaleType.xAxis === 'linear'
      ? (x: number) => Math.floor((x - xAxis.min) / xStep)
      : (x: number) => Math.floor((Math.log10(x) - xLogMin) * xLogUnityLength)
  const rowFromBottomOf =
    scaleType.yAxis === 'linear'
      ? (y: number) => Math.floor((y - yAxis.min) / yStep)
      : (y: number) => Math.floor((Math.log10(y) - yLogMin) * yLogUnityLength)

  return (event: { id: number; x: number; y: number }): [number, number] => {
    const column = columnOf(event.x)
    // Flip vertically so larger y values sit closer to the top row.
    const row = plotHeight - 1 - rowFromBottomOf(event.y)
    return [column, row]
  }
}

/**
 * Reports whether an event cannot be drawn under the given scale types.
 *
 * An event is rejected when it has a non-positive coordinate on an axis
 * whose scale type is 'logarithmic', because log10 of such a value is not
 * a finite number.
 */
export const shouldSkipEvent = (
  event: { id: number; x: number; y: number },
  scaleType: ScatterPlotScaleType,
): boolean => {
  const xNotLoggable = scaleType.xAxis === 'logarithmic' && event.x <= 0
  if (xNotLoggable) {
    return true
  }
  return scaleType.yAxis === 'logarithmic' && event.y <= 0
}

/** Input for `downsampleScatterData`. */
type DownsampleScatterDataProps = {
  /** Per-pixel, per-series events (the shape `groupScatterEventsIntoGrid` returns). */
  eventsBySeriesIdByPixelPosition: EventsBySeriesIdByPixelPosition
  /** Priority per series; higher wins. Missing series default to 0. */
  zIndexBySeriesId?: Record<string, number>
  /** Multiplier applied to a series' per-pixel event count. Missing entries default to 1. */
  downsampledRatioBySeriesId?: Record<string, number>
}

/** Pixel position ("column,row") -> id of the series chosen for that pixel. */
type DownsampleScatterDataReturnValue = Map<string, string>

/**
 * Picks one representative series for every occupied pixel.
 *
 * Selection rules:
 * - The series with the highest zIndex wins.
 * - Ties are broken by the larger weighted count
 *   (`events.length * downsampledRatio`).
 * - Remaining ties keep the series encountered first.
 * - If no series wins a comparison (e.g. NaN zIndex values), the pixel's
 *   first series is used.
 * - Pixels with no series are omitted from the result.
 */
export const downsampleScatterData = ({
  eventsBySeriesIdByPixelPosition,
  zIndexBySeriesId = {},
  downsampledRatioBySeriesId,
}: DownsampleScatterDataProps): DownsampleScatterDataReturnValue => {
  const seriesIdByPixelPosition: DownsampleScatterDataReturnValue = new Map()

  for (const [
    pixelPosition,
    eventsBySeriesId,
  ] of eventsBySeriesIdByPixelPosition) {
    // `Map.keys()` returns an iterator, not an array, so `[0]` indexing does
    // not work; destructure to take the first key instead.
    const [firstSeriesId] = eventsBySeriesId.keys()
    if (firstSeriesId === undefined) {
      // Empty cell: there is no series to represent it.
      continue
    }

    let maxZIndex = -Infinity
    let maxPointCount = -Infinity
    let mostPrevalentSeriesId = firstSeriesId

    for (const [seriesId, events] of eventsBySeriesId) {
      const zIndex = zIndexBySeriesId[seriesId] ?? 0
      const downsampledRatio = downsampledRatioBySeriesId?.[seriesId] ?? 1
      const downsampledPointCount = events.length * downsampledRatio
      if (
        zIndex > maxZIndex ||
        (zIndex === maxZIndex && downsampledPointCount > maxPointCount)
      ) {
        maxZIndex = zIndex
        maxPointCount = downsampledPointCount
        mostPrevalentSeriesId = seriesId
      }
    }
    seriesIdByPixelPosition.set(pixelPosition, mostPrevalentSeriesId)
  }

  return seriesIdByPixelPosition
}
