* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef BLOCK_EPICLOGUE_TRANSPOSE_HPP
#define BLOCK_EPICLOGUE_TRANSPOSE_HPP
#include "catlass/catlass.hpp"
#include "catlass/arch/resource.hpp"
#include "catlass/epilogue/dispatch_policy.hpp"
#include "catlass/gemm/helper.hpp"
#include "catlass/gemv/tile/vec_copy_gm_to_ub.hpp"
#include "catlass/gemv/tile/vec_copy_ub_to_gm.hpp"
#include "catlass/gemv/helper.hpp"
#include "catlass/gemv_coord.hpp"
#include "catlass/layout/layout.hpp"
#include "catlass/matrix_coord.hpp"
namespace Catlass::Epilogue::Block {
/// Block-level epilogue that relocates a split-complex tensor (separate real and
/// imaginary planes) through the unified buffer (UB).
///
/// For every index `i` along `layout.shape(0)` the functor loads `layout.shape(1)`
/// contiguous rows of `layout.shape(2)` elements from the source planes and stores
/// them back starting at element offset `i * shape(2)`, leaving a gap of
/// `(shape(0) - 1) * shape(2)` elements between consecutive rows.
/// NOTE(review): under the AscendC convention that `dstStride` is the byte gap
/// between consecutive blocks in GM, this permutes [shape0, shape1, shape2] into
/// [shape1, shape0, shape2] — confirm against the caller's expected output layout.
///
/// @tparam ArchTag_   Architecture tag; must expose `UB_SIZE`.
/// @tparam DataType_  Type bundle exposing `Element` and `Layout`.
/// @tparam STAGES_    Number of UB buffer pairs that are cycled through (must be > 0).
template <
    class ArchTag_,
    class DataType_,
    uint32_t STAGES_
>
class BlockEpilogueTranspose {
public:
    using ArchTag = ArchTag_;
    using Element = typename DataType_::Element;
    using ScaleLayout = typename DataType_::Layout;
    static constexpr uint32_t STAGES = STAGES_;

    // The constructor divides the UB by STAGES and the buffers are indexed by stage.
    static_assert(STAGES_ > 0, "BlockEpilogueTranspose requires at least one stage");

    /// Kernel arguments: global-memory addresses of the four planes plus the layout
    /// whose shape(0..2) drive the copy loops.
    struct Params {
        GM_ADDR ptrSrcReal;
        GM_ADDR ptrSrcImag;
        GM_ADDR ptrOutReal;
        GM_ADDR ptrOutImag;
        ScaleLayout layout;

        CATLASS_HOST_DEVICE
        Params() {}

        CATLASS_HOST_DEVICE
        Params(
            GM_ADDR ptrSrcReal_,
            GM_ADDR ptrSrcImag_,
            GM_ADDR ptrOutReal_,
            GM_ADDR ptrOutImag_,
            ScaleLayout layout_)
            : ptrSrcReal(ptrSrcReal_), ptrSrcImag(ptrSrcImag_),
              ptrOutReal(ptrOutReal_), ptrOutImag(ptrOutImag_), layout(layout_) {}
    };

    /// Carves `2 * STAGES` equally sized, 32-byte-granular tiles out of the UB,
    /// starting at byte offset `ubByteStart`, laid out as real/imag pairs per stage.
    ///
    /// @param resource     Architecture resource that owns the UB buffer.
    /// @param ubByteStart  Byte offset inside the UB where the first tile begins.
    CATLASS_DEVICE
    BlockEpilogueTranspose(Arch::Resource<ArchTag>& resource, uint32_t ubByteStart = 0)
    {
        constexpr uint32_t numTensorList = 2;   // one real and one imaginary tile per stage
        constexpr uint32_t UB_ALIGN_BYTES = 32; // tiles are rounded down to this granularity

        uint32_t tileSizeByByte = ArchTag::UB_SIZE / numTensorList / STAGES;
        tileSizeByByte = tileSizeByByte / UB_ALIGN_BYTES * UB_ALIGN_BYTES;
        computeLength = tileSizeByByte / sizeof(Element);

        for (uint32_t i = 0; i < STAGES; ++i) {
            ubSrcRealTensor[i] = resource.ubBuf.template GetBufferByByte<Element>(ubByteStart);
            ubByteStart += tileSizeByByte;
            ubSrcImagTensor[i] = resource.ubBuf.template GetBufferByByte<Element>(ubByteStart);
            ubByteStart += tileSizeByByte;
            // Each stage uses its own index as the hardware event id.
            UbInEventList[i] = i;
        }
    }

    CATLASS_DEVICE
    ~BlockEpilogueTranspose()
    {}

    /// Runs the copy for the tensors described by `params`.
    ///
    /// NOTE(review): there is no check that `shape(1) * shape(2)` elements fit into
    /// one UB tile (`computeLength` elements); the caller appears to be responsible
    /// for choosing shapes that fit — confirm.
    ///
    /// @param params  Global-memory pointers and layout; see `Params`.
    CATLASS_DEVICE
    void operator()(Params const& params)
    {
        // Bind the member tensors to this call's global-memory planes.
        srcRealGm.SetGlobalBuffer((__gm__ Element *)params.ptrSrcReal);
        srcImagGm.SetGlobalBuffer((__gm__ Element *)params.ptrSrcImag);
        outRealGm.SetGlobalBuffer((__gm__ Element *)params.ptrOutReal);
        outImagGm.SetGlobalBuffer((__gm__ Element *)params.ptrOutImag);

        uint32_t const innerBatchDim = params.layout.shape(0);
        uint32_t const radixDim = params.layout.shape(1);
        uint32_t const remainDim = params.layout.shape(2);
        uint32_t const rowBytes = static_cast<uint32_t>(remainDim * sizeof(Element));

        // The copy descriptors do not depend on the loop index, so build them once.
        AscendC::DataCopyExtParams copyParams{static_cast<uint16_t>(radixDim), rowBytes, 0, 0, 0};
        AscendC::DataCopyPadExtParams<Element> padParams{false, 0, 0, 0};

        AscendC::DataCopyExtParams copyOutParams;
        copyOutParams.blockCount = static_cast<uint16_t>(radixDim);
        copyOutParams.blockLen = rowBytes;
        copyOutParams.srcStride = 0;
        // Skip the rows that belong to the other (innerBatchDim - 1) indices.
        copyOutParams.dstStride = (innerBatchDim - 1) * remainDim * sizeof(Element);
        copyOutParams.rsv = 0;

        // Mark every stage's buffers as free so the first load of each stage can start.
        for (uint32_t i = 0; i < STAGES; i++) {
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(UbInEventList[i]);
        }

        uint32_t stage = 0;
        for (uint32_t innerIdx = 0; innerIdx < innerBatchDim; ++innerIdx) {
            int32_t const eventId = UbInEventList[stage];

            // Wait until the previous store from this stage's buffers has finished.
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(eventId);

            // 64-bit offsets: the element index can exceed 32 bits for large tensors.
            uint64_t const srcOffset = static_cast<uint64_t>(innerIdx) * radixDim * remainDim;
            AscendC::DataCopyPad(ubSrcRealTensor[stage], srcRealGm[srcOffset], copyParams, padParams);
            AscendC::DataCopyPad(ubSrcImagTensor[stage], srcImagGm[srcOffset], copyParams, padParams);

            // The store below must not start before the loads above have landed in UB.
            AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(eventId);
            AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(eventId);

            uint64_t const dstOffset = static_cast<uint64_t>(innerIdx) * remainDim;
            AscendC::DataCopyPad(outRealGm[dstOffset], ubSrcRealTensor[stage], copyOutParams);
            AscendC::DataCopyPad(outImagGm[dstOffset], ubSrcImagTensor[stage], copyOutParams);

            // Release this stage's buffers for a later load.
            AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(eventId);

            // Rotate through all stages; the former `1 - flagId` toggle was only
            // valid for STAGES == 2 and indexed out of bounds for STAGES == 1.
            stage = (stage + 1) % STAGES;
        }

        // Drain the outstanding "buffer free" events so none leak past this call.
        for (uint32_t i = 0; i < STAGES; i++) {
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(UbInEventList[i]);
        }
    }

private:
    // Per-stage UB tiles for the real and imaginary planes.
    AscendC::LocalTensor<Element> ubSrcRealTensor[STAGES];
    AscendC::LocalTensor<Element> ubSrcImagTensor[STAGES];
    // Global-memory views; rebound at the start of every operator() call.
    AscendC::GlobalTensor<Element> srcRealGm;
    AscendC::GlobalTensor<Element> srcImagGm;
    AscendC::GlobalTensor<Element> outRealGm;
    AscendC::GlobalTensor<Element> outImagGm;
    // Hardware event id used by each stage.
    int32_t UbInEventList[STAGES];
    // Number of elements that fit into one UB tile; recorded by the constructor
    // and not consulted by operator().
    uint32_t computeLength;
};
}
#endif