import {
  GridColDef,
  GridEventListener,
  GridGroupingColDefOverride,
  GridGroupingColDefOverrideParams,
  GridValidRowModel,
} from "@mui/x-data-grid-premium";
import { GridInitialStatePremium } from "@mui/x-data-grid-premium/models/gridStatePremium";
import { isFunction, sortBy } from "lodash";
import React from "react";

/**
 * Applies persisted DataGrid column state (saved widths and saved column
 * order) to the column definitions and the grouping-column override that are
 * passed to the grid.
 *
 * We maintain column state ourselves, as described here:
 * https://github.com/mui/mui-x/issues/970#issuecomment-776612509
 *
 * @param state - Saved grid state. Only `state.columns.dimensions` and
 *   `state.columns.orderedFields` are read; the object is never mutated.
 * @param columns - Column definitions in their default order.
 * @param groupingColDef - Optional grouping-column override, either an object
 *   or a factory receiving `GridGroupingColDefOverrideParams`.
 * @returns `columnsExtended`: `columns` with saved widths applied (and `flex`
 *   cleared for those columns), sorted by the saved order; columns missing
 *   from the saved order are inserted at their index in `columns`.
 *   `groupingColDefExtended`: `groupingColDef` with the saved grouping-column
 *   width applied (and `flex` cleared), or `groupingColDef` unchanged when no
 *   width was saved.
 */
export function useDataGridColumns<RowModel extends GridValidRowModel>(
  state: GridInitialStatePremium | undefined,
  columns: readonly GridColDef<RowModel>[],
  groupingColDef?:
    | GridGroupingColDefOverride<RowModel>
    | ((
        params: GridGroupingColDefOverrideParams
      ) => GridGroupingColDefOverride<RowModel> | undefined | null)
) {
  const columnsExtended = React.useMemo(() => {
    const dimensions = state?.columns?.dimensions;
    const savedOrder = state?.columns?.orderedFields;

    // 1. add pre-saved widths to columns (a saved width overrides `flex`)
    let result: readonly GridColDef<RowModel>[] = dimensions
      ? columns.map(c => {
          const width = dimensions[c.field]?.width;
          return width ? { ...c, width, flex: undefined } : c;
        })
      : columns;

    // 2. sort columns according to their pre-saved order
    if (savedOrder) {
      // Work on a copy: splicing `state.columns.orderedFields` directly would
      // mutate the caller's state object during render (and throw if frozen).
      const orderedFields = [...savedOrder];
      columns.forEach((c, index) => {
        // A column missing from the saved order (e.g. added recently) is
        // inserted at its position in the initial `columns` array.
        if (!orderedFields.includes(c.field)) {
          orderedFields.splice(index, 0, c.field);
        }
      });

      // Precompute positions once instead of a findIndex per sort key; keep the
      // first occurrence to match findIndex semantics on duplicate fields.
      const position = new Map<string, number>();
      orderedFields.forEach((field, i) => {
        if (!position.has(field)) {
          position.set(field, i);
        }
      });
      result = sortBy(result, c => position.get(c.field) ?? -1);
    }

    return result;
  }, [columns, state]);

  const groupingColDefExtended = React.useMemo(() => {
    // The grouping column's saved dimensions are looked up under the tree-data
    // field first, then under the row-grouping field.
    const savedWidth = (
      state?.columns?.dimensions?.["__tree_data_group__"] ??
      state?.columns?.dimensions?.["__row_group_by_columns_group__"]
    )?.width;

    if (!groupingColDef || !savedWidth) {
      return groupingColDef;
    }

    if (isFunction(groupingColDef)) {
      // Capture the narrowed factory so the closure keeps its callable type.
      const makeGroupingColDef = groupingColDef;
      return (params: GridGroupingColDefOverrideParams) => ({
        ...makeGroupingColDef(params),
        width: savedWidth,
        flex: undefined,
      });
    }

    return { ...groupingColDef, width: savedWidth, flex: undefined };
  }, [groupingColDef, state]);

  return { columnsExtended, groupingColDefExtended };
}
