import { assert, it, describe } from "vitest";
import {
	parseSafetensorsMetadata,
	parseSafetensorsShardFilename,
	globMatch,
	isQuantizedTensor,
} from "./parse-safetensors-metadata";
import { sum } from "../utils/sum";

describe("parseSafetensorsMetadata", () => {
	it("fetch info for single-file (with the default conventional filename)", async () => {
		// Pinned revision so the expected header and counts stay stable.
		const result = await parseSafetensorsMetadata({
			repo: "google-bert/bert-base-uncased",
			revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
			computeParametersCount: true,
		});

		// Narrow to the single-file variant so `header` is available.
		assert(!result.sharded);
		const { header } = result;
		assert.deepStrictEqual(header.__metadata__, { format: "pt" });

		// The header lists every tensor; spot-check a single entry.
		assert.deepStrictEqual(header["bert.embeddings.LayerNorm.beta"], {
			dtype: "F32",
			shape: [768],
			data_offsets: [0, 3072],
		});

		// ~110M parameters, all stored as F32.
		const expectedTotal = 110_106_428;
		assert.deepStrictEqual(result.parameterCount, { F32: expectedTotal });
		assert.strictEqual(sum(Object.values(result.parameterCount)), expectedTotal);

		assert.deepStrictEqual(result.filepaths, ["model.safetensors"]);
	});

	it("fetch info for sharded (with the default conventional filename)", async () => {
		const result = await parseSafetensorsMetadata({
			repo: "bigscience/bloom",
			revision: "053d9cd9fbe814e091294f67fcfedb3397b954bb",
			computeParametersCount: true,
		});

		// Narrow to the sharded variant so `headers` is available.
		assert(result.sharded);

		// BLOOM is split across 72 shards, each contributing its own header.
		const shardCount = 72;
		assert.strictEqual(Object.keys(result.headers).length, shardCount);

		// Spot-check one tensor inside one shard.
		const shardHeader = result.headers["model_00012-of-00072.safetensors"];
		assert.deepStrictEqual(shardHeader["h.10.input_layernorm.weight"], {
			dtype: "BF16",
			shape: [14336],
			data_offsets: [3288649728, 3288678400],
		});

		// ~176B parameters, all stored as BF16.
		const expectedTotal = 176_247_271_424;
		assert.deepStrictEqual(result.parameterCount, { BF16: expectedTotal });
		assert.strictEqual(sum(Object.values(result.parameterCount)), expectedTotal);

		// The index file is listed first, followed by every shard.
		const { filepaths } = result;
		assert.strictEqual(filepaths[0], "model.safetensors.index.json");
		assert.strictEqual(filepaths.length, 1 + shardCount);
		assert.ok(filepaths.includes("model_00012-of-00072.safetensors"));
	});

	it("fetch info for single-file with multiple dtypes", async () => {
		const result = await parseSafetensorsMetadata({
			repo: "roberta-base",
			revision: "e2da8e2f811d1448a5b465c236feacd80ffbac7b",
			computeParametersCount: true,
		});

		assert(!result.sharded);

		// The checkpoint mixes two dtypes; together they make up the ~124M total.
		const f32Params = 124_697_433;
		const i64Params = 514;
		assert.deepStrictEqual(result.parameterCount, { F32: f32Params, I64: i64Params });
		assert.strictEqual(sum(Object.values(result.parameterCount)), f32Params + i64Params);
	});

	it("fetch info for single-file with file path", async () => {
		const weightsPath = "unet/diffusion_pytorch_model.safetensors";
		const result = await parseSafetensorsMetadata({
			repo: "CompVis/stable-diffusion-v1-4",
			path: weightsPath,
			revision: "133a221b8aa7292a167afc5127cb63fb5005638b",
			computeParametersCount: true,
		});

		assert(!result.sharded);
		const { header } = result;
		assert.deepStrictEqual(header.__metadata__, { format: "pt" });

		// Spot-check a single tensor from the header.
		assert.deepStrictEqual(header["up_blocks.3.resnets.0.norm2.bias"], {
			dtype: "F32",
			shape: [320],
			data_offsets: [3_409_382_416, 3_409_383_696],
		});

		const expectedTotal = 859_520_964;
		assert.deepStrictEqual(result.parameterCount, { F32: expectedTotal });
		assert.strictEqual(sum(Object.values(result.parameterCount)), expectedTotal);

		// filepaths contains only the explicitly requested path.
		assert.deepStrictEqual(result.filepaths, [weightsPath]);
	});

	it("fetch info for sharded with file path", async () => {
		const indexPath = "7b/1/model.safetensors.index.json";
		const result = await parseSafetensorsMetadata({
			repo: "Alignment-Lab-AI/ALAI-gemma-7b",
			path: indexPath,
			revision: "37e307261fe97bbf8b2463d61dbdd1a10daa264c",
			computeParametersCount: true,
		});

		assert(result.sharded);

		const shardCount = 4;
		assert.strictEqual(Object.keys(result.headers).length, shardCount);

		// Header keys are bare shard filenames (no "7b/1/" folder prefix).
		const lastShardHeader = result.headers["model-00004-of-00004.safetensors"];
		assert.deepStrictEqual(lastShardHeader["model.layers.24.mlp.up_proj.weight"], {
			dtype: "BF16",
			shape: [24576, 3072],
			data_offsets: [301996032, 452990976],
		});

		const expectedTotal = 8_537_680_896;
		assert.deepStrictEqual(result.parameterCount, { BF16: expectedTotal });
		assert.strictEqual(sum(Object.values(result.parameterCount)), expectedTotal);

		// Unlike header keys, file paths keep the folder prefix: index first, then shards.
		const { filepaths } = result;
		assert.strictEqual(filepaths[0], indexPath);
		assert.strictEqual(filepaths.length, 1 + shardCount);
		for (const shardPath of ["7b/1/model-00001-of-00004.safetensors", "7b/1/model-00004-of-00004.safetensors"]) {
			assert.ok(filepaths.includes(shardPath));
		}
	});

	it("fetch info for sharded, but get param count directly from metadata", async () => {
		const result = await parseSafetensorsMetadata({
			repo: "hf-internal-testing/sharded-model-metadata-num-parameters",
			revision: "999395eb3db277f3d7a0393402b02486ca91cef8",
			computeParametersCount: true,
		});

		// The ~109M total is expected to come from the repo metadata.
		assert(result.sharded);
		assert.strictEqual(result.parameterTotal, 109_482_240);
	});

	it("fetch info for single-file, but get param count directly from metadata", async () => {
		const result = await parseSafetensorsMetadata({
			repo: "hf-internal-testing/single-file-model",
			revision: "75fcd3fed0285ac7f1092897ff2aefdf24bf872e",
			computeParametersCount: true,
		});

		// Same ~109M total as the sharded variant, for a single-file layout.
		assert(!result.sharded);
		assert.strictEqual(result.parameterTotal, 109_482_240);
	});

	it("should detect sharded safetensors filename", () => {
		// https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
		const safetensorsFilename = "model_00005-of-00072.safetensors";
		const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

		// Fail fast with one clear message if the name is not recognised (instead of
		// four separate `undefined` mismatches); this also narrows away null.
		assert(safetensorsShardFileInfo, `expected ${safetensorsFilename} to be parsed as a shard filename`);
		assert.strictEqual(safetensorsShardFileInfo.prefix, "model_");
		assert.strictEqual(safetensorsShardFileInfo.basePrefix, "model");
		assert.strictEqual(safetensorsShardFileInfo.shard, "00005");
		assert.strictEqual(safetensorsShardFileInfo.total, "00072");
	});

	it("should detect sharded safetensors filename with 6 digits", () => {
		// https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/model-00001-of-000163.safetensors
		const safetensorsFilename = "model-00001-of-000163.safetensors";
		const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

		// Fail fast with one clear message if the name is not recognised; also narrows away null.
		assert(safetensorsShardFileInfo, `expected ${safetensorsFilename} to be parsed as a shard filename`);
		assert.strictEqual(safetensorsShardFileInfo.prefix, "model-");
		assert.strictEqual(safetensorsShardFileInfo.basePrefix, "model");
		// Only the total is 6 digits wide here; the shard index keeps 5 digits.
		assert.strictEqual(safetensorsShardFileInfo.shard, "00001");
		assert.strictEqual(safetensorsShardFileInfo.total, "000163");
	});

	it("should support sub-byte data types", () => {
		// NOTE(review): this test never calls into parse-safetensors-metadata; it only
		// checks a locally declared list. Consider asserting against the module's
		// exported dtype type instead so it catches regressions.
		const newDataTypes: Array<"F4" | "F6_E2M3" | "F6_E3M2" | "E8M0"> = ["F4", "F6_E2M3", "F6_E3M2", "E8M0"];

		for (const dtype of newDataTypes) {
			const tensorInfo = {
				dtype,
				shape: [1, 2],
				data_offsets: [0, 1] as [number, number],
			};

			assert.ok(typeof tensorInfo.dtype === "string");
			// Reuse the declared list instead of repeating the literals.
			assert.ok(newDataTypes.includes(tensorInfo.dtype));
		}
	});

	it("should handle parameter counting with sub-byte data types", () => {
		const mockHeader = {
			tensor_f4: {
				dtype: "F4" as const,
				shape: [10, 20],
				data_offsets: [0, 100] as [number, number],
			},
			tensor_f6_e2m3: {
				dtype: "F6_E2M3" as const,
				shape: [5, 10],
				data_offsets: [100, 150] as [number, number],
			},
			tensor_f6_e3m2: {
				dtype: "F6_E3M2" as const,
				shape: [8, 12],
				data_offsets: [150, 246] as [number, number],
			},
			tensor_e8m0: {
				dtype: "E8M0" as const,
				shape: [4, 6],
				data_offsets: [246, 270] as [number, number],
			},
			__metadata__: { format: "pt" },
		};

		// Local per-dtype counter: for every non-metadata entry with a non-empty
		// shape, adds the product of its dimensions to that dtype's running total.
		const computeNumOfParamsByDtypeSingleFile = (header: typeof mockHeader) => {
			type TensorEntry = { dtype: string; shape: number[]; data_offsets: [number, number] };
			const totals: Partial<Record<string, number>> = {};
			for (const [name, tensor] of Object.entries(header) as [string, TensorEntry][]) {
				// Skip the metadata entry (checked first, so its missing `shape` is never read)
				// and zero-rank tensors.
				if (name === "__metadata__" || tensor.shape.length === 0) {
					continue;
				}
				const elementCount = tensor.shape.reduce((acc: number, dim: number) => acc * dim, 1);
				totals[tensor.dtype] = (totals[tensor.dtype] ?? 0) + elementCount;
			}
			return totals;
		};

		const counts = computeNumOfParamsByDtypeSingleFile(mockHeader);

		// Counts are element counts, independent of how many bits each dtype uses.
		assert.strictEqual(counts.F4, 10 * 20);
		assert.strictEqual(counts.F6_E2M3, 5 * 10);
		assert.strictEqual(counts.F6_E3M2, 8 * 12);
		assert.strictEqual(counts.E8M0, 4 * 6);
	});

	it("fetch info for GPTQ quantized 8B model", async () => {
		const result = await parseSafetensorsMetadata({
			repo: "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16",
			revision: "3921b6aee65496a708b0af456c964ceca7423193",
			computeParametersCount: true,
		});

		const counts = result.parameterCount;
		assert.ok(counts);

		// This checkpoint mixes I32 and F16 tensors; both must be present.
		for (const dtype of ["I32", "F16"] as const) {
			assert.ok(counts[dtype]);
		}
		assert.strictEqual(counts.I32, 6_979_321_856);
		assert.strictEqual(counts.F16, 1_052_315_648);

		// Prefer a total reported by the parser; otherwise sum the per-dtype counts.
		const numericCounts = Object.values(counts).filter((value): value is number => typeof value === "number");
		const total = result.parameterTotal ?? sum(numericCounts);

		assert.strictEqual(total, 8_031_637_504);
	});

	it("fetch info for openai/gpt-oss-20b (large sharded model)", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "openai/gpt-oss-20b",
			computeParametersCount: true,
			revision: "bbf09307421df45099c1e7dcbd64e3106ce5b403",
		});

		assert(parse.sharded);
		// Exact shard count (this also covers "more than one header").
		assert.strictEqual(Object.keys(parse.headers).length, 3);
		assert.ok(parse.parameterCount);

		// Prefer a total reported by the parser, falling back to the per-dtype sum.
		// `??` rather than `||`: a reported total of 0 must be asserted on, not
		// silently replaced by the sum.
		const totalParams = parse.parameterTotal ?? sum(Object.values(parse.parameterCount));

		assert.strictEqual(totalParams, 21_511_953_984); // ~21.5B

		// The checkpoint mixes BF16 and U8 tensors; both must be counted.
		assert.ok(parse.parameterCount.BF16 && parse.parameterCount.U8);
	});

	it("should support FP4 and UE8 data types in type system", () => {
		const newDataTypes: Array<"FP4" | "UE8"> = ["FP4", "UE8"];

		// Each literal can be used as a tensor-info dtype.
		newDataTypes.forEach((dtype) => {
			const tensorInfo = { dtype, shape: [1, 2], data_offsets: [0, 1] as [number, number] };
			assert.ok(typeof tensorInfo.dtype === "string");
			assert.ok(["FP4", "UE8"].includes(tensorInfo.dtype));
		});

		const mockHeader = {
			tensor_fp4: {
				dtype: "FP4" as const,
				shape: [100, 200],
				data_offsets: [0, 5000] as [number, number],
			},
			tensor_ue8: {
				dtype: "UE8" as const,
				shape: [50, 100],
				data_offsets: [5000, 10000] as [number, number],
			},
			__metadata__: { format: "pt" },
		};

		// Local per-dtype counter: sums the product of each non-empty shape per
		// dtype, ignoring the `__metadata__` entry.
		const computeNumOfParamsByDtypeSingleFile = (header: typeof mockHeader) => {
			type TensorEntry = { dtype: string; shape: number[]; data_offsets: [number, number] };
			const entries = Object.entries(header) as [string, TensorEntry][];
			return entries.reduce<Partial<Record<string, number>>>((totals, [name, tensor]) => {
				// Skip the metadata entry before touching `shape`, and skip zero-rank tensors.
				if (name !== "__metadata__" && tensor.shape.length > 0) {
					const elementCount = tensor.shape.reduce((acc: number, dim: number) => acc * dim, 1);
					totals[tensor.dtype] = (totals[tensor.dtype] ?? 0) + elementCount;
				}
				return totals;
			}, {});
		};

		const counts = computeNumOfParamsByDtypeSingleFile(mockHeader);

		assert.strictEqual(counts.FP4, 100 * 200);
		assert.strictEqual(counts.UE8, 50 * 100);
	});

	describe("globMatch", () => {
		// globMatch treats `*` as "any run of characters"; without `*` it is an exact match.
		it("exact match when no wildcard", () => {
			assert.strictEqual(globMatch("foo", "foo"), true);
			assert.strictEqual(globMatch("foo", "foobar"), false);
			assert.strictEqual(globMatch("foo", "xfoo"), false);
			assert.strictEqual(globMatch("foo", "xfoox"), false);
		});

		it("single leading wildcard (*.ext)", () => {
			assert.strictEqual(globMatch("*.txt", "file.txt"), true);
			assert.strictEqual(globMatch("*.txt", ".txt"), true);
			assert.strictEqual(globMatch("*.txt", "file.txt.bak"), false);
			assert.strictEqual(globMatch("*.txt", "txt"), false);
		});

		it("single trailing wildcard (prefix.*)", () => {
			assert.strictEqual(globMatch("model.*", "model.bin"), true);
			assert.strictEqual(globMatch("model.*", "model."), true);
			assert.strictEqual(globMatch("model.*", "my_model.bin"), false);
		});

		it("wildcard on both sides (*mid*)", () => {
			assert.strictEqual(globMatch("*layer*", "model.layer.weight"), true);
			assert.strictEqual(globMatch("*layer*", "layer"), true);
			assert.strictEqual(globMatch("*layer*", "no_match"), false);
		});

		it("multiple wildcards", () => {
			assert.strictEqual(globMatch("a*b*c", "abc"), true);
			assert.strictEqual(globMatch("a*b*c", "aXXbYYc"), true);
			assert.strictEqual(globMatch("a*b*c", "aXXbYY"), false);
			assert.strictEqual(globMatch("a*b*c", "XXbYYc"), false);
		});

		it("wildcard-only pattern matches anything", () => {
			assert.strictEqual(globMatch("*", "anything"), true);
			assert.strictEqual(globMatch("*", ""), true);
		});

		it("typical quantization config patterns", () => {
			assert.strictEqual(globMatch("lm_head", "lm_head"), true);
			assert.strictEqual(globMatch("lm_head", "model.lm_head"), false);
			assert.strictEqual(globMatch("*lm_head*", "model.lm_head.weight"), true);
		});

		// Title previously claimed bare names "match via substring", but the test
		// asserts the opposite for globMatch; renamed to describe what it checks.
		it("bare module name without wildcard does not substring-match", () => {
			// globMatch is a strict glob matcher: no wildcard means exact match only.
			assert.strictEqual(globMatch("lm_head", "model.lm_head.weight"), false);
			// isQuantizedTensor adds substring matching for bare names (no *) on top
			// of this, to match Python transformers; see its own tests.
		});
	});

	describe("isQuantizedTensor", () => {
		// Builds a minimal bitsandbytes-style config whose only knob is the exclusion list.
		const makeConfig = (modules: string[]) => ({
			quant_method: "bitsandbytes" as const,
			modules_to_not_convert: modules,
		});

		// Asserts isQuantizedTensor(name, config) === expected for each [name, expected] pair.
		const expectAll = (config: ReturnType<typeof makeConfig>, cases: Array<[string, boolean]>) => {
			for (const [tensorName, expected] of cases) {
				assert.strictEqual(isQuantizedTensor(tensorName, config), expected);
			}
		};

		it("returns false when no quantization config", () => {
			const quantized = isQuantizedTensor("model.layer.weight", undefined);
			assert.strictEqual(quantized, false);
		});

		it("returns true when modules_to_not_convert is empty", () => {
			expectAll(makeConfig([]), [["model.layer.weight", true]]);
		});

		it("bare module name excludes tensors containing that substring (Python compat)", () => {
			expectAll(makeConfig(["lm_head"]), [
				["model.lm_head.weight", false],
				["lm_head", false],
				["lm_head.weight", false],
				["model.embed_tokens.weight", true],
			]);
		});

		it("glob pattern with wildcards uses globMatch", () => {
			expectAll(makeConfig(["*lm_head*"]), [
				["model.lm_head.weight", false],
				["model.embed_tokens.weight", true],
			]);
		});

		it("multiple exclusion patterns", () => {
			expectAll(makeConfig(["lm_head", "embed_tokens"]), [
				["model.lm_head.weight", false],
				["model.embed_tokens.weight", false],
				["model.layers.0.self_attn.q_proj.weight", true],
			]);
		});
	});

	it("fetch info for moonshotai/Kimi-K2.5 (large index file >20MB)", async () => {
		// The many experts in this model push its index file to ~23.5MB, so this
		// covers fetching an index larger than 20MB.
		const result = await parseSafetensorsMetadata({
			repo: "moonshotai/Kimi-K2.5",
			revision: "2426b45b6af0da48d0dcce71bbce6225e5c73adc",
			computeParametersCount: true,
		});

		assert(result.sharded);
		const shardCount = Object.keys(result.headers).length;
		assert.strictEqual(shardCount, 64);

		assert.deepStrictEqual(result.parameterCount, { F32: 23_040, I32: 1_014_687_129_600, BF16: 43_902_267_888 });
		const total = sum(Object.values(result.parameterCount));
		assert.strictEqual(total, 1_058_589_420_528);
	});
});
