import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useRouter } from "next/router";
import { HEADER_HEIGHT } from "domains/assets/components/AssetZoom/Header";
import { extractFirstQueryParam } from "domains/commons/misc";
import { getImageDimensions } from "domains/commons/utils/getImageDimensions";
import { FileImageType } from "domains/file-manager/interfaces";
import { useAssetCrop } from "domains/image/hooks/useAssetCrop";
import { useAssetUpload } from "domains/image/hooks/useAssetUpload";
import {
  findLearningRateId,
  FORM_TYPE,
  INITIAL_FORM,
  SDXL_DEFAULT_DATA,
} from "domains/models/constants/train";
import { TrainForm, TrainImage } from "domains/models/interfaces/train";
import {
  getDefaultTrainingSteps,
  getTrainingParameters,
} from "domains/models/utils/train";
import { useScenarioToast } from "domains/notification/hooks/useScenarioToast";
import { useTeamContext } from "domains/teams/contexts/TeamProvider";
import { usePlanContext } from "domains/teams/hooks/usePlan";
import { AnalyticsEvents } from "infra/analytics/constants/Events";
import Track from "infra/analytics/Track";
import { useHandleApiError } from "infra/api/error";
import {
  GetModelsByModelIdApiResponse,
  useGetModelsByModelIdQuery,
  usePostModelsMutation,
  usePostModelsTrainingImagesByModelIdMutation,
  usePutModelsByModelIdMutation,
  usePutModelsTagsByModelIdMutation,
  usePutModelsTrainByModelIdMutation,
} from "infra/api/generated/api";
import { apiSlice } from "infra/store/apiSlice";
import { API_TAGS } from "infra/store/constants";
import store from "infra/store/store";
import _ from "lodash";

import { HStack, VStack } from "@chakra-ui/react";
import { skipToken } from "@reduxjs/toolkit/dist/query";

import ModelTrainBottom from "./Bottom";
import ModelTrainImages from "./Images";
import ModalCrop, { ModalCropProps } from "./ModalCrop";
import ModelTrainSidebar from "./Sidebar";

/** Props accepted by the `ModelTrain` component. */
interface ModelTrainProps {
  /**
   * Id of the model to edit. When `undefined`, no model is fetched and a new
   * one is created on the first save or the first image upload.
   */
  id: string | undefined;
  /** Invoked whenever the loaded model reports `status === "trained"`. */
  onAlreadyTrained: () => void;
  /**
   * When true, the layout fills its container's height ("100%") instead of
   * the viewport minus the header, and the sidebar title is hidden.
   */
  isInsideModal?: boolean;
  /** Forwarded to `ModelTrainImages` as its `initialAssets` prop. */
  initialAssets?: FileImageType[];
  /** Invoked after a training has been queued successfully. */
  onTrain?: () => void;
  /** Invoked with the new model's id right after the model is created. */
  onCreateModel?: (id: string) => void;
}

/** Signature of `showLimitModal` from the plan context. */
type ShowLimitModal = ReturnType<typeof usePlanContext>["showLimitModal"];

/**
 * Opens the plan-limit modal that matches a "quota" error raised while
 * queuing a training.
 *
 * The error payload is untyped here, so every field is read defensively.
 * A missing `data.reason` used to make `reason.includes(...)` throw a
 * TypeError from inside the error handler.
 *
 * @param error - Error caught from the train mutation.
 * @param showLimitModal - `showLimitModal` from `usePlanContext()`.
 */
function showTrainingQuotaLimitModal(
  error: unknown,
  showLimitModal: ShowLimitModal
) {
  const rawReason: unknown = _.get(error, "data.reason");
  const reason = typeof rawReason === "string" ? rawReason : "";
  const remainingSeconds = _.get(error, "data.details.remainingSeconds");

  if (remainingSeconds) {
    showLimitModal("planCooldown", {
      timeout: remainingSeconds,
      type: "training",
    });
  } else if (reason.includes("minute")) {
    showLimitModal("planModelMinuteTrainingRateLimit");
  } else if (reason.includes("hourly")) {
    showLimitModal("planModelHourlyTrainingRateLimit");
  } else if (reason.includes("daily")) {
    showLimitModal("planModelDailyTrainingRateLimit");
  } else if (_.get(error, "data.details.actionName") === "parallel-training") {
    showLimitModal("parallelTrainings");
  } else if (reason.includes("You have reached your plan's limit.")) {
    showLimitModal("notEnoughCreativeUnits");
  }
}

/**
 * Model training editor.
 *
 * Renders three parts:
 * - a sidebar holding the training form,
 * - a panel listing the model's training images,
 * - a crop modal for images that are about to be uploaded.
 *
 * Without an `id`, the form starts from `INITIAL_FORM` and the model is
 * created on the first save or upload. With an `id`, the model is fetched
 * and the form is initialised from its stored name and parameters.
 */
export default function ModelTrain({
  id,
  isInsideModal,
  initialAssets,
  onTrain,
  onCreateModel,
  onAlreadyTrained,
}: ModelTrainProps) {
  const router = useRouter();
  const { showLimitModal } = usePlanContext();
  const handleApiError = useHandleApiError();
  const { selectedTeam } = useTeamContext();
  const { successToast } = useScenarioToast();
  // `form` is the editable local state. `remoteForm` is what was last loaded
  // from or saved to the API; the two are diffed to decide whether to save.
  const [form, setForm] = useState<TrainForm | undefined>();
  const [remoteForm, setRemoteForm] = useState<TrainForm | undefined>();
  const [isSaving, setIsSaving] = useState<boolean>(false);
  const [isTraining, setIsTraining] = useState<boolean>(false);
  const [imagesToCrop, setImagesToCrop] = useState<TrainImage[]>([]);
  // Measured width of each training image, keyed by its download URL.
  const [imagesWidths, setImagesWidths] = useState<{ [key: string]: number }>(
    {}
  );
  // Reassigned to `imagesWidths` on every render. The measuring effect uses
  // it to reuse known widths without depending on the state itself.
  const imagesWidthsRef = useRef<{ [key: string]: number }>(imagesWidths);
  // Progress of the current upload batch. It is set to 0 when a batch starts
  // and reset to undefined once the batch completes.
  const [uploadedImagesCount, setUploadedImagesCount] = useState<
    number | undefined
  >();
  const [createModelTrigger] = usePostModelsMutation();
  const [addModelImageTrigger] = usePostModelsTrainingImagesByModelIdMutation();
  const [updateModelTrigger] = usePutModelsByModelIdMutation();
  const [updateModelTagsTrigger] = usePutModelsTagsByModelIdMutation();
  const [trainModelTrigger] = usePutModelsTrainByModelIdMutation();
  const { uploadImage } = useAssetUpload();
  const { cropImage } = useAssetCrop();

  const {
    isLoading: isModelLoading,
    data: modelData,
    refetch: refetchModel,
  } = useGetModelsByModelIdQuery(
    id
      ? {
          modelId: id,
          teamId: selectedTeam.id,
        }
      : skipToken
  );

  const model = modelData?.model;
  // Images narrower than this many pixels are flagged as low resolution.
  const lowResMax = form?.type === "sd-1_5" ? 400 : 800;

  const maxTrainingImages = (() => {
    if (selectedTeam.isFreePlan) {
      return 15;
    }
    if (form?.type === "flux.1-lora") {
      return 50;
    }
    if (form?.type === "sd-xl-lora") {
      if (form?.preset === "style") {
        return selectedTeam.plan === "cu-creator" ? 30 : 50;
      }
      if (form?.preset === "subject") {
        return 25;
      }
    }
    if (selectedTeam.plan === "cu-creator") {
      return 30;
    }
    if (selectedTeam.plan === "cu-pro") {
      return 50;
    }
    return 100;
  })();
  const isFormChanged = useMemo(
    () => !_.isEqual(form, remoteForm),
    [form, remoteForm]
  );
  const datasetSize = model?.trainingImages?.length ?? 0;
  const isSavable = !!form?.name && isFormChanged && !isModelLoading;
  const isNotEnoughTrainingImages = datasetSize < 5;
  const isTooMuchTrainingImages = datasetSize > maxTrainingImages;
  const isMissingClassSlug =
    form?.type === "sd-1_5" && form?.flow === "guided" && !form?.classSlug;
  const isGuessingDescription = useMemo(
    () =>
      (model?.trainingImages ?? []).some(
        (image) => image.description === undefined
      ),
    [model?.trainingImages]
  );
  // An image that has not been measured yet also counts as low resolution,
  // so training stays blocked until every width is known.
  const isLowResTrainingImages = useMemo(
    () =>
      (model?.trainingImages ?? []).some(
        (image) =>
          imagesWidths[image.downloadUrl] === undefined ||
          imagesWidths[image.downloadUrl] < lowResMax
      ),
    [model?.trainingImages, imagesWidths, lowResMax]
  );
  const isImageSizesLoading = useMemo(
    () =>
      (model?.trainingImages ?? []).some(
        (image) => imagesWidths[image.downloadUrl] === undefined
      ),
    [model?.trainingImages, imagesWidths]
  );

  const isTrainable =
    !!form?.name &&
    !isModelLoading &&
    !isNotEnoughTrainingImages &&
    !isTooMuchTrainingImages &&
    !isMissingClassSlug &&
    !isLowResTrainingImages &&
    !isImageSizesLoading &&
    !isGuessingDescription;
  // First failing validation, in priority order; undefined when none fails.
  const trainingError =
    (isNotEnoughTrainingImages && "notEnoughTrainingImages") ||
    (isTooMuchTrainingImages && "tooMuchTrainingImages") ||
    (isMissingClassSlug && "missingClassSlug") ||
    (isImageSizesLoading && "imageSizesLoading") ||
    (isLowResTrainingImages && "lowResTrainingImages") ||
    (isGuessingDescription && "guessingDescription") ||
    undefined;

  const minTrainingSteps = 10;
  const maxTrainingSteps = Math.min(
    datasetSize * 500,
    form?.type === "flux.1-lora" ? 5_000 : 25_000
  );
  const defaultTrainingSteps = form
    ? getDefaultTrainingSteps(form, datasetSize)
    : 0;

  // ----------------------------------

  /**
   * Sends the tag difference between `form.tags` and `model.tags` to the API.
   *
   * Returns an async undo callback when tags were changed, otherwise
   * undefined. Throws when the form is not initialised yet.
   */
  const updateModelTags = useCallback(
    async (model: GetModelsByModelIdApiResponse["model"]) => {
      if (!form) throw new Error("missing form");

      const tagsToAdd = _.difference(form.tags, model.tags ?? []);
      const tagsToRemove = _.difference(model.tags ?? [], form.tags);

      const body = {
        add: tagsToAdd,
        delete: tagsToRemove,
      };

      if (body.add.length > 0 || body.delete.length > 0) {
        await updateModelTagsTrigger({
          teamId: selectedTeam.id,
          modelId: model.id,
          body,
        }).unwrap();

        return async () => {
          await updateModelTagsTrigger({
            teamId: selectedTeam.id,
            modelId: model.id,
            body: {
              add: body.delete,
              delete: body.add,
            },
          }).unwrap();
        };
      }
      return undefined;
    },
    [updateModelTagsTrigger, form, selectedTeam.id]
  );

  /**
   * Pushes the form's tags and training parameters to `model`.
   *
   * On success, `remoteForm` is set to the saved form. On failure, the tag
   * change is rolled back on a best-effort basis and the original error is
   * rethrown. `isBeforeTrain` is forwarded to `getTrainingParameters`.
   */
  const updateModel = useCallback(
    async (
      model: GetModelsByModelIdApiResponse["model"],
      isBeforeTrain?: boolean
    ) => {
      if (!form) throw new Error("missing form");
      const undoActions: (() => Promise<void>)[] = [];
      try {
        const undoTags = await updateModelTags(model);
        if (undoTags) undoActions.push(undoTags);

        await updateModelTrigger({
          teamId: selectedTeam.id,
          modelId: model.id,
          body: getTrainingParameters(form, datasetSize, { isBeforeTrain })
            .body,
        }).unwrap();

        setRemoteForm(form);
      } catch (err) {
        for (const undoAction of undoActions) {
          try {
            await undoAction();
          } catch {
            // Best-effort rollback: a failing undo must not replace the
            // original error, which is what callers report to the user.
          }
        }
        throw err;
      }
    },
    [
      updateModelTags,
      updateModelTrigger,
      setRemoteForm,
      selectedTeam.id,
      form,
      datasetSize,
    ]
  );

  /**
   * Creates a model from the form and reports it through `onCreateModel`.
   * It then pushes the form's tags and parameters to the new model and
   * resolves to the new model's id.
   */
  const createModel = useCallback(async () => {
    if (!form) throw new Error("missing form");
    // Only guided SD 1.5 models carry a class, and "default" means none.
    const classSlug =
      form.type === "sd-1_5" &&
      form.flow === "guided" &&
      form.classSlug &&
      form.classSlug !== "default"
        ? form.classSlug
        : undefined;
    const newModel = await createModelTrigger({
      teamId: selectedTeam.id,
      body: {
        name: form.name,
        classSlug,
      },
    }).unwrap();

    if (onCreateModel) {
      onCreateModel(newModel.model.id);
    }

    await updateModel(newModel.model);
    return newModel.model.id;
  }, [createModelTrigger, updateModel, onCreateModel, selectedTeam.id, form]);

  /**
   * Saves the form. It creates the model when there is no `id`, otherwise it
   * updates the loaded model. Nothing happens while the form or the model is
   * still loading.
   */
  const save = useCallback(async () => {
    if (!form) return;
    try {
      setIsSaving(true);
      if (!id) {
        await createModel();
      } else if (model) {
        await updateModel(model);
      } else {
        return;
      }
      successToast({ title: "Model saved" });
    } catch (error) {
      handleApiError(error, "Error saving model");
    } finally {
      setIsSaving(false);
    }
  }, [
    createModel,
    setIsSaving,
    successToast,
    updateModel,
    handleApiError,
    id,
    form,
    model,
  ]);

  /**
   * Saves the form with `isBeforeTrain` set, queues a training, and refetches
   * the model. Calls are ignored while a previous train request is running.
   */
  const train = useCallback(async () => {
    if (!model || !form || isTraining) return;
    try {
      setIsTraining(true);

      await updateModel(model, true);
      await trainModelTrigger({
        modelId: model.id,
        teamId: selectedTeam.id,
        body: {},
      }).unwrap();

      Track(AnalyticsEvents.CreateModel.TrainingQueued, {
        modelId: model.id,
      });
      await refetchModel();
      if (onTrain) onTrain();
      successToast({ title: "Model training has been queued" });
    } catch (error) {
      handleApiError(error, "Error queuing training", {
        quota: () => {
          showTrainingQuotaLimitModal(error, showLimitModal);
        },
      });
    } finally {
      setIsTraining(false);
    }
  }, [
    model,
    form,
    isTraining,
    updateModel,
    trainModelTrigger,
    selectedTeam.id,
    refetchModel,
    onTrain,
    successToast,
    handleApiError,
    showLimitModal,
  ]);

  // ----------------------------------

  const handleDeleteAll = useCallback(async () => {
    await refetchModel();
  }, [refetchModel]);

  const handleImagesToCropChange = useCallback<ModalCropProps["onChange"]>(
    (images) => {
      setImagesToCrop(images);
    },
    [setImagesToCrop]
  );

  /**
   * Uploads the images confirmed in the crop modal one by one and attaches
   * each one to the model, creating the model first when needed.
   *
   * - Uncropped images reuse their existing `assetId` when present and are
   *   uploaded otherwise.
   * - Cropped images go through `cropImage`.
   *
   * NOTE(review): a missing asset or an API error aborts the batch with
   * `return`. In that case `uploadedImagesCount` keeps its partial value,
   * `imagesToCrop` is not cleared, and the model is not refetched, so images
   * already added in this batch stay hidden until the next refetch. Confirm
   * whether this is intended.
   */
  const handleImagesToCropUpload = useCallback<ModalCropProps["onSubmit"]>(
    async (images) => {
      const modelId = model ? model.id : await createModel();

      setUploadedImagesCount(0);

      for (const image of images) {
        try {
          if (!image.cropped) {
            const assetId =
              image.original.assetId ??
              (
                await uploadImage({
                  imageData: image.original.src,
                  name: "training-image",
                  preventCompression: true,
                })
              )?.id;
            if (!assetId) {
              return;
            }

            await addModelImageTrigger({
              modelId: modelId,
              teamId: selectedTeam.id,
              body: {
                assetId: assetId,
              },
            }).unwrap();
          } else {
            const croppedAsset = await cropImage({
              asset: image.original,
              cropData: image.cropped.metadata,
              preventCompression: true,
            });
            if (!croppedAsset) {
              return;
            }

            await addModelImageTrigger({
              modelId: modelId,
              teamId: selectedTeam.id,
              body: {
                assetId: croppedAsset.id,
              },
            }).unwrap();
          }
        } catch (error) {
          handleApiError(error, "Error uploading training image");
          return;
        }
        setUploadedImagesCount((count) => (count ?? 0) + 1);
      }

      setImagesToCrop([]);
      setUploadedImagesCount(undefined);

      successToast({
        title: `${images.length} ${
          images.length === 1 ? "image" : "images"
        } uploaded`,
      });
      Track(AnalyticsEvents.CreateModel.AddedTrainingImage, {
        modelId: modelId,
      });

      // Refresh the model after a short delay. A model created in this
      // handler is not this hook's query yet, so its cache tag is
      // invalidated instead of calling `refetchModel`.
      setTimeout(async () => {
        if (model) {
          await refetchModel();
        } else {
          store.dispatch(
            apiSlice.util.invalidateTags([
              {
                type: API_TAGS.model,
                id: `modelId:${modelId}`,
              },
            ])
          );
        }
      }, 100);
    },
    [
      model,
      createModel,
      successToast,
      addModelImageTrigger,
      selectedTeam.id,
      handleApiError,
      uploadImage,
      cropImage,
      refetchModel,
    ]
  );

  const handleImageChange = useCallback(async () => {
    await refetchModel();
  }, [refetchModel]);

  // ----------------------------------

  imagesWidthsRef.current = imagesWidths;

  // Initialise the form once the router is ready. Steps:
  // - For a new model (no `id`), the form starts from INITIAL_FORM.
  // - For an existing model, the form and remoteForm are filled from the
  //   fetched model.
  // - When the URL carries `?retrain`, a non-Flux model is switched to
  //   "flux.1-lora" and its stored parameters are not loaded, and the
  //   `retrain` query param is removed from the URL afterwards.
  useEffect(() => {
    if (form || !router.isReady) return;

    const isRetrain = !!extractFirstQueryParam(router.query.retrain);
    const newForm: TrainForm = { ...INITIAL_FORM };

    if (id && model) {
      newForm.type = _.includes(FORM_TYPE, model.type)
        ? (model.type as TrainForm["type"])
        : "flux.1-lora";

      if (model.name) {
        newForm.name = model.name;
      }

      if (isRetrain && newForm.type !== "flux.1-lora") {
        newForm.type = "flux.1-lora";
      } else {
        // NOTE(review): `defaultTrainingSteps` is derived from `form`, which
        // is always undefined once this effect gets past its guard. It is
        // therefore 0 here, and this check only skips a missing or zero
        // maxTrainSteps. Confirm whether the default should be computed from
        // `newForm` instead.
        if (
          model.parameters?.maxTrainSteps &&
          model.parameters?.maxTrainSteps !== defaultTrainingSteps
        ) {
          newForm.totalTrainingSteps = model.parameters.maxTrainSteps;
        }

        if (model.parameters?.learningRate) {
          newForm.learningRate =
            findLearningRateId(model.parameters.learningRate) ??
            newForm.learningRate;
        }

        if (model.parameters?.learningRateUnet) {
          newForm.learningRateUnet =
            findLearningRateId(model.parameters.learningRateUnet) ??
            newForm.learningRateUnet;
        }

        if (model.parameters?.textEncoderTrainingRatio) {
          newForm.textEncoderTrainingRatio =
            model.parameters.textEncoderTrainingRatio;
        }

        if (model.parameters?.learningRateTextEncoder) {
          newForm.textEncoderLearningRate =
            findLearningRateId(model.parameters.learningRateTextEncoder) ??
            newForm.textEncoderLearningRate;
        }

        // SD 1.5: a stored conceptPrompt means the unguided flow. Otherwise
        // the flow is guided and takes the model's class (or "default").
        if (newForm.type === "sd-1_5" && model.parameters?.conceptPrompt) {
          newForm.flow = "unguided";
        } else if (newForm.type === "sd-1_5") {
          newForm.flow = "guided";
          newForm.classSlug = model.class?.slug ?? "default";
        }

        // SDXL: choose the candidate preset whose SDXL_DEFAULT_DATA matches
        // the loaded form values, and fall back to "custom".
        if (newForm.type === "sd-xl-lora") {
          newForm.preset = (() => {
            const possiblePresets = (() => {
              if (model.parameters?.conceptPrompt === "daiton style") {
                return ["style" as const];
              } else {
                return ["subject" as const, "custom" as const];
              }
            })();
            const possiblePresetData = possiblePresets.map(
              (preset) => SDXL_DEFAULT_DATA[preset]
            );
            return (
              possiblePresets.find((preset, index) => {
                const data = possiblePresetData[index];
                return _.isEqual(data, _.pick(newForm, _.keys(data)));
              }) ?? "custom"
            );
          })();
        }

        if (newForm.type === "flux.1-lora") {
          newForm.isOptimizedForLikeness =
            model.parameters?.optimizeFor === "likeness";
        }
      }

      setForm(newForm);
      setRemoteForm(newForm);

      if (isRetrain) {
        void router.replace({
          pathname: router.pathname,
          query: _.omit(router.query, "retrain"),
        });
      }
    } else if (!id) {
      setForm(newForm);
    }
  }, [router, id, form, model, defaultTrainingSteps, setForm, setRemoteForm]);

  // Clamp an explicit step count into [minTrainingSteps, maxTrainingSteps].
  // An unset (or zero) step count is left alone.
  useEffect(() => {
    if (!form || !model) return;
    if (form.totalTrainingSteps && form.totalTrainingSteps < minTrainingSteps) {
      setForm({ ...form, totalTrainingSteps: minTrainingSteps });
    } else if (
      form.totalTrainingSteps &&
      form.totalTrainingSteps > maxTrainingSteps
    ) {
      setForm({ ...form, totalTrainingSteps: maxTrainingSteps });
    }
  }, [setForm, id, form, model, minTrainingSteps, maxTrainingSteps]);

  // Refetch every 2 s while any training image still lacks a `description`
  // or an `automaticCaptioning` value.
  useEffect(() => {
    const interval = setInterval(async () => {
      const hasMissingDescription = model?.trainingImages?.some(
        (image) =>
          image.description === undefined ||
          image.automaticCaptioning === undefined
      );
      if (hasMissingDescription) {
        await refetchModel();
      } else {
        clearInterval(interval);
      }
    }, 2_000);
    return () => {
      if (interval) clearInterval(interval);
    };
  }, [refetchModel, model]);

  useEffect(() => {
    if (model?.status !== "trained") return;
    onAlreadyTrained();
  }, [model, onAlreadyTrained]);

  // Measure the width of every training image that has not been measured
  // yet, in parallel. Results are merged into state rather than replacing
  // it. Otherwise a slower pass started for an older image list could
  // overwrite the widths found by a newer pass and leave
  // `isImageSizesLoading` stuck at true.
  useEffect(() => {
    const urls = (model?.trainingImages ?? []).map(
      (asset) => asset.downloadUrl
    );
    void (async () => {
      const measured = await Promise.all(
        urls.map(async (url) => {
          const knownWidth = imagesWidthsRef.current[url];
          if (knownWidth !== undefined) return { url, width: knownWidth };
          try {
            return { url, width: (await getImageDimensions(url)).width };
          } catch {
            // Leave this URL unmeasured instead of aborting the whole pass
            // with an unhandled rejection.
            return undefined;
          }
        })
      );
      const widths: { [key: string]: number } = {};
      for (const entry of measured) {
        if (entry) widths[entry.url] = entry.width;
      }
      setImagesWidths((previous) => ({ ...previous, ...widths }));
    })();
  }, [model?.trainingImages]);

  // ----------------------------------

  // True when a caption differs from its automatic caption.
  const haveCaptionsBeenUpdated = useMemo(
    () =>
      (model?.trainingImages ?? []).some(
        (image) =>
          image.automaticCaptioning !== undefined &&
          image.description !== image.automaticCaptioning
      ),
    [model]
  );

  return (
    <>
      <HStack
        align="stretch"
        justify="stretch"
        w="100%"
        h={isInsideModal ? "100%" : `calc(100vh - ${HEADER_HEIGHT}px)`}
        spacing={0}
      >
        <ModelTrainSidebar
          model={model}
          isTitleHidden={isInsideModal}
          form={form}
          setForm={setForm}
          isSavable={isSavable}
          isSaving={isSaving}
          isTrainable={isTrainable}
          isTraining={isTraining}
          trainingError={trainingError}
          defaultTrainingSteps={defaultTrainingSteps}
          minTrainingSteps={minTrainingSteps}
          maxTrainingSteps={maxTrainingSteps}
          onSaveClick={save}
          onTrainClick={train}
          haveCaptionsBeenUpdated={haveCaptionsBeenUpdated}
        />

        <VStack align="stretch" flex={1} spacing={0}>
          <ModelTrainImages
            model={model}
            lowResMax={lowResMax}
            initialAssets={initialAssets}
            maxTrainingImages={maxTrainingImages}
            onImageUpload={handleImagesToCropChange}
            onDeleteAll={handleDeleteAll}
            onImageChange={handleImageChange}
          />
          <ModelTrainBottom
            form={form}
            trainingImagesCount={datasetSize}
            maxTrainingImages={maxTrainingImages}
          />
        </VStack>
      </HStack>

      <ModalCrop
        images={imagesToCrop}
        lowResMax={lowResMax}
        uploadedImagesCount={uploadedImagesCount}
        onSubmit={handleImagesToCropUpload}
        onChange={handleImagesToCropChange}
      />
    </>
  );
}
