// Core dependencies
import { Program } from "@coral-xyz/anchor";
import { Connection, Keypair, PublicKey, TransactionInstruction } from "@solana/web3.js";

// Local imports
import { BasketsProgram } from "./idl/types";
import { prepareV0Transactions, VersionedTxs } from "./utils/txUtils";
import { updateTokenPricesIxs } from "./instructions/update/updateTokenPrices";
import { withdrawBeforeRebalanceIx } from "./instructions/rebalance/withdrawBeforeRebalance";
import { BasketState, computeRebalanceInfos, fetchBasketState, getBasketTokenPrices, RebalanceInfo } from "./state/basket";
import { depositAfterRebalanceIx } from "./instructions/rebalance/depositAfterRebalance";
import { generateSwapInstruction, getQuoteResponseHandler } from "./instructions/jup";
import { MAX_NUMBER_OF_SWAPS, MAX_SELL_VALUE_PER_TOKEN, MIN_SWAP_VALUE } from "./utils/constants";
import { createAtasIxs } from "./utils/createAtas";


/**
 * Computes how much of `fromInfo`'s token can be sold into `toInfo` in a single swap.
 *
 * The swap value is capped by both the seller's surplus (`-fromInfo.valueDiff`) and the
 * buyer's deficit (`toInfo.valueDiff`); the resulting raw amount is further capped by the
 * seller's spendable balance (`fromInfo.maxSpendAmount`, treated as 0 when negative).
 *
 * @param params.fromInfo - rebalance info of the token being sold.
 * @param params.toInfo - rebalance info of the token being bought.
 * @returns `amount` in the sold token's raw units (`tokenDecimals`-scaled, floored, never
 *   negative) and its `value` at `fromInfo.tokenPrice`. Both are 0 when no swap is possible:
 *   no surplus/deficit, nothing spendable, or a price that is not a positive finite number.
 */
export function getMaxRebalanceAmount(params: {
    fromInfo: RebalanceInfo,
    toInfo: RebalanceInfo,
}): {
    amount: number,
    value: number,
} {
    const { fromInfo, toInfo } = params;
    const { tokenPrice, tokenDecimals } = fromInfo;
    const noSwap = { amount: 0, value: 0 };

    const valueLimit = Math.min(-fromInfo.valueDiff, toInfo.valueDiff);
    // Guard before dividing: a non-positive/NaN limit or price would otherwise produce a
    // negative, infinite or NaN amount that slips past callers' `amount === 0` checks.
    if (!(valueLimit > 0) || !(tokenPrice > 0) || !Number.isFinite(tokenPrice))
        return noSwap;

    const scale = 10 ** tokenDecimals;
    const amount = Math.min(
        Math.max(fromInfo.maxSpendAmount, 0),
        Math.floor(valueLimit * scale / tokenPrice)
    );
    if (!Number.isFinite(amount) || amount <= 0)
        return noSwap;

    return { amount, value: amount * tokenPrice / scale };
}

/**
 * Builds the withdraw -> Jupiter swap -> deposit instruction triple that moves value from
 * `fromInfo`'s token to `toInfo`'s token, sized by `getMaxRebalanceAmount`.
 *
 * @param sdkParams - payer, program and Jupiter settings used to build the instructions.
 * @param params.slippageBps - slippage passed to the Jupiter quote request.
 * @param params.minSwapValue - swaps worth less than this are skipped.
 * @returns the swapped `amount` (raw units of the sold token), its `value`, the three
 *   instructions and the swap's lookup tables. Returns amount/value 0 with no instructions
 *   when the pair is not worth swapping or when the Jupiter swap instruction cannot be
 *   generated (that failure is treated as "skip this pair"). Errors from the quote request
 *   itself are not caught and propagate to the caller.
 */
export async function generateRebalanceInstructionsForTokenPair(
    sdkParams: {
        payer: PublicKey,
        connection: Connection,
        program: Program<BasketsProgram>,
        priorityFee: number,
        jupiterApiKey: string,
        maxAllowedAccounts: number,
    },
    params: {
        basketState: BasketState,
        fromInfo: RebalanceInfo,
        toInfo: RebalanceInfo,
        slippageBps: number,
        minSwapValue: number,
    }
): Promise<{
    amount: number;
    value: number;
    ixs: TransactionInstruction[];
    luts: PublicKey[];
}> {
    const { basketState, fromInfo, toInfo, slippageBps, minSwapValue } = params;
    // Fresh object per call so callers can never share/mutate the same empty arrays.
    const noSwap = () => ({ amount: 0, value: 0, ixs: [], luts: [] });

    const { amount, value } = getMaxRebalanceAmount({ fromInfo, toInfo });
    // `<= 0` (not `=== 0`) so a non-positive amount can never turn into instructions.
    if (amount <= 0 || value < minSwapValue)
        return noSwap();

    const quoteResponse = await getQuoteResponseHandler({
        jupiterApiKey: sdkParams.jupiterApiKey,
        maxAllowedAccounts: sdkParams.maxAllowedAccounts,
        fromToken: fromInfo.token,
        toToken: toInfo.token,
        amount: amount,
        slippageBps,
    });

    // Build the swap first: if Jupiter cannot produce it the pair is skipped (best-effort),
    // so there is no point building the withdraw/deposit legs beforehand.
    const jupIxAndLuts = await generateSwapInstruction({
        payer: sdkParams.payer,
        jupiterApiKey: sdkParams.jupiterApiKey,
        quoteResponse: quoteResponse,
    }).catch(() => null);
    if (!jupIxAndLuts)
        return noSwap();

    const withdrawIx = await withdrawBeforeRebalanceIx({
        program: sdkParams.program,
        basketState: basketState,
        payer: sdkParams.payer,
        fromTokenMint: fromInfo.token,
        toTokenMint: toInfo.token,
        amountToWithdraw: amount,
        checkWeights: true,
        fromTokenWeight: fromInfo.targetWeight,
        toTokenWeight: toInfo.targetWeight,
    });

    const depositIx = await depositAfterRebalanceIx({
        program: sdkParams.program,
        basketState: basketState,
        payer: sdkParams.payer,
        fromTokenMint: fromInfo.token,
        toTokenMint: toInfo.token,
        checkWeights: true,
    });

    return {
        amount,
        value,
        ixs: [withdrawIx, jupIxAndLuts.ix, depositIx],
        luts: jupIxAndLuts.luts,
    };
}

/**
 * Builds transactions for one caller-specified swap inside a basket, using a Jupiter quote
 * the caller already obtained. Basket weight checks are disabled (`checkWeights: false`).
 *
 * @param params.fromTokenWeight - optional override; defaults to the basket's target weight
 *   for `fromToken`.
 * @param params.toTokenWeight - optional override; defaults to the basket's target weight
 *   for `toToken`.
 * @throws Error when a weight override is omitted and the corresponding mint is not in the
 *   basket's composition (previously an `undefined` weight was passed through silently).
 * @returns two batches: the ATA-creation transactions from `createAtasIxs`, then a single
 *   withdraw -> swap -> deposit transaction.
 */
export async function swapTokensHandler(
    sdkParams: {
        payer: PublicKey,
        connection: Connection,
        program: Program<BasketsProgram>,
        priorityFee: number,
        jupiterApiKey: string,
    },
    params: {
        basket: PublicKey;
        fromToken: PublicKey;
        toToken: PublicKey;
        fromAmount: number;
        quoteResponse: any;
        fromTokenWeight?: number;
        toTokenWeight?: number;
    }
): Promise<VersionedTxs> {
    const basketState: BasketState = await fetchBasketState(sdkParams.program, params.basket);

    // Only consulted when the caller did not supply an explicit weight (via `??` below).
    const targetWeightOf = (mint: PublicKey, side: "from" | "to") => {
        const index = basketState.compositionMints.findIndex(m => m.equals(mint));
        if (index === -1)
            throw new Error(
                `${side} token ${mint.toBase58()} is not in basket ${params.basket.toBase58()}; ` +
                `pass ${side}TokenWeight explicitly`
            );
        return basketState.compositionTargetWeights[index];
    };

    const withdrawIx = await withdrawBeforeRebalanceIx({
        program: sdkParams.program,
        basketState: basketState,
        payer: sdkParams.payer,
        fromTokenMint: params.fromToken,
        toTokenMint: params.toToken,
        amountToWithdraw: params.fromAmount,
        checkWeights: false,
        fromTokenWeight: params.fromTokenWeight ?? targetWeightOf(params.fromToken, "from"),
        toTokenWeight: params.toTokenWeight ?? targetWeightOf(params.toToken, "to"),
    });

    const jupIxAndLuts = await generateSwapInstruction({
        payer: sdkParams.payer,
        jupiterApiKey: sdkParams.jupiterApiKey,
        quoteResponse: params.quoteResponse,
    });

    const depositIx = await depositAfterRebalanceIx({
        program: sdkParams.program,
        basketState: basketState,
        payer: sdkParams.payer,
        fromTokenMint: params.fromToken,
        toTokenMint: params.toToken,
        checkWeights: false,
    });

    const preIxs = await createAtasIxs(sdkParams.connection, {
        payer: sdkParams.payer,
        mints: [params.fromToken, params.toToken],
    });

    return await prepareV0Transactions({
        connection: sdkParams.connection,
        payer: sdkParams.payer,
        priorityFee: sdkParams.priorityFee,
        multipleIxs: [...preIxs, [withdrawIx, jupIxAndLuts.ix, depositIx]],
        multipleLookupTableAddresses: [...new Array(preIxs.length).fill([]), jupIxAndLuts.luts],
        signers: [...new Array(preIxs.length).fill([]), []],
        batches: [preIxs.length, 1],
    });
}

/**
 * Mints that are never picked as the sell side when swap pairs are chosen automatically
 * (i.e. neither `fromToken` nor `toToken` is supplied to `rebalanceBasketTokensHandler`).
 * NOTE(review): the reason this mint is excluded is not visible in this file — document it.
 */
const AUTO_REBALANCE_EXCLUDED_SELL_MINTS = new Set<string>([
    "8888xhvqnTWZ6tNbfAyP8mLLfaKsqfiW33tgA8RZ8888",
]);

/**
 * Builds the transactions that move a basket toward its target weights.
 *
 * A token's `valueDiff < 0` means it holds too much value (seller); `valueDiff > 0` means
 * it needs value added (buyer). Pair selection:
 * - `fromToken` and `toToken` both in the basket: a single swap between them.
 * - only `fromToken`: sell it into buyers, largest deficit first.
 * - only `toToken`: buy it with sellers, largest surplus first.
 * - neither: greedily match sellers (largest surplus first) with buyers (largest deficit
 *   first), skipping `AUTO_REBALANCE_EXCLUDED_SELL_MINTS` as sellers.
 * A supplied token that is not in the basket is ignored.
 *
 * @param params.minSwapValue - swaps worth less than this are skipped (default `MIN_SWAP_VALUE`).
 * @param params.maxSellValuePerToken - caps how much value each seller may give up
 *   (default `MAX_SELL_VALUE_PER_TOKEN`).
 * @param params.maxNumberOfSwaps - caps the number of swaps in the multi-swap modes
 *   (default `MAX_NUMBER_OF_SWAPS`).
 * @throws Error("No swaps to perform") when no worthwhile swap was found.
 * @returns two batches: price-update and ATA-creation transactions, then one transaction
 *   per swap.
 */
export async function rebalanceBasketTokensHandler(
    sdkParams: {
        payer: PublicKey,
        connection: Connection,
        program: Program<BasketsProgram>,
        priorityFee: number,
        jupiterApiKey: string,
        maxAllowedAccounts: number,
    },
    params: {
        basket: PublicKey;
        fromToken?: PublicKey;
        toToken?: PublicKey;
        minSwapValue?: number;
        maxSellValuePerToken?: number;
        maxNumberOfSwaps?: number;
    }
): Promise<VersionedTxs> {
    // Extract and set default params
    const basket = params.basket;
    const fromToken = params.fromToken ?? PublicKey.default;
    const toToken = params.toToken ?? PublicKey.default;
    const minSwapValue = params.minSwapValue ?? MIN_SWAP_VALUE;
    const maxSellValue = params.maxSellValuePerToken ?? MAX_SELL_VALUE_PER_TOKEN;
    const maxSwaps = params.maxNumberOfSwaps ?? MAX_NUMBER_OF_SWAPS;

    // Get price update instructions
    const { ixs: updatePricesIxs, luts: updatePricesLuts } = await updateTokenPricesIxs({
        program: sdkParams.program,
        basket: basket,
    });

    // First batch: one tx per price-update ix; ATA-creation txs are appended at the end.
    const firstBatchIxs = updatePricesIxs.map(ix => [ix]);
    const firstBatchLuts = new Array(updatePricesIxs.length).fill(updatePricesLuts);
    const firstBatchSigners = new Array(updatePricesIxs.length).fill([]);
    // Second batch: one tx per swap.
    const secondBatchIxs: TransactionInstruction[][] = [];
    const secondBatchLuts: PublicKey[][] = [];
    const secondBatchSigners: Keypair[][] = [];
    // Mints touched by the queued swaps (may contain duplicates until deduped below).
    const tokenMints: PublicKey[] = [];

    // Get basket state and calculate values
    const basketState = await fetchBasketState(sdkParams.program, basket);
    const slippageBps = basketState.rebalanceSlippageBps;
    const oraclePrices = await getBasketTokenPrices(sdkParams.program, basketState);
    const { rebalanceInfos } = computeRebalanceInfos({
        basketState,
        oraclePrices,
    });

    // Sort ascending by valueDiff: sellers (< 0) first, buyers (> 0) last.
    const sortedRebalanceInfos = rebalanceInfos.sort((a, b) => a.valueDiff - b.valueDiff);
    // Indices [0, posIndex) are sellers; indices (negIndex, length) are buyers.
    let posIndex = 0;
    while (posIndex < sortedRebalanceInfos.length && sortedRebalanceInfos[posIndex].valueDiff < 0)
        posIndex++;
    let negIndex = sortedRebalanceInfos.length - 1;
    while (negIndex >= 0 && sortedRebalanceInfos[negIndex].valueDiff > 0)
        negIndex--;
    // Cap how much value any single token may give up.
    for (const info of sortedRebalanceInfos)
        if (info.valueDiff <= 0)
            info.valueDiff = Math.max(info.valueDiff, -maxSellValue);

    const swapLimitReached = (): boolean => secondBatchIxs.length >= maxSwaps;

    // Seller still has surplus and buyer still has a deficit.
    const isSwapCandidate = (seller: RebalanceInfo, buyer: RebalanceInfo): boolean =>
        seller.valueDiff < 0 && buyer.valueDiff > 0 && !buyer.token.equals(fromToken);

    // Builds the swap for (seller, buyer); if worthwhile, queues it and moves the swapped
    // value from the seller's surplus to the buyer's deficit so later pairs see the update.
    const tryQueueSwap = async (seller: RebalanceInfo, buyer: RebalanceInfo): Promise<void> => {
        const { amount, value, ixs, luts } = await generateRebalanceInstructionsForTokenPair(sdkParams, {
            basketState,
            fromInfo: seller,
            toInfo: buyer,
            slippageBps,
            minSwapValue,
        });
        if (amount <= 0) return;

        secondBatchIxs.push(ixs);
        secondBatchLuts.push(luts);
        secondBatchSigners.push([]);
        tokenMints.push(seller.token, buyer.token);
        seller.maxSpendAmount -= amount;
        seller.valueDiff += value;
        buyer.valueDiff -= value;
    };

    const fromInfo = sortedRebalanceInfos.find(info => info.token.equals(fromToken));
    const toInfo = sortedRebalanceInfos.find(info => info.token.equals(toToken));

    if (fromInfo && toInfo) {
        // Explicit pair: at most one swap.
        await tryQueueSwap(fromInfo, toInfo);
    } else if (fromInfo) {
        // Sell fromToken into buyers, largest deficit first.
        for (let i = sortedRebalanceInfos.length - 1; i > negIndex && !swapLimitReached(); i--) {
            const buyer = sortedRebalanceInfos[i];
            if (isSwapCandidate(fromInfo, buyer)) await tryQueueSwap(fromInfo, buyer);
        }
    } else if (toInfo) {
        // Buy toToken with sellers, largest surplus first.
        for (let i = 0; i < posIndex && !swapLimitReached(); i++) {
            const seller = sortedRebalanceInfos[i];
            if (isSwapCandidate(seller, toInfo)) await tryQueueSwap(seller, toInfo);
        }
    } else {
        // The limit is checked in BOTH loop conditions: a plain `break` would only leave the
        // inner loop and let the outer loop keep queueing swaps past `maxSwaps`.
        for (let i = 0; i < posIndex && !swapLimitReached(); i++) {
            const seller = sortedRebalanceInfos[i];
            if (AUTO_REBALANCE_EXCLUDED_SELL_MINTS.has(seller.token.toBase58())) continue;
            for (let j = sortedRebalanceInfos.length - 1; j > negIndex && !swapLimitReached(); j--) {
                const buyer = sortedRebalanceInfos[j];
                if (isSwapCandidate(seller, buyer)) await tryQueueSwap(seller, buyer);
            }
        }
    }

    if (secondBatchIxs.length === 0)
        throw new Error("No swaps to perform");

    // Create token accounts if needed. A mint can appear in several swaps, so dedupe first
    // to avoid asking for the same ATA more than once.
    const uniqueMints = tokenMints.filter(
        (mint, index) => tokenMints.findIndex(other => other.equals(mint)) === index
    );
    const preIxs = await createAtasIxs(sdkParams.connection, {
        payer: sdkParams.payer,
        mints: uniqueMints,
    });
    firstBatchIxs.push(...preIxs);
    firstBatchLuts.push(...new Array(preIxs.length).fill([]));
    firstBatchSigners.push(...new Array(preIxs.length).fill([]));

    // Prepare and return transactions
    return await prepareV0Transactions({
        connection: sdkParams.connection,
        payer: sdkParams.payer,
        priorityFee: sdkParams.priorityFee,
        multipleIxs: [...firstBatchIxs, ...secondBatchIxs],
        multipleLookupTableAddresses: [...firstBatchLuts, ...secondBatchLuts],
        signers: [...firstBatchSigners, ...secondBatchSigners],
        batches: [firstBatchIxs.length, secondBatchIxs.length],
    });
}
