import { describe, expect, it } from "vitest";
import type { ModelEntry } from "./list-models";
import { listModels } from "./list-models";

describe("listModels", () => {
	it("should list models for depth estimation", async () => {
		const results: ModelEntry[] = [];

		for await (const entry of listModels({
			search: { owner: "Intel", task: "depth-estimation" },
		})) {
			if (typeof entry.downloads === "number") {
				entry.downloads = 0;
			}
			if (typeof entry.likes === "number") {
				entry.likes = 0;
			}
			if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
				entry.updatedAt = new Date(0);
			}

			if (!["Intel/dpt-large", "Intel/dpt-hybrid-midas"].includes(entry.name)) {
				expect(entry.task).to.equal("depth-estimation");
				continue;
			}

			results.push(entry);
		}

		results.sort((a, b) => a.id.localeCompare(b.id));

		expect(results).deep.equal([
			{
				id: "621ffdc136468d709f17e709",
				name: "Intel/dpt-large",
				private: false,
				gated: false,
				downloads: 0,
				likes: 0,
				task: "depth-estimation",
				updatedAt: new Date(0),
			},
			{
				id: "638f07977559bf9a2b2b04ac",
				name: "Intel/dpt-hybrid-midas",
				gated: false,
				private: false,
				downloads: 0,
				likes: 0,
				task: "depth-estimation",
				updatedAt: new Date(0),
			},
		]);
	});

	it("should list indonesian models with gguf format", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { tags: ["gguf", "id"] },
			additionalFields: ["tags"],
			limit: 2,
		})) {
			count++;
			expect(entry.tags).to.include("gguf");
			expect(entry.tags).to.include("id");
		}

		expect(count).to.equal(2);
	});

	it("should search model by name", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { query: "t5" },
			limit: 10,
		})) {
			count++;
			expect(entry.name.toLocaleLowerCase()).to.include("t5");
		}

		expect(count).to.equal(10);
	});

	it("should search model by inference provider", async () => {
		let count = 0;
		for await (const entry of listModels({
			search: { inferenceProviders: ["together"] },
			additionalFields: ["inferenceProviderMapping"],
			limit: 10,
		})) {
			count++;
			if (Array.isArray(entry.inferenceProviderMapping)) {
				expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
			}
		}

		expect(count).to.equal(10);
	});

	it("should search model by several inference providers", async () => {
		let count = 0;
		const inferenceProviders = ["together", "replicate"];
		for await (const entry of listModels({
			search: { inferenceProviders },
			additionalFields: ["inferenceProviderMapping"],
			limit: 10,
		})) {
			count++;
			if (Array.isArray(entry.inferenceProviderMapping)) {
				expect(
					entry.inferenceProviderMapping.filter(({ provider }) => inferenceProviders.includes(provider)).length
				).toBeGreaterThan(0);
			}
		}

		expect(count).to.equal(10);
	});
});
