/*

 * Copyright (c) 2025 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */

#include "c_interface_utils.h"

#include "atb/utils/config.h"

#include "atb/utils/singleton.h"

#include "atb/utils/log.h"



using namespace atb;

using namespace atb::cinterfaceTest;



const int64_t MLAINOUTMLAPP = 28;

const int64_t dims = 7168;

const int64_t dimB = 2112;

const int64_t dimC = 1536;

const int64_t dimD = 1;

const int64_t quantScale0 = 0;

const int64_t quantOffset0 = 0;

const int64_t wdqkv = 1 * 224 * 2112 * 32;

const int64_t deScale0 = 2112;

const int64_t bias0 = 2112;

const int64_t beta1 = 1536;

const int64_t gamma1 = 1536;

const int64_t quantScale1 = 1;

const int64_t quantOffset1 = 1;



const int64_t blockSize = 128;

const int64_t numTokens = 32;

const int64_t numHeads = 32;

const int64_t kvHeads = 1;

const int64_t kSeqlen = 256;

const int64_t batch = numTokens * kSeqlen;

const int64_t numBlocks = 64;



const int64_t sizeofFP16 = 2;

const int64_t wuq = 1 * 48 * numHeads * 192 * 32;

const int64_t deScale1 = numHeads * 192;

const int64_t bias1 = numHeads * 192;

const int64_t gamma2 = 512;

const int64_t cosNum = numTokens * 64;

const int64_t sinNum = numTokens * 64;

const int64_t wuk = numHeads * 128 * 512;



const int64_t kvCache = numBlocks * blockSize * 1 * 512;

const int64_t kvCacheC0 = numBlocks * blockSize * 1 * 576;

const int64_t kvCacheC2 = numBlocks * numHeads * 512 / 32 * blockSize * 32;

const int64_t kvCacheC3 = numBlocks * numHeads * 512 / 16 * blockSize * 16;



const int64_t kvCacheRope = numBlocks * blockSize * 1 * 64;

const int64_t kvCacheRopeC2 = numBlocks * numHeads * 64 / 16 * blockSize * 16;

const int64_t kvCacheRopeC3 = numBlocks * numHeads * 64 / 16 * blockSize * 16;

const int64_t slotmapping = numTokens;

const int64_t ctkvScale = 1;

const int64_t qNopeScale = numHeads;



const int64_t outTensor0C0 = numTokens * numHeads * 576;

const int64_t outTensor0C1 = numTokens * numHeads * 512;

const int64_t outTensor0C2 = numTokens * numHeads * 512;



const int64_t outTensor1C0 = numBlocks * blockSize * 576;

const int64_t outTensor1C1 = numBlocks * blockSize * 512;

const int64_t outTensor1C2 = numBlocks * numHeads * 512 / 32 * blockSize * 32;

const int64_t outTensor1C3 = numBlocks * numHeads * 512 / 16 * blockSize * 16;



const int64_t outTensor2 = numTokens * numHeads * 64;



const int64_t outTensor3C1 = numBlocks * blockSize * 1 * 64;

const int64_t outTensor3C2 = numBlocks * numHeads * 64 / 16 * blockSize * 16;





TEST(TestATBACL, TestMLAPreProcesscomb0Q0C0)

{

    atb::Context *context = nullptr;

    aclrtStream stream = nullptr;

    int64_t deviceId = 0;

    Init(&context, &stream, &deviceId);

    if (!atb::GetSingleton<atb::Config>().Is910B()) {

        ATB_LOG(ERROR) << "MLA PreProcess only supports A2/A3";

        Destroy(&context, &stream);

        GTEST_SKIP();

    }

    uint8_t *inoutHost[MLAINOUTMLAPP];

    uint8_t *inoutDevice[MLAINOUTMLAPP];

    aclTensor *tensorList[MLAINOUTMLAPP];

    size_t inoutSize[MLAINOUTMLAPP] = {

        numTokens * dims * sizeofFP16,

        dims * sizeofFP16,

        dims * sizeofFP16,

        quantScale0 * sizeofFP16,

        quantOffset0,

        wdqkv,

        deScale0 * sizeof(int64_t),

        bias0 * sizeof(int32_t),

        gamma1 * sizeofFP16,

        beta1 * sizeofFP16,

        quantScale1 * sizeofFP16,

        quantOffset1,

        wuq,

        deScale1 * sizeof(float),

        bias1 * sizeof(int32_t),

        gamma2 * sizeofFP16,

        cosNum * sizeofFP16,

        sinNum * sizeofFP16,

        wuk * sizeofFP16,

        kvCacheC0 * sizeofFP16,

        kvCacheRope * sizeofFP16,

        slotmapping * sizeof(int32_t),

        ctkvScale * sizeofFP16,

        qNopeScale * sizeofFP16,

        outTensor0C0 * sizeofFP16,

        outTensor1C0 * sizeofFP16,

        outTensor2 * sizeofFP16,

        outTensor1C3 * sizeofFP16,

    };

    CreateInOutData(MLAINOUTMLAPP, inoutHost, inoutDevice, inoutSize);

    size_t i = 0;

    aclDataType inputFormat = ACL_FLOAT16;



    // 0

    std::vector<int64_t> viewDim = {numTokens, dims};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 1

    viewDim = {dims};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 2

    viewDim = {dims};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 3

    viewDim = {dimD};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 4

    viewDim = {dimD};

    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 5

    viewDim = {1, 224, dimB, 32};

    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_FRACTAL_NZ, tensorList, i, inoutDevice[i]);



    // 6

    viewDim = {dimB};

    CreateACLTensorInOut(viewDim, ACL_INT64, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 7

    viewDim = {dimB};

    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 8

    viewDim = {dimC};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 9

    viewDim = {dimC};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 10

    viewDim = {dimD};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 11

    viewDim = {dimD};

    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 12

    viewDim = {1, 48, numHeads * 192, 32};

    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_FRACTAL_NZ, tensorList, i, inoutDevice[i]);



    // 13

    viewDim = {numHeads * 192};

    CreateACLTensorInOut(viewDim, ACL_INT64, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 14

    viewDim = {numHeads * 192};

    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 15

    viewDim = {gamma2};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 16

    viewDim = {numTokens, 64};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 17

    viewDim = {numTokens, 64};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 18

    viewDim = {numHeads, 128, 512};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 19

    viewDim = {numBlocks, blockSize, 1, 576};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 20

    viewDim = {numBlocks, blockSize, 1, 64};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 21

    viewDim = {numTokens};

    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 22

    viewDim = {ctkvScale};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 23

    viewDim = {qNopeScale};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // out

    // 0

    viewDim = {numTokens, numHeads, 576};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 1

    viewDim = {numBlocks, blockSize, 1, 576};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 2

    viewDim = {numTokens, numHeads, 64};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);



    // 3

    viewDim = {numBlocks, blockSize, 1, 64};

    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);

    uint64_t workspaceSize = 0;

    atb::Operation *op = nullptr;



    atb::Status ret = AtbMLAPreprocessGetWorkspaceSize(

        tensorList[0], tensorList[1], tensorList[2], tensorList[3], tensorList[4], tensorList[5], tensorList[6],

        tensorList[7], tensorList[8], tensorList[9], tensorList[10], tensorList[11], tensorList[12], tensorList[13],

        tensorList[14], tensorList[15], tensorList[16], tensorList[17], tensorList[18], tensorList[19], tensorList[20],

        tensorList[21], tensorList[22], tensorList[23], 0, 0, 0, 1e-5, 2, 3, true, true, true, 0, 0, tensorList[24],

        tensorList[25], tensorList[26], tensorList[27], &workspaceSize, &op, context);



    EXPECT_EQ(ret, atb::NO_ERROR);

    void *workspaceAddr = nullptr;

    if (workspaceSize > 0) {

        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);

        EXPECT_EQ(ret, ACL_SUCCESS);

    }

    ret = AtbMLAPreprocess(workspaceAddr, workspaceSize, op, context);

    EXPECT_EQ(ret, atb::NO_ERROR);

    ret = aclrtSynchronizeStream(stream);



    if (workspaceSize > 0) {

        EXPECT_EQ(aclrtFree(workspaceAddr), ACL_SUCCESS);

    }

    EXPECT_EQ(atb::DestroyOperation(op), atb::NO_ERROR);

    Destroy(&context, &stream);

    for (i = 0; i < MLAINOUTMLAPP; i++) {

        aclrtFreeHost(inoutHost[i]);

        aclrtFree(inoutDevice[i]);

    }

}