import { updateModelsQueryDataByUpdatingCollections } from "domains/collections/api/updateQueryDataByUpdatingCollections";
import { updateQueryDataByUpdatingModelConcepts } from "domains/models/api/updateQueryDataByUpdatingModelConcepts";
import { updateQueryDataByUpdatingTags } from "domains/tags/api/updateQueryDataByUpdatingTags";
import { GetModelPresetsByModelIdApiResponse } from "infra/api/generated/api";
import { apiSlice, SelectedInvalidatedByTag } from "infra/store/apiSlice";
import { API_TAGS } from "infra/store/constants";
import { Endpoints } from "infra/store/interface";
import _ from "lodash";

/**
 * Stable string keys for model-related requests that several components must
 * share (presumably passed as an RTK Query `fixedCacheKey` — verify at call sites).
 */
export const modelsFixedKeys = { deleteModel: "deleteModel" };

/** Tag object built by the helpers below; `type` keeps the exact type of `API_TAGS.model`. */
type ModelTag = { type: typeof API_TAGS.model; id: string };

/**
 * Tag scoping cached model data to a single model (`modelId:<id>`).
 * Model-detail queries provide it; `putModelsExamplesByModelId` invalidates it.
 */
const modelIdTag = (modelId: string): ModelTag => ({
  type: API_TAGS.model,
  id: `modelId:${modelId}`,
});

/**
 * Tag scoping `getModels` lists to a single collection (`collectionId:<id>`).
 * Built in one place so the provider (`getModels`) and the invalidator
 * (`putModelsByCollectionId`) always agree on the id format.
 */
const collectionIdTag = (
  collectionId: string | null | undefined
): ModelTag => ({
  type: API_TAGS.model,
  id: `collectionId:${collectionId}`,
});

/**
 * Returns true when `err` (e.g. a rejected `queryFulfilled` value) has the
 * shape `{ error: { status } }` with exactly the given `status`.
 */
const hasErrorStatus = (err: unknown, status: number): boolean => {
  if (typeof err !== "object" || err === null) return false;
  const { error } = err as { error?: unknown };
  if (typeof error !== "object" || error === null) return false;
  return (error as { status?: unknown }).status === status;
};

/**
 * RTK Query endpoint overrides for the models API: list merging/pagination for
 * `getModels`, cache tags, manual cache patches, and retry options.
 */
export const modelsEndpoints: Endpoints = {
  postModelsTrainingImagesByModelId: {
    extraOptions: {
      maxRetries: 5,
    },
  },

  getModels: {
    /**
     * Accumulates successive responses into one cache entry.
     * With `loadedOnly` the incoming result replaces the entry instead.
     */
    merge(existing, incoming, { arg }) {
      if (arg.loadedOnly) {
        return incoming;
      }

      // Capture before merging: once `unionBy` has run, `existing.models` is
      // always an array, so checking it afterwards could never succeed.
      const hadModels = Boolean(existing.models);

      // Append incoming models, de-duplicated by id.
      // NOTE(review): `unionBy` keeps the FIRST occurrence, so a model already
      // in the cache is not replaced by a fresher copy from `incoming` —
      // confirm this is intended.
      existing.models = _.unionBy(existing.models, incoming.models, "id");

      // Alternative kept for reference, in case some assets don't update:
      // existing.models = _.unionBy(
      //   _.cloneDeep(existing.models),
      //   incoming.models,
      //   "id"
      // );

      // Only advance the token while the end of the pagination has not been
      // reached; once the cached token is cleared it stays cleared.
      if (!hadModels || existing.nextPaginationToken) {
        existing.nextPaginationToken = incoming.nextPaginationToken;
      }
    },
    // Cache key: args not listed here (presumably the pagination token) map to
    // the same entry, which `merge` accumulates into.
    serializeQueryArgs: ({ endpointName, queryArgs }) => {
      return (
        endpointName +
        queryArgs.teamId +
        queryArgs.privacy +
        queryArgs.status +
        queryArgs.collectionId +
        queryArgs.loadedOnly
      );
    },
    // Refetch whenever any arg changes, including ones excluded from the key.
    forceRefetch({ currentArg, previousArg }) {
      return !_.isEqual(currentArg, previousArg);
    },
    providesTags: (_result, _error, arg) => {
      return [
        {
          type: API_TAGS.model,
        },
        collectionIdTag(arg.collectionId),
      ];
    },
  },

  getModelsByModelId: {
    // After the model loads, copy its training images into any cached
    // `postSearch` results that contain this model.
    onQueryStarted: async (arg, { dispatch, queryFulfilled, getState }) => {
      try {
        const { data } = await queryFulfilled;
        if (data) {
          for (const {
            endpointName,
            originalArgs,
          } of apiSlice.util.selectInvalidatedBy(getState(), [
            API_TAGS.search,
          ]) as SelectedInvalidatedByTag[]) {
            if (endpointName === "postSearch") {
              dispatch(
                apiSlice.util.updateQueryData(
                  endpointName,
                  originalArgs,
                  (draft) => {
                    const model = draft?.results?.find(
                      (item) => item.type === "model" && item.id === arg.modelId
                    );
                    if (model) {
                      // Only `trainingImages` changes; other fields stay as cached.
                      Object.assign(model, {
                        trainingImages: data.model.trainingImages,
                      });
                    }
                  }
                )
              );
            }
          }
        }
      } catch (err: unknown) {
        // A 404 is ignored; any other error is rethrown.
        // NOTE(review): a rethrow from this handler may surface as an
        // unhandled promise rejection — confirm that is intended.
        if (!hasErrorStatus(err, 404)) throw err;
      }
    },

    providesTags: (_result, _error, arg) => {
      return [modelIdTag(arg.modelId)];
    },
  },

  getModelsExamplesByModelId: {
    providesTags: (_result, _error, arg) => {
      return [modelIdTag(arg.modelId)];
    },
  },

  getModelsDescriptionByModelId: {
    providesTags: (_result, _error, arg) => {
      return [modelIdTag(arg.modelId)];
    },
  },

  getModelPresetsByModelId: {
    // Clear `numSamples` from every preset's parameters.
    transformResponse: (response) => {
      return {
        presets: (response as GetModelPresetsByModelIdApiResponse).presets.map(
          (preset) => ({
            ...preset,
            parameters: {
              ...preset.parameters,
              numSamples: undefined,
            },
          })
        ),
      };
    },
    providesTags: (_result, _error, arg) => {
      return [modelIdTag(arg.modelId)];
    },
  },

  deleteModelsByModelId: {
    // Pessimistic update: once the delete succeeds, remove the model from
    // every cached `getModels` list instead of refetching those lists.
    onQueryStarted: async (arg, { dispatch, queryFulfilled, getState }) => {
      const undoList: (() => void)[] = [];
      try {
        await queryFulfilled;
        for (const {
          endpointName,
          originalArgs,
        } of apiSlice.util.selectInvalidatedBy(getState(), [
          API_TAGS.model,
        ]) as SelectedInvalidatedByTag[]) {
          if (endpointName === "getModels") {
            const update = dispatch(
              apiSlice.util.updateQueryData(
                endpointName,
                originalArgs,
                (draft) => {
                  // The entry may not hold a list yet: nothing to remove.
                  if (!draft?.models) return;
                  draft.models = draft.models.filter(
                    (model) => model.id !== arg.modelId
                  );
                }
              )
            );
            undoList.push(update.undo);
          }
        }
      } catch {
        // Either the delete failed (nothing has been patched yet, so the list
        // is empty) or a cache update threw part-way: revert applied patches.
        for (const undo of undoList) {
          undo();
        }
      }
    },
  },

  postModels: {
    invalidatesTags: [API_TAGS.model],
  },

  postModelsCopyByModelId: {
    invalidatesTags: [API_TAGS.model],
  },

  putModelsDescriptionByModelId: {
    // Optimistically show the new description; revert if the request fails.
    onQueryStarted: async (arg, { dispatch, queryFulfilled }) => {
      const patchResult = dispatch(
        apiSlice.util.updateQueryData(
          "getModelsDescriptionByModelId",
          {
            teamId: arg.teamId,
            modelId: arg.modelId,
          },
          (draft) => {
            if (draft) {
              draft.description = {
                ...draft.description,
                value: arg.body.description,
              };
            }
          }
        )
      );
      queryFulfilled.catch(() => patchResult.undo());
    },
    invalidatesTags: [API_TAGS.model],
  },

  putModelsByModelId: {
    // After a successful update, propagate the updated model's concepts into
    // other cached queries. A failed request leaves the caches untouched
    // (and no longer rejects this handler's promise).
    onQueryStarted: async (_arg, { dispatch, queryFulfilled, getState }) => {
      const result = await queryFulfilled.catch(() => undefined);
      if (!result?.data) return;
      updateQueryDataByUpdatingModelConcepts({
        updatedModel: result.data.model,
        dispatch,
        getState,
      });
    },
    invalidatesTags: [API_TAGS.model],
  },

  postModelsTransferByModelId: {
    invalidatesTags: [API_TAGS.model],
  },

  putModelsTagsByModelId: {
    // Optimistically apply the tag additions/removals; revert on failure.
    onQueryStarted: async (arg, { dispatch, queryFulfilled, getState }) => {
      const undo = updateQueryDataByUpdatingTags(
        arg.teamId ?? "",
        "model",
        arg.modelId,
        arg.body.add ?? [],
        arg.body.delete ?? [],
        {
          dispatch,
          getState,
        }
      ).undo;
      queryFulfilled.catch(() => undo());
    },
  },

  putModelsExamplesByModelId: {
    invalidatesTags: (_result, _error, arg) => {
      return [modelIdTag(arg.modelId)];
    },
  },

  putModelsByCollectionId: {
    // Optimistically add the models to the collection; revert on failure.
    onQueryStarted: async (arg, { dispatch, queryFulfilled, getState }) => {
      const undo = updateModelsQueryDataByUpdatingCollections(
        arg.teamId,
        "add",
        arg.body.modelIds,
        arg.collectionId,
        {
          dispatch,
          getState,
        }
      ).undo;

      queryFulfilled.catch(() => undo());
    },
    // Must use the same id format that `getModels` provides; the previous
    // `collection:<id>` never matched the provided `collectionId:<id>`.
    invalidatesTags: (_result, _error, arg) => {
      return [collectionIdTag(arg.collectionId)];
    },
  },

  postModelPresetByModelId: {
    invalidatesTags: [API_TAGS.model],
  },

  deleteModelPresetByModelIdAndPresetId: {
    invalidatesTags: [API_TAGS.model],
  },

  deleteModelsByCollectionId: {
    // Optimistically remove the models from the collection; revert on failure.
    onQueryStarted: async (arg, { dispatch, queryFulfilled, getState }) => {
      const undo = updateModelsQueryDataByUpdatingCollections(
        arg.teamId,
        "remove",
        arg.body.modelIds,
        arg.collectionId,
        {
          dispatch,
          getState,
        }
      ).undo;

      queryFulfilled.catch(() => undo());
    },
  },

  putModelsTrainByModelId: {
    invalidatesTags: [API_TAGS.job],
  },
};
