import { useFloating } from '@floating-ui/react'
import Highcharts from 'highcharts'
import HighchartsReact from 'highcharts-react-official'
import React, {
  ReactNode,
  useCallback,
  useEffect,
  useRef,
  useState,
} from 'react'
import { ErrorBoundary } from 'react-error-boundary'
import styled from 'styled-components'

import {
  DEFAULT_USE_FLOATING_PROPS,
  setFloatingTooltipReferencePoint,
} from 'components/tooltip'
import { TooltipContainer } from 'components/tooltip/TooltipContainer'

import { ScatterPlotChartScale } from 'pages/analysis/store/selectors'

import { useAsyncMemo } from 'shared/hooks/useAsyncMemo'
import { useEventCallback } from 'shared/hooks/useEventCallback'
import { useStable } from 'shared/hooks/useStable'
import { useSize } from 'shared/utils/utils'
import { analysisWorker } from 'shared/worker'
import { KDTree } from 'shared/worker/ScatterPlotSeries'

import { ErrorFallback } from '../ErrorFallback'
import {
  computeSeriesImageDataURL,
  renderSeriesImage,
} from './high-performance-scatter-plot.utils'

/**
 * Props of `HighPerformanceScatterPlotBase`.
 *
 * `TooltipProps` are the props of the `Tooltip` component: `seriesId` is filled in
 * by the plot with the hovered series, everything else comes from `tooltipProps`.
 */
type HighPerformanceScatterPlotBaseProps<
  TooltipProps extends { seriesId: string },
> = {
  // ---- Data ---------------------------------------------------------------

  /**
   * Series id for each occupied pixel position, together with the plot size it
   * was computed for. Nothing is drawn or hoverable while `undefined`.
   * NOTE(review): the key format of `seriesIdByPixelPosition` is defined by the
   * producer of this map — confirm there.
   */
  series:
    | {
        seriesIdByPixelPosition: Map<string, string>
        plotWidth: number
        plotHeight: number
      }
    | undefined
  /** While `.current` is `true`, a newly computed series image is not drawn. */
  isComputingSeriesRef: React.MutableRefObject<boolean>
  /**
   * Axis settings spread over `options.xAxis` / `options.yAxis`. No chart is
   * created while this is `undefined`.
   */
  scale: ScatterPlotChartScale | undefined

  // ---- Appearance ---------------------------------------------------------

  /** Per-series colors, passed to the series image computation. */
  colorBySeriesId: Record<string, string>
  /** Series ids passed to the series image computation as highlighted ids. */
  highlightedSeriesId?: string[]
  /** Per-series dot sizes, passed to the series image computation. */
  seriesDotSizes?: Record<string, number>
  /**
   * Base Highcharts options. The chart's width/height default to the measured
   * size of the plot wrapper unless set here.
   */
  options?: Highcharts.Options
  /** When `true`, xy zooming is turned off. Defaults to `false`. */
  shouldDisableZoom?: boolean
  /** When `true`, no tooltip is shown for the hovered point. Defaults to `false`. */
  shouldHideTooltip?: boolean

  // ---- Chart access & callbacks -------------------------------------------

  /** Receives `{ chart, container }` for every newly created chart. */
  highchartsRef?: React.ForwardedRef<HighchartsReact.RefObject>
  /**
   * Called on right-click while a point is hovered (the browser's menu is then
   * suppressed). Not called when no point is hovered.
   */
  onContextMenu?: (
    event: React.MouseEvent<HTMLDivElement, MouseEvent>,
    seriesId: string | undefined,
  ) => void
  /** Called with the chart as `this` each time the series image has been drawn. */
  onRenderFinished?: (this: Highcharts.Chart) => void
  /** Called with each newly created Highcharts chart. */
  onHighchartsChartChange?: (chart: Highcharts.Chart) => void

  // ---- Tooltip ------------------------------------------------------------

  /** Component rendered inside the tooltip for the hovered series. */
  Tooltip: React.ComponentType<TooltipProps>
  /** Props forwarded to `Tooltip`, except `seriesId`. */
  tooltipProps: Omit<TooltipProps, 'seriesId'>
}

/**
 * Scatter plot intended for large point counts.
 *
 * The Highcharts chart only provides axes and the plot frame: its single
 * scatter series is created empty. The points themselves are drawn as one
 * image, computed by `computeSeriesImageDataURL` and added by
 * `renderSeriesImage`. Hover is resolved against a KD-tree of the occupied
 * pixel positions, built in the analysis worker. The nearest point's series is
 * shown in `Tooltip`, anchored with floating-ui at that point.
 */
export const HighPerformanceScatterPlotBase = <
  TooltipProps extends { seriesId: string },
>({
  series,
  isComputingSeriesRef,
  scale,
  colorBySeriesId,
  options,
  shouldDisableZoom = false,
  highlightedSeriesId,
  seriesDotSizes,
  shouldHideTooltip = false,
  highchartsRef,
  onContextMenu,
  onRenderFinished,
  onHighchartsChartChange,
  Tooltip,
  tooltipProps,
}: HighPerformanceScatterPlotBaseProps<TooltipProps>): ReactNode => {
  const highchartsContainerRef = useRef<HTMLDivElement>(null)
  const chartRef = useRef<HighchartsReact.RefObject>()
  // Debounced because the chart is rebuilt whenever width/height change.
  const [chartWrapperRef, { width, height }] = useSize<HTMLDivElement>({
    debounce: 500,
  })

  const [hoveredPoint, setHoveredPoint] = useState<{
    position: [number, number]
    seriesId: string
  }>()

  // Keeps the internal ref and the caller's (callback or object) ref in sync.
  const handleSetHighchartsRef = useEventCallback(
    (chart: HighchartsReact.RefObject) => {
      chartRef.current = chart
      if (highchartsRef) {
        if (typeof highchartsRef === 'function') {
          highchartsRef(chart)
        } else {
          highchartsRef.current = chart
        }
      }
    },
  )

  const highchartsChart = useHighchartsChart(
    scale,
    highchartsContainerRef,
    options,
    width,
    height,
    shouldDisableZoom,
    handleSetHighchartsRef,
    onHighchartsChartChange,
  )

  const kdTree = useKDTree(series?.seriesIdByPixelPosition)

  const seriesImageDataUrl = useSeriesImageDataUrl(
    series,
    highlightedSeriesId,
    useStable(seriesDotSizes),
    colorBySeriesId,
  )

  useRenderSeries(
    highchartsChart,
    seriesImageDataUrl,
    isComputingSeriesRef,
    onRenderFinished,
  )

  const { floatingStyles, refs } = useFloating(DEFAULT_USE_FLOATING_PROPS)

  // Only intercept right-clicks over a hovered point; otherwise the browser's
  // default context menu is left alone.
  const handleContextMenu = useEventCallback(
    (event: React.MouseEvent<HTMLDivElement, MouseEvent>) => {
      if (!onContextMenu || !hoveredPoint) {
        return
      }

      event.preventDefault()
      onContextMenu(event, hoveredPoint.seriesId)
    },
  )

  const handleMouseMove = useEventCallback(
    (event: React.MouseEvent<HTMLDivElement, MouseEvent>) => {
      const chart = chartRef.current?.chart
      // Highcharts' `Chart#destroy()` deletes the instance's own properties.
      // `chartRef` keeps pointing at a destroyed chart until a new one is
      // created (e.g. while `scale` is undefined), so `pointer` may be missing.
      if (!chart?.pointer) {
        setHoveredPoint(undefined)
        return
      }

      const { chartX, chartY } = chart.pointer.normalize(event.nativeEvent)
      const { plotLeft, plotTop, plotWidth, plotHeight } = chart

      // Cursor position relative to the top-left corner of the plot area.
      const x = chartX - plotLeft
      const y = chartY - plotTop
      if (x < 0 || x >= plotWidth || y < 0 || y >= plotHeight || !kdTree) {
        setHoveredPoint(undefined)
        return
      }

      const nearestPoint = findNearestPoint([x, y], kdTree, 0)
      if (!nearestPoint) {
        console.error('Could not find hovered point pixel position', kdTree)
        setHoveredPoint(undefined)
        return
      }

      setHoveredPoint(nearestPoint)
      // Anchor the tooltip at the point itself (converted back to chart
      // coordinates), not at the cursor.
      setFloatingTooltipReferencePoint(
        refs,
        [nearestPoint.position[0] + plotLeft, nearestPoint.position[1] + plotTop],
        true,
      )
    },
  )

  const handleMouseLeave = useEventCallback(() => {
    setHoveredPoint(undefined)
  })

  // The wrapper is both the floating-ui reference element and the element
  // whose size drives the chart dimensions.
  const handleSetContainerRef = useEventCallback(
    (container: HTMLDivElement) => {
      refs.setReference(container)
      chartWrapperRef(container)
    },
  )

  return (
    <ErrorBoundary FallbackComponent={ErrorFallback}>
      <ChartWrapper
        ref={handleSetContainerRef}
        onContextMenu={handleContextMenu}
        onMouseMove={handleMouseMove}
        onMouseLeave={handleMouseLeave}
      >
        {!shouldHideTooltip && hoveredPoint ? (
          <TooltipContainer ref={refs.setFloating} style={floatingStyles}>
            <Tooltip
              {...(tooltipProps as TooltipProps)}
              seriesId={hoveredPoint.seriesId}
            />
          </TooltipContainer>
        ) : null}
        <div ref={highchartsContainerRef} />
      </ChartWrapper>
    </ErrorBoundary>
  )
}

/**
 * Creates the Highcharts chart for the plot inside `containerRef` and keeps its
 * zoom setting in sync.
 *
 * A chart is created only once the container is mounted, `width`/`height` are
 * non-zero and `scale` is defined. It is destroyed and rebuilt whenever the
 * container ref, size, `options`, `scale` or `setRef` change. Every new chart is
 * reported through `setRef` and `onHighchartsChartChange`. Changing only the
 * identity of `onHighchartsChartChange` does NOT rebuild the chart.
 *
 * @param isLassoToolActive when `true`, xy zooming is turned off
 * @returns the most recently created chart, or `undefined` before the first one.
 *   It is not cleared on destroy, so it can refer to a destroyed instance until
 *   the next chart is created.
 */
const useHighchartsChart = (
  scale: ScatterPlotChartScale | undefined,
  containerRef: React.RefObject<HTMLDivElement>,
  options: Highcharts.Options | undefined,
  width: number,
  height: number,
  isLassoToolActive: boolean,
  setRef: (chart: HighchartsReact.RefObject) => void,
  onHighchartsChartChange: ((chart: Highcharts.Chart) => void) | undefined,
) => {
  const highchartsChartRef = useRef<Highcharts.Chart>()
  const [highchartsChart, setHighchartsChart] = useState<Highcharts.Chart>()

  // Stable wrapper that always calls the latest callback. Without it, a new
  // callback identity (e.g. an inline arrow in the parent) would tear down and
  // rebuild the chart. If that callback sets parent state, each rebuild would
  // trigger another one.
  const notifyHighchartsChartChange = useEventCallback(
    (chart: Highcharts.Chart) => {
      onHighchartsChartChange?.(chart)
    },
  )

  useEffect(() => {
    const container = containerRef.current
    if (!container || width === 0 || height === 0 || scale === undefined) {
      return
    }

    const newHighchartsChart = new Highcharts.Chart(container, {
      ...options,
      chart: {
        ...(options?.chart ?? {}),
        // Explicit sizes in `options` win over the measured wrapper size.
        width: options?.chart?.width ?? width,
        height: options?.chart?.height ?? height,
      },
      // `scale` axis settings override those from `options`.
      xAxis: {
        ...options?.xAxis,
        ...scale.xAxis,
      },
      yAxis: {
        ...options?.yAxis,
        ...scale.yAxis,
      },
      // Points are drawn as an image elsewhere; the series only exists so the
      // chart lays out a scatter plot area.
      series: [{ type: 'scatter', data: [] }],
    })

    highchartsChartRef.current = newHighchartsChart
    setHighchartsChart(newHighchartsChart)
    notifyHighchartsChartChange(newHighchartsChart)
    setRef({
      chart: newHighchartsChart,
      container: containerRef,
    })

    return () => {
      newHighchartsChart.destroy()
      highchartsChartRef.current = undefined
    }
  }, [
    containerRef,
    height,
    notifyHighchartsChartChange,
    options,
    scale,
    setRef,
    width,
  ])

  // Re-applied after every chart (re)creation as well as on toggle.
  useEffect(() => {
    highchartsChartRef.current?.update({
      chart: { zooming: { type: isLassoToolActive ? undefined : 'xy' } },
    })
  }, [highchartsChart, isLassoToolActive])

  return highchartsChart
}

/**
 * Series-to-pixel assignment for the scatter plot, plus the plot size it was
 * computed for. The size is passed on as the width/height of the series image
 * (presumably in pixels — confirm with the producer).
 */
type UseSeriesResult = {
  plotWidth: number
  plotHeight: number
  /** Series id for each occupied pixel position (key format defined by the producer). */
  seriesIdByPixelPosition: Map<string, string>
}

/**
 * Builds, in the analysis worker, a KD-tree over the occupied pixel positions,
 * used for nearest-point hover lookups.
 *
 * The previous tree stays in place until the new one resolves. Results of a
 * superseded request (input changed or component unmounted) are discarded.
 *
 * @returns the latest tree, or `undefined` when there is no input or building
 *   the tree failed
 */
const useKDTree = (
  seriesIdByPixelPosition: Map<string, string> | undefined,
) => {
  const [kdTree, setKDTree] = useState<KDTree>()

  useEffect(() => {
    if (!seriesIdByPixelPosition) {
      setKDTree(undefined)
      return
    }

    let cancelled = false

    analysisWorker
      .computeScatterPlotKDTree(seriesIdByPixelPosition)
      .then(tree => {
        if (!cancelled) {
          setKDTree(tree)
        }
      })
      .catch((error: unknown) => {
        // Without this the rejection would go unhandled. Also drop the old
        // tree, which no longer matches the current pixel positions.
        console.error('Could not compute scatter plot KD-tree', error)
        if (!cancelled) {
          setKDTree(undefined)
        }
      })

    return () => {
      cancelled = true
    }
  }, [seriesIdByPixelPosition])

  return kdTree
}

/**
 * Asynchronously computes the data URL of the series image via
 * `computeSeriesImageDataURL`.
 *
 * The computation yields `undefined` when there is no pixel map or the plot
 * width/height is 0. `colorBySeriesId` goes through `useStable` before being used
 * as a dependency. `seriesDotSizes` is expected to be stabilised by the caller.
 *
 * @returns the `value` reported by `useAsyncMemo` for that computation
 */
const useSeriesImageDataUrl = (
  series: UseSeriesResult | undefined,
  highlightedSeriesIds: string[] | undefined,
  seriesDotSizes: Record<string, number> | undefined,
  colorBySeriesId: Record<string, string>,
) => {
  const stableColorBySeriesId = useStable(colorBySeriesId)

  const seriesIdByPixelPosition = series?.seriesIdByPixelPosition
  const plotWidth = series?.plotWidth
  const plotHeight = series?.plotHeight

  const computeImageDataUrl = useCallback(async () => {
    if (!seriesIdByPixelPosition || !plotWidth || !plotHeight) {
      return undefined
    }

    return computeSeriesImageDataURL({
      width: plotWidth,
      height: plotHeight,
      seriesIdByPixelPosition,
      colorBySeriesId: stableColorBySeriesId,
      highlightedSeriesIds,
      seriesDotSizes,
    })
  }, [
    highlightedSeriesIds,
    plotHeight,
    plotWidth,
    seriesDotSizes,
    seriesIdByPixelPosition,
    stableColorBySeriesId,
  ])

  const { value: imageDataUrl } = useAsyncMemo(computeImageDataUrl)
  return imageDataUrl
}

/**
 * Draws the series image onto `highchartsChart` with `renderSeriesImage`, then
 * calls `onRenderFinished` with the chart as `this`.
 *
 * Nothing is drawn when the chart has no renderer, there is no image yet, or
 * `isComputingSeriesRef.current` is `true`. The ref is only read when one of the
 * effect's dependencies changes; flipping it alone does not trigger a redraw.
 * The drawn element is destroyed when any dependency changes and on unmount.
 */
const useRenderSeries = (
  highchartsChart: Highcharts.Chart | undefined,
  seriesImageDataUrl: string | undefined,
  isComputingSeriesRef: React.MutableRefObject<boolean>,
  onRenderFinished: ((this: Highcharts.Chart) => void) | undefined,
) => {
  useEffect(() => {
    if (
      !(
        highchartsChart?.renderer &&
        seriesImageDataUrl &&
        !isComputingSeriesRef.current
      )
    ) {
      return
    }

    // Captured by the cleanup below, so no ref is needed to remember it.
    const seriesImage = renderSeriesImage(highchartsChart, seriesImageDataUrl)
    onRenderFinished?.call(highchartsChart)

    return () => {
      seriesImage?.destroy()
    }
  }, [
    highchartsChart,
    isComputingSeriesRef,
    onRenderFinished,
    seriesImageDataUrl,
  ])
}

/**
 * Squared Euclidean distance between two 2-D points. Comparing squared
 * distances gives the same ordering as real distances without `Math.sqrt`.
 */
const squaredDistance = (a: [number, number], b: [number, number]) =>
  (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

/**
 * Nearest-neighbour search in a 2-D KD-tree.
 *
 * First descends into the side of the splitting line that contains
 * `searchPoint`. It then only searches the other side if the splitting line is
 * strictly closer than the best match found so far. On ties the earlier
 * candidate wins: the node itself, then the near side, then the far side.
 *
 * @param searchPoint position to search around, in the tree's coordinate space
 * @param kdTree tree (or subtree) to search; must be non-empty
 * @param axisIndex axis this node splits on (0 = x, 1 = y). It alternates per
 *   level, so callers start at the axis the root was built with.
 * @returns the point of `kdTree` closest to `searchPoint`
 */
const findNearestPoint = (
  searchPoint: [number, number],
  kdTree: KDTree,
  axisIndex: number,
) => {
  let nearestPoint = kdTree.point
  let nearestSquaredDistance = squaredDistance(nearestPoint.position, searchPoint)

  // Signed offset of the search point from this node's splitting line.
  const splittingPointDistance =
    searchPoint[axisIndex] - nearestPoint.position[axisIndex]
  const isLeftOfSplit = splittingPointDistance < 0
  const nearSide = isLeftOfSplit ? kdTree.left : kdTree.right
  const farSide = isLeftOfSplit ? kdTree.right : kdTree.left
  const nextAxisIndex = (axisIndex + 1) % 2

  if (nearSide) {
    const candidate = findNearestPoint(searchPoint, nearSide, nextAxisIndex)
    const candidateSquaredDistance = squaredDistance(
      candidate.position,
      searchPoint,
    )
    if (candidateSquaredDistance < nearestSquaredDistance) {
      nearestPoint = candidate
      nearestSquaredDistance = candidateSquaredDistance
    }
  }

  // Every point on the far side is at least |splittingPointDistance| away, so
  // that side can only help if the splitting line is closer than the best so far.
  if (farSide && splittingPointDistance ** 2 < nearestSquaredDistance) {
    const candidate = findNearestPoint(searchPoint, farSide, nextAxisIndex)
    if (squaredDistance(candidate.position, searchPoint) < nearestSquaredDistance) {
      nearestPoint = candidate
    }
  }

  return nearestPoint
}

/**
 * Outer element of the plot. It fills its parent, clips overflowing content and
 * is the positioning context (`position: relative`) for its children. The
 * component also uses it as the floating-ui reference and measures it for
 * sizing.
 */
const ChartWrapper = styled.div`
  position: relative;
  overflow: hidden;
  width: 100%;
  height: 100%;
`
