import { ResourceSource, LabelEnum, PixelData } from '../../types/common';
import {
  DeeplabLabel,
  ModelNameOf,
  SemanticSegmentationModelSources,
  SemanticSegmentationConfig,
  SemanticSegmentationModelName,
  SelfieSegmentationLabel,
} from '../../types/semanticSegmentation';
import { IMAGENET1K_MEAN, IMAGENET1K_STD } from '../../constants/commonVision';
import {
  fetchModelPath,
  ResolveLabels as ResolveLabelsFor,
  VisionLabeledModule,
} from './VisionLabeledModule';

const PascalVocSegmentationConfig = {
  labelMap: DeeplabLabel,
  preprocessorConfig: {
    normMean: IMAGENET1K_MEAN,
    normStd: IMAGENET1K_STD,
  },
};
// Built-in model registry: maps every supported model name to its label map and
// optional preprocessing parameters. `as const satisfies` checks each entry
// against SemanticSegmentationConfig while preserving the per-key literal types,
// which SegmentationLabels (below) indexes into to recover each model's exact
// label map.
const ModelConfigs = {
  // All torchvision-style Pascal VOC variants (plain and quantized) share one config.
  'deeplab-v3-resnet50': PascalVocSegmentationConfig,
  'deeplab-v3-resnet101': PascalVocSegmentationConfig,
  'deeplab-v3-mobilenet-v3-large': PascalVocSegmentationConfig,
  'lraspp-mobilenet-v3-large': PascalVocSegmentationConfig,
  'fcn-resnet50': PascalVocSegmentationConfig,
  'fcn-resnet101': PascalVocSegmentationConfig,
  'deeplab-v3-resnet50-quantized': PascalVocSegmentationConfig,
  'deeplab-v3-resnet101-quantized': PascalVocSegmentationConfig,
  'deeplab-v3-mobilenet-v3-large-quantized': PascalVocSegmentationConfig,
  'lraspp-mobilenet-v3-large-quantized': PascalVocSegmentationConfig,
  'fcn-resnet50-quantized': PascalVocSegmentationConfig,
  'fcn-resnet101-quantized': PascalVocSegmentationConfig,
  'selfie-segmentation': {
    labelMap: SelfieSegmentationLabel,
    // No normalization constants for this model: the loaders below pass empty
    // mean/std arrays to the native module when preprocessorConfig is absent.
    preprocessorConfig: undefined,
  },
} as const satisfies Record<
  SemanticSegmentationModelName,
  SemanticSegmentationConfig<LabelEnum>
>;

/** @internal Inferred literal type of the registry above; indexed by model name to recover per-model config types. */
type ModelConfigsType = typeof ModelConfigs;

/**
 * Resolves the {@link LabelEnum} for a given built-in model name.
 * @typeParam M - A built-in model name from {@link SemanticSegmentationModelName}.
 * @category Types
 */
export type SegmentationLabels<M extends SemanticSegmentationModelName> =
  ModelConfigsType[M]['labelMap'];

/**
 * @internal Resolves `T` to a concrete label map via {@link ResolveLabelsFor}:
 * built-in model names are looked up in the registry, while a custom
 * {@link LabelEnum} is presumably passed through unchanged (see VisionLabeledModule).
 */
type ResolveLabels<T extends SemanticSegmentationModelName | LabelEnum> =
  ResolveLabelsFor<T, ModelConfigsType>;

/**
 * Generic semantic segmentation module with type-safe label maps.
 * Use a model name (e.g. `'deeplab-v3-resnet50'`) as the generic parameter for built-in models,
 * or a custom label enum for custom configs.
 * @typeParam T - Either a built-in model name (`'deeplab-v3-resnet50'`,
 *   `'deeplab-v3-resnet50-quantized'`, `'deeplab-v3-resnet101'`,
 *   `'deeplab-v3-resnet101-quantized'`, `'deeplab-v3-mobilenet-v3-large'`,
 *   `'deeplab-v3-mobilenet-v3-large-quantized'`, `'lraspp-mobilenet-v3-large'`,
 *   `'lraspp-mobilenet-v3-large-quantized'`, `'fcn-resnet50'`,
 *   `'fcn-resnet50-quantized'`, `'fcn-resnet101'`, `'fcn-resnet101-quantized'`,
 *   `'selfie-segmentation'`) or a custom {@link LabelEnum} label map.
 * @category Typescript API
 */
export class SemanticSegmentationModule<
  T extends SemanticSegmentationModelName | LabelEnum,
> extends VisionLabeledModule<
  Record<'ARGMAX', Int32Array> & Record<keyof ResolveLabels<T>, Float32Array>,
  ResolveLabels<T>
> {
  private constructor(labelMap: ResolveLabels<T>, nativeModule: unknown) {
    super(labelMap, nativeModule);
  }

  /**
   * Creates a segmentation instance for a built-in model.
   * The config object is discriminated by `modelName` — each model can require different fields.
   * @param namedSources - A {@link SemanticSegmentationModelSources} object specifying which model to load and where to fetch it from.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the chosen model's label map.
   * @example
   * ```ts
   * const segmentation = await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50);
   * ```
   */

  static async fromModelName<C extends SemanticSegmentationModelSources>(
    namedSources: C,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<SemanticSegmentationModule<ModelNameOf<C>>> {
    const { modelName, modelSource } = namedSources;
    const { labelMap } = ModelConfigs[modelName];
    const { preprocessorConfig } = ModelConfigs[
      modelName
    ] as SemanticSegmentationConfig<LabelEnum>;
    const normMean = preprocessorConfig?.normMean ?? [];
    const normStd = preprocessorConfig?.normStd ?? [];
    const allClassNames = Object.keys(labelMap).filter((k) => isNaN(Number(k)));
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    const nativeModule = await global.loadSemanticSegmentation(
      modelPath,
      normMean,
      normStd,
      allClassNames
    );
    return new SemanticSegmentationModule<ModelNameOf<C>>(
      labelMap as ResolveLabels<ModelNameOf<C>>,
      nativeModule
    );
  }

  /**
   * Creates a segmentation instance with a user-provided model binary and label map.
   * Use this when working with a custom-exported segmentation model that is not one of the built-in models.
   * Internally uses `'custom'` as the model name for telemetry unless overridden.
   *
   * ## Required model contract
   *
   * The `.pte` model binary must expose a single `forward` method with the following interface:
   *
   * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in
   * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`.
   * H and W are read from the model's declared input shape at load time.
   *
   * **Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW) containing raw
   * logits — one channel per class, in the same order as the entries in your `labelMap`.
   * For binary segmentation a single-channel output is also supported: channel 0 is treated
   * as the foreground probability and a synthetic background channel is added automatically.
   *
   * Preprocessing (resize → normalize) and postprocessing (softmax, argmax, resize back to
   * original dimensions) are handled by the native runtime.
   * @param modelSource - A fetchable resource pointing to the model binary.
   * @param config - A {@link SemanticSegmentationConfig} object with the label map and optional preprocessing parameters.
   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
   * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the provided label map.
   * @example
   * ```ts
   * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const;
   * const segmentation = await SemanticSegmentationModule.fromCustomModel(
   *   'https://example.com/custom_model.pte',
   *   { labelMap: MyLabels },
   * );
   * ```
   */
  static async fromCustomModel<L extends LabelEnum>(
    modelSource: ResourceSource,
    config: SemanticSegmentationConfig<L>,
    onDownloadProgress: (progress: number) => void = () => {}
  ): Promise<SemanticSegmentationModule<L>> {
    const normMean = config.preprocessorConfig?.normMean ?? [];
    const normStd = config.preprocessorConfig?.normStd ?? [];
    const allClassNames = Object.keys(config.labelMap).filter((k) =>
      isNaN(Number(k))
    );
    const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
    const nativeModule = await global.loadSemanticSegmentation(
      modelPath,
      normMean,
      normStd,
      allClassNames
    );
    return new SemanticSegmentationModule<L>(
      config.labelMap as ResolveLabels<L>,
      nativeModule
    );
  }

  /**
   * Executes the model's forward pass to perform semantic segmentation on the provided image.
   *
   * Supports two input types:
   * 1. **String path/URI**: File path, URL, or Base64-encoded string
   * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage)
   *
   * **Note**: For VisionCamera frame processing, use `runOnFrame` instead.
   * @param input - Image source (string or PixelData object)
   * @param classesOfInterest - An optional list of label keys indicating which per-class probability masks to include in the output. `ARGMAX` is always returned regardless.
   * @param resizeToInput - Whether to resize the output masks to the original input image dimensions. If `false`, returns the raw model output dimensions. Defaults to `true`.
   * @returns A Promise resolving to an object with an `'ARGMAX'` key mapped to an `Int32Array` of per-pixel class indices, and each requested class label mapped to a `Float32Array` of per-pixel probabilities.
   * @throws {RnExecutorchError} If the model is not loaded.
   */
  override async forward<K extends keyof ResolveLabels<T>>(
    input: string | PixelData,
    classesOfInterest: K[] = [],
    resizeToInput: boolean = true
  ): Promise<Record<'ARGMAX', Int32Array> & Record<K, Float32Array>> {
    const classesOfInterestNames = classesOfInterest.map(String);
    return super.forward(input, classesOfInterestNames, resizeToInput);
  }
}
