import { GeneratedColumnConfig, and, eq, getTableColumns } from "drizzle-orm"
import {
  MySqlColumn,
  MySqlDatabase,
  boolean,
  int,
  mysqlTable,
  primaryKey,
  timestamp,
  varchar,
  PreparedQueryHKTBase,
  MySqlTableWithColumns,
  MySqlQueryResultHKT,
} from "drizzle-orm/mysql-core"

import type {
  Adapter,
  AdapterAccount,
  AdapterAccountType,
  AdapterSession,
  AdapterUser,
  VerificationToken,
  AdapterAuthenticator,
} from "@auth/core/adapters"
import { Awaitable } from "@auth/core/types"

/**
 * Build the complete set of tables the adapter operates on.
 *
 * Any table present in `schema` is used exactly as given. Every missing table
 * is replaced by the default MySQL definition below. Each default sits on the
 * right-hand side of `??`, so it is only constructed when actually needed.
 *
 * @param schema - Caller-supplied tables; any subset may be omitted.
 * @returns All five tables, with the defaults filling in the gaps.
 */
export function defineTables(
  schema: Partial<DefaultMySqlSchema> = {}
): Required<DefaultMySqlSchema> {
  // Length shared by the default tables' ordinary string columns.
  const varcharLength = 255

  // user: `id` falls back to a random UUID when an insert leaves it out.
  const usersTable =
    schema.usersTable ??
    (mysqlTable("user", {
      id: varchar("id", { length: varcharLength })
        .primaryKey()
        .$defaultFn(() => crypto.randomUUID()),
      name: varchar("name", { length: varcharLength }),
      email: varchar("email", { length: varcharLength }).unique(),
      emailVerified: timestamp("emailVerified", { mode: "date", fsp: 3 }),
      image: varchar("image", { length: varcharLength }),
    }) satisfies DefaultMySqlUsersTable)

  // account: primary key is (provider, providerAccountId); rows reference
  // their user with ON DELETE CASCADE.
  const accountsTable =
    schema.accountsTable ??
    (mysqlTable(
      "account",
      {
        userId: varchar("userId", { length: varcharLength })
          .notNull()
          .references(() => usersTable.id, { onDelete: "cascade" }),
        type: varchar("type", { length: varcharLength })
          .$type<AdapterAccountType>()
          .notNull(),
        provider: varchar("provider", { length: varcharLength }).notNull(),
        providerAccountId: varchar("providerAccountId", {
          length: varcharLength,
        }).notNull(),
        refresh_token: varchar("refresh_token", { length: varcharLength }),
        access_token: varchar("access_token", { length: varcharLength }),
        expires_at: int("expires_at"),
        token_type: varchar("token_type", { length: varcharLength }),
        scope: varchar("scope", { length: varcharLength }),
        // Wider than the other columns; presumably because ID tokens can be
        // long — confirm against the providers in use.
        id_token: varchar("id_token", { length: 2048 }),
        session_state: varchar("session_state", { length: varcharLength }),
      },
      (table) => ({
        compositePk: primaryKey({
          columns: [table.provider, table.providerAccountId],
        }),
      })
    ) satisfies DefaultMySqlAccountsTable)

  // session: keyed by the session token; rows cascade with their user.
  const sessionsTable =
    schema.sessionsTable ??
    (mysqlTable("session", {
      sessionToken: varchar("sessionToken", {
        length: varcharLength,
      }).primaryKey(),
      userId: varchar("userId", { length: varcharLength })
        .notNull()
        .references(() => usersTable.id, { onDelete: "cascade" }),
      expires: timestamp("expires", { mode: "date" }).notNull(),
    }) satisfies DefaultMySqlSessionsTable)

  // verificationToken: primary key is (identifier, token).
  const verificationTokensTable =
    schema.verificationTokensTable ??
    (mysqlTable(
      "verificationToken",
      {
        identifier: varchar("identifier", { length: varcharLength }).notNull(),
        token: varchar("token", { length: varcharLength }).notNull(),
        expires: timestamp("expires", { mode: "date" }).notNull(),
      },
      (table) => ({
        compositePk: primaryKey({
          columns: [table.identifier, table.token],
        }),
      })
    ) satisfies DefaultMySqlVerificationTokenTable)

  // authenticator: primary key is (userId, credentialID); credentialID is
  // additionally unique on its own, and rows cascade with their user.
  const authenticatorsTable =
    schema.authenticatorsTable ??
    (mysqlTable(
      "authenticator",
      {
        credentialID: varchar("credentialID", { length: varcharLength })
          .notNull()
          .unique(),
        userId: varchar("userId", { length: varcharLength })
          .notNull()
          .references(() => usersTable.id, { onDelete: "cascade" }),
        providerAccountId: varchar("providerAccountId", {
          length: varcharLength,
        }).notNull(),
        credentialPublicKey: varchar("credentialPublicKey", {
          length: varcharLength,
        }).notNull(),
        counter: int("counter").notNull(),
        credentialDeviceType: varchar("credentialDeviceType", {
          length: varcharLength,
        }).notNull(),
        credentialBackedUp: boolean("credentialBackedUp").notNull(),
        transports: varchar("transports", { length: varcharLength }),
      },
      (table) => ({
        compositePk: primaryKey({
          columns: [table.userId, table.credentialID],
        }),
      })
    ) satisfies DefaultMySqlAuthenticatorTable)

  return {
    usersTable,
    accountsTable,
    sessionsTable,
    verificationTokensTable,
    authenticatorsTable,
  }
}

/**
 * Create an Auth.js adapter backed by a Drizzle ORM MySQL database.
 *
 * Tables are resolved through `defineTables(schema)`, so any table missing
 * from `schema` falls back to the default definition.
 *
 * Most mutating methods re-read the affected row after writing it and return
 * what the database actually stored.
 *
 * @param client - A Drizzle MySQL database instance.
 * @param schema - Optional custom tables to use instead of the defaults.
 * @returns An `Adapter` implementation operating on those tables.
 */
export function MySqlDrizzleAdapter(
  client: MySqlDatabase<MySqlQueryResultHKT, PreparedQueryHKTBase, any>,
  schema?: DefaultMySqlSchema
): Adapter {
  const {
    usersTable,
    accountsTable,
    sessionsTable,
    verificationTokensTable,
    authenticatorsTable,
  } = defineTables(schema)

  return {
    async createUser(data: AdapterUser) {
      const { id, ...insertData } = data
      // When the users table generates its own ids (`$defaultFn`), let Drizzle
      // assign one; otherwise insert the id supplied by Auth.js.
      const hasDefaultId = getTableColumns(usersTable)["id"]["defaultFn"]

      // `$returningId()` may yield no row (hence the `| []`); in that case
      // fall back to the id we were given.
      const [insertedUser] = (await client
        .insert(usersTable)
        .values(hasDefaultId ? insertData : { ...insertData, id })
        .$returningId()) as [{ id: string }] | []

      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, insertedUser ? insertedUser.id : id))
        .then((res) => res[0]) as Awaitable<AdapterUser>
    },
    async getUser(userId: string) {
      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, userId))
        .then((res) =>
          res.length > 0 ? res[0] : null
        ) as Awaitable<AdapterUser | null>
    },
    async getUserByEmail(email: string) {
      return client
        .select()
        .from(usersTable)
        .where(eq(usersTable.email, email))
        .then((res) =>
          res.length > 0 ? res[0] : null
        ) as Awaitable<AdapterUser | null>
    },
    async createSession(data: {
      sessionToken: string
      userId: string
      expires: Date
    }) {
      await client.insert(sessionsTable).values(data)

      return client
        .select()
        .from(sessionsTable)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))
        .then((res) => res[0])
    },
    async getSessionAndUser(sessionToken: string) {
      // Join first, then filter: same SQL as any call order, but reads like it.
      return client
        .select({
          session: sessionsTable,
          user: usersTable,
        })
        .from(sessionsTable)
        .innerJoin(usersTable, eq(usersTable.id, sessionsTable.userId))
        .where(eq(sessionsTable.sessionToken, sessionToken))
        .then((res) => (res.length > 0 ? res[0] : null)) as Awaitable<{
        session: AdapterSession
        user: AdapterUser
      } | null>
    },
    async updateUser(data: Partial<AdapterUser> & Pick<AdapterUser, "id">) {
      if (!data.id) {
        throw new Error("No user id.")
      }

      await client
        .update(usersTable)
        .set(data)
        .where(eq(usersTable.id, data.id))

      const [result] = await client
        .select()
        .from(usersTable)
        .where(eq(usersTable.id, data.id))

      if (!result) {
        throw new Error("No user found.")
      }

      return result as Awaitable<AdapterUser>
    },
    async updateSession(
      data: Partial<AdapterSession> & Pick<AdapterSession, "sessionToken">
    ) {
      await client
        .update(sessionsTable)
        .set(data)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))

      return client
        .select()
        .from(sessionsTable)
        .where(eq(sessionsTable.sessionToken, data.sessionToken))
        .then((res) => res[0])
    },
    async linkAccount(data: AdapterAccount) {
      await client.insert(accountsTable).values(data)
    },
    async getUserByAccount(
      account: Pick<AdapterAccount, "provider" | "providerAccountId">
    ) {
      const result = await client
        .select({
          account: accountsTable,
          user: usersTable,
        })
        .from(accountsTable)
        .innerJoin(usersTable, eq(accountsTable.userId, usersTable.id))
        .where(
          and(
            eq(accountsTable.provider, account.provider),
            eq(accountsTable.providerAccountId, account.providerAccountId)
          )
        )
        .then((res) => res[0])

      const user = result?.user ?? null

      return user as Awaitable<AdapterUser | null>
    },
    async deleteSession(sessionToken: string) {
      await client
        .delete(sessionsTable)
        .where(eq(sessionsTable.sessionToken, sessionToken))
    },
    async createVerificationToken(data: VerificationToken) {
      await client.insert(verificationTokensTable).values(data)

      // Re-read by the full (identifier, token) key: one identifier may have
      // several outstanding tokens, so filtering on `identifier` alone could
      // return a different token than the one just inserted.
      return client
        .select()
        .from(verificationTokensTable)
        .where(
          and(
            eq(verificationTokensTable.identifier, data.identifier),
            eq(verificationTokensTable.token, data.token)
          )
        )
        .then((res) => res[0])
    },
    async useVerificationToken(params: { identifier: string; token: string }) {
      const tokenMatch = and(
        eq(verificationTokensTable.identifier, params.identifier),
        eq(verificationTokensTable.token, params.token)
      )

      // NOTE(review): the read and the delete are separate statements with no
      // transaction, so two concurrent callers could both observe the token
      // before it is removed. Consider a transaction if the driver supports it.
      const deletedToken = await client
        .select()
        .from(verificationTokensTable)
        .where(tokenMatch)
        .then((res) => (res.length > 0 ? res[0] : null))

      if (deletedToken) {
        await client.delete(verificationTokensTable).where(tokenMatch)
      }

      return deletedToken
    },
    async deleteUser(id: string) {
      await client.delete(usersTable).where(eq(usersTable.id, id))
    },
    async unlinkAccount(
      params: Pick<AdapterAccount, "provider" | "providerAccountId">
    ) {
      await client
        .delete(accountsTable)
        .where(
          and(
            eq(accountsTable.provider, params.provider),
            eq(accountsTable.providerAccountId, params.providerAccountId)
          )
        )
    },
    async getAccount(providerAccountId: string, provider: string) {
      return client
        .select()
        .from(accountsTable)
        .where(
          and(
            eq(accountsTable.provider, provider),
            eq(accountsTable.providerAccountId, providerAccountId)
          )
        )
        .then((res) => res[0] ?? null) as Promise<AdapterAccount | null>
    },
    async createAuthenticator(data: AdapterAuthenticator) {
      await client.insert(authenticatorsTable).values(data)

      const [authenticator] = await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, data.credentialID))

      return (authenticator ?? null) as Awaitable<AdapterAuthenticator>
    },
    async getAuthenticator(credentialID: string) {
      const [authenticator] = await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, credentialID))

      return (authenticator ?? null) as Awaitable<AdapterAuthenticator | null>
    },
    async listAuthenticatorsByUserId(userId: string) {
      const authenticators = await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.userId, userId))

      return authenticators as Awaitable<AdapterAuthenticator[]>
    },
    async updateAuthenticatorCounter(credentialID: string, newCounter: number) {
      await client
        .update(authenticatorsTable)
        .set({ counter: newCounter })
        .where(eq(authenticatorsTable.credentialID, credentialID))

      const [authenticator] = await client
        .select()
        .from(authenticatorsTable)
        .where(eq(authenticatorsTable.credentialID, credentialID))

      if (!authenticator) throw new Error("Authenticator not found.")

      return authenticator as Awaitable<AdapterAuthenticator>
    },
  }
}

/**
 * Loose structural shape of a Drizzle MySQL column.
 *
 * The table types below use it to describe which columns the adapter expects.
 * `T` pins only what the adapter relies on:
 * - the JS value type (`data` / `dataType`),
 * - nullability,
 * - the column class,
 * - and optionally whether the column is a single-column primary key.
 *
 * All other column metadata is left broad.
 */
type DefaultMyqlColumn<
  T extends {
    columnType:
      | "MySqlVarChar"
      | "MySqlText"
      | "MySqlBoolean"
      | "MySqlTimestamp"
      | "MySqlInt"
    dataType: "string" | "number" | "boolean" | "date"
    data: string | number | boolean | Date
    notNull: boolean
    isPrimaryKey?: boolean
  },
> = MySqlColumn<{
  name: string
  tableName: string
  columnType: T["columnType"]
  dataType: T["dataType"]
  data: T["data"]
  notNull: T["notNull"]
  // Omitting `isPrimaryKey` (or passing anything but `true`) yields `false`.
  isPrimaryKey: T["isPrimaryKey"] extends true ? true : false
  driverParam: string | number | boolean
  hasDefault: boolean
  hasRuntimeDefault: boolean
  isAutoincrement: boolean
  enumValues: any
  generated: GeneratedColumnConfig<T["data"]> | undefined
}>

/**
 * Shape a custom `usersTable` must have.
 *
 * `id` is a non-null string primary key. `name`, `email` and `image` are
 * strings, and `emailVerified` is a timestamp; none of those four is required
 * to be non-null.
 */
export type DefaultMySqlUsersTable = MySqlTableWithColumns<{
  dialect: "mysql"
  name: string
  schema: string | undefined
  columns: {
    id: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      dataType: "string"
      data: string
      notNull: true
      isPrimaryKey: true
    }>
    name: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      dataType: "string"
      data: string
      notNull: boolean
    }>
    email: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      dataType: "string"
      data: string
      notNull: boolean
    }>
    emailVerified: DefaultMyqlColumn<{
      columnType: "MySqlTimestamp"
      dataType: "date"
      data: Date
      notNull: boolean
    }>
    image: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      dataType: "string"
      data: string
      notNull: boolean
    }>
  }
}>

/**
 * Shape a custom `accountsTable` must have.
 *
 * `userId`, `type`, `provider` and `providerAccountId` are non-null strings.
 * The token/session columns are strings of unconstrained nullability, and
 * `expires_at` is an integer.
 *
 * Every entry supplies exactly the fields `DefaultMyqlColumn` declares. An
 * earlier stray `driverParam` on `access_token` was silently ignored, because
 * `DefaultMyqlColumn` fixes the driver param itself, so it has been removed.
 */
export type DefaultMySqlAccountsTable = MySqlTableWithColumns<{
  name: string
  columns: {
    userId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: true
      dataType: "string"
    }>
    type: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: true
      dataType: "string"
    }>
    provider: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: true
      dataType: "string"
    }>
    providerAccountId: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: true
      dataType: "string"
    }>
    refresh_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
    access_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
    expires_at: DefaultMyqlColumn<{
      columnType: "MySqlInt"
      data: number
      notNull: boolean
      dataType: "number"
    }>
    token_type: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
    scope: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
    id_token: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
    session_state: DefaultMyqlColumn<{
      columnType: "MySqlVarChar" | "MySqlText"
      data: string
      notNull: boolean
      dataType: "string"
    }>
  }
  dialect: "mysql"
  schema: string | undefined
}>

/**
 * Shape a custom `sessionsTable` must have.
 *
 * - `sessionToken`: non-null string primary key.
 * - `userId`: non-null string.
 * - `expires`: non-null timestamp.
 */
export type DefaultMySqlSessionsTable = MySqlTableWithColumns<{
  dialect: "mysql"
  name: string
  schema: string | undefined
  columns: {
    sessionToken: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
      isPrimaryKey: true
    }>
    userId: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    expires: DefaultMyqlColumn<{
      dataType: "date"
      data: Date
      notNull: true
      columnType: "MySqlTimestamp"
    }>
  }
}>

/**
 * Shape a custom `verificationTokensTable` must have.
 *
 * - `identifier` and `token`: non-null strings.
 * - `expires`: non-null timestamp.
 */
export type DefaultMySqlVerificationTokenTable = MySqlTableWithColumns<{
  dialect: "mysql"
  name: string
  schema: string | undefined
  columns: {
    identifier: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    token: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    expires: DefaultMyqlColumn<{
      dataType: "date"
      data: Date
      notNull: true
      columnType: "MySqlTimestamp"
    }>
  }
}>

/**
 * Shape a custom `authenticatorsTable` must have.
 *
 * All columns are non-null except `transports`, which must be nullable
 * (`notNull: false`).
 */
export type DefaultMySqlAuthenticatorTable = MySqlTableWithColumns<{
  dialect: "mysql"
  name: string
  schema: string | undefined
  columns: {
    credentialID: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    userId: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    providerAccountId: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    credentialPublicKey: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    counter: DefaultMyqlColumn<{
      dataType: "number"
      data: number
      notNull: true
      columnType: "MySqlInt"
    }>
    credentialDeviceType: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: true
      columnType: "MySqlVarChar" | "MySqlText"
    }>
    credentialBackedUp: DefaultMyqlColumn<{
      dataType: "boolean"
      data: boolean
      notNull: true
      columnType: "MySqlBoolean"
    }>
    transports: DefaultMyqlColumn<{
      dataType: "string"
      data: string
      notNull: false
      columnType: "MySqlVarChar" | "MySqlText"
    }>
  }
}>

/**
 * Tables that can be handed to `MySqlDrizzleAdapter`.
 *
 * The users and accounts tables are required members of this type. The
 * remaining three may be omitted; `defineTables` substitutes its default
 * definitions for any that are missing.
 */
export type DefaultMySqlSchema = {
  /** Table of Auth.js users. */
  usersTable: DefaultMySqlUsersTable
  /** Table linking provider accounts to users. */
  accountsTable: DefaultMySqlAccountsTable
  /** Optional sessions table. */
  sessionsTable?: DefaultMySqlSessionsTable
  /** Optional verification-token table. */
  verificationTokensTable?: DefaultMySqlVerificationTokenTable
  /** Optional WebAuthn authenticator table. */
  authenticatorsTable?: DefaultMySqlAuthenticatorTable
}
