import { defaultAbiCoder } from '@ethersproject/abi'
import { abi as IAstraCLPoolABI } from '@airdao/astra-cl-core/artifacts/contracts/interfaces/IAstraCLPool.sol/IAstraCLPool.json'
import { Fixture } from 'ethereum-waffle'
import { BigNumber, constants, ContractTransaction, Wallet } from 'ethers'
import { solidityPack } from 'ethers/lib/utils'
import { ethers, waffle } from 'hardhat'
import { IAstraCLPool, ISAMB, MockTimeSwapRouter02, TestERC20 } from '../typechain'
import completeFixture from './shared/completeFixture'
import { ADDRESS_THIS, FeeAmount, MSG_SENDER, TICK_SPACINGS } from './shared/constants'
import { encodePriceSqrt } from './shared/encodePriceSqrt'
import { expandTo18Decimals } from './shared/expandTo18Decimals'
import { expect } from './shared/expect'
import { encodePath } from './shared/path'
import snapshotGasCost from './shared/snapshotGasCost'
import { getMaxTick, getMinTick } from './shared/ticks'

describe('SwapRouter gas tests', function () {
  // NOTE(review): generous timeout, presumably because the fixture deploys and seeds three pools
  this.timeout(40_000)
  // signers assigned in the `before` hook: `wallet` deploys/provides liquidity, `trader` swaps
  let wallet: Wallet, trader: Wallet

  /**
   * Waffle fixture for the gas tests.
   *
   * Deploys the periphery stack via `completeFixture` and sets up approvals. The default
   * signer approves the router and the position manager. The trader approves the router
   * and receives 1,000,000 (18-decimal) units of every test token.
   *
   * It then creates three MEDIUM-fee pools: tokens[0]/tokens[1], tokens[1]/tokens[2] and
   * SAMB/tokens[0]. Each is seeded with a full-range position minted to `wallet`.
   *
   * Relies on the outer `wallet` and `trader` having been assigned (by the `before` hook)
   * before the fixture first runs.
   */
  const swapRouterFixture: Fixture<{
    samb: ISAMB
    router: MockTimeSwapRouter02
    tokens: [TestERC20, TestERC20, TestERC20]
    pools: [IAstraCLPool, IAstraCLPool, IAstraCLPool]
  }> = async (wallets, provider) => {
    const { samb, factory, router, tokens, nft } = await completeFixture(wallets, provider)

    // approve & fund wallets (sequential: every call is a transaction)
    for (const token of tokens) {
      await token.approve(router.address, constants.MaxUint256)
      await token.approve(nft.address, constants.MaxUint256)
      await token.connect(trader).approve(router.address, constants.MaxUint256)
      await token.transfer(trader.address, expandTo18Decimals(1_000_000))
    }

    // desired amount of each token deposited into every pool's initial position
    const liquidity = 1000000

    /** Creates (if needed) and initializes the MEDIUM-fee pool for the pair, then mints a full-range position. */
    async function createPool(tokenAddressA: string, tokenAddressB: string) {
      // order the pair by address so that token0 < token1
      const [token0, token1]: [string, string] =
        tokenAddressA.toLowerCase() > tokenAddressB.toLowerCase()
          ? [tokenAddressB, tokenAddressA]
          : [tokenAddressA, tokenAddressB]

      await nft.createAndInitializePoolIfNecessary(
        token0,
        token1,
        FeeAmount.MEDIUM,
        encodePriceSqrt(100005, 100000) // we don't want to cross any ticks
      )

      const liquidityParams = {
        token0,
        token1,
        fee: FeeAmount.MEDIUM,
        tickLower: getMinTick(TICK_SPACINGS[FeeAmount.MEDIUM]),
        tickUpper: getMaxTick(TICK_SPACINGS[FeeAmount.MEDIUM]),
        recipient: wallet.address,
        amount0Desired: liquidity,
        amount1Desired: liquidity,
        amount0Min: 0,
        amount1Min: 0,
        deadline: 2 ** 32,
      }

      return nft.mint(liquidityParams)
    }

    /** Wraps native value into SAMB for the default signer, approves the position manager, then creates the pool. */
    async function createPoolSAMB(tokenAddress: string) {
      await samb.deposit({ value: liquidity * 2 })
      await samb.approve(nft.address, constants.MaxUint256)
      return createPool(samb.address, tokenAddress)
    }

    // create pools
    await createPool(tokens[0].address, tokens[1].address)
    await createPool(tokens[1].address, tokens[2].address)
    await createPoolSAMB(tokens[0].address)

    const poolAddresses = await Promise.all([
      factory.getPool(tokens[0].address, tokens[1].address, FeeAmount.MEDIUM),
      factory.getPool(tokens[1].address, tokens[2].address, FeeAmount.MEDIUM),
      factory.getPool(samb.address, tokens[0].address, FeeAmount.MEDIUM),
    ])

    const pools = poolAddresses.map((poolAddress) => new ethers.Contract(poolAddress, IAstraCLPoolABI, wallet)) as [
      IAstraCLPool,
      IAstraCLPool,
      IAstraCLPool
    ]

    return {
      samb,
      router,
      tokens,
      pools,
    }
  }

  // per-test contracts, reassigned from the fixture in the `load fixture` hook
  let router: MockTimeSwapRouter02
  let samb: ISAMB
  let tokens: [TestERC20, TestERC20, TestERC20]
  let pools: [IAstraCLPool, IAstraCLPool, IAstraCLPool]

  // created once in the `before` hook from the available signers
  let loadFixture: ReturnType<typeof waffle.createFixtureLoader>

  /**
   * ABI-encodes a router `unwrapSAMB(uint256)` call for use inside a multicall:
   * the 4-byte selector followed by the ABI-encoded amount.
   *
   * @param amount the single `uint256` argument (callers pass the minimum amount they expect)
   */
  function encodeUnwrapSAMB(amount: number) {
    const selector = router.interface.getSighash('unwrapSAMB(uint256)')
    const encodedArgs = defaultAbiCoder.encode(['uint256'], [amount])
    return solidityPack(['bytes4', 'bytes'], [selector, encodedArgs])
  }

  /**
   * ABI-encodes a router `sweepToken(address,uint256)` call for use inside a multicall:
   * the 4-byte selector followed by the ABI-encoded `(token, amount)` arguments.
   *
   * The argument types are spelled out explicitly (matching the signature string), the same way
   * `encodeUnwrapSAMB` does. This avoids an untyped `as any` lookup into `interface.functions`.
   *
   * @param token address of the token to sweep
   * @param amount the `uint256` argument (presumably the minimum balance to sweep — confirm)
   */
  function encodeSweep(token: string, amount: number) {
    const selector = router.interface.getSighash('sweepToken(address,uint256)')
    const encodedArgs = defaultAbiCoder.encode(['address', 'uint256'], [token, amount])
    return solidityPack(['bytes4', 'bytes'], [selector, encodedArgs])
  }

  // Resolve the signers once per suite: the first is `wallet`, the second `trader`.
  before('create fixture loader', async () => {
    const signers = await (ethers as any).getSigners()
    wallet = signers[0]
    trader = signers[1]
    loadFixture = waffle.createFixtureLoader(signers)
  })

  // Reload (or restore) the fixture state before every test and refresh the shared handles.
  beforeEach('load fixture', async () => {
    const fixture = await loadFixture(swapRouterFixture)
    router = fixture.router
    samb = fixture.samb
    tokens = fixture.tokens
    pools = fixture.pools
  })

  /**
   * Sends an exact-input multi-hop swap along `tokens` from `trader` through the router's
   * deadline-taking `multicall`. Every hop uses the MEDIUM fee tier.
   *
   * - If the path starts at SAMB, `amountIn` is attached as native value.
   * - If the path ends at SAMB, the swap output goes to the router (ADDRESS_THIS) with a
   *   swap-level minimum of 0. A trailing `unwrapSAMB(amountOutMinimum)` call is appended instead.
   *
   * NOTE(review): the multicall deadline is the constant `1`, presumably valid against the mock
   * router's time — confirm.
   *
   * @param tokens token addresses along the route, input first
   * @param amountIn exact amount of the first token to spend (default 2)
   * @param amountOutMinimum minimum acceptable output (default 1)
   * @returns the submitted multicall transaction
   */
  async function exactInput(
    tokens: string[],
    amountIn: number = 2,
    amountOutMinimum: number = 1
  ): Promise<ContractTransaction> {
    const inputIsSAMB = tokens[0] === samb.address
    const outputIsSAMB = tokens[tokens.length - 1] === samb.address

    // paying with SAMB: the exact input is sent as native value
    const value = inputIsSAMB ? amountIn : 0

    const params = {
      path: encodePath(tokens, new Array(tokens.length - 1).fill(FeeAmount.MEDIUM)),
      recipient: outputIsSAMB ? ADDRESS_THIS : MSG_SENDER,
      amountIn,
      amountOutMinimum: outputIsSAMB ? 0 : amountOutMinimum, // save on calldata
    }

    const data = [router.interface.encodeFunctionData('exactInput', [params])]
    if (outputIsSAMB) {
      // the minimum is handed to the unwrap step instead of the swap itself
      data.push(encodeUnwrapSAMB(amountOutMinimum))
    }

    return router.connect(trader)['multicall(uint256,bytes[])'](1, data, { value })
  }

  /**
   * Sends a single-pool (MEDIUM fee) exact-input swap from `trader` through the router's
   * deadline-taking `multicall`.
   *
   * - If `tokenIn` is SAMB, `amountIn` is attached as native value.
   * - If `tokenOut` is SAMB, the output goes to the router (ADDRESS_THIS) with a swap-level
   *   minimum of 0. A trailing `unwrapSAMB(amountOutMinimum)` call is appended instead.
   *
   * @param tokenIn address of the token spent
   * @param tokenOut address of the token received
   * @param amountIn exact amount of `tokenIn` to spend (default 3)
   * @param amountOutMinimum minimum acceptable output (default 1)
   * @param sqrtPriceLimitX96 optional price limit; 0 is sent when omitted
   * @returns the submitted multicall transaction
   */
  async function exactInputSingle(
    tokenIn: string,
    tokenOut: string,
    amountIn: number = 3,
    amountOutMinimum: number = 1,
    sqrtPriceLimitX96?: BigNumber
  ): Promise<ContractTransaction> {
    const inputIsSAMB = tokenIn === samb.address
    const outputIsSAMB = tokenOut === samb.address

    // paying with SAMB: the exact input is sent as native value
    const value = inputIsSAMB ? amountIn : 0

    const params = {
      tokenIn,
      tokenOut,
      fee: FeeAmount.MEDIUM,
      recipient: outputIsSAMB ? ADDRESS_THIS : MSG_SENDER,
      amountIn,
      amountOutMinimum: outputIsSAMB ? 0 : amountOutMinimum, // save on calldata
      sqrtPriceLimitX96: sqrtPriceLimitX96 ?? 0,
    }

    const data = [router.interface.encodeFunctionData('exactInputSingle', [params])]
    if (outputIsSAMB) {
      // the minimum is handed to the unwrap step instead of the swap itself
      data.push(encodeUnwrapSAMB(amountOutMinimum))
    }

    return router.connect(trader)['multicall(uint256,bytes[])'](1, data, { value })
  }

  /**
   * Sends an exact-output multi-hop swap along `tokens` from `trader` through the router's
   * deadline-taking `multicall`. `tokens` is given input first and every hop uses the MEDIUM
   * fee tier.
   *
   * The encoded path is built from the reversed token list (output token first).
   *
   * - If the path starts at SAMB, `amountInMaximum` is attached as native value and a
   *   `refundAMB` call is appended (presumably returning whatever the swap did not use).
   * - If the path ends at SAMB, the output goes to the router (ADDRESS_THIS) and a trailing
   *   `unwrapSAMB(amountOut)` call is appended.
   *
   * @param tokens token addresses along the route, input first
   * @param amountOut exact amount of the last token to receive (default 1)
   * @param amountInMaximum cap on the input spent (default 10 — deliberately loose, we don't care)
   * @returns the submitted multicall transaction
   */
  async function exactOutput(
    tokens: string[],
    amountOut: number = 1,
    amountInMaximum: number = 10
  ): Promise<ContractTransaction> {
    const inputIsSAMB = tokens[0] === samb.address
    const outputIsSAMB = tokens[tokens.length - 1] === samb.address

    // paying with SAMB: send the maximum as native value; the surplus is handled by refundAMB
    const value = inputIsSAMB ? amountInMaximum : 0

    const params = {
      path: encodePath(tokens.slice().reverse(), new Array(tokens.length - 1).fill(FeeAmount.MEDIUM)),
      recipient: outputIsSAMB ? ADDRESS_THIS : MSG_SENDER,
      amountOut,
      amountInMaximum,
    }

    const data = [router.interface.encodeFunctionData('exactOutput', [params])]
    if (inputIsSAMB) {
      data.push(router.interface.encodeFunctionData('refundAMB'))
    }

    if (outputIsSAMB) {
      data.push(encodeUnwrapSAMB(amountOut))
    }

    return router.connect(trader)['multicall(uint256,bytes[])'](1, data, { value })
  }

  /**
   * Sends a single-pool (MEDIUM fee) exact-output swap from `trader` through the router's
   * deadline-taking `multicall`.
   *
   * Paying with SAMB attaches `amountInMaximum` as native value and appends a `refundAMB` call.
   * Receiving SAMB sends the output to the router (ADDRESS_THIS) and appends
   * `unwrapSAMB(amountOut)`.
   *
   * @param tokenIn address of the token spent
   * @param tokenOut address of the token received
   * @param amountOut exact amount of `tokenOut` to receive (default 1)
   * @param amountInMaximum cap on `tokenIn` spent (default 3)
   * @param sqrtPriceLimitX96 optional price limit; 0 is sent when omitted
   * @returns the submitted multicall transaction
   */
  async function exactOutputSingle(
    tokenIn: string,
    tokenOut: string,
    amountOut: number = 1,
    amountInMaximum: number = 3,
    sqrtPriceLimitX96?: BigNumber
  ): Promise<ContractTransaction> {
    const paysWithSAMB = tokenIn === samb.address
    const receivesSAMB = tokenOut === samb.address

    const params = {
      tokenIn,
      tokenOut,
      fee: FeeAmount.MEDIUM,
      recipient: receivesSAMB ? ADDRESS_THIS : MSG_SENDER,
      amountOut,
      amountInMaximum,
      sqrtPriceLimitX96: sqrtPriceLimitX96 ?? 0,
    }

    // swap first, then the optional refund, then the optional unwrap
    const calls = [
      router.interface.encodeFunctionData('exactOutputSingle', [params]),
      ...(paysWithSAMB ? [router.interface.encodeFunctionData('refundAMB')] : []),
      ...(receivesSAMB ? [encodeUnwrapSAMB(amountOut)] : []),
    ]

    const value = paysWithSAMB ? amountInMaximum : 0
    return router.connect(trader)['multicall(uint256,bytes[])'](1, calls, { value })
  }

  // TODO should really throw this in the fixture
  beforeEach('initialize feeGrowthGlobals', async () => {
    // Swap 1 unit each way through every pool (amountOutMinimum 0). The pools' feeGrowthGlobal
    // accumulators are then already non-zero when the measured swaps run.
    const paths: [string, string][] = [
      [tokens[0].address, tokens[1].address],
      [tokens[1].address, tokens[0].address],
      [tokens[1].address, tokens[2].address],
      [tokens[2].address, tokens[1].address],
      [tokens[0].address, samb.address],
      [samb.address, tokens[0].address],
    ]

    // sequential on purpose: each swap is a transaction sent by the same signer
    for (const path of paths) {
      await exactInput(path, 1, 0)
    }
  })

  beforeEach('ensure feeGrowthGlobals are >0', async () => {
    // Every accumulator of every pool must equal this exact value. That is stricter than the
    // hook title's "> 0" and also pins the fee math.
    // NOTE(review): numerically ≈ 2^128 / 999975 — presumably one unit of fee over the pool's liquidity; confirm
    const expected = '340290874192793283295456993856614'
    const expectedRow = [expected, expected]

    const actual = await Promise.all(
      pools.map(async (pool) => {
        const [growth0, growth1] = await Promise.all([pool.feeGrowthGlobal0X128(), pool.feeGrowthGlobal1X128()])
        return [growth0.toString(), growth1.toString()]
      })
    )

    expect(actual).to.deep.eq([expectedRow, expectedRow, expectedRow])
  })

  /** Reads the current tick of each of the three pools and asserts all are exactly 0. */
  async function expectAllTicksZero() {
    const ticks = await Promise.all(pools.map(async (pool) => (await pool.slot0()).tick))
    expect(ticks).to.deep.eq([0, 0, 0])
  }

  // the measured swaps must neither start nor end on a different tick
  beforeEach('ensure ticks are 0 before', expectAllTicksZero)

  afterEach('ensure ticks are 0 after', expectAllTicksZero)

  describe('#exactInput', () => {
    /** Single-hop exactInput params over the MEDIUM-fee `from`/`to` pool, spending exactly 3 of `from`. */
    const singleHop = (from: string, to: string, recipient: string, amountOutMinimum: number) => ({
      path: encodePath([from, to], [FeeAmount.MEDIUM]),
      recipient,
      amountIn: 3,
      amountOutMinimum,
    })

    /** ABI-encodes each params object as a router `exactInput` call, preserving order. */
    const encodeExactInputs = (swaps: ReturnType<typeof singleHop>[]) =>
      swaps.map((swap) => router.interface.encodeFunctionData('exactInput', [swap]))

    /** Wraps 3 native units into SAMB for `trader` and lets the router spend it. */
    async function fundTraderWithSAMB() {
      await samb.connect(trader).deposit({ value: 3 })
      await samb.connect(trader).approve(router.address, constants.MaxUint256)
    }

    /** Submits `calls` as `trader` through the router's deadline-taking multicall (no native value). */
    const multicallAsTrader = (calls: string[]) => router.connect(trader)['multicall(uint256,bytes[])'](1, calls)

    it('0 -> 1', async () => {
      const [first, second] = tokens
      await snapshotGasCost(exactInput([first.address, second.address]))
    })

    it('0 -> 1 minimal', async () => {
      const calleeFactory = await ethers.getContractFactory('TestAstraCLCallee')
      const callee = await calleeFactory.deploy()
      await tokens[0].connect(trader).approve(callee.address, constants.MaxUint256)

      // NOTE(review): 4295128740 looks like TickMath.MIN_SQRT_RATIO + 1 (effectively no price limit) — confirm
      const swapTx = callee.connect(trader).swapExact0For1(pools[0].address, 2, trader.address, '4295128740')
      await snapshotGasCost(swapTx)
    })

    it('0 -> 1 -> 2', async () => {
      const fullPath = tokens.map(({ address }) => address)
      await snapshotGasCost(exactInput(fullPath, 3))
    })

    it('SAMB -> 0', async () => {
      const sambSortsFirst = samb.address.toLowerCase() < tokens[0].address.toLowerCase()
      await snapshotGasCost(exactInput([samb.address, tokens[0].address], sambSortsFirst ? 2 : 3))
    })

    it('0 -> SAMB', async () => {
      const tokenSortsFirst = tokens[0].address.toLowerCase() < samb.address.toLowerCase()
      await snapshotGasCost(exactInput([tokens[0].address, samb.address], tokenSortsFirst ? 2 : 3))
    })

    it('2 trades (via router)', async () => {
      await fundTraderWithSAMB()
      // both outputs stay in the router (minimum 0 saves calldata) and are swept in the last call
      const calls = [
        ...encodeExactInputs([
          singleHop(samb.address, tokens[0].address, ADDRESS_THIS, 0),
          singleHop(tokens[1].address, tokens[0].address, ADDRESS_THIS, 0),
        ]),
        encodeSweep(tokens[0].address, 2),
      ]
      await snapshotGasCost(multicallAsTrader(calls))
    })

    it('2 trades (directly to sender)', async () => {
      await fundTraderWithSAMB()
      const calls = encodeExactInputs([
        singleHop(samb.address, tokens[0].address, MSG_SENDER, 1),
        singleHop(tokens[1].address, tokens[0].address, MSG_SENDER, 1),
      ])
      await snapshotGasCost(multicallAsTrader(calls))
    })

    it('3 trades (directly to sender)', async () => {
      await fundTraderWithSAMB()
      const calls = encodeExactInputs([
        singleHop(samb.address, tokens[0].address, MSG_SENDER, 1),
        singleHop(tokens[0].address, tokens[1].address, MSG_SENDER, 1),
        singleHop(tokens[1].address, tokens[2].address, MSG_SENDER, 1),
      ])
      await snapshotGasCost(multicallAsTrader(calls))
    })
  })

  describe('#exactInputSingle', () => {
    /**
     * Returns 2 when `tokenIn` sorts before `tokenOut` by lowercase address, otherwise 3.
     * NOTE(review): the reason for the per-direction amounts is not visible here — confirm.
     */
    const amountInFor = (tokenIn: string, tokenOut: string) =>
      tokenIn.toLowerCase() < tokenOut.toLowerCase() ? 2 : 3

    it('0 -> 1', async () => {
      const [tokenA, tokenB] = tokens
      await snapshotGasCost(exactInputSingle(tokenA.address, tokenB.address))
    })

    it('SAMB -> 0', async () => {
      const tokenIn = samb.address
      const tokenOut = tokens[0].address
      await snapshotGasCost(exactInputSingle(tokenIn, tokenOut, amountInFor(tokenIn, tokenOut)))
    })

    it('0 -> SAMB', async () => {
      const tokenIn = tokens[0].address
      const tokenOut = samb.address
      await snapshotGasCost(exactInputSingle(tokenIn, tokenOut, amountInFor(tokenIn, tokenOut)))
    })
  })

  describe('#exactOutput', () => {
    // [test title, route (input first)]; routes are resolved lazily because the fixture
    // is only loaded in beforeEach
    const cases: [string, () => string[]][] = [
      ['0 -> 1', () => tokens.slice(0, 2).map((token) => token.address)],
      ['0 -> 1 -> 2', () => tokens.map((token) => token.address)],
      ['SAMB -> 0', () => [samb.address, tokens[0].address]],
      ['0 -> SAMB', () => [tokens[0].address, samb.address]],
    ]

    for (const [title, route] of cases) {
      it(title, async () => {
        await snapshotGasCost(exactOutput(route()))
      })
    }
  })

  describe('#exactOutputSingle', () => {
    // [test title, (tokenIn, tokenOut)]; resolved lazily because the fixture is only loaded in beforeEach
    const cases: [string, () => [string, string]][] = [
      ['0 -> 1', () => [tokens[0].address, tokens[1].address]],
      ['SAMB -> 0', () => [samb.address, tokens[0].address]],
      ['0 -> SAMB', () => [tokens[0].address, samb.address]],
    ]

    for (const [title, pair] of cases) {
      it(title, async () => {
        const [tokenIn, tokenOut] = pair()
        await snapshotGasCost(exactOutputSingle(tokenIn, tokenOut))
      })
    }
  })
})
