import { scaleBand, scaleLinear } from "@visx/scale";
import { Bar, BarRounded, BarStack } from "@visx/shape";
import type {
  StackedStandKpiComparisonInfoDto,
  StackedStandKpiSummaryDto,
  StandKpiSliceDto,
} from "apis/oag";
import type { KpiGroupUserLensDto, StackedStandKpiComparisonDto, StackedStandKpiDto } from "apis/oag";
import { DisplayOption, ResultDataState, StackingType, VisualAidType } from "apis/oag";
import { AverageLabel, AverageLine } from "components/Lenses/common/AverageLine";
import { Chart } from "components/Lenses/common/Chart";
import { Label } from "components/Lenses/common/Label";
import { MedianLabel, MedianLine } from "components/Lenses/common/MedianLine";
import OperationCount, { StyledOperationCountContainer } from "components/Lenses/common/OperationCount";
import TargetLine from "components/Lenses/common/TargetLine";
import { TooltipGroup, TooltipHighlightValue } from "components/Lenses/common/Tooltip";
import { TooltipVisualAidInfo, useChartTooltip } from "components/Lenses/common/useChartTooltip";
import { useOutlierThreshold } from "components/Lenses/common/useOutlierThreshold";
import type { StackedKpiChartComparisonProps } from "components/Lenses/ContainerLens/StackedKpi/interfaces";
import { getMinMax, getSVGNormalizedValue } from "components/Lenses/utils";
import { RIGHT_AXIS_PADDING } from "components/TvDChart/constants";
import { useTargetSegments } from "hooks/charting/useTargetSegments";
import { useKpiTypes } from "hooks/useKpiTypes";
import { useWellShortInfoSuspended } from "hooks/useWellShortInfo";
import { roundNumberToDecimal } from "pages/Lens/LensSummaryView";
import { shade, tint, transparentize } from "polished";
import React, { useCallback, useMemo } from "react";
import { useResizeDetector } from "react-resize-detector";
import { OPERATION_COUNT_HEIGHT, TOP_LABEL_HEIGHT, TOP_LABEL_MARGIN } from "utils/constants";
import { useUOM, useUOMbyLens, UtilDimensions } from "utils/format";
import { truncateMiddleString } from "utils/helper";
import { useColors } from "utils/useColors";
import { useCustomTheme } from "utils/useTheme";

/**
 * Props accepted by `StackedComparisonWellChart`: the shared stacked-KPI
 * comparison chart props plus the lens configuration and the display option.
 */
interface StackedColumnProps extends StackedKpiChartComparisonProps {
  /** Lens settings read by the chart (stacking type, label, visual aids, operation count, manual y-axis, ...). */
  lens: KpiGroupUserLensDto;
  /** Selects the layout: `DisplayOption.Shift` renders Day/Night columns, anything else renders one column per comparison well. */
  displayOption: DisplayOption;
}

/**
 * One column of the stacked chart after the API payload has been flattened:
 * either a comparison well or a "Day"/"Night" shift bucket.
 */
type TransformedDataItem = {
  /** Position of the column in the transformed list. */
  index: number;
  /** Well id for comparison columns, or the literal "Day"/"Night" for shift columns. */
  id: number | string;
  /** Target value associated with the column, when one is provided. */
  targetValue?: number | null;
  /** When truthy the column is drawn with outlier styling and excluded from the y-axis domain values. */
  isOutlier?: boolean;
  /** Operation count shown under the column and in the tooltip. */
  operationCount?: number;
  /** Sum of all slice values in `values` for this column. */
  totalValue?: number;
  /** Display name (well name, or "Day"/"Night"). */
  name: string;
  // Plotted value per slice: the slice average or its distribution, depending on the lens stacking type.
  values: Record<number, number>; // kpi id -> value
  // Slice average per KPI; shown in the tooltip regardless of stacking type.
  average: Record<number, number>; // kpi id -> value
};

/** Ordered list of chart columns fed to the stacked bar chart. */
type TransformedData = TransformedDataItem[];

/**
 * Stacked KPI column chart.
 *
 * Renders one stacked column per valid comparison well or, when
 * `displayOption` is `DisplayOption.Shift`, exactly two columns ("Day" and
 * "Night"). Each stack layer corresponds to one KPI type. Average, median and
 * target lines, operation counts and outlier markers are drawn according to
 * the lens settings, and a tooltip is shown when a column is hovered.
 *
 * The container is measured with `useResizeDetector`; while the derived plot
 * width/height are `NaN` only an empty container `<div>` is rendered so the
 * size can be measured first.
 *
 * @param detailed - Switches between the two sets of plot margins and label placements used below.
 * @param data - API payload; read as a shift summary or as a well comparison depending on `displayOption`.
 * @param focalWellColor - Colour handed to `BarStack` as the bar colour (defaults to an empty string).
 * @param dimension - Dimension used to resolve the unit of measure for values.
 * @param lens - Lens configuration (stacking type, label, visual aids, manual y-axis, ...).
 * @param displayOption - Chooses between the shift layout and the per-well layout.
 * @param templateType - When equal to "Rotate vs Slide" a stronger shade multiplier and a grayscale filter on "Night" bars are applied.
 */
export function StackedComparisonWellChart({
  detailed,
  data,
  focalWellColor = "",
  dimension,
  lens,
  displayOption,
  templateType,
}: StackedColumnProps) {
  const { getColor } = useColors();
  const { data: wellShortInfo } = useWellShortInfoSuspended();
  const { data: kpiTypes } = useKpiTypes();
  const valueUOM = useUOMbyLens(dimension, lens);
  const percentageUOM = useUOM(UtilDimensions.Percentage);

  const { width: chartWidthHook, height: chartHeightHook, ref: containerRef } = useResizeDetector();
  const { chartWidth, chartHeight: containerHeight } = {
    chartHeight: getSVGNormalizedValue(chartHeightHook),
    chartWidth: getSVGNormalizedValue(chartWidthHook),
  };
  // The operation count strip, when shown, takes its height out of the SVG area.
  const chartHeight = getSVGNormalizedValue(containerHeight - (lens.showOperationCount ? OPERATION_COUNT_HEIGHT : 0));

  const operationCountTopMargin = useMemo(() => {
    if (lens.label && lens.showOperationCount) {
      return TOP_LABEL_MARGIN;
    }
    return 0;
  }, [lens.label, lens.showOperationCount]);

  const topMargin = useMemo(() => {
    if (lens.label) {
      return TOP_LABEL_HEIGHT + operationCountTopMargin;
    }
    return 0;
  }, [lens.label, operationCountTopMargin]);

  const plotHeight = getSVGNormalizedValue(chartHeight - (detailed ? 136 : 36) - topMargin);
  const plotWidth = getSVGNormalizedValue(chartWidth - (detailed ? 76 : 44) - 10);

  // Picks the plotted number for a slice/summary: its distribution share when
  // the lens stacks by distribution, otherwise its average.
  const getData = useCallback(
    (d: StandKpiSliceDto | StackedStandKpiSummaryDto) => {
      if (!d) return 0;
      if (lens.stackingType === StackingType.Distribution) return (d as StandKpiSliceDto).distribution;
      return d.average;
    },
    [lens.stackingType],
  );

  // KPI ids in stacking order (sorted by `position`).
  const kpiKeys = useMemo(() => {
    if (displayOption === DisplayOption.Shift) {
      return ((data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.daySlices || [])
        .slice() // copy first: Array.prototype.sort sorts in place and `data` is a prop we must not mutate
        .sort((a, b) => a.position - b.position)
        .map((item) => item.kpiTypeId);
    }
    return (data as StackedStandKpiComparisonDto & StackedStandKpiDto).summaryByKpi
      ?.slice()
      .sort((a, b) => a.position - b.position)
      .map((item) => item.id);
  }, [displayOption, data]);

  const [average, median] = useMemo(() => {
    // TODO this is kinda bad... We are using the same type for comparison and non comparison
    if (displayOption === DisplayOption.Shift) {
      return [
        getData((data as unknown as StackedStandKpiDto)?.summary),
        // Guard against a missing summary the same way the average above does.
        (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.summary?.median ?? 0,
      ];
    }

    const aggregate = ((data as StackedStandKpiComparisonDto & StackedStandKpiDto).summaryByKpi || []).find(
      (e) => e.isAggregate,
    );

    return [aggregate?.allAverage ?? 0, aggregate?.median ?? 0];
  }, [displayOption, data, getData]);

  // Flattens the API payload into one item per chart column.
  const transformedData = useMemo<TransformedData>(() => {
    if (displayOption === DisplayOption.Shift) {
      const daySliceMap = (
        (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.daySlices ?? []
      ).reduce(
        (acc, slice) => {
          acc[slice.kpiTypeId] = getData(slice);
          return acc;
        },
        {} as Record<number, number>,
      );
      const nightSliceMap = (
        (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.nightSlices ?? []
      ).reduce(
        (acc, slice) => {
          acc[slice.kpiTypeId] = getData(slice);
          return acc;
        },
        {} as Record<number, number>,
      );
      const daySliceMapAverage = (
        (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.daySlices ?? []
      ).reduce(
        (acc, slice) => {
          acc[slice.kpiTypeId] = slice.average;
          return acc;
        },
        {} as Record<number, number>,
      );
      const nightSliceMapAverage = (
        (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.nightSlices ?? []
      ).reduce(
        (acc, slice) => {
          acc[slice.kpiTypeId] = slice.average;
          return acc;
        },
        {} as Record<number, number>,
      );
      const totalDayAverage = (
        data as StackedStandKpiComparisonDto & StackedStandKpiDto
      )?.detailsByShift?.daySlices?.reduce((acc, slice) => acc + getData(slice), 0);
      const totalNightAverage = (
        data as StackedStandKpiComparisonDto & StackedStandKpiDto
      )?.detailsByShift?.nightSlices?.reduce((acc, slice) => acc + getData(slice), 0);
      return [
        {
          index: 0,
          id: "Day",
          name: "Day",
          targetValue: (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.dayTargetValue,
          totalValue: totalDayAverage,
          values: daySliceMap,
          average: daySliceMapAverage,
          operationCount: (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift
            ?.dayOperationCount,
        },
        {
          index: 1,
          id: "Night",
          name: "Night",
          targetValue: (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift?.nightTargetValue,
          totalValue: totalNightAverage,
          values: nightSliceMap,
          average: nightSliceMapAverage,
          operationCount: (data as StackedStandKpiComparisonDto & StackedStandKpiDto)?.detailsByShift
            ?.nightOperationCount,
        },
      ];
    }
    if (!(data as StackedStandKpiComparisonDto & StackedStandKpiDto).comparisons) return [];
    return (data as StackedStandKpiComparisonDto & StackedStandKpiDto).comparisons
      .filter((d) => d.dataState === ResultDataState.Valid)
      .map((detail, index) => {
        const categoryInfo: StackedStandKpiComparisonInfoDto | undefined = detail?.detailsByDisplayOption?.[0]; // TODO we shall not assume the first one is the right one
        const sliceMap = (categoryInfo?.slices || []).reduce(
          (acc, slice) => {
            acc[slice.kpiTypeId] = getData(slice);
            return acc;
          },
          {} as Record<number, number>,
        );
        const sliceMapAverage = (categoryInfo?.slices || []).reduce(
          (acc, slice) => {
            acc[slice.kpiTypeId] = slice.average;
            return acc;
          },
          {} as Record<number, number>,
        );
        const totalValue = Object.keys(sliceMap)
          .map((e) => sliceMap[+e])
          .reduce((acc, item) => acc + item, 0);
        return {
          index,
          id: detail.wellId,
          name: wellShortInfo?.byId[detail.wellId]?.name ?? "Unnamed",
          targetValue: detail.targetAverage,
          totalValue,
          operationCount: categoryInfo?.operationCount,
          values: sliceMap,
          average: sliceMapAverage,
        };
      })
      .filter((e) => e);
  }, [data, displayOption, getData, wellShortInfo?.byId]);

  const categoryScale = useMemo(
    () =>
      scaleBand<string>({
        domain: transformedData.map((i) => i.id.toString()),
        // Bands span the full plot width; shift mode uses a smaller outer padding than the per-well mode.
        range: [0, plotWidth],
        paddingInner: 0.2,
        paddingOuter: displayOption === DisplayOption.Shift ? 0.2 : 1,
      }),
    [plotWidth, displayOption, transformedData],
  );

  const computedTargetSegments = useTargetSegments(transformedData, categoryScale);

  // Column totals (outliers excluded) that feed the outlier threshold and the y-axis domain.
  const totalValues = useMemo(
    () => transformedData?.filter((stand) => !stand?.isOutlier).map((stand) => stand.totalValue || 0),
    [transformedData],
  );

  const { outlierThreshold, gradientDefinition, gradientFill } = useOutlierThreshold({
    values: totalValues,
    enabled: lens?.showsOutliers,
    selectedVisualAids: lens?.selectedVisualAids,
    targetSegments: computedTargetSegments,
  });

  const valueScale = useMemo(() => {
    return scaleLinear<number>({
      domain: getMinMax(
        totalValues,
        computedTargetSegments,
        outlierThreshold,
        median,
        lens?.selectedVisualAids,
        lens?.outlierFlaggingType,
        lens?.isManualYaxis,
        lens?.yaxisEnd,
        lens?.yaxisStart,
      ),
      range: [plotHeight, 0],
      clamp: true,
    });
  }, [
    totalValues,
    computedTargetSegments,
    outlierThreshold,
    median,
    lens?.selectedVisualAids,
    lens?.outlierFlaggingType,
    lens?.isManualYaxis,
    lens?.yaxisEnd,
    lens?.yaxisStart,
    plotHeight,
  ]);

  const {
    isDark,
    themeStyle: { colors: themeColors },
  } = useCustomTheme();

  const columnWidth = categoryScale.bandwidth();

  const { showTooltip, hideTooltip, tooltipElement } = useChartTooltip<TransformedDataItem>({
    containerRef,
    renderContent: ({ tooltipData }) => (
      <>
        <TooltipHighlightValue>
          {lens.stackingType === StackingType.Values
            ? valueUOM.display(tooltipData?.totalValue)
            : (kpiKeys || [])
                .slice()
                .reverse()
                .map(
                  (key) =>
                    kpiTypes?.byId[key] && (
                      <div key={key}>
                        {kpiTypes.byId[key].name}:{" "}
                        {roundNumberToDecimal((tooltipData?.values[key] ?? 0) * 100, 2) + "%"}
                      </div>
                    ),
                )}
        </TooltipHighlightValue>
        <TooltipGroup>
          {(kpiKeys || [])
            .slice()
            .reverse()
            .map(
              (key) =>
                kpiTypes?.byId[key] && (
                  <div key={key}>
                    {kpiTypes.byId[key].name}: {valueUOM.display(tooltipData?.average[key] ?? 0)}
                  </div>
                ),
            )}
        </TooltipGroup>

        <TooltipVisualAidInfo
          selectedVisualAids={lens?.selectedVisualAids}
          display={valueUOM.display}
          targetValue={tooltipData?.targetValue}
          averageValue={average}
          median={median}
        />

        {lens.showOperationCount && tooltipData?.operationCount ? (
          <TooltipHighlightValue>Operation count: {tooltipData?.operationCount}</TooltipHighlightValue>
        ) : null}
        <span>{tooltipData?.name}</span>
        {tooltipData?.isOutlier ? <TooltipHighlightValue fontWeight="normal">OUTLIER</TooltipHighlightValue> : null}
      </>
    ),
  });

  // Anchors the tooltip at the horizontal centre and the top of the hovered column.
  const handleMouseOver = useCallback(
    (barData: TransformedDataItem) => {
      showTooltip({
        tooltipLeft: (categoryScale(barData.id.toString()) || 0) + columnWidth / 2,
        tooltipTop: valueScale(barData.totalValue || 0),
        tooltipData: barData,
      });
    },
    [categoryScale, columnWidth, showTooltip, valueScale],
  );

  // Starting initialization only with div to calculate plot width & height
  if (Number.isNaN(plotWidth) || Number.isNaN(plotHeight)) {
    return (
      <div
        ref={containerRef}
        style={{
          padding: 0,
          width: "100%",
          height: "100%",
          position: "relative",
          overflow: "visible",
        }}
      />
    );
  }

  return (
    <div
      ref={containerRef}
      style={{
        padding: 0,
        width: "100%",
        height: "100%",
        position: "relative",
        overflow: "visible",
      }}
    >
      <StyledOperationCountContainer visible={lens.showOperationCount} />
      <svg width={chartWidth} height={chartHeight} style={{ overflow: "visible", userSelect: "none" }}>
        {gradientDefinition}
        <Chart
          topMargin={topMargin}
          detailed={detailed}
          isManual={lens?.isManualYaxis}
          chartWidth={chartWidth}
          chartHeight={chartHeight}
          plotWidth={plotWidth}
          plotHeight={plotHeight}
          valueScale={valueScale}
          showOperationCount={lens.showOperationCount}
          categoryScale={categoryScale}
          operationCountTopMargin={-operationCountTopMargin}
          valueUOM={lens.stackingType === StackingType.Distribution ? percentageUOM : valueUOM}
          rightTickFormat={
            lens.stackingType === StackingType.Distribution ? (value) => `${value as string}%` : undefined
          }
          tickFormat={(id) =>
            displayOption === DisplayOption.Shift
              ? transformedData.find((item) => item.id === id)?.name
              : truncateMiddleString(wellShortInfo?.byId[+id]?.name || "", columnWidth / 10)
          }
        >
          <BarStack<TransformedDataItem, number>
            x={(d) => d.id.toString()}
            keys={kpiKeys}
            value={(d, key) => d.values[key] ?? 0}
            data={transformedData}
            xScale={categoryScale}
            yScale={valueScale}
            onMouseOut={hideTooltip}
            color={() => focalWellColor}
          >
            {(barStacks) => [
              // First pass: the bar segments plus, on the top-most layer, the
              // column label, operation count and outlier marker.
              ...barStacks.map((barStack) => {
                let fullBarHeight = 0;
                return (
                  <React.Fragment key={`bar-stack-${barStack.index}`}>
                    {barStack.bars.map((bar, index) => {
                      if (bar.bar.some((e) => Number.isNaN(e))) return null;
                      const bars = Object.values(bar?.bar?.data?.values ?? []).filter((x) => x);
                      // NOTE(review): compares the layer index with the count of non-zero
                      // values; this may be off by one (length - 1) — confirm intent.
                      const isLast = barStack.index === bars.length;
                      fullBarHeight += bar.height;
                      return (
                        <React.Fragment key={`bar-stack-${barStack.index}-${bar.index}-fragment`}>
                          <BarRounded
                            key={`bar-stack-${barStack.index}-${bar.index}`}
                            x={bar.x}
                            y={bar.y}
                            top={isLast ? lens.stackingType !== StackingType.Distribution : undefined}
                            radius={2}
                            height={bar.height}
                            width={bar.width}
                            filter={(() => {
                              if (templateType === "Rotate vs Slide" && bar.bar.data.name === "Night")
                                return `grayscale(${barStack.index * 50}%)`;
                              return "";
                            })()}
                            fill={(() => {
                              if (!getColor({ key: bar.bar.data.id.toString() }) && !bar.color) return "";
                              if (bar.bar.data?.isOutlier)
                                return isDark
                                  ? `${transparentize(
                                      0.9,
                                      shade(0.3 * (barStacks.length - barStack.index - 1), themeColors.outlier_bars),
                                    )}`
                                  : `${tint(0.3 * barStack.index, themeColors.outlier_bars_stacked_transparent)}`;
                              let multiplier = 0.3;
                              if (templateType === "Rotate vs Slide") multiplier = 0.5;
                              return displayOption === DisplayOption.Shift
                                ? bar.bar.data.name === "Night"
                                  ? shade(multiplier * (barStacks.length - barStack.index - 1), bar.color)
                                  : tint(multiplier * barStack.index, bar.color)
                                : tint(
                                    multiplier * (barStack.index - 1),
                                    getColor({ key: (bar.bar.data.id ?? "").toString() }),
                                  );
                            })()}
                          />
                          {barStack.index === barStacks.length - 1 && (
                            <>
                              <Label
                                barX={bar.x}
                                key={`bar-stack-label-${barStack.index}`}
                                barY={bar.y + (detailed ? 0 : 10)}
                                topLabel={bar.y !== plotHeight}
                                detailed={detailed}
                                index={index}
                                columnWidth={bar.width}
                                barHeight={fullBarHeight}
                                value={
                                  lens.stackingType === StackingType.Distribution
                                    ? ""
                                    : valueUOM.display(bar.bar.data.totalValue, { unit: "" })
                                }
                                label={lens?.label}
                              />
                              {lens.showOperationCount ? (
                                <OperationCount
                                  topMargin={-operationCountTopMargin}
                                  x={bar.x}
                                  width={bar.width}
                                  index={index}
                                  detailed={detailed}
                                  value={bar.bar.data.operationCount}
                                />
                              ) : null}

                              {bar.bar.data?.isOutlier ? (
                                <Bar
                                  x={bar.x - 1}
                                  y={bar.y - 1}
                                  width={bar.width + 2}
                                  height={detailed ? 50 : 25}
                                  fill={gradientFill}
                                />
                              ) : null}
                            </>
                          )}
                        </React.Fragment>
                      );
                    })}
                  </React.Fragment>
                );
              }),
              // Second pass: per-segment inner labels, emitted after every bar.
              ...barStacks.map((barStack) => (
                <React.Fragment key={`bar-stack-inner-${barStack.index}`}>
                  {barStack.bars.map((bar, index) => {
                    if (bar.bar.some((e) => Number.isNaN(e))) return null;
                    return (
                      <Label
                        barX={bar.x}
                        innerLabel
                        key={`bar-stack-inner-label-${barStack.index}-${bar.index}`}
                        barY={bar.y + bar.height / 2}
                        columnWidth={bar.width}
                        barHeight={bar.height}
                        index={index}
                        value={
                          lens.stackingType === StackingType.Distribution
                            ? `${roundNumberToDecimal(bar.bar.data.values[bar.key] * 100, 2)}%`
                            : valueUOM.display(bar.bar.data.values[bar.key], { unit: "" })
                        }
                        label={lens?.label}
                      />
                    );
                  })}
                </React.Fragment>
              )),
            ]}
          </BarStack>

          {/* Rectangles to cover the whole stacked bar and handle all the events */}
          {transformedData.map((item) => {
            const x = categoryScale(item.id.toString());

            return (
              <Bar
                key={item.id}
                x={x}
                y={valueScale(item.totalValue || 0)}
                width={columnWidth}
                height={valueScale(0) - valueScale(item.totalValue || 0)}
                fillOpacity={0}
                onMouseOver={() => handleMouseOver(item)}
                onMouseOut={hideTooltip}
              />
            );
          })}

          <AverageLine
            isVisible={(lens?.selectedVisualAids ?? []).includes(VisualAidType.Average)}
            y={valueScale(average)}
            x={-RIGHT_AXIS_PADDING}
            width={plotWidth + RIGHT_AXIS_PADDING}
          />

          <MedianLine
            isVisible={(lens?.selectedVisualAids ?? []).includes(VisualAidType.Median)}
            y={valueScale(median)}
            x={-RIGHT_AXIS_PADDING}
            width={plotWidth + RIGHT_AXIS_PADDING}
          />

          {(lens?.selectedVisualAids ?? []).includes(VisualAidType.Targets) &&
            computedTargetSegments.map(({ target, lineStart, lineEnd, showTag }) => {
              if (!target) return null;
              return (
                <TargetLine
                  key={`${target}-${lineStart}-${lineEnd}`}
                  start={lineStart}
                  end={lineEnd}
                  y={valueScale(target)}
                  label={valueUOM.display(target, { unit: "" })}
                  showTag={!!showTag}
                  detailed={detailed}
                />
              );
            })}
        </Chart>
      </svg>

      {(lens?.selectedVisualAids ?? []).includes(VisualAidType.Average) && (
        <AverageLabel
          style={{
            top: `${valueScale(average) - 16 + (lens.showOperationCount ? OPERATION_COUNT_HEIGHT : 0)}px`,
          }}
        >
          Average: {valueUOM.display(average)}
        </AverageLabel>
      )}

      {(lens?.selectedVisualAids ?? []).includes(VisualAidType.Median) && (
        <MedianLabel
          isDetailed={detailed}
          style={{
            top: `${valueScale(median) - 16 + (lens.showOperationCount ? OPERATION_COUNT_HEIGHT : 0)}px`,
          }}
        >
          Median: {valueUOM.display(median)}
        </MedianLabel>
      )}

      {tooltipElement}
    </div>
  );
}
