import React, {
  MouseEventHandler,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from "react";
import { reverseStrength } from "domains/assets/utils";
import { isAssetInferenceTexture } from "domains/assets/utils/isType";
import { mapAssetTypeToJobType } from "domains/assets/utils/mapAssetsToJobs";
import { useDebounce } from "domains/commons/hooks/useDebounce";
import { SDXL_DEFAULT_BASE_MODEL } from "domains/inference/constants/baseModel";
import {
  getGuidanceForModel,
  getSamplingStepsForModel,
} from "domains/inference/constants/pipelinesConfigs";
import {
  ApiInferenceImg2Img,
  ApiInferenceImg2ImgIpAdapter,
  ApiInferenceImg2ImgTexture,
} from "domains/inference/interfaces/ApiTypes";
import { useJobSessionContext } from "domains/jobs/contexts/JobSessionProviderV2";
import {
  getIsModelFluxDev,
  getIsModelFluxPro,
  getIsModelFluxPro1_1,
  getIsModelFluxPro1_1Ultra,
  getIsModelFluxSchnell,
  getIsModelSdxl,
} from "domains/models/utils";
import { useScenarioToast } from "domains/notification/hooks/useScenarioToast";
import { useTeamContext } from "domains/teams/contexts/TeamProvider";
import { usePlanContext } from "domains/teams/hooks/usePlan";
import ButtonWithCuIndicator from "domains/ui/components/ButtonWithCuIndicator";
import { AnalyticsEvents } from "infra/analytics/constants/Events";
import Track from "infra/analytics/Track";
import { useHandleApiError } from "infra/api/error";
import { removeCuFromType } from "infra/api/ExludeCU";
import {
  GetAssetsByAssetIdApiResponse,
  GetJobIdApiResponse,
  GetModelsByModelIdApiResponse,
  useGetModelsInferencesByModelIdAndInferenceIdQuery,
} from "infra/api/generated/api";
import {
  useGetInferenceCostQuery,
  usePostInferenceMutation,
} from "infra/store/apiSlice";
import _ from "lodash";

import { skipToken } from "@reduxjs/toolkit/query";

/** Props of the `VaryButton` component. */
interface VaryButtonProps {
  /** Asset whose generation the button re-runs as a variation. */
  asset: GetAssetsByAssetIdApiResponse["asset"];
  /** Whether the card is currently shown; hidden cards render no button. */
  isVisible: boolean;
  /**
   * Current width of the surrounding card (presumably px — confirm with the
   * callers); the button is not rendered on cards that are too narrow.
   */
  cardWidth: number;
}

/**
 * Minimum card width for which the "Vary" button is rendered; narrower cards
 * have no room for it. Same unit as the `cardWidth` prop (presumably px).
 */
const MIN_CARD_WIDTH_FOR_VARY = 180;

/**
 * "Vary" button overlaid on an asset card.
 *
 * Starts a variation of `asset` through `useAssetVary` and shows its
 * creative-unit cost. Renders nothing when the asset cannot be varied, when
 * the card is not visible, or when the card is narrower than
 * `MIN_CARD_WIDTH_FOR_VARY`.
 */
export default React.memo(function VaryButton({
  asset,
  cardWidth,
  isVisible,
}: VaryButtonProps) {
  const { canVary, cuCostVary, isVaryLoadingCu, isVaryLoading, varyImage } =
    useAssetVary({
      asset,
      isVisible,
    });

  const handleVaryClick = useCallback<MouseEventHandler<HTMLButtonElement>>(
    (e) => {
      // Don't let the click bubble up to the parent card.
      e.stopPropagation();
      void varyImage();
    },
    [varyImage]
  );

  if (!canVary || !isVisible || cardWidth < MIN_CARD_WIDTH_FOR_VARY) {
    return null;
  }

  // The cost is "loading" only while it is unknown. A truthiness check would
  // treat a legitimate cost of 0 as unknown and keep the indicator spinning.
  const isCuCostUnknown = typeof cuCostVary !== "number";

  return (
    <ButtonWithCuIndicator
      onClick={handleVaryClick}
      size="sm"
      variant="blurred"
      isInline
      isLoading={isVaryLoading}
      cuCost={cuCostVary}
      isCuLoading={isVaryLoadingCu || isCuCostUnknown}
    >
      Vary
    </ButtonWithCuIndicator>
  );
});

// ------------------------------------

/** Arguments accepted by the `useAssetVary` hook. */
interface UseAssetVaryArgs {
  /**
   * Whether the asset is currently visible; data fetching only starts once
   * this has stayed true for a while (debounced inside the hook).
   */
  isVisible: boolean;
  /** Asset the variation request is derived from. */
  asset: GetAssetsByAssetIdApiResponse["asset"];
}

/**
 * Builds and sends a "vary" request for an existing image asset.
 *
 * The request is an img2img-style inference that reuses the asset's model,
 * prompt and size, and takes the asset itself as the input image.
 *
 * The original inference supplies the concepts, base model and IP-adapter
 * settings. It is only fetched after `isVisible` has stayed true for one
 * second, so sweeping over a grid of cards does not fire one request per card.
 *
 * If `varyImage` is called before the request parameters are ready, the call
 * is queued and replayed once they resolve. A queued call is:
 * - dropped if the card stops being visible meanwhile;
 * - reported with an error toast if the parameters can no longer be obtained
 *   (e.g. the inference lookup was skipped or failed).
 *
 * @returns
 * - `canVary`: whether the asset's model and type support varying.
 * - `cuCostVary`: creative-unit cost estimate, `undefined` until known.
 * - `isVaryLoadingCu`: whether the cost estimate is being fetched.
 * - `isVaryLoading`: whether a vary request is being sent or is queued.
 * - `varyImage`: triggers the variation.
 */
export function useAssetVary({ asset, isVisible }: UseAssetVaryArgs) {
  const { selectedTeam } = useTeamContext();
  const { addJob } = useJobSessionContext();
  const [generateImageTrigger] = usePostInferenceMutation();
  const handleApiError = useHandleApiError();
  const { showLimitModal } = usePlanContext();
  const { successToast, errorToast } = useScenarioToast();
  const [isVaryLoading, setIsVaryLoading] = useState<boolean>(false);
  // True when "Vary" was requested before `paramsGenerate` was available; the
  // replay effect further down sends the request once it is.
  const [isWaitingForVary, setIsWaitingForVary] = useState<boolean>(false);
  // Only fetch data for cards that stay visible for a while.
  const isDebouncedVisible = useDebounce(isVisible, 1_000);

  // Minimal model object, enough for the model-type helpers below.
  const partialModel = useMemo(
    () =>
      ({
        id: asset.metadata.modelId,
        type: asset.metadata.modelType,
      } as GetModelsByModelIdApiResponse["model"]),
    [asset]
  );
  const canVary = useMemo(() => {
    // Varying is disabled for the Flux Pro family of models.
    const isModelAllowed =
      !getIsModelFluxPro1_1(partialModel) &&
      !getIsModelFluxPro1_1Ultra(partialModel) &&
      !getIsModelFluxPro(partialModel);
    if (!isModelAllowed) {
      return false;
    }

    const allowedJobTypes: GetJobIdApiResponse["job"]["jobType"][] = [
      "inference",
    ];
    const jobType = mapAssetTypeToJobType(asset.metadata.type);
    if (Array.isArray(jobType)) {
      return allowedJobTypes.some((type) => jobType.includes(type));
    }
    return allowedJobTypes.includes(jobType);
  }, [asset.metadata.type, partialModel]);

  const { modelId, inferenceId } = asset.metadata;
  const inferenceQueryArgs =
    modelId && inferenceId && canVary && isDebouncedVisible
      ? {
          teamId: selectedTeam.id,
          modelId,
          inferenceId,
        }
      : skipToken;
  const isInferenceQueryEnabled = inferenceQueryArgs !== skipToken;
  const {
    data: inferenceData,
    isFetching: isFetchingInference,
    isUninitialized: isInferenceUninitialized,
  } = useGetModelsInferencesByModelIdAndInferenceIdQuery(inferenceQueryArgs);
  const inference = inferenceData?.inference;

  // Request body for the variation; undefined until every input is available.
  const paramsGenerate = useMemo(() => {
    if (
      !asset.metadata.modelId ||
      !inference ||
      !canVary ||
      !isDebouncedVisible
    ) {
      return undefined;
    }

    const guidanceConfig = getGuidanceForModel({
      model: partialModel,
      scheduler: undefined,
    });
    const samplingStepsConfig = getSamplingStepsForModel({
      model: partialModel,
      scheduler: undefined,
    });

    const isTextureInference = isAssetInferenceTexture(asset);
    const globalParameters = {
      modelId: asset.metadata.modelId,
      concepts: inference.parameters.concepts,
      // Only forward a custom SDXL base model; the default one is implied.
      baseModelId:
        !isTextureInference &&
        getIsModelSdxl({
          type: asset.metadata.modelType,
        } as GetModelsByModelIdApiResponse["model"]) &&
        asset.metadata.modelId !== "stable-diffusion-xl-base-1.0" &&
        inference.parameters.baseModelId !== SDXL_DEFAULT_BASE_MODEL.id
          ? inference.parameters.baseModelId
          : undefined,
      prompt: asset.metadata.prompt ?? "",
      negativePrompt: asset.metadata.negativePrompt,
      guidance: guidanceConfig.default,
      numInferenceSteps: samplingStepsConfig.default,
      numSamples: 2,
      width: asset.metadata.width,
      height: asset.metadata.height,
    };

    if (
      [
        "txt2img_ip_adapter",
        "img2img_ip_adapter",
        "controlnet_ip_adapter",
      ].includes(inference.parameters.type)
    ) {
      const isFluxInference =
        getIsModelFluxDev({
          id: asset.metadata.baseModelId,
        } as GetModelsByModelIdApiResponse["model"]) ||
        getIsModelFluxSchnell({
          id: asset.metadata.baseModelId,
        } as GetModelsByModelIdApiResponse["model"]) ||
        getIsModelFluxDev({
          id: asset.metadata.modelId,
        } as GetModelsByModelIdApiResponse["model"]) ||
        getIsModelFluxSchnell({
          id: asset.metadata.modelId,
        } as GetModelsByModelIdApiResponse["model"]);

      // Flux requests take a list of reference images; fall back to the single
      // `ipAdapterImageId` when the inference has no list.
      const fluxIpAdapterImageIds =
        inference.parameters.ipAdapterImageIds ??
        _.compact([inference.parameters.ipAdapterImageId]);

      const body: ApiInferenceImg2ImgIpAdapter = {
        inferenceType: "img2img-ip-adapter",
        teamId: selectedTeam.id,
        body: {
          ...globalParameters,
          ...(isFluxInference
            ? {
                ipAdapterImageIds: fluxIpAdapterImageIds,
                // One scale per reference image so both arrays always have
                // the same length (even when no image id is available).
                ipAdapterScales: fluxIpAdapterImageIds.map(() => 0.25),
              }
            : {
                ipAdapterImageId: inference.parameters.ipAdapterImageId,
                ipAdapterScale: 0.25,
              }),
          ipAdapterType: inference.parameters.ipAdapterType,
          strength: reverseStrength(0.5),
          imageId: asset.id,
        },
      };
      return body;
    }

    if (isTextureInference) {
      const body: ApiInferenceImg2ImgTexture = {
        inferenceType: "img2img-texture",
        teamId: selectedTeam.id,
        body: {
          ...globalParameters,
          strength: reverseStrength(0.25),
          imageId: asset.id,
        },
      };
      return body;
    }

    const body: ApiInferenceImg2Img = {
      inferenceType: "img2img",
      teamId: selectedTeam.id,
      body: {
        ...globalParameters,
        strength: reverseStrength(0.25),
        imageId: asset.id,
      },
    };
    return body;
  }, [
    asset,
    inference,
    selectedTeam.id,
    canVary,
    isDebouncedVisible,
    partialModel,
  ]);

  const { data: costData, isFetching: isCuLoading } = useGetInferenceCostQuery(
    paramsGenerate ?? skipToken
  );
  const cuCostVary = costData?.creativeUnitsCost;

  // ----------------------------------

  const varyImage = useCallback(async () => {
    if (!paramsGenerate) {
      // Parameters are not ready yet: queue the request; the replay effect
      // below sends it (or drops it) once the pending data settles.
      if (canVary && isVisible) {
        setIsWaitingForVary(true);
      }
      return;
    }

    setIsVaryLoading(true);
    try {
      const { job } = await generateImageTrigger(paramsGenerate)
        .unwrap()
        .then(removeCuFromType);
      addJob(job);

      Track(AnalyticsEvents.Inference.CreatedInference, {
        ...paramsGenerate.body,
      });
      successToast({
        title: "Images are being generated",
      });
    } catch (error: any) {
      handleApiError(error, "Error generating images", {
        quota: () => {
          const remainingSeconds = _.get(
            error,
            "data.details.remainingSeconds"
          );
          // `reason` may be absent; never call string methods on it blindly.
          const reason = _.get(error, "data.reason");

          if (remainingSeconds) {
            showLimitModal("planCooldown", {
              timeout: remainingSeconds,
              type: "generation",
            });
          } else if (_.get(error, "data.name") === "RateLimitError") {
            showLimitModal("planImageGenerationsRateLimit");
          } else if (
            _.get(error, "data.details.actionName") === "parallel-inferences"
          ) {
            showLimitModal("parallelInferences");
          } else if (
            typeof reason === "string" &&
            reason.includes("You have reached your plan's limit.")
          ) {
            showLimitModal("notEnoughCreativeUnits");
          } else {
            errorToast({
              title: "Error generating images",
              description: reason,
            });
          }
        },
        file_size: () => {
          errorToast({
            title: "Error generating images",
            description: "The reference image file size is too big.",
          });
        },
      });
    } finally {
      setIsVaryLoading(false);
    }
  }, [
    paramsGenerate,
    generateImageTrigger,
    addJob,
    successToast,
    canVary,
    isVisible,
    handleApiError,
    showLimitModal,
    errorToast,
  ]);

  // ----------------------------------

  // Replay effect for a queued "Vary" click. It sends the request once
  // `paramsGenerate` is ready. Otherwise it stops waiting as soon as the
  // parameters can no longer arrive. Re-invoking `varyImage` while nothing is
  // pending would just re-queue the request and loop forever.
  useEffect(() => {
    if (!isWaitingForVary) {
      return;
    }

    if (!canVary || !isVisible) {
      // The card was hidden (or the asset can no longer be varied): cancel.
      setIsWaitingForVary(false);
      return;
    }

    if (paramsGenerate) {
      setIsWaitingForVary(false);
      void varyImage();
      return;
    }

    // Parameters can still arrive while the visibility debounce has not
    // elapsed, or while the inference query is enabled and not yet settled.
    const isInferencePending =
      !isDebouncedVisible ||
      (isInferenceQueryEnabled &&
        (isFetchingInference || isInferenceUninitialized));
    if (!isInferencePending) {
      // The query was skipped (missing ids) or failed: give up visibly.
      setIsWaitingForVary(false);
      errorToast({
        title: "Error generating images",
        description:
          "The settings of the original generation could not be loaded.",
      });
    }
  }, [
    isWaitingForVary,
    canVary,
    isVisible,
    paramsGenerate,
    isDebouncedVisible,
    isInferenceQueryEnabled,
    isFetchingInference,
    isInferenceUninitialized,
    varyImage,
    errorToast,
  ]);

  // ----------------------------------

  return useMemo(
    () => ({
      canVary,
      cuCostVary,
      isVaryLoadingCu: isCuLoading,
      isVaryLoading: isVaryLoading || isWaitingForVary,
      varyImage,
    }),
    [
      canVary,
      cuCostVary,
      isCuLoading,
      isVaryLoading,
      isWaitingForVary,
      varyImage,
    ]
  );
}
