import {
  AssertionValueFunctionContext,
  AssertionValueFunctionResult,
  GradingResult,
} from 'promptfoo';
import { z } from 'zod';

import { TwilioAgentProvider, TwilioProvider } from '../providers';
import { TwilioProviderResponse } from '../providers/twilio';

/**
 * Shape that `isTwilioProviderContext` narrows a promptfoo assertion context
 * to: the provider is one of the Twilio providers and the provider response is
 * a `TwilioProviderResponse`.
 */
type TwilioProviderContext = AssertionValueFunctionContext & {
  provider: TwilioAgentProvider | TwilioProvider;
  providerResponse: TwilioProviderResponse;
  // Assertion config read by `usedTool`.
  config?: {
    // Each entry describes one tool call that must appear in the session
    // history. All fields are optional: `usedTool` treats an unset field as
    // matching any tool and a set field as a substring match (names are
    // compared after `sanitizeFunctionCallName`).
    expectedTools: {
      name?: string;
      input?: string;
      output?: string;
    }[];
  };
};

/**
 * Type guard for assertion contexts produced by a Twilio provider.
 *
 * @param context - promptfoo assertion context.
 * @returns `true` only when the provider is a `TwilioAgentProvider` or
 *   `TwilioProvider` AND the provider response carries a truthy
 *   `metadata.sessionId` (required to fetch the session history).
 */
function isTwilioProviderContext(
  context: AssertionValueFunctionContext,
): context is TwilioProviderContext {
  if (
    !(context.provider instanceof TwilioAgentProvider) &&
    !(context.provider instanceof TwilioProvider)
  ) {
    return false;
  }

  // A type predicate must yield a real boolean; returning the session id
  // itself leaked a string (or undefined) to callers.
  return Boolean(context.providerResponse?.metadata?.sessionId);
}

// Fields required on every message of the session Messages listing parsed by
// `getHistory`; merged into each role-specific variant of `MessageSchema`.
const MessageBase = z.object({
  account_sid: z.string(),
  assistant_id: z.string(),
  date_created: z.string(),
  date_updated: z.string(),
  id: z.string(),
  identity: z.string(),
});

/**
 * Normalizes a tool/function name for loose comparison: every maximal run of
 * characters outside `[A-Za-z0-9_-]` is collapsed into a single underscore.
 *
 * @param name - raw tool name.
 * @returns the sanitized name (an empty input stays empty).
 */
export function sanitizeFunctionCallName(name: string): string {
  const disallowedRun = /[^a-zA-Z0-9_-]+/;
  const segments = name.split(disallowedRun);
  return segments.join('_');
}

/**
 * Zod schema for one message of a session's history, discriminated on `role`:
 * - `user`: `content.content` holds the text; `meta` must be an object (no
 *   fields are read from it).
 * - `assistant`: `content.content` holds the text; `meta.tokens` must carry
 *   numeric `completionTokens`, `promptTokens` and `totalTokens`.
 * - `tool`: `content` holds the tool `name` and its `input` and `output`, all
 *   as strings; `meta` must be an object.
 * Every variant is merged with `MessageBase` for the shared fields.
 */
const MessageSchema = z.discriminatedUnion('role', [
  z
    .object({
      role: z.literal('user'),
      content: z.object({
        content: z.string(),
      }),
      meta: z.object({}),
    })
    .merge(MessageBase),
  z
    .object({
      role: z.literal('assistant'),
      content: z.object({
        content: z.string(),
      }),
      meta: z.object({
        tokens: z.object({
          completionTokens: z.number(),
          promptTokens: z.number(),
          totalTokens: z.number(),
        }),
      }),
    })
    .merge(MessageBase),
  z
    .object({
      role: z.literal('tool'),
      content: z.object({
        input: z.string(),
        output: z.string(),
        name: z.string(),
      }),
      meta: z.object({}),
    })
    .merge(MessageBase),
]);

// Static TypeScript type of a parsed message (a union over the three roles).
export type Message = z.infer<typeof MessageSchema>;

// One page of the session Messages listing as parsed by `getHistory`: the
// messages on the page plus paging metadata. The `*_page_url` fields are
// nullable (presumably null when no such page exists — confirm against the
// Twilio API documentation).
const History = z.object({
  messages: z.array(MessageSchema),
  meta: z.object({
    first_page_url: z.string().or(z.null()),
    next_page_url: z.string().or(z.null()),
    previous_page_url: z.string().or(z.null()),
    url: z.string(),
    key: z.literal('messages'),
    page: z.number(),
    page_size: z.number(),
  }),
});

/**
 * Fetches every message of a session, following the listing's
 * `next_page_url` links until no further page is advertised.
 *
 * @param sessionId - session whose messages are listed; URL-encoded into the
 *   request path.
 * @param authorizationHeader - value sent verbatim as the `Authorization`
 *   header.
 * @param domain - base URL that the `/v1/Sessions/...` path and any relative
 *   page links are resolved against.
 * @returns all messages, page after page, in the order the API lists them.
 * @throws Error when a page request returns a non-2xx status.
 * @throws ZodError when a page body does not match the `History` schema.
 */
async function getHistory(
  sessionId: string,
  authorizationHeader: string,
  domain: string,
): Promise<Message[]> {
  const headers = {
    Authorization: authorizationHeader,
    'Content-Type': 'application/json',
  };

  const firstPageUrl = new URL(
    `/v1/Sessions/${encodeURIComponent(sessionId)}/Messages`,
    domain,
  );
  firstPageUrl.searchParams.append('PageSize', '100');

  const messages: Message[] = [];
  // Guards against a server that keeps pointing back at an already-fetched page.
  const fetchedPages = new Set<string>();
  let pageUrl: URL | null = firstPageUrl;

  while (pageUrl !== null && !fetchedPages.has(pageUrl.href)) {
    fetchedPages.add(pageUrl.href);
    const response: Response = await fetch(pageUrl, { headers });
    // Check the status before parsing so API errors surface with their status
    // instead of as an opaque JSON/schema failure.
    if (!response.ok) {
      throw new Error(
        `Failed to fetch messages for session ${sessionId}: ${response.status} ${response.statusText}`,
      );
    }
    const page = History.parse(await response.json());
    messages.push(...page.messages);
    const nextLink: string | null = page.meta.next_page_url;
    // Page links may be absolute or relative; resolving against `domain`
    // handles both.
    pageUrl = nextLink ? new URL(nextLink, domain) : null;
  }

  return messages;
}

/**
 * Returns the tool-call messages from `messages`, preserving their order.
 *
 * The explicit type predicate narrows the result to the `tool` variant, so
 * callers can read `content.name` / `content.input` / `content.output`
 * without depending on TypeScript 5.5+ inferred predicates.
 *
 * @param messages - parsed session messages.
 * @returns only the messages whose `role` is `'tool'`.
 */
export function findAllToolCalls(
  messages: Message[],
): Extract<Message, { role: 'tool' }>[] {
  return messages.filter(
    (message): message is Extract<Message, { role: 'tool' }> =>
      message.role === 'tool',
  );
}

/**
 * Returns the tool calls that belong to a given assistant response: the tool
 * messages after the first assistant message whose text equals `response`,
 * up to (not including) the next user message or the end of the list.
 *
 * NOTE(review): this assumes the history lists messages newest-first, so a
 * response's tool calls follow it in the array — confirm against the API.
 *
 * @param messages - parsed session messages.
 * @param response - exact assistant text to locate.
 * @returns matching tool messages; empty when no assistant message matches.
 */
export function findToolCallsForResponse(
  messages: Message[],
  response: string,
): ReturnType<typeof findAllToolCalls> {
  const responseIndex = messages.findIndex(
    (message) =>
      message.role === 'assistant' && message.content.content === response,
  );
  if (responseIndex === -1) {
    return [];
  }

  // Array.prototype.findIndex has no fromIndex parameter (its second argument
  // is `thisArg`), so restrict the search by comparing indices explicitly.
  const nextUserIndex = messages.findIndex(
    (message, index) => index > responseIndex && message.role === 'user',
  );
  // With no following user message, the window runs to the end of the list.
  const end = nextUserIndex === -1 ? messages.length : nextUserIndex;

  return findAllToolCalls(messages.slice(responseIndex + 1, end));
}

/**
 * promptfoo assertion that checks which tools were invoked during the Twilio
 * session that produced the response.
 *
 * For each entry in `config.expectedTools`, it looks for a tool message in the
 * session history that satisfies every set criterion:
 * - its name contains the expected name (both compared after
 *   `sanitizeFunctionCallName`);
 * - its input contains the expected input;
 * - its output contains the expected output.
 *
 * Unset criteria match any tool. Each entry becomes one component result, and
 * the assertion passes only when all of them pass.
 *
 * NOTE(review): with no `expectedTools` configured, `every` over an empty list
 * makes the assertion pass without checking anything.
 *
 * @param _output - model output (unused; tool calls come from the API).
 * @param context - promptfoo assertion context; must come from
 *   `TwilioProvider` or `TwilioAgentProvider` and carry a session id.
 * @returns a failing result (score 0) in three cases: the context is not a
 *   Twilio one, the session id is missing, or the history cannot be fetched or
 *   parsed. Otherwise, the aggregated per-tool results.
 */
export async function usedTool(
  _output: string,
  context: AssertionValueFunctionContext,
): Promise<AssertionValueFunctionResult> {
  if (!isTwilioProviderContext(context)) {
    return {
      pass: false,
      score: 0,
      reason:
        'Assertion can only be used with TwilioProvider or TwilioAgentProvider',
    };
  }

  const { provider, providerResponse } = context;
  let authorizationHeader = '';
  let url = '';
  // The providers' request options are read under @ts-ignore, presumably
  // because they are not part of the providers' public typings.
  if (provider instanceof TwilioAgentProvider) {
    authorizationHeader =
      // @ts-ignore
      provider.agentProviderInstance?.requestOptions?.headers?.Authorization ||
      '';
    url = provider.agentProviderInstance.defaultUrl;
  } else if (provider instanceof TwilioProvider) {
    // Optional chaining (as in the agent branch) so missing request options
    // fall back to '' instead of throwing a TypeError.
    // @ts-ignore
    authorizationHeader = provider.requestOptions?.headers?.Authorization || '';
    url = provider.defaultUrl;
  }

  const sessionId = providerResponse.metadata?.sessionId;
  if (!sessionId) {
    return {
      pass: false,
      score: 0,
      reason: 'Invalid request',
    };
  }

  let history: Message[];
  try {
    history = await getHistory(sessionId, authorizationHeader, url);
  } catch (error: unknown) {
    // Report network/HTTP/schema failures as a failed assertion with context
    // instead of letting the exception escape.
    const detail = error instanceof Error ? error.message : String(error);
    return {
      pass: false,
      score: 0,
      reason: `Failed to fetch Twilio session history: ${detail}`,
    };
  }

  const tools = findAllToolCalls(history);
  const expectedTools = context.config?.expectedTools ?? [];

  const toolTests = expectedTools.map((expectedTool): GradingResult => {
    const tool = tools.find((candidate) => {
      const { name, input, output } = candidate.content;
      if (
        expectedTool.name &&
        !sanitizeFunctionCallName(name).includes(
          sanitizeFunctionCallName(expectedTool.name),
        )
      ) {
        return false;
      }
      if (expectedTool.input && !input.includes(expectedTool.input)) {
        return false;
      }
      if (expectedTool.output && !output.includes(expectedTool.output)) {
        return false;
      }
      return true;
    });

    const description = JSON.stringify(expectedTool);
    return {
      pass: !!tool,
      score: tool ? 1 : 0,
      reason: tool
        ? `Tool ${description} found`
        : `Tool ${description} not found`,
      assertion: {
        type: 'javascript',
        value: tool?.content,
      },
    };
  });

  const pass = toolTests.every((test) => test.pass);

  return {
    pass,
    score: pass ? 1 : 0,
    reason: pass ? 'Tools used' : 'Tools not used',
    componentResults: toolTests,
  };
}
