import { addConsumeAwareSignal, addToEnd } from './utils'
import type {
  OmitKeyof,
  QueryFunction,
  QueryFunctionContext,
  QueryKey,
} from './types'

/**
 * Options shared by every `streamedQuery` configuration, regardless of whether
 * a custom reducer is supplied.
 */
type BaseStreamedQueryParams<TQueryFnData, TQueryKey extends QueryKey> = {
  /**
   * Produces the stream whose chunks are fed into the query, either directly
   * or wrapped in a Promise. Receives the query function context.
   */
  streamFn: (
    context: QueryFunctionContext<TQueryKey>,
  ) => AsyncIterable<TQueryFnData> | Promise<AsyncIterable<TQueryFnData>>
  /**
   * How a refetch of an already-fetched query is handled.
   * `streamedQuery` falls back to `'reset'` when this is omitted.
   */
  refetchMode?: 'append' | 'reset' | 'replace'
}

/**
 * Variant without a custom reducer. `reducer` and `initialValue` are typed
 * `never` so they cannot be supplied here; chunks are accumulated by
 * `streamedQuery`'s default reducer instead.
 */
type SimpleStreamedQueryParams<
  TQueryFnData,
  TQueryKey extends QueryKey,
> = BaseStreamedQueryParams<TQueryFnData, TQueryKey> & {
  reducer?: never
  initialValue?: never
}

/**
 * Variant with a custom reducer. `reducer` and `initialValue` are both
 * required, so a reducer always has a well-typed starting accumulator.
 */
type ReducibleStreamedQueryParams<
  TQueryFnData,
  TData,
  TQueryKey extends QueryKey,
> = BaseStreamedQueryParams<TQueryFnData, TQueryKey> & {
  /** Folds one streamed chunk into the accumulated data. */
  reducer: (acc: TData, chunk: TQueryFnData) => TData
  /** Starting accumulator passed to the first `reducer` call. */
  initialValue: TData
}

/**
 * Parameters accepted by `streamedQuery`: either no custom reducer at all, or
 * a reducer together with its initial value. Supplying only one of
 * `reducer` / `initialValue` matches neither member and is rejected by the
 * type checker.
 */
type StreamedQueryParams<TQueryFnData, TData, TQueryKey extends QueryKey> =
  | SimpleStreamedQueryParams<TQueryFnData, TQueryKey>
  | ReducibleStreamedQueryParams<TQueryFnData, TData, TQueryKey>

/**
 * Builds a query function that reads chunks from an AsyncIterable and streams
 * them into the query cache as they arrive.
 *
 * With the default reducer the query data is an Array of every chunk received
 * so far. With a custom `reducer` it is whatever the reducer accumulates,
 * starting from `initialValue`.
 * The query stays `pending` until the first chunk is written and becomes
 * `success` afterwards. Its fetchStatus remains `'fetching'` until the stream
 * ends.
 *
 * @param streamFn - Returns (directly or via a Promise) the AsyncIterable to
 * read chunks from. It receives a copy of the query function context whose
 * `signal` is wired up by `addConsumeAwareSignal`.
 * @param refetchMode - How a refetch of an already-fetched query is handled:
 * - `'reset'` (default): resets the query state before streaming again,
 *   erasing the data and putting the query back into `pending`.
 * - `'append'`: reduces new chunks onto the data already in the cache.
 * - `'replace'`: accumulates the new chunks off-cache and writes the result
 *   in a single update once the stream ends.
 * @param reducer - Folds each chunk into the accumulated data. Defaults to
 * appending the chunk to the end of an array.
 * @param initialValue - Accumulator used before the first chunk. It is also
 * returned when the cache ends up holding no data (e.g. the stream yields
 * nothing). Defaults to an empty array.
 * @returns A `QueryFunction` resolving to the data held in the cache once the
 * stream finishes, or `initialValue` if there is none.
 */
export function streamedQuery<
  TQueryFnData = unknown,
  TData = Array<TQueryFnData>,
  TQueryKey extends QueryKey = QueryKey,
>({
  streamFn,
  refetchMode = 'reset',
  reducer = (items, chunk) =>
    addToEnd(items as Array<TQueryFnData>, chunk) as TData,
  initialValue = [] as TData,
}: StreamedQueryParams<TQueryFnData, TData, TQueryKey>): QueryFunction<
  TData,
  TQueryKey
> {
  return async (context) => {
    // Only pull the plain fields we need; `signal` must not be read here so
    // that it is only touched through the consume-aware wrapper below.
    const { client, queryKey } = context

    const cachedQuery = client
      .getQueryCache()
      .find({ queryKey, exact: true })
    // A refetch means this exact query already exists and has been fetched.
    const isRefetch = !!cachedQuery && cachedQuery.isFetched()
    const isReplaceRefetch = isRefetch && refetchMode === 'replace'

    if (isRefetch && refetchMode === 'reset') {
      cachedQuery.setState({
        ...cachedQuery.resetState,
        fetchStatus: 'fetching',
      })
    }

    // Off-cache accumulator, only used while replace-refetching.
    let accumulated = initialValue

    // Widened to `boolean` so the checker does not narrow it to `false`:
    // it is flipped asynchronously by the callback below.
    let aborted: boolean = false as boolean
    const markAborted = () => (aborted = true)

    // NOTE(review): per its name, `addConsumeAwareSignal` presumably only
    // invokes `markAborted` once `streamFn` actually reads `signal` and it
    // aborts — confirm against ./utils.
    const streamContext = addConsumeAwareSignal<
      OmitKeyof<typeof context, 'signal'>
    >(
      {
        client,
        meta: context.meta,
        queryKey,
        pageParam: context.pageParam,
        direction: context.direction,
      },
      () => context.signal,
      markAborted,
    )

    const stream = await streamFn(streamContext)

    for await (const chunk of stream) {
      // Leaving `for await` early calls the iterator's `return()`.
      if (aborted) {
        break
      }

      if (!isReplaceRefetch) {
        // Write every chunk straight into the cache, seeding from
        // `initialValue` when the cache has no data yet.
        client.setQueryData<TData>(queryKey, (previous) =>
          reducer(previous !== undefined ? previous : initialValue, chunk),
        )
        continue
      }

      accumulated = reducer(accumulated, chunk)
    }

    // A replace-refetch that ran to completion publishes its result at once;
    // an aborted one leaves the previous data untouched.
    const shouldPublish = isReplaceRefetch && !aborted
    if (shouldPublish) {
      client.setQueryData<TData>(queryKey, accumulated)
    }

    return client.getQueryData(queryKey) ?? initialValue
  }
}
