/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/


/*!
 * \file transpose.asc
 * \brief Transpose sample: host-side tiling generation, the AscendC transpose kernel, and a main() driver.
 */

#include <cstdio>
#include <cstdlib>

#include "acl/acl.h"
#include "data_utils.h"
#include "kernel_operator.h"
#include "tiling/tiling_api.h"

namespace optiling {
// Host-side tiling data definition for the TransposeCustom operator.
// Layout, in declaration order: four uint32_t shape fields (b, n, s, hnDiv)
// followed by a nested ConfusionTransposeTiling struct. The kernel-side
// MyCustomKernel::VecTiling declares its members in the same order.
BEGIN_TILING_DATA_DEF(TransposeCustomTilingData)
TILING_DATA_FIELD_DEF(uint32_t, b);
TILING_DATA_FIELD_DEF(uint32_t, n);
TILING_DATA_FIELD_DEF(uint32_t, s);
TILING_DATA_FIELD_DEF(uint32_t, hnDiv);
// Parameters filled in by AscendC::GetTransposeTilingInfo (see ComputeTiling).
TILING_DATA_FIELD_DEF_STRUCT(ConfusionTransposeTiling, ConfusionTransposeTilingData);
END_TILING_DATA_DEF;
// Associates the tiling data class with the operator name TransposeCustom.
REGISTER_TILING_DATA_CLASS(TransposeCustom, TransposeCustomTilingData)
} // namespace optiling

/**
 * Populates `tiling` for a source tensor of shape {b, n, s, hnDiv}.
 *
 * Stores the four dimensions in `tiling`, queries the minimum temporary buffer
 * size for the transpose, and lets the tiling API fill the nested
 * ConfusionTransposeTilingData using that minimum size.
 *
 * @param b      first shape dimension
 * @param n      second shape dimension
 * @param s      third shape dimension
 * @param hnDiv  fourth shape dimension
 * @param tiling tiling object to fill (output)
 */
void ComputeTiling(uint32_t b, uint32_t n, uint32_t s, uint32_t hnDiv, optiling::TransposeCustomTilingData& tiling)
{
    // Element size handed to both tiling queries (2 bytes).
    constexpr uint32_t elemBytes = sizeof(uint16_t);
    // Transpose-type selector handed to both tiling queries.
    // NOTE(review): presumably corresponds to the TRANSPOSE_NZ2ND_0213 mode used by the kernel — confirm.
    constexpr uint32_t transposeTypeIn = 1;

    std::vector<int64_t> dims = {b, n, s, hnDiv};
    ge::Shape srcShape(dims);

    // Only the minimum is consumed; the maximum is required by the API as an out-parameter.
    uint32_t tmpSizeMax = 0;
    uint32_t tmpSizeMin = 0;
    AscendC::GetTransposeMaxMinTmpSize(srcShape, elemBytes, transposeTypeIn, tmpSizeMax, tmpSizeMin);
    AscendC::GetTransposeTilingInfo(srcShape, tmpSizeMin, elemBytes, transposeTypeIn,
                                    tiling.ConfusionTransposeTilingData);

    tiling.set_b(b);
    tiling.set_n(n);
    tiling.set_s(s);
    tiling.set_hnDiv(hnDiv);
}

/**
 * Serializes `tilingData` into a freshly malloc'ed byte buffer.
 *
 * The buffer capacity is sizeof(optiling::TransposeCustomTilingData) and the
 * content is written by TransposeCustomTilingData::SaveToBuffer.
 *
 * @param tilingData tiling object to serialize; may be nullptr
 * @return heap buffer owned by the caller (release with free()), or nullptr
 *         when `tilingData` is nullptr or the allocation fails
 */
uint8_t* GetTilingBuf(optiling::TransposeCustomTilingData* tilingData)
{
    if (tilingData == nullptr) {
        return nullptr;
    }

    const uint32_t tilingSize = sizeof(optiling::TransposeCustomTilingData);
    uint8_t* buf = static_cast<uint8_t*>(malloc(tilingSize));
    if (buf == nullptr) {
        // Do not let SaveToBuffer write through a null pointer.
        return nullptr;
    }

    tilingData->SaveToBuffer(buf, tilingSize);
    return buf;
}

uint8_t* GenerateTiling(uint32_t b, uint32_t n, uint32_t s, uint32_t hnDiv)
{
    optiling::TransposeCustomTilingData tiling;
    ComputeTiling(b, n, s, hnDiv, tiling);
    return GetTilingBuf(&tiling);
}

namespace MyCustomKernel {
// Kernel-side tiling record. Members are declared in the same order as the
// fields of optiling::TransposeCustomTilingData; CopyTiling fills it word by
// word from the tiling buffer in global memory.
struct VecTiling {
    uint32_t b;
    uint32_t n;
    uint32_t s;
    uint32_t hnDiv;
    ConfusionTransposeTiling TransposeTilingData;
};

// Single-pass transpose of a b * n * s * hnDiv element tensor:
// copy in from global memory, AscendC::Transpose, copy out to global memory.
template <typename T>
class KernelTranspose {
public:
    __aicore__ inline KernelTranspose() {}

    // Binds the global source/destination buffers, reserves one input and one
    // output queue buffer large enough for the whole tensor, and stores the
    // transpose tiling parameters.
    __aicore__ inline void Init(__gm__ uint8_t* srcGm, __gm__ uint8_t* dstGm, const VecTiling& tilingData,
                                AscendC::TPipe* pipeIn)
    {
        ASCENDC_ASSERT(AscendC::GetBlockNum() != 0, { KERNEL_LOG(KERNEL_ERROR, "block dim can not be zero!"); });
        b = tilingData.b;
        n = tilingData.n;
        s = tilingData.s;
        hnDiv = tilingData.hnDiv;

        const uint32_t elemCount = ElemCount();
        srcGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcGm), elemCount);
        dstGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstGm), elemCount);

        pipe = pipeIn;
        pipe->InitBuffer(srcQueue, 1, elemCount * sizeof(T));
        pipe->InitBuffer(dstQueue, 1, elemCount * sizeof(T));
        tiling = tilingData.TransposeTilingData;
    }

    // Runs the three stages once, in order.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Total number of elements in the tensor.
    __aicore__ inline uint32_t ElemCount() const
    {
        return b * n * s * hnDiv;
    }

    // Global memory -> input queue.
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> input = srcQueue.AllocTensor<T>();
        AscendC::DataCopy(input, srcGlobal, ElemCount());
        srcQueue.EnQue(input);
    }

    // Input queue -> Transpose -> output queue.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> input = srcQueue.DeQue<T>();
        AscendC::LocalTensor<T> output = dstQueue.AllocTensor<T>();
        AscendC::Transpose(output, input, AscendC::TransposeType::TRANSPOSE_NZ2ND_0213, tiling);
        dstQueue.EnQue<T>(output);
        srcQueue.FreeTensor(input);
    }

    // Output queue -> global memory.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> output = dstQueue.DeQue<T>();
        AscendC::DataCopy(dstGlobal, output, ElemCount());
        dstQueue.FreeTensor(output);
    }

private:
    AscendC::TPipe* pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> srcQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> dstQueue;
    AscendC::GlobalTensor<T> srcGlobal;
    AscendC::GlobalTensor<T> dstGlobal;
    uint32_t b = 0;
    uint32_t n = 0;
    uint32_t s = 0;
    uint32_t hnDiv = 0;
    ConfusionTransposeTiling tiling;
};
} // namespace MyCustomKernel

// Copies sizeof(VecTiling) / sizeof(uint32_t) 32-bit words from the tiling
// buffer in global memory into the local VecTiling object.
__aicore__ inline void CopyTiling(MyCustomKernel::VecTiling* tiling, GM_ADDR tilingGM)
{
    constexpr uint32_t wordCount = sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t);

    auto src = reinterpret_cast<__gm__ uint32_t*>(tilingGM);
    uint32_t* dst = reinterpret_cast<uint32_t*>(tiling);

    for (uint32_t idx = 0; idx < wordCount; ++idx) {
        dst[idx] = src[idx];
    }
}

// Kernel entry point: loads the tiling from global memory, then runs a
// half-precision KernelTranspose from x to y.
extern "C" __global__ __aicore__ void transpose_custom(GM_ADDR x, GM_ADDR y, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);

    AscendC::TPipe tpipe;
    MyCustomKernel::KernelTranspose<half> transposeOp;

    MyCustomKernel::VecTiling vecTiling;
    CopyTiling(&vecTiling, tiling);

    transposeOp.Init(x, y, vecTiling, &tpipe);
    transposeOp.Process();
}

namespace {
// Block count used for the kernel launch in main().
constexpr uint32_t USED_CORE_NUM{1};
// Number of uint32_t words copied to the device as tiling data.
// NOTE(review): expected to equal sizeof(MyCustomKernel::VecTiling) / sizeof(uint32_t),
// which is what the kernel reads in CopyTiling — confirm.
constexpr uint32_t TILINGDATA_SIZE{22};
// Tensor dimensions passed to GenerateTiling as (B, N, S, HNDIV).
constexpr uint32_t B{1};
constexpr uint32_t N{2};
constexpr uint32_t S{64};
constexpr uint32_t HNDIV{32};
} // namespace

int32_t main(int32_t argc, char* argv[])
{
    size_t inputSize = B * S * N * HNDIV * sizeof(uint16_t);
    size_t outputSize = B * S * N * HNDIV * sizeof(uint16_t);
    size_t tilingSize = TILINGDATA_SIZE * sizeof(uint32_t);

    aclInit(nullptr);
    int32_t deviceId = 0;
    aclrtSetDevice(deviceId);
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    uint8_t *xHost, *yHost;
    uint8_t *xDevice, *yDevice, *tilingDevice;

    aclrtMallocHost((void**)(&xHost), inputSize);
    aclrtMallocHost((void**)(&yHost), outputSize);

    aclrtMalloc((void**)&xDevice, inputSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&yDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc((void**)&tilingDevice, tilingSize, ACL_MEM_MALLOC_HUGE_FIRST);

    ReadFile("./input/input_x.bin", inputSize, xHost, inputSize);

    uint8_t* buf = GenerateTiling(B, N, S, HNDIV);

    aclrtMemcpy(xDevice, inputSize, xHost, inputSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(tilingDevice, tilingSize, buf, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);

    // Execute the kernel
    transpose_custom<<<USED_CORE_NUM, nullptr, stream>>>(xDevice, yDevice, tilingDevice);

    // Wait for the stop event to complete
    aclrtSynchronizeStream(stream);

    // Copy result to host memory and write to output file
    aclrtMemcpy(yHost, outputSize, yDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile("./output/output.bin", yHost, outputSize);

    // Clean up memory
    free(buf);

    aclrtFree(xDevice);
    aclrtFree(yDevice);
    aclrtFree(tilingDevice);
    aclrtFreeHost(xHost);
    aclrtFreeHost(yHost);

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}