import { assetsEndpoints } from "domains/assets/api/endpoints";
import { collectionsEndpoints } from "domains/collections/api/endpoints";
import { inferencesEndpoints } from "domains/inference/api/endpoints";
import { ApiInferenceType } from "domains/inference/interfaces/ApiTypes";
import { jobsEndpoints } from "domains/jobs/api/endpoints";
import { modelsEndpoints } from "domains/models/api/endpoints";
import { searchEndpoints } from "domains/search/api/endpoints";
import { skyboxesEndpoints } from "domains/skyboxes/api/endpoints";
import { subscriptionsEndpoints } from "domains/subscriptions/api/endpoints";
import { teamsEndpoints } from "domains/teams/api/endpoints";
import { usersEndpoints } from "domains/user/api/endpoints";
import {
  GetJobIdApiResponse,
  PostAssetGetBulkApiArg,
  PostAssetGetBulkApiResponse,
  PostPixelateInferencesApiArg,
  PostRemoveBackgroundInferencesApiArg,
  PostRestyleInferencesApiArg,
  PostSkyboxBase360InferencesApiArg,
  PostSkyboxUpscale360InferencesApiArg,
  PostTextureInferencesApiArg,
  PostUpscaleInferencesApiArg,
  PostVectorizeInferencesApiArg,
  PutModelsTrainByModelIdApiArg,
  scenarioApi,
} from "infra/api/generated/api";
import { API_TAGS } from "infra/store/constants";
import _ from "lodash";

/** A record whose values are RTK Query tag type names. */
type Tags = Record<string, string>;

/**
 * Flattens a tag record into the array of tag names expected by
 * `enhanceEndpoints({ addTagTypes })`.
 *
 * @param tags - record mapping arbitrary keys to tag type names.
 * @returns the record's values, in `Object.values` order.
 */
const tagsRecordToArrays = (tags: Tags): string[] => Object.values(tags);

// Every tag type declared in API_TAGS, registered on the generated API so
// that hand-written endpoints can provide/invalidate them.
const scenarioApiExtraTagTypes = tagsRecordToArrays(API_TAGS);

/**
 * The generated `scenarioApi`, extended with the project's tag types and with
 * per-domain endpoint overrides. Later spreads win when two domains enhance
 * the same endpoint, so the order of the spreads below is significant.
 */
export const enhancedScenarioApi = scenarioApi.enhanceEndpoints({
  addTagTypes: scenarioApiExtraTagTypes,
  endpoints: {
    ...inferencesEndpoints,
    ...teamsEndpoints,
    ...modelsEndpoints,
    ...jobsEndpoints,
    ...skyboxesEndpoints,
    ...usersEndpoints,
    ...searchEndpoints,
    ...assetsEndpoints,
    ...collectionsEndpoints,
    ...subscriptionsEndpoints,
  },
});

/**
 * Builds a cost-estimation ("dry run") request for a fixed `/generate/<path>`
 * endpoint: same URL, method and body as the real generation call, with
 * `dryRun` forced to `"true"`.
 *
 * @param path - path segment after `/generate/`, e.g. `"pixelate"`.
 * @param queryArg - the endpoint's query argument; only `body`,
 *   `originalAssets` and `teamId` are forwarded to the request.
 */
const buildGenerateDryRunRequest = (
  path: string,
  queryArg: { body?: unknown; originalAssets?: unknown; teamId?: unknown }
) => ({
  url: `/generate/${path}`,
  method: "POST",
  body: queryArg.body,
  params: {
    originalAssets: queryArg.originalAssets,
    teamId: queryArg.teamId,
    dryRun: "true",
  },
});

/**
 * Hand-written endpoints injected on top of `enhancedScenarioApi`: the
 * generic inference mutation, the per-feature cost estimations (dry runs),
 * and a bulk asset fetch that provides per-asset tags.
 */
export const apiSlice = enhancedScenarioApi.injectEndpoints({
  endpoints: (build) => ({
    // Launches an inference; invalidates job, asset and limits tags so those
    // lists refetch after a new job is created.
    postInference: build.mutation<
      {
        job: NonNullable<GetJobIdApiResponse["job"]>;
      },
      ApiInferenceType
    >({
      query: (queryArg) => ({
        url: `/generate/${queryArg.inferenceType}`,
        method: "POST",
        body: queryArg.body,
        params: {
          teamId: queryArg.teamId,
          originalAssets: queryArg.originalAssets,
          dryRun: queryArg.dryRun ?? "false",
        },
      }),
      invalidatesTags: [
        { type: API_TAGS.job },
        { type: API_TAGS.asset },
        { type: API_TAGS.limits },
      ],
    }),

    // Same request as `postInference` but with `dryRun: "true"`, returning
    // only the creative-units cost.
    getInferenceCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<ApiInferenceType, "dryRun">
    >({
      query: (queryArg) => ({
        url: `/generate/${queryArg.inferenceType}`,
        method: "POST",
        body: queryArg.body,
        params: {
          teamId: queryArg.teamId,
          originalAssets: queryArg.originalAssets,
          dryRun: "true",
        },
      }),
      serializeQueryArgs: ({ queryArgs, endpointName }) => {
        // Drop asset/image ids, the mask, the seed and the prompts from the
        // cache key so that editing them does not trigger a new cost request.
        const params = _.omit(queryArgs?.body || {}, [
          "ipAdapterImageId",
          "maskId",
          "controlImageId",
          "mask",
          "seed",
          "imageId",
          "prompt",
          "negativePrompt",
        ] as (keyof typeof queryArgs.body)[]) as typeof queryArgs.body;

        return {
          endpointName,
          // The inference type selects the URL (and therefore the cost), and
          // different types can share the same remaining body fields, so it
          // must be part of the key to avoid returning another type's cost.
          // NOTE(review): teamId is still not part of the key; confirm costs
          // are team-independent.
          inferenceType: queryArgs.inferenceType,
          params,
        };
      },
    }),

    getPixelateCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostPixelateInferencesApiArg, "dryRun">
    >({
      query: (queryArg) => buildGenerateDryRunRequest("pixelate", queryArg),
    }),

    getImageUpscaleCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostUpscaleInferencesApiArg, "dryRun">
    >({
      query: (queryArg) => buildGenerateDryRunRequest("upscale", queryArg),
    }),

    getRemoveBackgroundCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostRemoveBackgroundInferencesApiArg, "dryRun"> & {
        // Client-only values used as the cache key; not sent to the server.
        cacheKeys: {
          width?: number;
          height?: number;
          assetId?: string;
        };
      }
    >({
      serializeQueryArgs: ({ queryArgs, endpointName }) => {
        return {
          endpointName,
          // A new request is only issued when the assetId, width or height
          // changes; all other argument changes reuse the cached cost.
          cacheKeys: queryArgs.cacheKeys,
        };
      },
      query: (queryArg) =>
        buildGenerateDryRunRequest("remove-background", queryArg),
    }),

    getVectorizationCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostVectorizeInferencesApiArg, "dryRun">
    >({
      query: (queryArg) => buildGenerateDryRunRequest("vectorize", queryArg),
    }),

    getSkyboxInferenceCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostSkyboxBase360InferencesApiArg, "dryRun">
    >({
      query: (queryArg) =>
        buildGenerateDryRunRequest("skybox-base-360", queryArg),
    }),

    getSkyboxUpscaleCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostSkyboxUpscale360InferencesApiArg, "dryRun">
    >({
      query: (queryArg) =>
        buildGenerateDryRunRequest("skybox-upscale-360", queryArg),
    }),

    getTextureMapsCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostTextureInferencesApiArg, "dryRun">
    >({
      query: (queryArg) => buildGenerateDryRunRequest("texture", queryArg),
    }),

    getRestyleInferenceCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PostRestyleInferencesApiArg, "dryRun">
    >({
      query: (queryArg) => buildGenerateDryRunRequest("restyle", queryArg),
    }),

    // Dry run of model training (PUT /models/:modelId/train).
    getModelTrainCost: build.query<
      {
        creativeUnitsCost: number;
      },
      Omit<PutModelsTrainByModelIdApiArg, "dryRun">
    >({
      query: (queryArg) => ({
        url: `/models/${queryArg.modelId}/train`,
        method: "PUT",
        body: queryArg.body,
        params: {
          teamId: queryArg.teamId,
          dryRun: "true",
          trainingImagesCount: queryArg.trainingImagesCount,
        },
      }),
    }),

    // Fetches several assets at once; provides one `asset` tag per requested
    // id (`assetId:<id>`) so per-asset invalidation refetches this query.
    getAssetsBulk: build.query<
      PostAssetGetBulkApiResponse,
      PostAssetGetBulkApiArg
    >({
      query: (queryArg) => ({
        url: `/assets/get-bulk`,
        method: "POST",
        body: queryArg.body,
        params: {
          originalAssets: queryArg.originalAssets,
          teamId: queryArg.teamId,
        },
      }),
      providesTags: (_result, _error, arg) => {
        return (
          arg.body.assetIds?.map((assetId) => ({
            type: API_TAGS.asset,
            id: `assetId:${assetId}`,
          })) ?? []
        );
      },
    }),
  }),

  overrideExisting: false,
});

// React hooks generated by RTK Query for the endpoints injected into apiSlice.
export const {
  // Mutations
  usePostInferenceMutation,

  // Cost estimations (dry runs)
  useGetInferenceCostQuery,
  useGetPixelateCostQuery,
  useGetImageUpscaleCostQuery,
  useGetRemoveBackgroundCostQuery,
  useGetVectorizationCostQuery,
  useGetSkyboxInferenceCostQuery,
  useGetSkyboxUpscaleCostQuery,
  useGetTextureMapsCostQuery,
  useGetRestyleInferenceCostQuery,
  useGetModelTrainCostQuery,

  // Assets
  useGetAssetsBulkQuery,
  useLazyGetAssetsBulkQuery,
} = apiSlice;

/** Name of any endpoint registered on `apiSlice` (generated or injected). */
type EndpointName = keyof (typeof apiSlice)["endpoints"];

/**
 * An endpoint name paired with the original argument it was called with.
 * Presumably mirrors entries returned by RTK Query's `selectInvalidatedBy`;
 * verify against callers.
 */
export type SelectedInvalidatedByTag = {
  // NOTE(review): `any` because the argument shape differs per endpoint;
  // switching to `unknown` would break callers that read fields directly.
  originalArgs: any;
  endpointName: EndpointName;
};
