import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { normalizeInferenceProviderMapping } from "../utils/normalizeInferenceProviderMapping";

/**
 * `ApiModelInfo` keys that `listModels` always requests via the `expand` query parameter.
 *
 * These back the base {@link ModelEntry} fields yielded by `listModels` (`task`, `private`,
 * `gated`, `downloads`, `likes`, `updatedAt`), so they are sent on every request and are
 * excluded from the caller-selectable `ModelAdditionalField` set.
 */
export const MODEL_EXPAND_KEYS = [
	"pipeline_tag",
	"private",
	"gated",
	"downloads",
	"likes",
	"lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * `ApiModelInfo` keys that this module allows to be requested through the `expand` query parameter.
 *
 * Keys listed here but not in {@link MODEL_EXPAND_KEYS} make up the raw (non-derived) part of
 * `ModelAdditionalField`, i.e. what callers of `listModels` may pass in `additionalFields`.
 * The `satisfies` clause keeps every entry a valid `ApiModelInfo` key while preserving the literal types.
 */
export const MODEL_EXPANDABLE_KEYS = [
	"author",
	"cardData",
	"config",
	"createdAt",
	"disabled",
	"downloads",
	"downloadsAllTime",
	"gated",
	"gitalyUid",
	"inferenceProviderMapping",
	"lastModified",
	"library_name",
	"likes",
	"model-index",
	"pipeline_tag",
	"private",
	"safetensors",
	"sha",
	"spaces",
	"tags",
	"transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

/**
 * Fields that `listModels` computes client-side from raw API data instead of returning an API key verbatim.
 */
export interface ModelDerivedFields {
	/** File names collected from the `rfilename` of each entry in the API's `siblings` list. */
	filePaths: string[];
}

/**
 * For each derived field, the API key that must be expanded so the field can be computed
 * (e.g. requesting `filePaths` makes `listModels` send `expand=siblings`).
 */
export const MODEL_DERIVED_FIELD_TO_API_KEY: Record<keyof ModelDerivedFields, keyof ApiModelInfo> = {
	filePaths: "siblings",
};

/**
 * Extra fields a caller may request from `listModels` through `additionalFields`: every expandable
 * key that is not already expanded by default, plus the client-side derived fields.
 */
export type ModelAdditionalField =
	| Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]>
	| keyof ModelDerivedFields;

/**
 * Shape of the extra properties attached to each listed model for the requested fields `T`:
 * raw API fields are typed from `ApiModelInfo`, derived fields from `ModelDerivedFields`.
 *
 * NOTE(review): `listModels` passes `inferenceProviderMapping` through
 * `normalizeInferenceProviderMapping`, so the runtime value may not match
 * `ApiModelInfo["inferenceProviderMapping"]` exactly — confirm the types line up.
 */
export type ResolveModelAdditionalFields<T extends ModelAdditionalField> = Pick<ApiModelInfo, T & keyof ApiModelInfo> &
	Pick<ModelDerivedFields, T & keyof ModelDerivedFields>;

/**
 * Base shape of every model yielded by `listModels`, before any `additionalFields` are merged in.
 */
export interface ModelEntry {
	/** Populated from the API's `_id` field — not from the API's `id`, which is mapped to `name`. */
	id: string;
	/** Model name, populated from the API's `id` field. */
	name: string;
	/** Populated from the API's `private` field. */
	private: boolean;
	/** Gating status as reported by the API's `gated` field. */
	gated: false | "auto" | "manual";
	/** Pipeline tag, populated from the API's `pipeline_tag` field; may be absent. */
	task?: PipelineType;
	/** Populated from the API's `likes` field. */
	likes: number;
	/** Populated from the API's `downloads` field. */
	downloads: number;
	/** `Date` parsed from the API's `lastModified` field. */
	updatedAt: Date;
}

/**
 * Lists models from the Hub, following pagination until `limit` models have been yielded or
 * there are no more pages.
 *
 * Each request asks for at most 500 models and always expands {@link MODEL_EXPAND_KEYS}. The
 * next page URL is read from the `Link` response header. Every yielded entry carries the base
 * {@link ModelEntry} fields plus the fields requested through `additionalFields`.
 *
 * @throws The error built by `createApiError` when a page request returns a non-OK response.
 */
export async function* listModels<const T extends ModelAdditionalField = never>(
	params?: {
		search?: {
			/**
			 * Will search in the model name for matches
			 */
			query?: string;
			/**
			 * Sent to the API as the `author` filter.
			 */
			owner?: string;
			/**
			 * Sent to the API as the `pipeline_tag` filter.
			 */
			task?: PipelineType;
			/**
			 * Each tag is sent as a separate `filter` query parameter.
			 */
			tags?: string[];
			/**
			 * Will search for models that have one of the inference providers in the list.
			 */
			inferenceProviders?: string[];
			/**
			 * Will search for models that support at least one of those local apps (eg "lmstudio", "mlx-lm", ...)
			 */
			apps?: string[];
		};
		/**
		 * Base URL of the Hub. Falls back to `HUB_URL` when omitted or empty.
		 */
		hubUrl?: string;
		/**
		 * Extra fields to expand and include on each yielded entry. Derived fields (e.g. `filePaths`)
		 * are computed from the API key they depend on.
		 */
		additionalFields?: T[];
		/**
		 * Set to limit the number of models returned. `0` or a negative value yields nothing.
		 */
		limit?: number;
		/**
		 * Sort models by a specific field.
		 */
		sort?:
			| "createdAt"
			| "downloads"
			| "likes"
			| "lastModified"
			| "likes30d"
			| "trendingScore"
			| "num_parameters"
			| "mainSize"
			| "id";
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>,
): AsyncGenerator<ModelEntry & ResolveModelAdditionalFields<T>> {
	const accessToken = params && checkCredentials(params);
	let totalToFetch = params?.limit ?? Infinity;
	// The limit is otherwise only checked *after* a yield. Without this guard, a non-positive
	// limit would still send a request and yield the first model of the page.
	if (totalToFetch <= 0) {
		return;
	}
	const fetchFn = params?.fetch ?? fetch;
	// Derived fields (e.g. `filePaths`) are expanded through the API key they are computed from.
	const additionalExpandKeys =
		params?.additionalFields?.map(
			(field) => MODEL_DERIVED_FIELD_TO_API_KEY[field as keyof ModelDerivedFields] ?? field,
		) ?? [];
	const search = new URLSearchParams([
		...Object.entries({
			// Request at most 500 models per page; larger limits are served across several pages.
			limit: String(Math.min(totalToFetch, 500)),
			...(params?.search?.owner ? { author: params.search.owner } : undefined),
			...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
			...(params?.search?.query ? { search: params.search.query } : undefined),
			...(params?.search?.inferenceProviders
				? { inference_provider: params.search.inferenceProviders.join(",") }
				: undefined),
			...(params?.search?.apps ? { apps: params.search.apps.join(",") } : undefined),
			...(params?.sort ? { sort: params.sort } : undefined),
		}),
		...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
		...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...additionalExpandKeys.map((val) => ["expand", val] satisfies [string, string]),
	]).toString();
	let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

	while (url) {
		const res: Response = await fetchFn(url, {
			headers: {
				accept: "application/json",
				...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
			},
		});

		if (!res.ok) {
			throw await createApiError(res);
		}

		const items: ApiModelInfo[] = await res.json();

		for (const item of items) {
			const additional: Record<string, unknown> = {};
			if (params?.additionalFields) {
				for (const field of params.additionalFields) {
					if (field === "filePaths") {
						additional.filePaths = (item.siblings ?? []).map((s) => s.rfilename);
					} else if (field === "inferenceProviderMapping" && item.inferenceProviderMapping) {
						additional.inferenceProviderMapping = normalizeInferenceProviderMapping(
							item.id,
							item.inferenceProviderMapping,
						);
					} else {
						additional[field] = item[field as keyof ApiModelInfo];
					}
				}
			}

			yield {
				// Spread first so the base fields below always take precedence.
				...additional,
				id: item._id,
				name: item.id,
				private: item.private,
				task: item.pipeline_tag,
				downloads: item.downloads,
				gated: item.gated,
				likes: item.likes,
				updatedAt: new Date(item.lastModified),
			} as ModelEntry & ResolveModelAdditionalFields<T>;
			totalToFetch--;

			if (totalToFetch <= 0) {
				return;
			}
		}

		const linkHeader = res.headers.get("Link");

		url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
		// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
	}
}
