/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file one_axis_concat_pure_copy.h
 * \brief one axis concat pure copy
 */

#ifndef ONE_AXIS_CONCAT_PURE_COPY_
#define ONE_AXIS_CONCAT_PURE_COPY_

#include "concat_base.h"
#include "kernel_operator.h"
#include "kernel_operator_list_tensor_intf.h"

namespace Concat {
using namespace AscendC;

template <typename TILINGDATA = ConcatTilingData>
class OneAxisConcatPureCopy {
public:
    __aicore__ inline OneAxisConcatPureCopy(const TILINGDATA& tilingData, TPipe& pipe)
        : tilingData_(tilingData), pipe_(pipe){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR dst);
    __aicore__ inline void Process();

private:
    __aicore__ inline __gm__ int8_t* GetTensorAddr(uint32_t index, int64_t offset);
    __aicore__ inline void CopyInSingleTensor(
        const GlobalTensor<int8_t> srcGm, int64_t copyRows, int64_t copyCols, int64_t srcOffset, int64_t tensorDim1);
    __aicore__ inline void CopyOutSingleTensor(int64_t copyRows, int64_t copyCols, int64_t dstOffset);
    __aicore__ inline void ProcessSingleTensor(
        int64_t tensorIdx, int64_t tensorDim1, int64_t totalRows, int64_t totalCols, int64_t globalSrcOffset,
        int64_t globalDstOffset);
    __aicore__ inline void ProcessNoSplitDim1();
    __aicore__ inline void SetCopyInparam(int64_t rows, int64_t cols, int64_t srcStride, int64_t dstStride);
    __aicore__ inline void ProcessSplitDim1();
    __aicore__ inline int64_t GetTensorDim1(int64_t idx);
    __aicore__ inline void UpdateSplitInfo();

private:
    TPipe& pipe_;
    TQueBind<QuePosition::VECIN, QuePosition::VECOUT, BUFFER_NUM> inQueue_;
    GlobalTensor<int8_t> dstGlobal_;
    GlobalTensor<int8_t> srcGlobal_;
    const TILINGDATA& tilingData_;
    uint32_t blockIdx_{0};
    int64_t blockOffset_{0};

    uint32_t colsUsedCoreNum_{1};
    uint32_t rowsUsedCoreNum_{1};
    uint32_t blockIdxInCol_{0};
    uint32_t blockIdxInRow_{0};

    DataCopyExtParams copyInParam_{0, 0, 0, 0, 0};
    DataCopyPadExtParams<int8_t> padParam_{false, 0, 0, 0};
    DataCopyExtParams copyOutParam_{0, 0, 0, 0, 0};

    int64_t endTensorIdx_{0};
    int64_t endTensorOffset_{0};
    int64_t startTensorIdx_{0};
    int64_t startTensorOffset_{0};
    ListTensorDesc inputList_;
    uint64_t buf_[ARRAY_SIZE];
    TensorDesc<int8_t> desc_;
};

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::Init(GM_ADDR x, GM_ADDR dst)
{
    blockIdx_ = GetBlockIdx();
    if constexpr (IsSame<TILINGDATA, ConcatTilingData>::value) {
        rowsUsedCoreNum_ = tilingData_.uoDim0;
        colsUsedCoreNum_ = tilingData_.uoDim1;
        blockIdxInCol_ = blockIdx_ % colsUsedCoreNum_;
        blockIdxInRow_ = blockIdx_ / colsUsedCoreNum_;
        if (blockIdxInCol_ != 0) {
            startTensorIdx_ = tilingData_.endTensorIdx[blockIdx_ - 1];
            startTensorOffset_ = tilingData_.endTensorOffset[blockIdx_ - 1];
        }
        endTensorOffset_ = tilingData_.endTensorOffset[blockIdx_];
        endTensorIdx_ = tilingData_.endTensorIdx[blockIdx_];
    }
    if constexpr (IsSame<TILINGDATA, ConcatTilingDataNoArray>::value) {
        blockOffset_ = static_cast<int64_t>(blockIdx_) * tilingData_.ubFactorDim0;
        dstGlobal_.SetGlobalBuffer((__gm__ int8_t*)dst + tilingData_.catDim1 * tilingData_.dtypeSize * static_cast<int64_t>(blockOffset_));
    } else {
        blockOffset_ = blockIdxInRow_ * tilingData_.ubFactorDim0;
        int64_t colOffset = static_cast<int64_t>(blockIdxInCol_) * tilingData_.ubFactorDim1 * tilingData_.dtypeSize;
        dstGlobal_.SetGlobalBuffer(
            (__gm__ int8_t*)dst + blockOffset_ * tilingData_.catDim1 * tilingData_.dtypeSize + colOffset);
    }

    pipe_.InitBuffer(inQueue_, BUFFER_NUM, tilingData_.bufferSize * tilingData_.dtypeSize);

    inputList_ = ListTensorDesc(reinterpret_cast<__gm__ void*>(x));
    desc_.SetShapeAddr(&buf_[0]);
    if constexpr (IsSame<TILINGDATA, ConcatTilingData>::value) {
        UpdateSplitInfo();
    }
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::UpdateSplitInfo()
{
    if (startTensorOffset_ == GetTensorDim1(startTensorIdx_)) {
        startTensorIdx_ += 1;
        startTensorOffset_ = 0;
    }
}

template <typename TILINGDATA>
__aicore__ inline __gm__ int8_t* OneAxisConcatPureCopy<TILINGDATA>::GetTensorAddr(uint32_t index, int64_t offset)
{
    return inputList_.GetDataPtr<int8_t>(index) + offset;
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::Process()
{
    if (blockIdx_ >= GetBlockNum()) {
        return;
    }
    if constexpr (IsSame<TILINGDATA, ConcatTilingData>::value) {
        ProcessSplitDim1();
    } else {
        ProcessNoSplitDim1();
    }
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::ProcessNoSplitDim1()
{
    int64_t totalRows = tilingData_.ubFactorDim0;
    if (blockIdx_ == GetBlockNum() - 1) {
        totalRows = tilingData_.tailUbFactorDim0;
    }
    int64_t totalColOffset = 0;
    for (int64_t i = 0; i < tilingData_.tensorNum; i++) {
        int64_t dim1 = GetTensorDim1(i);
        int64_t dim0stride = GetTensorDim0Stride<TILINGDATA>(tilingData_, i, dim1);
        int64_t globalSrcOffset = blockOffset_ * dim0stride;
        ProcessSingleTensor(i, dim0stride, totalRows, dim1, globalSrcOffset, totalColOffset);
        totalColOffset += dim1;
    }
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::ProcessSplitDim1()
{
    int64_t totalRows = tilingData_.ubFactorDim0;
    if (blockIdxInRow_ == rowsUsedCoreNum_ - 1) {
        totalRows = tilingData_.tailUbFactorDim0;
    }
    int64_t totalColOffset = 0;
    int64_t globalSrcOffset = 0;
    int64_t tensorDim1 = GetTensorDim1(startTensorIdx_);
    int64_t dim0stride = GetTensorDim0Stride<TILINGDATA>(tilingData_, startTensorIdx_, tensorDim1);
    globalSrcOffset = blockOffset_ * dim0stride + startTensorOffset_;
    if (startTensorIdx_ == endTensorIdx_) {
        ProcessSingleTensor(
            startTensorIdx_, dim0stride, totalRows, endTensorOffset_ - startTensorOffset_, globalSrcOffset,
            totalColOffset);
        return;
    }
    ProcessSingleTensor(
        startTensorIdx_, dim0stride, totalRows, tensorDim1 - startTensorOffset_, globalSrcOffset, totalColOffset);
    totalColOffset += tensorDim1 - startTensorOffset_;
    for (int64_t i = startTensorIdx_ + 1; i < endTensorIdx_; i++) {
        tensorDim1 = GetTensorDim1(i);
        dim0stride = GetTensorDim0Stride<TILINGDATA>(tilingData_, i, tensorDim1);
        globalSrcOffset = blockOffset_ * dim0stride;
        ProcessSingleTensor(i, dim0stride, totalRows, tensorDim1, globalSrcOffset, totalColOffset);
        totalColOffset += tensorDim1;
    }
    tensorDim1 = GetTensorDim1(endTensorIdx_);
    dim0stride = GetTensorDim0Stride<TILINGDATA>(tilingData_, endTensorIdx_, tensorDim1);
    globalSrcOffset = blockOffset_ * dim0stride;
    ProcessSingleTensor(endTensorIdx_, dim0stride, totalRows, endTensorOffset_, globalSrcOffset, totalColOffset);
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::ProcessSingleTensor(
    int64_t tensorIdx, int64_t tensorDim1, int64_t totalRows, int64_t totalCols, int64_t globalSrcOffset,
    int64_t globalDstOffset)
{
    if (tensorDim1 == 0 || totalRows == 0 || totalCols == 0) {
        return;
    }
    srcGlobal_.SetGlobalBuffer(GetTensorAddr(tensorIdx, globalSrcOffset * tilingData_.dtypeSize));
    int64_t colFactor = min(totalCols, static_cast<int64_t>(tilingData_.bufferSize));
    int64_t rowFactor = tilingData_.bufferSize / colFactor;
    int64_t colLoopSize = (totalCols + colFactor - 1) / colFactor;
    int64_t rowLoopSize = (totalRows + rowFactor - 1) / rowFactor;
    int64_t tailColFactor = totalCols - (colLoopSize - 1) * colFactor;
    int64_t tailRowFactor = totalRows - (rowLoopSize - 1) * rowFactor;
    int64_t srcOffset = 0;
    int64_t dstOffset = 0;
    for (int64_t m = 0; m < rowLoopSize; m++) {
        int64_t copyRows = m == rowLoopSize - 1 ? tailRowFactor : rowFactor;
        dstOffset = globalDstOffset + m * rowFactor * tilingData_.catDim1;
        for (int64_t n = 0; n < colLoopSize - 1; n++) {
            srcOffset = m * rowFactor * tensorDim1 + n * colFactor;
            CopyInSingleTensor(srcGlobal_, copyRows, colFactor, srcOffset, tensorDim1);
            CopyOutSingleTensor(copyRows, colFactor, dstOffset);
            dstOffset += colFactor;
        }
        srcOffset = m * rowFactor * tensorDim1 + (colLoopSize - 1) * colFactor;
        CopyInSingleTensor(srcGlobal_, copyRows, tailColFactor, srcOffset, tensorDim1);
        CopyOutSingleTensor(copyRows, tailColFactor, dstOffset);
    }
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::CopyInSingleTensor(
    const GlobalTensor<int8_t> srcGm, int64_t copyRows, int64_t copyCols, int64_t srcOffset, int64_t tensorDim1)
{
    LocalTensor<int8_t> srcLocal = inQueue_.AllocTensor<int8_t>();
    SetCopyInparam(copyRows, copyCols, tensorDim1 - copyCols, 0);
    DataCopyPad<int8_t, PaddingMode::Compact>(
        srcLocal, srcGm[srcOffset * tilingData_.dtypeSize], copyInParam_, padParam_);
    inQueue_.EnQue(srcLocal);
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::CopyOutSingleTensor(
    int64_t copyRows, int64_t copyCols, int64_t dstOffset)
{
    LocalTensor<int8_t> dstLocal = inQueue_.DeQue<int8_t>();
    copyOutParam_.blockCount = copyRows;
    copyOutParam_.blockLen = copyCols * tilingData_.dtypeSize;
    copyOutParam_.dstStride = (tilingData_.catDim1 - copyCols) * tilingData_.dtypeSize;
    copyOutParam_.srcStride = 0;
    DataCopyPad<int8_t, PaddingMode::Compact>(dstGlobal_[dstOffset * tilingData_.dtypeSize], dstLocal, copyOutParam_);
    inQueue_.FreeTensor(dstLocal);
}

template <typename TILINGDATA>
__aicore__ inline void OneAxisConcatPureCopy<TILINGDATA>::SetCopyInparam(
    int64_t rows, int64_t cols, int64_t srcStride, int64_t dstStride)
{
    copyInParam_.blockCount = rows;
    copyInParam_.blockLen = cols * tilingData_.dtypeSize;
    copyInParam_.srcStride = srcStride * tilingData_.dtypeSize;
    copyInParam_.dstStride = dstStride;
}

template <typename TILINGDATA>
__aicore__ inline int64_t OneAxisConcatPureCopy<TILINGDATA>::GetTensorDim1(int64_t idx)
{
    if (idx < PRELOAD_DIM1_SIZE) {
        return tilingData_.preLoadDim1[idx];
    }
    return GetNonConDimSize<TILINGDATA, int8_t>(tilingData_, idx, inputList_, desc_) * tilingData_.sameShapeTensorDim1;
}
} // namespace Concat

#endif // ONE_AXIS_CONCAT_PURE_COPY_