/*

 * Copyright (c) 2025 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */

#include "c_interface_utils.h"

#include "atb/utils/config.h"

#include "atb/utils/singleton.h"

#include "atb/utils/log.h"





using namespace atb;

using namespace atb::cinterfaceTest;



const int64_t MLAINOUTMLA = 12;

const int64_t headSizeQk = 576;

const int64_t headSizeVo = 512;

const int64_t maxNumBlocksPerQuery = 16;

const int64_t dimD = 1;

const int64_t dimE = 125;

const int64_t dimF = 512;

const int64_t dimG = 2;

const int64_t sizeofFP16 = 2;

const int64_t numHeads = 32;

const int64_t kvHeads = 1;

const int64_t blockSize = 128;

const int64_t kSeqlen = 256;

const int64_t numTokens = 32;

const int64_t batch = numTokens * kSeqlen;

const int64_t numBlocks = 64;



TEST(TestATBACL, TestMLAM0C2C1)

{

    atb::Context *context = nullptr;

    aclrtStream stream = nullptr;

    int64_t deviceId = 0;

    Init(&context, &stream, &deviceId);

    if (!atb::GetSingleton<atb::Config>().Is910B()) {

        ATB_LOG(ERROR) << "MLA only supports A2/A3";

        Destroy(&context, &stream);

        GTEST_SKIP();

    }

    uint8_t *inoutHost[MLAINOUTMLA];

    uint8_t *inoutDevice[MLAINOUTMLA];

    aclTensor *tensorList[MLAINOUTMLA];

    size_t inoutSize[MLAINOUTMLA] = {

        numTokens * numHeads * headSizeVo * sizeofFP16,

        numTokens * numHeads * (headSizeQk - headSizeVo) * sizeofFP16,

        numBlocks * blockSize * kvHeads * headSizeVo * sizeofFP16,

        numBlocks * blockSize * kvHeads * (headSizeQk - headSizeVo) * sizeofFP16,

        batch * maxNumBlocksPerQuery * sizeof(int32_t),

        batch * sizeof(int),

        (dimE + dimG * batch) * blockSize * sizeofFP16,

        batch * sizeof(int),

        numHeads * sizeof(float),

        numHeads * sizeof(float),

        numTokens * numHeads * dimF * sizeofFP16,

        numTokens * numHeads * dimD * sizeofFP16,

    };

    CreateInOutData(MLAINOUTMLA, inoutHost, inoutDevice, inoutSize);

    size_t i = 0;

    aclDataType inputFormat = ACL_FLOAT16;

    // 0

    std::vector<std::vector<int64_t>> viewDim = {

        {numTokens, numHeads, headSizeVo},              // 0

        {numTokens, numHeads, headSizeQk - headSizeVo}, // 1

        {numBlocks, blockSize, kvHeads, 512},           // 2

        {numBlocks, blockSize, kvHeads, 64},            // 3

        {batch, maxNumBlocksPerQuery},                  // 4

        {batch},                                        // 5

        {(dimE + dimG * batch), blockSize},             // 6

        {batch},                                        // 7

        {numHeads},                                     // 8

        {numHeads},                                     // 9

        {numTokens, numHeads, 512},                     // 10

        {numTokens, numHeads, 1},                       // 11

    };

    while (i < viewDim.size()) {

        if (i == 4) {

            CreateACLTensorInOut(viewDim[i], ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);

        } else if (i == 5 || i == 7) {

            CreateACLTensorInOut(viewDim[i], ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutHost[i]);

        } else if (8 == i || 9 == i) {

            CreateACLTensorInOut(viewDim[i], ACL_FLOAT, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);

        } else {

            CreateACLTensorInOut(viewDim[i], inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);

        }

    }

    uint64_t workspaceSize = 0;

    atb::Operation *op = nullptr;



    atb::Status ret =

        AtbMLAGetWorkspaceSize(tensorList[0], tensorList[1], tensorList[2], tensorList[3], tensorList[4], tensorList[5],

                               tensorList[6], tensorList[7], tensorList[8], tensorList[9], 32, 1.0, 1, 0, 2, 1,

                               tensorList[10], tensorList[11], &workspaceSize, &op, context);

    EXPECT_EQ(ret, atb::NO_ERROR);

    void *workspaceAddr = nullptr;

    if (workspaceSize > 0) {

        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);

        EXPECT_EQ(ret, ACL_SUCCESS);

    }

    ret = AtbMLA(workspaceAddr, workspaceSize, op, context);

    EXPECT_EQ(ret, atb::NO_ERROR);



    ret = aclrtSynchronizeStream(stream);

    EXPECT_EQ(ret, ACL_SUCCESS);



    if (workspaceSize > 0) {

        EXPECT_EQ(aclrtFree(workspaceAddr), ACL_SUCCESS);

    }

    EXPECT_EQ(atb::DestroyOperation(op), NO_ERROR);

    Destroy(&context, &stream);

    for (i = 0; i < MLAINOUTMLA; i++) {

        aclrtFreeHost(inoutHost[i]);

        aclrtFree(inoutDevice[i]);

    }

}