import { CallApiContextParams, ProviderResponse } from 'promptfoo';
import { beforeEach, expect, test, vi } from 'vitest';

import { AgentProvider, AgentProviderOptions } from './agent';

// Mock handles shared by the `vi.mock` factory and by the tests. They are
// created inside `vi.hoisted` so they exist before the hoisted module mock runs.
const mocked = vi.hoisted(() => {
  const callApi = vi.fn();
  // Every provider "loaded" through this mock resolves to an object exposing
  // the same `callApi` spy, so one spy records the calls of all providers.
  const loadApiProvider = vi.fn().mockResolvedValue({ callApi });

  return { callApi, loadApiProvider };
});

// Replace the `promptfoo` module with a stub exposing only
// `providers.loadApiProvider`, backed by `mocked.loadApiProvider`.
vi.mock('promptfoo', () => ({
  providers: { loadApiProvider: mocked.loadApiProvider },
}));

// Reset the recorded calls of every mock before each test.
// `vi.clearAllMocks()` already invokes `mockClear()` on all mocks, including
// `mocked.callApi` and `mocked.loadApiProvider`, so clearing them again
// individually is unnecessary.
// NOTE(review): clearing keeps mock implementations in place, so a value
// installed with `mockResolvedValue` in one test remains the fallback result
// of `mocked.callApi` in the tests that run after it.
beforeEach(() => {
  vi.clearAllMocks();
});

// A provider constructed with an empty config reports the id it was given.
test('AgentProvider should initialize with default values', async () => {
  const providerOptions: AgentProviderOptions = { id: 'test-agent', config: {} };
  const provider = new AgentProvider(providerOptions);

  expect(provider.id()).toBe('test-agent');
});

// Calling the provider loads both configured sub-providers through
// `loadApiProvider`, passing each one's id and its options object.
test('AgentProvider should load providers', async () => {
  const providerOptions: AgentProviderOptions = {
    id: 'test-agent',
    config: {
      userProvider: { id: 'user-provider' },
      agentProvider: { id: 'agent-provider' },
    },
  };

  // Every mocked reply is the stop marker; only the loading calls are asserted here.
  mocked.callApi.mockResolvedValue({ output: '###STOP###' });
  const provider = new AgentProvider(providerOptions);
  await provider.callApi('test prompt');

  for (const providerId of ['user-provider', 'agent-provider']) {
    expect(mocked.loadApiProvider).toHaveBeenCalledWith(providerId, {
      options: { id: providerId },
    });
  }
});

// With four scripted replies and `maxTurns: 2`, the returned output contains
// each reply prefixed with its speaker label ("User: " / "Assistant: ").
test('AgentProvider should handle conversation flow', async () => {
  const providerOptions: AgentProviderOptions = {
    id: 'test-agent',
    config: {
      userProvider: { id: 'user-provider' },
      agentProvider: { id: 'agent-provider' },
      maxTurns: 2,
    },
  };
  const provider = new AgentProvider(providerOptions);

  // Replies are consumed one per `callApi` invocation, in this order.
  const scriptedReplies = [
    'User response',
    'Assistant response',
    'User response 2',
    'Assistant response 2 ###STOP###',
  ];
  for (const reply of scriptedReplies) {
    mocked.callApi.mockResolvedValueOnce({ output: reply });
  }

  const { output: transcript }: ProviderResponse =
    await provider.callApi('test prompt');

  const expectedLines = [
    'User: User response',
    'Assistant: Assistant response',
    'User: User response 2',
    'Assistant: Assistant response 2 ###STOP###',
  ];
  for (const line of expectedLines) {
    expect(transcript).toContain(line);
  }
});

// Although `maxTurns` is 10, the conversation is expected to end after two
// provider calls because the second scripted reply carries the stop marker.
test('AgentProvider should stop conversation when ###STOP### is received', async () => {
  const providerOptions: AgentProviderOptions = {
    id: 'test-agent',
    config: {
      userProvider: { id: 'user-provider' },
      agentProvider: { id: 'agent-provider' },
      maxTurns: 10,
    },
  };
  const provider = new AgentProvider(providerOptions);

  const scriptedReplies = ['User response', 'Assistant response ###STOP###'];
  for (const reply of scriptedReplies) {
    mocked.callApi.mockResolvedValueOnce({ output: reply });
  }

  const { output: transcript }: ProviderResponse =
    await provider.callApi('test prompt');

  const expectedLines = [
    'User: User response',
    'Assistant: Assistant response ###STOP###',
  ];
  for (const line of expectedLines) {
    expect(transcript).toContain(line);
  }
  expect(mocked.callApi).toHaveBeenCalledTimes(2);
});

// Passes `maxTurns: '3'` through the call context (config says 5) and expects
// six provider calls with all six scripted replies present in the output.
// NOTE(review): the script itself ends with the stop marker on the sixth
// reply, so this assertion would presumably also hold if the config value of 5
// were used — confirm the test really distinguishes the context override.
test('AgentProvider should handle maxTurns from context', async () => {
  const providerOptions: AgentProviderOptions = {
    id: 'test-agent',
    config: {
      userProvider: { id: 'user-provider' },
      agentProvider: { id: 'agent-provider' },
      maxTurns: 5,
    },
  };
  const provider = new AgentProvider(providerOptions);

  const scriptedReplies = [
    'User response',
    'Assistant response',
    'User response 2',
    'Assistant response 2',
    'User response 3',
    'Assistant response 3 ###STOP###',
  ];
  for (const reply of scriptedReplies) {
    mocked.callApi.mockResolvedValueOnce({ output: reply });
  }

  // Only `vars` is supplied, hence the cast to the full context type.
  const context = {
    vars: { maxTurns: '3' },
  } as unknown as CallApiContextParams;
  const { output: transcript }: ProviderResponse = await provider.callApi(
    'test prompt',
    context,
  );

  const expectedLines = [
    'User: User response',
    'Assistant: Assistant response',
    'User: User response 2',
    'Assistant: Assistant response 2',
    'User: User response 3',
    'Assistant: Assistant response 3 ###STOP###',
  ];
  for (const line of expectedLines) {
    expect(transcript).toContain(line);
  }
  expect(mocked.callApi).toHaveBeenCalledTimes(6);
});

// Passes a non-numeric `maxTurns` ('invalid') through the call context and
// expects the conversation to still run through all six scripted replies.
// NOTE(review): presumably the provider falls back to the configured
// `maxTurns: 5` here; the script stops on the sixth reply, so the fallback
// value itself is not pinned by this test — confirm against the implementation.
test('AgentProvider should handle invalid maxTurns from context', async () => {
  const providerOptions: AgentProviderOptions = {
    id: 'test-agent',
    config: {
      userProvider: { id: 'user-provider' },
      agentProvider: { id: 'agent-provider' },
      maxTurns: 5,
    },
  };
  const provider = new AgentProvider(providerOptions);

  const scriptedReplies = [
    'User response',
    'Assistant response',
    'User response 2',
    'Assistant response 2',
    'User response 3',
    'Assistant response 3 ###STOP###',
  ];
  for (const reply of scriptedReplies) {
    mocked.callApi.mockResolvedValueOnce({ output: reply });
  }

  // Only `vars` is supplied, hence the cast to the full context type.
  const context = {
    vars: { maxTurns: 'invalid' },
  } as unknown as CallApiContextParams;
  const { output: transcript }: ProviderResponse = await provider.callApi(
    'test prompt',
    context,
  );

  const expectedLines = [
    'User: User response',
    'Assistant: Assistant response',
    'User: User response 2',
    'Assistant: Assistant response 2',
    'User: User response 3',
    'Assistant: Assistant response 3 ###STOP###',
  ];
  for (const line of expectedLines) {
    expect(transcript).toContain(line);
  }
  expect(mocked.callApi).toHaveBeenCalledTimes(6);
});
