import { Typography, useTheme } from "@mui/material";
import {
  AllSeriesType,
  AxisConfig,
  BarPlot,
  BarPlotProps,
  BarSeriesType,
  ChartsAxisContentProps,
  ChartsGrid,
  ChartsReferenceLine,
  ChartsTooltip,
  ChartsXAxis,
  ChartsXAxisProps,
  ChartsYAxis,
  PieValueType,
  ResponsiveChartContainer,
  ResponsiveChartContainerProps,
  ScatterValueType,
} from "@mui/x-charts";
import { ChartsYReferenceLineProps } from "@mui/x-charts/ChartsReferenceLine/ChartsYReferenceLine";
import {
  AxisScaleConfig,
  ChartSeriesDefaultized,
  ChartSeriesType,
  ChartsYAxisProps,
  DatasetType,
  MakeOptional,
} from "@mui/x-charts/internals";
import ChartTooltip from "./ChartTooltip";

/**
 * Props for the `BarChart` component.
 *
 * Any remaining `ResponsiveChartContainerProps` (e.g. `width` / `height`) that
 * are not listed here are forwarded to the underlying `ResponsiveChartContainer`.
 */
type BarChartProps = {
  /**
   * Rows rendered by the chart. Also read by the tooltip to build its title
   * when `indexAxisLabelProperty` is set.
   */
  dataset: DatasetType;
  /** Series passed to the chart container; the tooltip lists only `type: "bar"` series. */
  series: AllSeriesType[];
  /** Band-scale x-axis configs passed to the chart container. */
  xAxis?: MakeOptional<AxisConfig<"band", any, ChartsXAxisProps>, "id">[];
  /**
   * Y-axis configs passed to the chart container. When the bar values give no
   * usable range (e.g. they are all zero), the component applies fixed 0–10
   * bounds to the y-axis instead.
   */
  yAxis?: MakeOptional<
    AxisConfig<keyof AxisScaleConfig, any, ChartsYAxisProps>,
    "id"
  >[];
  /**
   * Bar label mode. A function is called for each bar to produce its label; a
   * string option shows the bar's value. Labels are skipped on bars shorter
   * than `theme.spacing(2)`.
   */
  barLabel?: BarPlotProps["barLabel"];
  /** Fill color for the bar label text. */
  barLabelColor?: string | undefined;
  /** Y reference lines; each one is rendered with `ChartsReferenceLine`. */
  referenceLines?: ChartsYReferenceLineProps[];
  /** Forwarded to `BarPlot`'s `onItemClick`. */
  onItemClick?: BarPlotProps["onItemClick"];
  /**
   * Name of a dataset property that holds each row's display label.
   *
   * When set, the chart replaces this property on every row with the row's
   * 1-based index ("1.", "2.", …) so the axis shows short index labels. The
   * tooltip title then shows the index followed by the full label taken from
   * the original `dataset`. Presumably the x-axis `dataKey` should point at
   * this property for the index labels to appear — confirm against callers.
   */
  indexAxisLabelProperty?: string;
  /**
   * If set, the axis tooltip adds a bold total row. The total is the sum of all
   * bar values at the hovered index, formatted by this function. The row is
   * omitted when the sum is 0.
   */
  stackedTooltipTotalFormatter?: (total: number) => string;
} & ResponsiveChartContainerProps;

/**
 * Y-axis settings applied when every bar value in the chart is 0. With no
 * non-zero values the data presumably gives no range to scale against, so
 * fixed bounds keep the grid and tick labels readable.
 */
const ZERO_DATA_Y_AXIS_DEFAULTS = {
  min: 0,
  max: 10,
  // x-charts reads the approximate tick count from `tickNumber`; the former
  // `tickCount` key is not an axis option and was silently ignored.
  tickNumber: 6,
};

/**
 * Collects every bar value in the chart as a number.
 *
 * Series that name a `dataKey` are read from `dataset`. Series that carry their
 * own inline `data` array are read from that array; previously they were
 * wrongly counted as all-zero. Missing or non-numeric values count as 0.
 */
function collectBarValues(
  dataset: DatasetType,
  series: AllSeriesType[]
): number[] {
  return series.filter(isBarSeriesType).flatMap((barSeries) => {
    const { dataKey } = barSeries;
    const rawValues: unknown[] =
      dataKey !== undefined
        ? dataset.map((row) => row[dataKey])
        : barSeries.data ?? [];
    return rawValues.map((value) => Number(value) || 0);
  });
}

/**
 * Picks the y-axis config handed to the chart container.
 *
 * The zero-data bounds are used only when the chart has at least one bar and
 * every bar value is exactly 0. Negative values therefore no longer trigger
 * them. In that case the bounds are layered on top of each caller-supplied
 * axis, or used alone when none was supplied. Caller settings other than
 * `min`/`max`/`tickNumber` (ids, labels, formatters, …) are kept, so anything
 * that refers to an axis by id keeps working.
 *
 * @returns `yAxis` unchanged when the data has a usable range.
 */
function resolveYAxis(
  yAxis: BarChartProps["yAxis"],
  dataset: DatasetType,
  series: AllSeriesType[]
) {
  const values = collectBarValues(dataset, series);
  const allZero = values.length > 0 && values.every((value) => value === 0);

  if (!allZero) {
    return yAxis;
  }

  if (!yAxis || yAxis.length === 0) {
    return [ZERO_DATA_Y_AXIS_DEFAULTS];
  }

  return yAxis.map((axis) => ({ ...axis, ...ZERO_DATA_Y_AXIS_DEFAULTS }));
}

/**
 * When `indexAxisLabelProperty` is set, returns a copy of `dataset` in which
 * that property holds each row's 1-based position ("1.", "2.", …). This lets
 * the axis show short index labels. The input rows are not mutated, so the
 * tooltip can still read the full label from the original dataset.
 */
function withIndexLabels(
  dataset: DatasetType,
  indexAxisLabelProperty?: string
): DatasetType {
  if (!indexAxisLabelProperty) {
    return dataset;
  }

  return dataset.map((row, index) => ({
    ...row,
    [indexAxisLabelProperty]: `${index + 1}.`,
  }));
}

/**
 * Wrapper around the mui x-charts bar plot. It renders in a responsive
 * container with horizontal grid lines, an axis tooltip, styled x/y axes and
 * optional y reference lines.
 *
 * Because it is housed in a responsive container, the chart needs a size.
 * Either pass `height`/`width` as props (they are forwarded to the container)
 * or set a size on the parent element.
 */
export const BarChart = ({
  dataset,
  series,
  xAxis,
  yAxis,
  barLabel,
  barLabelColor,
  indexAxisLabelProperty,
  stackedTooltipTotalFormatter,
  referenceLines,
  onItemClick,
  ...responsiveContainerProps
}: BarChartProps) => {
  const theme = useTheme();

  // Bars shorter than this get no label. `theme.spacing` returns a CSS length
  // string (e.g. "16px" with the default theme); parseFloat keeps the number.
  const minLabelHeight = parseFloat(theme.spacing(2));

  // Both axes share the same line styling.
  const axisSlotProps = {
    axisLine: {
      style: {
        stroke: theme.palette.grey[400],
        strokeWidth: 2,
      },
    },
  };

  return (
    <ResponsiveChartContainer
      dataset={withIndexLabels(dataset, indexAxisLabelProperty)}
      series={series}
      xAxis={xAxis}
      yAxis={resolveYAxis(yAxis, dataset, series)}
      {...responsiveContainerProps}
    >
      <ChartsGrid horizontal />
      <BarPlot
        onItemClick={onItemClick}
        barLabel={(item, context) => {
          if (context.bar.height < minLabelHeight) {
            return undefined;
          }

          if (typeof barLabel === "function") {
            return barLabel(item, context);
          }

          // A string option (i.e. "value") shows the bar's raw value.
          if (typeof barLabel === "string") {
            return item.value?.toString();
          }

          return undefined;
        }}
        slotProps={{
          barLabel: {
            style: { fill: barLabelColor, fontSize: "0.7rem" },
          },
        }}
      />
      <ChartsTooltip
        slots={{
          axisContent: (params: ChartsAxisContentProps) => {
            const { dataIndex, axisValue, series: tooltipSeries } = params;

            // An axis tooltip should always carry an index; bail out defensively.
            if (dataIndex == null) {
              return null;
            }

            let title = axisValue?.toString() ?? "";

            if (indexAxisLabelProperty) {
              // The axis shows "1.", "2.", …; the tooltip pairs the index with
              // the full label taken from the caller's unmodified dataset.
              const fullLabel = dataset[dataIndex]?.[indexAxisLabelProperty];
              title = fullLabel
                ? `${dataIndex + 1}. ${fullLabel.toString()}`
                : "";
            }

            const barSeries = tooltipSeries.filter(isBarSeriesType);

            // Series whose value at this index is missing or 0 are omitted.
            const labels = barSeries
              .map((s) => {
                const val = s.data[dataIndex];
                if (!val) {
                  return undefined;
                }
                return {
                  color: s.getColor(dataIndex),
                  value: s.valueFormatter(
                    val as number & ScatterValueType & PieValueType,
                    {
                      dataIndex,
                    }
                  ),
                };
              })
              .filter(
                (label): label is { color: string; value: string } =>
                  label !== undefined
              );

            const total = barSeries
              .map((s) => (s.data[dataIndex] as number) ?? 0)
              .reduce((partialSum, a) => partialSum + a, 0);

            // The total row appears only with a formatter and a non-zero sum.
            const totalLabel =
              stackedTooltipTotalFormatter && total ? (
                <Typography
                  variant="body2"
                  fontWeight="bold"
                  sx={{ ml: 3, mt: 1 }}
                >
                  {stackedTooltipTotalFormatter(total)}
                </Typography>
              ) : undefined;

            return (
              <ChartTooltip title={title} labels={labels}>
                {totalLabel}
              </ChartTooltip>
            );
          },
        }}
      />
      <ChartsXAxis disableTicks slotProps={axisSlotProps} />
      <ChartsYAxis disableTicks slotProps={axisSlotProps} />
      {(referenceLines ?? []).map((line, index) => (
        <ChartsReferenceLine key={index} {...line} />
      ))}
    </ResponsiveChartContainer>
  );
};

/**
 * Type guard that picks bar series out of a mixed list of chart series.
 *
 * @param series - A raw (`AllSeriesType`) or defaultized chart series.
 * @returns `true` when the series' `type` discriminant is `"bar"`.
 */
function isBarSeriesType(
  series: AllSeriesType | ChartSeriesDefaultized<ChartSeriesType>
): series is BarSeriesType {
  const { type } = series;
  return type === "bar";
}
