* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License")
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file cast_impl.h
* \brief
*/
#ifndef CANN_CUSTOM_OPS_CAST_IMPL_H
#define CANN_CUSTOM_OPS_CAST_IMPL_H
#include "kernel_operator.h"
namespace AscendcCast {
using namespace AscendC;
using AscendC::MicroAPI::LoadDist;
using AscendC::MicroAPI::StoreDist;
using AscendC::MicroAPI::RegLayout;
using AscendC::MicroAPI::SatMode;
using AscendC::MicroAPI::MaskMergeMode;
// Compile-time map from a CAST_TPL_* type id to the C++ element type used by the kernels.
// Only the primary template is declared here; every supported id has an explicit specialization,
// so using an unsupported id fails at compile time.
template <int N>
struct TypeGetTool;

template <> struct TypeGetTool<CAST_TPL_BOOL> { using type = bool; };
template <> struct TypeGetTool<CAST_TPL_INT8> { using type = int8_t; };
template <> struct TypeGetTool<CAST_TPL_UINT8> { using type = uint8_t; };
template <> struct TypeGetTool<CAST_TPL_FLOAT8_E4M3FN> { using type = fp8_e4m3fn_t; };
template <> struct TypeGetTool<CAST_TPL_FLOAT8_E5M2> { using type = fp8_e5m2_t; };
template <> struct TypeGetTool<CAST_TPL_HIFLOAT8> { using type = hifloat8_t; };
template <> struct TypeGetTool<CAST_TPL_UINT16> { using type = uint16_t; };
template <> struct TypeGetTool<CAST_TPL_INT16> { using type = int16_t; };
template <> struct TypeGetTool<CAST_TPL_FLOAT16> { using type = half; };
template <> struct TypeGetTool<CAST_TPL_BF16> { using type = bfloat16_t; };
template <> struct TypeGetTool<CAST_TPL_FLOAT> { using type = float; };
template <> struct TypeGetTool<CAST_TPL_INT32> { using type = int32_t; };
template <> struct TypeGetTool<CAST_TPL_UINT32> { using type = uint32_t; };
template <> struct TypeGetTool<CAST_TPL_INT64> { using type = int64_t; };
// CAST_TPL_MAX is mapped to int32_t as well.
template <> struct TypeGetTool<CAST_TPL_MAX> { using type = int32_t; };
template <> struct TypeGetTool<CAST_TPL_COMPLEX32> { using type = complex32; };
template <> struct TypeGetTool<CAST_TPL_COMPLEX64> { using type = complex64; };
template <> struct TypeGetTool<CAST_TPL_FLOAT4_E2M1> { using type = fp4x2_e2m1_t; };
template <> struct TypeGetTool<CAST_TPL_FLOAT4_E1M2> { using type = fp4x2_e1m2_t; };
template <> struct TypeGetTool<CAST_TPL_DOUBLE> { using type = double; };
template <> struct TypeGetTool<CAST_TPL_INT4> { using type = int4x2_t; };
// Helper constants shared by the cast kernels in this file.
// B8_BITS / B7_BITS: divisor and round-up addend used when turning a bit count into bytes.
constexpr int64_t B8_BITS = 8;
constexpr int64_t B7_BITS = B8_BITS - 1;
// B2_BITS: divisor applied to offsets/sizes of the 4-bit packed types.
constexpr int64_t B2_BITS = 2;
// UB_BLOCK_ALIGN: granularity (in bytes) the mask buffer size is rounded up to.
constexpr int64_t UB_BLOCK_ALIGN = 32;
constexpr int64_t UB_BLOCK_ALIGN_MINUS_ONE = UB_BLOCK_ALIGN - 1;
// Low-nibble mask and nibble shift width.
constexpr int16_t B4_MASK = 0xF;
constexpr int16_t SHIFT_FOUR_BITS = 4;
// One-step element-wise cast kernel: ST elements are copied from GM into UB, converted to DT
// with a single Cast call using the configured rounding mode, and copied back to GM.
template <typename ST, typename DT>
class CastDirect {
public:
    __aicore__ inline CastDirect() {}
    // Stores tiling/rounding configuration, binds this core's GM windows and allocates the queues.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, RoundMode roundMode,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    // Runs CopyIn -> Compute -> CopyOut over every UB tile assigned to this core.
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;

protected:
    // gmOffset is an element offset relative to this core's GM window.
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // len is the number of elements to convert.
    __aicore__ inline void Compute(const int64_t &len);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

    const CastTilingData *tilingData_{nullptr};
    RoundMode rMode_{RoundMode::CAST_NONE};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    GlobalTensor<ST> xGm_;
    GlobalTensor<DT> yGm_;
    // blockLen of the copy descriptors is rewritten by Process for each tile size.
    DataCopyExtParams dataCopyInParams_;
    DataCopyPadExtParams<ST> padParams_;
    DataCopyExtParams dataCopyOutParams_;
    int64_t blockIdx_{0};
};
// Prepares the kernel: remembers the tiling data, pipe and rounding mode, points the global
// tensors at this core's slice (blockIdx * blockFormer elements in) and allocates the UB queues
// sized for one full tile (ubFormer elements).
template <typename ST, typename DT>
__aicore__ inline void CastDirect<ST, DT>::Init(GM_ADDR x, GM_ADDR y, RoundMode roundMode,
                                                const CastTilingData *tilingData, TPipe *pipePtr)
{
    rMode_ = roundMode;
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; the byte length is filled in later by Process.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyInParams_.blockCount = 1;
    dataCopyInParams_.blockLen = 0;
    dataCopyInParams_.srcStride = 0;
    dataCopyInParams_.dstStride = 0;
    dataCopyOutParams_.blockCount = 1;
    dataCopyOutParams_.blockLen = 0;
    dataCopyOutParams_.srcStride = 0;
    dataCopyOutParams_.dstStride = 0;

    // Input and output share the same element offset because the cast is one-to-one.
    const int64_t coreElemOffset = blockIdx_ * tilingData_->blockFormer;
    xGm_.SetGlobalBuffer((__gm__ ST *)x + coreElemOffset);
    yGm_.SetGlobalBuffer((__gm__ DT *)y + coreElemOffset);

    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(ST));
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT));
}
// Loads one tile of ST elements from GM (starting at element gmOffset) into a fresh input
// buffer and queues it for Compute. The byte count comes from dataCopyInParams_.blockLen.
template <typename ST, typename DT>
__aicore__ inline void CastDirect<ST, DT>::CopyIn(const int64_t &gmOffset)
{
    LocalTensor<ST> srcTile = inQueueX_.template AllocTensor<ST>();
    DataCopyPad(srcTile, xGm_[gmOffset], dataCopyInParams_, padParams_);
    inQueueX_.EnQue(srcTile);
}
// Converts len elements of the queued input tile from ST to DT using rMode_, releases the
// input buffer and queues the result for CopyOut.
template <typename ST, typename DT>
__aicore__ inline void CastDirect<ST, DT>::Compute(const int64_t &len)
{
    LocalTensor<ST> srcTile = inQueueX_.template DeQue<ST>();
    LocalTensor<DT> dstTile = outQueue_.template AllocTensor<DT>();

    Cast<DT, ST>(dstTile, srcTile, rMode_, len);

    inQueueX_.FreeTensor(srcTile);
    outQueue_.EnQue(dstTile);
}
// Writes the next converted tile to GM at element gmOffset and returns its buffer to the queue.
// The byte count comes from dataCopyOutParams_.blockLen.
template <typename ST, typename DT>
__aicore__ inline void CastDirect<ST, DT>::CopyOut(const int64_t &gmOffset)
{
    LocalTensor<DT> dstTile = outQueue_.template DeQue<DT>();
    DataCopyPad(yGm_[gmOffset], dstTile, dataCopyOutParams_);
    outQueue_.FreeTensor(dstTile);
}
// Drives the whole cast for this core: (loopNum - 1) full tiles of ubFormer elements followed by
// exactly one tail tile. Loop count and tail size depend on whether this is the last core.
template <typename ST, typename DT>
__aicore__ inline void CastDirect<ST, DT>::Process()
{
    const auto fullTile = tilingData_->ubFormer;

    // Default to the "former" core settings, switch to the "tail" ones on the last core.
    int64_t tileCount = tilingData_->ubLoopOfFormerBlock;
    int64_t lastTileLen = tilingData_->ubTailOfFormerBlock;
    if (blockIdx_ == tilingData_->blockNum - 1) {
        tileCount = tilingData_->ubLoopOfTailBlock;
        lastTileLen = tilingData_->ubTailOfTailBlock;
    }

    // Full-size tiles.
    dataCopyInParams_.blockLen = fullTile * sizeof(ST);
    dataCopyOutParams_.blockLen = fullTile * sizeof(DT);
    int64_t gmOffset = 0;
    for (int64_t remaining = tileCount - 1; remaining > 0; --remaining) {
        CopyIn(gmOffset);
        Compute(fullTile);
        CopyOut(gmOffset);
        gmOffset += fullTile;
    }

    // Tail tile: shrink the copy lengths to the remaining element count.
    dataCopyInParams_.blockLen = lastTileLen * sizeof(ST);
    dataCopyOutParams_.blockLen = lastTileLen * sizeof(DT);
    CopyIn(gmOffset);
    Compute(lastTileLen);
    CopyOut(gmOffset);
}
// Cast-to-bool kernel: each ST input element is compared with zero and an int8 result is
// produced from a zero tensor and the scalar 1 via Select (see Compute).
template <typename ST>
class CastDstBool {
public:
    __aicore__ inline CastDstBool() {}
    // Stores the tiling data, binds this core's GM windows and allocates queues and scratch buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    // Runs CopyIn -> Compute -> CopyOut over every UB tile assigned to this core.
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;
    // Scalar the input is compared against.
    static constexpr ST inZero_ = 0;

protected:
    // gmOffset is an element offset relative to this core's GM window.
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // len is the number of elements to convert.
    __aicore__ inline void Compute(const int64_t &len);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

    const CastTilingData *tilingData_{nullptr};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    // Holds the compare result bit mask.
    TBuf<> maskBuf_;
    // Backing storage of boolZeroTensor_.
    TBuf<> boolZeroBuf_;
    GlobalTensor<ST> xGm_;
    GlobalTensor<int8_t> yGm_;
    DataCopyExtParams dataCopyInParams_;
    DataCopyPadExtParams<ST> padParams_;
    DataCopyExtParams dataCopyOutParams_;
    int64_t blockIdx_{0};
    // All-zero int8 tensor of ubFormer elements, filled once in Init.
    LocalTensor<int8_t> boolZeroTensor_;
};
// Prepares the kernel: binds the GM windows, allocates the in/out queues, the compare-mask
// buffer and an all-zero int8 tensor, and resets the copy descriptors.
template <typename ST>
__aicore__ inline void CastDstBool<ST>::Init(GM_ADDR x, GM_ADDR y,
                                             const CastTilingData *tilingData, TPipe *pipePtr)
{
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; the byte length is filled in later by Process.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyInParams_.blockCount = 1;
    dataCopyInParams_.blockLen = 0;
    dataCopyInParams_.srcStride = 0;
    dataCopyInParams_.dstStride = 0;
    dataCopyOutParams_.blockCount = 1;
    dataCopyOutParams_.blockLen = 0;
    dataCopyOutParams_.srcStride = 0;
    dataCopyOutParams_.dstStride = 0;

    const int64_t coreElemOffset = blockIdx_ * tilingData_->blockFormer;
    xGm_.SetGlobalBuffer((__gm__ ST *)x + coreElemOffset);
    yGm_.SetGlobalBuffer((__gm__ int8_t *)y + coreElemOffset);

    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(ST));
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(int8_t));

    // One mask bit per element: round bits up to bytes, then bytes up to a UB_BLOCK_ALIGN multiple.
    const auto maskBytes = (tilingData_->ubFormer + B7_BITS) / B8_BITS;
    const auto maskBytesAligned = (maskBytes + UB_BLOCK_ALIGN_MINUS_ONE) / UB_BLOCK_ALIGN * UB_BLOCK_ALIGN;
    pipe_->InitBuffer(maskBuf_, maskBytesAligned);

    // Zero tensor reused by every Compute call as the tensor operand of Select.
    pipe_->InitBuffer(boolZeroBuf_, tilingData_->ubFormer * sizeof(int8_t));
    boolZeroTensor_ = boolZeroBuf_.Get<int8_t>();
    int8_t zeroFill = 0;
    Duplicate(boolZeroTensor_, zeroFill, tilingData_->ubFormer);
}
// Loads one tile of ST elements from GM (starting at element gmOffset) into a fresh input
// buffer and queues it for Compute.
template <typename ST>
__aicore__ inline void CastDstBool<ST>::CopyIn(const int64_t &gmOffset)
{
    LocalTensor<ST> srcTile = inQueueX_.template AllocTensor<ST>();
    DataCopyPad(srcTile, xGm_[gmOffset], dataCopyInParams_, padParams_);
    inQueueX_.EnQue(srcTile);
}
// Builds the bool output for len elements: the mask marks inputs equal to inZero_, then Select
// (tensor/scalar mode) combines the zero tensor with the scalar 1 according to that mask.
template <typename ST>
__aicore__ inline void CastDstBool<ST>::Compute(const int64_t &len)
{
    LocalTensor<ST> srcTile = inQueueX_.template DeQue<ST>();
    LocalTensor<uint8_t> isZeroMask = maskBuf_.Get<uint8_t>();

    // Bit mask of "x == 0"; the input tile is no longer needed afterwards.
    CompareScalar(isZeroMask, srcTile, inZero_, CMPMODE::EQ, len);
    inQueueX_.FreeTensor(srcTile);

    int8_t oneScalar = 1;
    LocalTensor<int8_t> dstTile = outQueue_.template AllocTensor<int8_t>();
    Select(dstTile, isZeroMask, boolZeroTensor_, oneScalar, SELMODE::VSEL_TENSOR_SCALAR_MODE, len);
    outQueue_.EnQue(dstTile);
}
// Writes the next int8 result tile to GM at element gmOffset and frees its buffer.
template <typename ST>
__aicore__ inline void CastDstBool<ST>::CopyOut(const int64_t &gmOffset)
{
    LocalTensor<int8_t> dstTile = outQueue_.template DeQue<int8_t>();
    DataCopyPad(yGm_[gmOffset], dstTile, dataCopyOutParams_);
    outQueue_.FreeTensor(dstTile);
}
// Drives the whole conversion for this core: (loopNum - 1) full tiles of ubFormer elements
// followed by exactly one tail tile; the last core uses the tail-block tiling values.
template <typename ST>
__aicore__ inline void CastDstBool<ST>::Process()
{
    const auto fullTile = tilingData_->ubFormer;
    const bool onLastCore = (blockIdx_ == tilingData_->blockNum - 1);

    int64_t tileCount = 0;
    int64_t lastTileLen = 0;
    if (onLastCore) {
        tileCount = tilingData_->ubLoopOfTailBlock;
        lastTileLen = tilingData_->ubTailOfTailBlock;
    } else {
        tileCount = tilingData_->ubLoopOfFormerBlock;
        lastTileLen = tilingData_->ubTailOfFormerBlock;
    }

    // Full-size tiles.
    dataCopyInParams_.blockLen = fullTile * sizeof(ST);
    dataCopyOutParams_.blockLen = fullTile * sizeof(int8_t);
    int64_t gmOffset = 0;
    int64_t tileIdx = 0;
    while (tileIdx < tileCount - 1) {
        CopyIn(gmOffset);
        Compute(fullTile);
        CopyOut(gmOffset);
        gmOffset += fullTile;
        ++tileIdx;
    }

    // Tail tile.
    dataCopyInParams_.blockLen = lastTileLen * sizeof(ST);
    dataCopyOutParams_.blockLen = lastTileLen * sizeof(int8_t);
    CopyIn(gmOffset);
    Compute(lastTileLen);
    CopyOut(gmOffset);
}
// Pass-through kernel for casts whose source and destination share the element type DT:
// data is moved GM -> UB, copied UB -> UB, and written back to GM unchanged.
template <typename DT>
class CastThrough {
public:
    __aicore__ inline CastThrough() {}
    // Stores the tiling data, binds this core's GM windows and allocates the queues.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    // Runs CopyIn -> Compute -> CopyOut over every UB tile assigned to this core.
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;

protected:
    // gmOffset is an element offset relative to this core's GM window.
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // len is the number of elements to copy.
    __aicore__ inline void Compute(const int64_t &len);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

    const CastTilingData *tilingData_{nullptr};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    GlobalTensor<DT> xGm_;
    GlobalTensor<DT> yGm_;
    // One descriptor serves both directions because input and output tiles have the same size.
    DataCopyExtParams dataCopyParams_;
    DataCopyPadExtParams<DT> padParams_;
    int64_t blockIdx_{0};
};
// Prepares the kernel: binds the GM windows at this core's element offset, allocates the two
// queues for ubFormer elements each and resets the shared copy descriptor.
template <typename DT>
__aicore__ inline void CastThrough<DT>::Init(GM_ADDR x, GM_ADDR y,
                                             const CastTilingData *tilingData, TPipe *pipePtr)
{
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; the byte length is filled in later by Process.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyParams_.blockCount = 1;
    dataCopyParams_.blockLen = 0;
    dataCopyParams_.srcStride = 0;
    dataCopyParams_.dstStride = 0;

    const int64_t coreElemOffset = blockIdx_ * tilingData_->blockFormer;
    xGm_.SetGlobalBuffer((__gm__ DT *)x + coreElemOffset);
    yGm_.SetGlobalBuffer((__gm__ DT *)y + coreElemOffset);

    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(DT));
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT));
}
// Loads one tile from GM (starting at element gmOffset) into a fresh input buffer and queues it.
template <typename DT>
__aicore__ inline void CastThrough<DT>::CopyIn(const int64_t &gmOffset)
{
    LocalTensor<DT> srcTile = inQueueX_.template AllocTensor<DT>();
    DataCopyPad(srcTile, xGm_[gmOffset], dataCopyParams_, padParams_);
    inQueueX_.EnQue(srcTile);
}
// Copies len elements from the queued input tile into an output tile without changing them,
// then releases the input buffer and queues the output for CopyOut.
template <typename DT>
__aicore__ inline void CastThrough<DT>::Compute(const int64_t &len)
{
    LocalTensor<DT> srcTile = inQueueX_.template DeQue<DT>();
    LocalTensor<DT> dstTile = outQueue_.template AllocTensor<DT>();

    Copy<DT>(dstTile, srcTile, len);

    inQueueX_.FreeTensor(srcTile);
    outQueue_.EnQue(dstTile);
}
// Writes the next output tile to GM at element gmOffset and frees its buffer.
template <typename DT>
__aicore__ inline void CastThrough<DT>::CopyOut(const int64_t &gmOffset)
{
    LocalTensor<DT> dstTile = outQueue_.template DeQue<DT>();
    DataCopyPad(yGm_[gmOffset], dstTile, dataCopyParams_);
    outQueue_.FreeTensor(dstTile);
}
// Drives the whole copy for this core: (loopNum - 1) full tiles of ubFormer elements followed
// by exactly one tail tile; the last core uses the tail-block tiling values.
template <typename DT>
__aicore__ inline void CastThrough<DT>::Process()
{
    const auto fullTile = tilingData_->ubFormer;

    int64_t tileCount = tilingData_->ubLoopOfFormerBlock;
    int64_t lastTileLen = tilingData_->ubTailOfFormerBlock;
    if (blockIdx_ == tilingData_->blockNum - 1) {
        tileCount = tilingData_->ubLoopOfTailBlock;
        lastTileLen = tilingData_->ubTailOfTailBlock;
    }

    // Full-size tiles.
    dataCopyParams_.blockLen = fullTile * sizeof(DT);
    int64_t gmOffset = 0;
    for (int64_t remaining = tileCount - 1; remaining > 0; --remaining) {
        CopyIn(gmOffset);
        Compute(fullTile);
        CopyOut(gmOffset);
        gmOffset += fullTile;
    }

    // Tail tile.
    dataCopyParams_.blockLen = lastTileLen * sizeof(DT);
    CopyIn(gmOffset);
    Compute(lastTileLen);
    CopyOut(gmOffset);
}
// Kernel for bit-packed input: the input is read as uint8 bytes holding eight 1-bit elements
// each, and the bytes are used directly as the Select mask to produce DT outputs built from a
// tensor of ones and the scalar zero (see Compute).
template <typename DT>
class CastUint1 {
public:
    __aicore__ inline CastUint1() {}
    // Stores the tiling data, binds this core's GM windows and allocates queues and the ones tensor.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    // Runs CopyIn -> Compute -> CopyOut over every UB tile assigned to this core.
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;

protected:
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // len is the number of output elements to produce.
    __aicore__ inline void Compute(const int64_t &len);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

    const CastTilingData *tilingData_{nullptr};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    // Backing storage of oneTensor_.
    TBuf<> oneBuf_;
    GlobalTensor<uint8_t> xGm_;
    GlobalTensor<DT> yGm_;
    DataCopyExtParams dataCopyInParams_;
    DataCopyPadExtParams<uint8_t> padParams_;
    DataCopyExtParams dataCopyOutParams_;
    int64_t blockIdx_{0};
    // Tensor of ubFormer elements all equal to DT(1), filled once in Init.
    LocalTensor<DT> oneTensor_;
};
// Prepares the kernel. The input window is addressed in bytes (element counts divided by
// B8_BITS), the output window in DT elements. Also allocates the queues and a tensor of ones.
template <typename DT>
__aicore__ inline void CastUint1<DT>::Init(GM_ADDR x, GM_ADDR y,
                                           const CastTilingData *tilingData, TPipe *pipePtr)
{
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; the byte length is filled in later by Process.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyInParams_.blockCount = 1;
    dataCopyInParams_.blockLen = 0;
    dataCopyInParams_.srcStride = 0;
    dataCopyInParams_.dstStride = 0;
    dataCopyOutParams_.blockCount = 1;
    dataCopyOutParams_.blockLen = 0;
    dataCopyOutParams_.srcStride = 0;
    dataCopyOutParams_.dstStride = 0;

    // Eight input elements share one byte, so the per-core input stride is blockFormer / 8 bytes.
    int64_t bytesPerCore = tilingData_->blockFormer / B8_BITS;
    int64_t inByteOffset = blockIdx_ * bytesPerCore;
    int64_t outElemOffset = blockIdx_ * tilingData_->blockFormer;
    xGm_.SetGlobalBuffer((__gm__ uint8_t *)x + inByteOffset);
    yGm_.SetGlobalBuffer((__gm__ DT *)y + outElemOffset);

    int64_t inTileBytes = tilingData_->ubFormer / B8_BITS;
    pipe_->InitBuffer(inQueueX_, bufferNum_, inTileBytes);
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT));

    // Tensor of ones reused by every Compute call as the tensor operand of Select.
    pipe_->InitBuffer(oneBuf_, tilingData_->ubFormer * sizeof(DT));
    oneTensor_ = oneBuf_.Get<DT>();
    DT oneFill = 1;
    Duplicate(oneTensor_, oneFill, tilingData_->ubFormer);
}
// Loads one tile of bit-packed input from GM into a fresh input buffer and queues it.
//
// gmOffset is the ELEMENT (bit) offset inside this core's window — the same value Process passes
// to CopyOut. xGm_ is a uint8 tensor holding eight elements per byte, so the element offset must
// be converted to a byte offset before indexing; indexing with the raw element offset would read
// from 8x too far into the input for every tile after the first.
// Assumes gmOffset is a multiple of B8_BITS (it advances in steps of ubFormer) — TODO confirm
// that the tiling always produces an ubFormer divisible by 8.
template <typename DT>
__aicore__ inline void CastUint1<DT>::CopyIn(const int64_t &gmOffset)
{
    const int64_t gmByteOffset = gmOffset / B8_BITS;
    auto xLocalIn = inQueueX_.template AllocTensor<uint8_t>();
    DataCopyPad(xLocalIn, xGm_[gmByteOffset], dataCopyInParams_, padParams_);
    inQueueX_.EnQue(xLocalIn);
}
// Expands len packed bits into DT values: the input bytes are used directly as the Select mask
// (tensor/scalar mode) choosing between the ones tensor and the scalar zero.
template <typename DT>
__aicore__ inline void CastUint1<DT>::Compute(const int64_t &len)
{
    LocalTensor<uint8_t> packedBits = inQueueX_.template DeQue<uint8_t>();
    LocalTensor<DT> dstTile = outQueue_.template AllocTensor<DT>();

    DT zeroScalar = 0;
    Select(dstTile, packedBits, oneTensor_, zeroScalar, SELMODE::VSEL_TENSOR_SCALAR_MODE, len);

    inQueueX_.FreeTensor(packedBits);
    outQueue_.EnQue(dstTile);
}
// Writes the next DT result tile to GM at element gmOffset and frees its buffer.
template <typename DT>
__aicore__ inline void CastUint1<DT>::CopyOut(const int64_t &gmOffset)
{
    LocalTensor<DT> dstTile = outQueue_.template DeQue<DT>();
    DataCopyPad(yGm_[gmOffset], dstTile, dataCopyOutParams_);
    outQueue_.FreeTensor(dstTile);
}
// Drives the whole conversion for this core: (loopNum - 1) full tiles of ubFormer elements
// followed by exactly one tail tile of tailNum elements; the last core uses the tail-block
// tiling values. Input copy lengths are in bytes of packed bits, output lengths in bytes of DT.
template <typename DT>
__aicore__ inline void CastUint1<DT>::Process()
{
    bool isLastBlockFlag = (blockIdx_ == tilingData_->blockNum - 1);
    int64_t loopNum = isLastBlockFlag ?
        tilingData_->ubLoopOfTailBlock : tilingData_->ubLoopOfFormerBlock;
    int64_t tailNum = isLastBlockFlag ?
        tilingData_->ubTailOfTailBlock : tilingData_->ubTailOfFormerBlock;
    int64_t gmOffset = 0;

    // Full tiles: ubFormer bits in, ubFormer DT elements out.
    dataCopyInParams_.blockLen = tilingData_->ubFormer / B8_BITS;
    dataCopyOutParams_.blockLen = tilingData_->ubFormer * sizeof(DT);
    for (int64_t i = 0; i < loopNum - 1; ++i) {
        CopyIn(gmOffset);
        Compute(tilingData_->ubFormer);
        CopyOut(gmOffset);
        gmOffset += tilingData_->ubFormer;
    }

    // Tail tile: only tailNum bits remain, so read ceil(tailNum / 8) bytes. Sizing this from
    // ubFormer (as before) reads past the data that belongs to the tail and, when ubFormer is
    // not a multiple of 8, past the ubFormer / 8 byte input buffer allocated in Init.
    dataCopyInParams_.blockLen = (tailNum + B7_BITS) / B8_BITS;
    dataCopyOutParams_.blockLen = tailNum * sizeof(DT);
    CopyIn(gmOffset);
    Compute(tailNum);
    CopyOut(gmOffset);
}
// Two-step cast kernel: ST elements are first converted to the intermediate type MT (held in a
// scratch buffer) and then from MT to DT, each step with its own rounding mode.
template <typename ST, typename MT, typename DT>
class CastTwo {
public:
    __aicore__ inline CastTwo() {}
    // Stores tiling data and both rounding modes, binds the GM windows and allocates the buffers.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, RoundMode roundMode1, RoundMode roundMode2,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    // Runs CopyIn -> Compute -> CopyOut over every UB tile assigned to this core.
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;

protected:
    // gmOffset is an element offset relative to this core's GM window.
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // len is the number of elements to convert.
    __aicore__ inline void Compute(const int64_t &len);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

    const CastTilingData *tilingData_{nullptr};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    // Scratch storage for the intermediate MT tile.
    TBuf<> midTypeBuf_;
    // Rounding mode of the ST -> MT step.
    RoundMode rMode1_{RoundMode::CAST_NONE};
    // Rounding mode of the MT -> DT step.
    RoundMode rMode2_{RoundMode::CAST_NONE};
    GlobalTensor<ST> xGm_;
    GlobalTensor<DT> yGm_;
    DataCopyExtParams dataCopyInParams_;
    DataCopyPadExtParams<ST> padParams_;
    DataCopyExtParams dataCopyOutParams_;
    int64_t blockIdx_{0};
};
// Prepares the kernel: remembers the configuration, binds the GM windows at this core's element
// offset, allocates the in/out queues plus the intermediate MT buffer (ubFormer elements each)
// and resets the copy descriptors.
template <typename ST, typename MT, typename DT>
__aicore__ inline void CastTwo<ST, MT, DT>::Init(GM_ADDR x, GM_ADDR y, RoundMode roundMode1,
    RoundMode roundMode2, const CastTilingData *tilingData, TPipe *pipePtr)
{
    rMode1_ = roundMode1;
    rMode2_ = roundMode2;
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; the byte length is filled in later by Process.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyInParams_.blockCount = 1;
    dataCopyInParams_.blockLen = 0;
    dataCopyInParams_.srcStride = 0;
    dataCopyInParams_.dstStride = 0;
    dataCopyOutParams_.blockCount = 1;
    dataCopyOutParams_.blockLen = 0;
    dataCopyOutParams_.srcStride = 0;
    dataCopyOutParams_.dstStride = 0;

    const int64_t coreElemOffset = blockIdx_ * tilingData_->blockFormer;
    xGm_.SetGlobalBuffer((__gm__ ST *)x + coreElemOffset);
    yGm_.SetGlobalBuffer((__gm__ DT *)y + coreElemOffset);

    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(ST));
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT));
    pipe_->InitBuffer(midTypeBuf_, tilingData_->ubFormer * sizeof(MT));
}
// Loads one tile of ST elements from GM (starting at element gmOffset) into a fresh input
// buffer and queues it for Compute.
template <typename ST, typename MT, typename DT>
__aicore__ inline void CastTwo<ST, MT, DT>::CopyIn(const int64_t &gmOffset)
{
    LocalTensor<ST> srcTile = inQueueX_.template AllocTensor<ST>();
    DataCopyPad(srcTile, xGm_[gmOffset], dataCopyInParams_, padParams_);
    inQueueX_.EnQue(srcTile);
}
// Converts len elements in two steps: ST -> MT with rMode1_, then MT -> DT with rMode2_.
//
// When the intermediate type is uint32_t and the destination is int64_t, half or float, the
// intermediate tile is reinterpreted as int32_t before the second cast (the second cast is
// issued as int32_t -> DT). All second-stage casts use rMode2_; previously the int64_t branch
// used rMode1_, i.e. the first stage's rounding mode, unlike every other branch.
template <typename ST, typename MT, typename DT>
__aicore__ inline void CastTwo<ST, MT, DT>::Compute(const int64_t &len)
{
    // Stage 1: source type to intermediate type; the input tile is released right after.
    auto xLocal = inQueueX_.template DeQue<ST>();
    LocalTensor<MT> midLocal = midTypeBuf_.Get<MT>();
    Cast<MT, ST>(midLocal, xLocal, rMode1_, len);
    inQueueX_.FreeTensor(xLocal);

    // Stage 2: intermediate type to destination type.
    auto yLocal = outQueue_.template AllocTensor<DT>();
    constexpr bool viaInt32View = std::is_same<MT, uint32_t>::value &&
        (std::is_same<DT, int64_t>::value || std::is_same<DT, half>::value ||
         std::is_same<DT, float>::value);
    if constexpr (viaInt32View) {
        LocalTensor<int32_t> tempTensor = midLocal.template ReinterpretCast<int32_t>();
        Cast<DT, int32_t>(yLocal, tempTensor, rMode2_, len);
    } else {
        Cast<DT, MT>(yLocal, midLocal, rMode2_, len);
    }
    outQueue_.EnQue(yLocal);
}
// Writes the next converted tile to GM at element gmOffset and frees its buffer.
template <typename ST, typename MT, typename DT>
__aicore__ inline void CastTwo<ST, MT, DT>::CopyOut(const int64_t &gmOffset)
{
    LocalTensor<DT> dstTile = outQueue_.template DeQue<DT>();
    DataCopyPad(yGm_[gmOffset], dstTile, dataCopyOutParams_);
    outQueue_.FreeTensor(dstTile);
}
// Drives the whole cast for this core: (loopNum - 1) full tiles of ubFormer elements followed
// by exactly one tail tile; the last core uses the tail-block tiling values.
template <typename ST, typename MT, typename DT>
__aicore__ inline void CastTwo<ST, MT, DT>::Process()
{
    const auto fullTile = tilingData_->ubFormer;
    const bool onLastCore = (blockIdx_ == tilingData_->blockNum - 1);

    int64_t tileCount = 0;
    int64_t lastTileLen = 0;
    if (onLastCore) {
        tileCount = tilingData_->ubLoopOfTailBlock;
        lastTileLen = tilingData_->ubTailOfTailBlock;
    } else {
        tileCount = tilingData_->ubLoopOfFormerBlock;
        lastTileLen = tilingData_->ubTailOfFormerBlock;
    }

    // Full-size tiles.
    dataCopyInParams_.blockLen = fullTile * sizeof(ST);
    dataCopyOutParams_.blockLen = fullTile * sizeof(DT);
    int64_t gmOffset = 0;
    int64_t tileIdx = 0;
    while (tileIdx < tileCount - 1) {
        CopyIn(gmOffset);
        Compute(fullTile);
        CopyOut(gmOffset);
        gmOffset += fullTile;
        ++tileIdx;
    }

    // Tail tile.
    dataCopyInParams_.blockLen = lastTileLen * sizeof(ST);
    dataCopyOutParams_.blockLen = lastTileLen * sizeof(DT);
    CopyIn(gmOffset);
    Compute(lastTileLen);
    CopyOut(gmOffset);
}
// Cast kernel built on the MicroAPI register interface.
// Template parameters:
//   id                   - selects which Compute* routine Compute dispatches to
//   ST / DT              - element types of the GM source / destination tensors
//   MST / MMT / MDT      - register element types used for load / intermediate / store
//   ldDist / stDist      - load and store distribution modes passed to MicroAPI::DataCopy
//   castMode1 / castMode2 - rounding modes supplied to the cast traits
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
class CastMicro {
public:
    __aicore__ inline CastMicro() {}
    // Stores the tiling data, binds this core's GM windows and allocates the queues.
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y,
                                const CastTilingData *tilingData, TPipe *pipePtr);
    __aicore__ inline void Process();

    // Buffer count passed to InitBuffer for each queue.
    static constexpr int32_t bufferNum_ = 2;

protected:
    __aicore__ inline void CopyIn(const int64_t &gmOffset);
    // Dispatches to the Compute* routine selected by id.
    __aicore__ inline void Compute(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void CopyOut(const int64_t &gmOffset);

private:
    // One routine per supported id; all share the (len, regLoop) signature.
    __aicore__ inline void ComputeInOut(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastInter(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastDeinter(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastCastDeinter(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastInterCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastDeinterCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastCastDeinterCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeCastInterCastCast(const int64_t &len, uint16_t regLoop);
    __aicore__ inline void ComputeDeinterShift(const int64_t &len, uint16_t regLoop);

protected:
    const CastTilingData *tilingData_{nullptr};
    TPipe *pipe_{nullptr};
    TQue<QuePosition::VECIN, 1> inQueueX_;
    TQue<QuePosition::VECOUT, 1> outQueue_;
    GlobalTensor<ST> xGm_;
    GlobalTensor<DT> yGm_;
    DataCopyExtParams dataCopyInParams_;
    DataCopyPadExtParams<ST> padParams_;
    DataCopyExtParams dataCopyOutParams_;
    int64_t blockIdx_{0};
};
// Prepares the kernel: binds the GM windows and allocates the UB queues. For the 4-bit packed
// source types (FLOAT4_E2M1 / FLOAT4_E1M2) and destination types (those plus INT4) the element
// offset and the buffer size are divided by B2_BITS; otherwise they are used as-is.
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::Init(GM_ADDR x, GM_ADDR y,
    const CastTilingData *tilingData, TPipe *pipePtr)
{
    pipe_ = pipePtr;
    tilingData_ = tilingData;
    blockIdx_ = GetBlockIdx();

    // Single contiguous burst, no padding; blockLen starts at zero.
    padParams_.isPad = false;
    padParams_.leftPadding = 0;
    padParams_.rightPadding = 0;
    dataCopyInParams_.blockCount = 1;
    dataCopyInParams_.blockLen = 0;
    dataCopyInParams_.srcStride = 0;
    dataCopyInParams_.dstStride = 0;
    dataCopyOutParams_.blockCount = 1;
    dataCopyOutParams_.blockLen = 0;
    dataCopyOutParams_.srcStride = 0;
    dataCopyOutParams_.dstStride = 0;

    // Source side.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1 || ORIG_DTYPE_X == DT_FLOAT4_E1M2
    xGm_.SetGlobalBuffer((__gm__ ST *)x + blockIdx_ * tilingData_->blockFormer / B2_BITS);
    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(ST) / B2_BITS);
#else
    xGm_.SetGlobalBuffer((__gm__ ST *)x + blockIdx_ * tilingData_->blockFormer);
    pipe_->InitBuffer(inQueueX_, bufferNum_, tilingData_->ubFormer * sizeof(ST));
#endif

    // Destination side.
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1 || ORIG_DTYPE_Y == DT_FLOAT4_E1M2 || ORIG_DTYPE_Y == DT_INT4
    yGm_.SetGlobalBuffer((__gm__ DT *)y + blockIdx_ * tilingData_->blockFormer / B2_BITS);
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT) / B2_BITS);
#else
    yGm_.SetGlobalBuffer((__gm__ DT *)y + blockIdx_ * tilingData_->blockFormer);
    pipe_->InitBuffer(outQueue_, bufferNum_, tilingData_->ubFormer * sizeof(DT));
#endif
}
// Loads one tile from GM (starting at xGm_[gmOffset]) into a fresh input buffer and queues it.
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::CopyIn(const int64_t &gmOffset)
{
    LocalTensor<ST> srcTile = inQueueX_.template AllocTensor<ST>();
    DataCopyPad(srcTile, xGm_[gmOffset], dataCopyInParams_, padParams_);
    inQueueX_.EnQue(srcTile);
}
// Compile-time dispatcher: forwards (len, regLoop) to the single Compute* routine that matches
// the template parameter id. The chain is written with `if constexpr ... else if constexpr` so
// that only the selected routine is instantiated for a given specialization; an id that matches
// none of the branches makes this function a no-op.
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::Compute(const int64_t &len, uint16_t regLoop)
{
if constexpr (id == CAST_TEMPLATE_MIRCRO_INOUT) {
ComputeInOut(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST) {
ComputeCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_INTER) {
ComputeCastInter(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_DEINTER) {
ComputeCastDeinter(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_CAST_DEINTER) {
ComputeCastCastDeinter(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_CAST) {
ComputeCastCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_INTER_CAST) {
ComputeCastInterCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_DEINTER_CAST) {
ComputeCastDeinterCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_CAST_DEINTER_CAST) {
ComputeCastCastDeinterCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_CAST_INTER_CAST_CAST) {
ComputeCastInterCastCast(len, regLoop);
} else if constexpr (id == CAST_TEMPLATE_MIRCRO_DEINTER_SHIFT) {
ComputeDeinterShift(len, regLoop);
}
}
// Load/store-only routine: no MicroAPI::Cast is issued. Each iteration loads one register of
// MST data using ldDist and stores the same register, viewed as MDT, using stDist under a mask.
//   len     - number of elements to process; seeds the running count given to UpdateMask
//   regLoop - number of register-sized iterations to execute
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeInOut(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
// Raw UB addresses of the tiles, typed with the register element types.
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
// Per-iteration steps for the load and store addresses, supplied by the tiling data.
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
uint32_t count = static_cast<uint32_t>(len);
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregTemp;
MicroAPI::MaskReg mask;
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
// count is passed by name on every iteration; presumably UpdateMask consumes it to build the
// mask for the current chunk — confirm against the MicroAPI reference.
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregTemp, srcAddr, regCopyInStep);
// The loaded register is reinterpreted as MDT; no value conversion happens here.
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
(MicroAPI::RegTensor<MDT> &)vregTemp, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
// Single-cast routine: each iteration loads one register of MST data, converts it to MMT with
// MicroAPI::Cast (rounding mode castMode1) and stores the result, viewed as MDT, under a mask.
//   len     - number of elements to process; seeds the running count given to UpdateMask
//   regLoop - number of register-sized iterations to execute
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeCast(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
uint32_t count = static_cast<uint32_t>(len);
// uint32 -> float: the store step from the tiling data is doubled for this type pair.
if constexpr (std::is_same<MST, uint32_t>::value && std::is_same<MDT, float>::value) {
regCopyOutStep = regCopyOutStep * 2;
}
// Cast trait chosen at compile time from the (MST, MMT) pair; the branches differ only in the
// register layout and saturation mode, the rounding mode is always castMode1.
static constexpr MicroAPI::CastTrait trait = []() {
if constexpr (std::is_same<MST, bfloat16_t>::value && std::is_same<MMT, int32_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::SAT, MaskMergeMode::ZEROING, castMode1};
} else if constexpr (std::is_same<MST, half>::value && std::is_same<MMT, int32_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
} else if constexpr (std::is_same<MST, float>::value && std::is_same<MMT, int32_t>::value) {
return MicroAPI::CastTrait{RegLayout::UNKNOWN, SatMode::SAT, MaskMergeMode::ZEROING, castMode1};
} else {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
}
}();
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregIn;
MicroAPI::RegTensor<MMT> vregOut;
MicroAPI::MaskReg mask;
// The cast itself always runs under a full mask; only the store is masked by the running count.
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<MST>();
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
// For the uint32 -> float build the store mask is generated with int64_t as the element type.
#if ORIG_DTYPE_X == DT_UINT32 && ORIG_DTYPE_Y == DT_FLOAT
mask = MicroAPI::UpdateMask<int64_t>(count);
#else
mask = MicroAPI::UpdateMask<MDT>(count);
#endif
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn, srcAddr, regCopyInStep);
// FP4 and uint32 sources are loaded as MST but cast from a reinterpreted register view
// (fp4x2_e2m1_t / fp4x2_e1m2_t / int64_t respectively); every other source casts from MST.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1
MicroAPI::Cast<MMT, fp4x2_e2m1_t, trait>(vregOut, (MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_FLOAT4_E1M2
MicroAPI::Cast<MMT, fp4x2_e1m2_t, trait>(vregOut, (MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_UINT32
MicroAPI::Cast<MMT, int64_t, trait>(vregOut, (MicroAPI::RegTensor<int64_t> &)vregIn, maskAll);
#else
MicroAPI::Cast<MMT, MST, trait>(vregOut, vregIn, maskAll);
#endif
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
(MicroAPI::RegTensor<MDT> &)vregOut, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
// Cast-then-interleave routine: each iteration loads one register of MST data, casts it to MMT,
// interleaves the cast result with an all-zero MMT register, and stores the two resulting
// registers (viewed as MDT) one after the other, each under its own freshly updated mask.
//   len     - number of elements to process; seeds the running count given to UpdateMask
//   regLoop - number of register-sized iterations to execute (two stores per iteration)
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeCastInter(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
static constexpr MicroAPI::CastTrait trait = {RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
MMT zeroValue = 0;
uint32_t count = static_cast<uint32_t>(len);
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregIn;
MicroAPI::RegTensor<MMT> vregCast;
MicroAPI::RegTensor<MMT> vregOut1;
MicroAPI::RegTensor<MMT> vregOut2;
MicroAPI::RegTensor<MMT> vregZero;
MicroAPI::MaskReg mask;
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<MST>();
// Zero register used as the second interleave operand; filled once, outside the loop.
MicroAPI::Duplicate(vregZero, zeroValue);
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn, srcAddr, regCopyInStep);
MicroAPI::Cast<MMT, MST, trait>(vregCast, vregIn, maskAll);
// Interleaving with zeros yields two output registers from one cast register.
MicroAPI::Interleave<MMT>(vregOut1, vregOut2, vregCast, vregZero);
// The mask is refreshed before each of the two stores, so both consume the running count.
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
(MicroAPI::RegTensor<MDT> &)vregOut1, regCopyOutStep, mask);
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
(MicroAPI::RegTensor<MDT> &)vregOut2, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
/*!
 * \brief Cast two input registers to MMT, de-interleave them as int32 and store one output.
 *
 * Per loop iteration: two MST registers are loaded and each is cast to MMT with
 * castMode1. Both results are reinterpreted as int32_t registers and
 * de-interleaved; only the first de-interleave output is stored (as MDT), the
 * second one is never read.
 *
 * The cast trait is chosen at compile time: float -> int64_t saturates,
 * float -> bfloat16_t does not. Other MST/MMT pairs have no return statement in
 * the selector lambda.
 *
 * \param len      Element count; truncated to uint32_t and handed to UpdateMask
 *                 as the running counter for the store predicate.
 * \param regLoop  Number of loop iterations (two loads and one store each).
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::ComputeCastDeinter(const int64_t &len, uint16_t regLoop)
{
    static constexpr MicroAPI::CastTrait kCastTrait = []() {
        if constexpr (std::is_same<MST, float>::value && std::is_same<MMT, int64_t>::value) {
            return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::SAT, MaskMergeMode::ZEROING, castMode1};
        } else if constexpr (std::is_same<MST, float>::value && std::is_same<MMT, bfloat16_t>::value) {
            return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, castMode1};
        }
    }();

    auto srcTensor = inQueueX_.template DeQue<ST>();
    auto dstTensor = outQueue_.template AllocTensor<DT>();

    __ubuf__ MST *srcPtr = (__ubuf__ MST *)srcTensor.GetPhyAddr();
    __ubuf__ MDT *dstPtr = (__ubuf__ MDT *)dstTensor.GetPhyAddr();

    int32_t loadStride = static_cast<int32_t>(tilingData_->regCopyInStep);
    int32_t storeStride = static_cast<int32_t>(tilingData_->regCopyOutStep);
    uint32_t remaining = static_cast<uint32_t>(len);

    __VEC_SCOPE__ {
        MicroAPI::RegTensor<MST> srcReg0;
        MicroAPI::RegTensor<MST> srcReg1;
        MicroAPI::RegTensor<MMT> castReg0;
        MicroAPI::RegTensor<MMT> castReg1;
        MicroAPI::RegTensor<int32_t> keptReg;
        MicroAPI::RegTensor<int32_t> unusedReg;
        MicroAPI::MaskReg storeMask;
        MicroAPI::MaskReg fullMask = MicroAPI::CreateMask<MST>();
        for (uint16_t iter = 0; iter < regLoop; ++iter) {
            // Two consecutive loads; the source pointer is post-updated by each.
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(
                srcReg0, srcPtr, loadStride);
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(
                srcReg1, srcPtr, loadStride);

            MicroAPI::Cast<MMT, MST, kCastTrait>(castReg0, srcReg0, fullMask);
            MicroAPI::Cast<MMT, MST, kCastTrait>(castReg1, srcReg1, fullMask);

            // De-interleave on an int32 view of both cast results; only keptReg is stored.
            MicroAPI::DeInterleave<int32_t>(keptReg, unusedReg,
                (MicroAPI::RegTensor<int32_t> &)castReg0, (MicroAPI::RegTensor<int32_t> &)castReg1);

            storeMask = MicroAPI::UpdateMask<MDT>(remaining);
            MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(
                dstPtr, (MicroAPI::RegTensor<MDT> &)keptReg, storeStride, storeMask);
        }
    }

    inQueueX_.FreeTensor(srcTensor);
    outQueue_.EnQue(dstTensor);
}
/*!
 * \brief Cast to MMT, zero-interleave, cast both parts to int64, de-interleave as int32 and store.
 *
 * Per loop iteration:
 *   1. load one MST register and cast it to MMT (castMode1);
 *   2. interleave the result with an all-zero MMT register;
 *   3. cast both interleave outputs to int64_t (castMode2, saturating);
 *   4. reinterpret the two int64 registers as int32 and de-interleave them;
 *   5. store the first de-interleave output as MDT; the second is never read.
 *
 * \param len      Element count; truncated to uint32_t and handed to UpdateMask
 *                 as the running counter for the store predicate.
 * \param regLoop  Number of loop iterations (one load and one store each).
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::ComputeCastCastDeinter(const int64_t &len, uint16_t regLoop)
{
    static constexpr MicroAPI::CastTrait kFirstTrait = {
        RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
    static constexpr MicroAPI::CastTrait kSecondTrait = {
        RegLayout::ZERO, SatMode::SAT, MaskMergeMode::ZEROING, castMode2};

    auto srcTensor = inQueueX_.template DeQue<ST>();
    auto dstTensor = outQueue_.template AllocTensor<DT>();

    __ubuf__ MST *srcPtr = (__ubuf__ MST *)srcTensor.GetPhyAddr();
    __ubuf__ MDT *dstPtr = (__ubuf__ MDT *)dstTensor.GetPhyAddr();

    int32_t loadStride = static_cast<int32_t>(tilingData_->regCopyInStep);
    int32_t storeStride = static_cast<int32_t>(tilingData_->regCopyOutStep);
    uint32_t remaining = static_cast<uint32_t>(len);
    MMT zeroFill = 0;

    __VEC_SCOPE__ {
        MicroAPI::RegTensor<MST> srcReg;
        MicroAPI::RegTensor<MMT> midReg;
        MicroAPI::RegTensor<MMT> zeroReg;
        MicroAPI::RegTensor<MMT> interReg0;
        MicroAPI::RegTensor<MMT> interReg1;
        MicroAPI::RegTensor<int64_t> i64Reg0;
        MicroAPI::RegTensor<int64_t> i64Reg1;
        MicroAPI::RegTensor<int32_t> keptReg;
        MicroAPI::RegTensor<int32_t> unusedReg;
        MicroAPI::Duplicate(zeroReg, zeroFill);
        MicroAPI::MaskReg storeMask;
        MicroAPI::MaskReg fullMask = MicroAPI::CreateMask<int8_t>();
        for (uint16_t iter = 0; iter < regLoop; ++iter) {
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(
                srcReg, srcPtr, loadStride);

            // Stage 1: MST -> MMT, then interleave with zeros.
            MicroAPI::Cast<MMT, MST, kFirstTrait>(midReg, srcReg, fullMask);
            MicroAPI::Interleave<MMT>(interReg0, interReg1, midReg, zeroReg);

            // Stage 2: both interleave outputs go to int64.
            MicroAPI::Cast<int64_t, MMT, kSecondTrait>(i64Reg0, interReg0, fullMask);
            MicroAPI::Cast<int64_t, MMT, kSecondTrait>(i64Reg1, interReg1, fullMask);

            // De-interleave on an int32 view of the int64 registers; only keptReg is stored.
            MicroAPI::DeInterleave<int32_t>(keptReg, unusedReg,
                (MicroAPI::RegTensor<int32_t> &)i64Reg0, (MicroAPI::RegTensor<int32_t> &)i64Reg1);

            storeMask = MicroAPI::UpdateMask<MDT>(remaining);
            MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(
                dstPtr, (MicroAPI::RegTensor<MDT> &)keptReg, storeStride, storeMask);
        }
    }

    inQueueX_.FreeTensor(srcTensor);
    outQueue_.EnQue(dstTensor);
}
// Two-stage cast: every loaded MST register is cast to MMT (trait1, castMode1)
// and then to MDT (trait2, castMode2) before it is stored.
//
// The concrete source/destination register views are selected by the
// ORIG_DTYPE_X / ORIG_DTYPE_Y build macros: packed fp4 inputs/outputs and
// uint16 inputs are reinterpreted before the corresponding Cast call.
//
// len      Element count; truncated to uint32_t and handed to UpdateMask as the
//          running counter for the store predicate.
// regLoop  Number of loop iterations (one load and one store each).
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeCastCast(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
uint32_t count = static_cast<uint32_t>(len);
// For uint16 -> bfloat16 the store stride from the tiling data is doubled.
// NOTE(review): presumably because this path runs through 32-bit lanes
// (see the int32_t reinterpretation below) - confirm against the tiling code.
if constexpr (std::is_same<MST, uint16_t>::value && std::is_same<MDT, bfloat16_t>::value) {
regCopyOutStep = regCopyOutStep * 2;
}
// First-stage trait (MST -> MMT). Only the listed type pairs return a trait;
// for any other pair the lambda has no return statement, so initialising the
// CastTrait fails to compile when this function is instantiated.
static constexpr MicroAPI::CastTrait trait1 = []() {
if constexpr (std::is_same<MST, uint8_t>::value && std::is_same<MMT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
} else if constexpr (std::is_same<MST, half>::value && std::is_same<MMT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::UNKNOWN, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
} else if constexpr (std::is_same<MST, uint16_t>::value && std::is_same<MMT, float>::value) {
return MicroAPI::CastTrait{RegLayout::UNKNOWN, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
}
}();
// Second-stage trait (MMT -> MDT), selected the same way as trait1.
static constexpr MicroAPI::CastTrait trait2 = []() {
if constexpr (std::is_same<MMT, bfloat16_t>::value && std::is_same<MDT, uint32_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode2};
} else if constexpr (std::is_same<MMT, bfloat16_t>::value && std::is_same<MDT, half>::value) {
return MicroAPI::CastTrait{RegLayout::UNKNOWN, SatMode::NO_SAT, MaskMergeMode::ZEROING, castMode2};
} else if constexpr (std::is_same<MMT, float>::value && std::is_same<MDT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, castMode2};
}
}();
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregIn;
MicroAPI::RegTensor<MMT> vregMid;
MicroAPI::RegTensor<MDT> vregOut;
MicroAPI::MaskReg mask;
// Both casts run under this all-true predicate; only the store uses `mask`.
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<uint8_t>();
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
// The store predicate is derived before the load. The UINT16 -> BF16 build
// derives it in MMT-sized elements, every other build in MDT-sized elements.
#if ORIG_DTYPE_X == DT_UINT16 && ORIG_DTYPE_Y == DT_BF16
mask = MicroAPI::UpdateMask<MMT>(count);
#else
mask = MicroAPI::UpdateMask<MDT>(count);
#endif
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn, srcAddr, regCopyInStep);
// Stage 1 (-> MMT). fp4 inputs are viewed as packed fp4x2 registers and
// uint16 inputs as int32_t registers; all other inputs are cast as MST.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1
MicroAPI::Cast<MMT, fp4x2_e2m1_t, trait1>(vregMid, (MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_FLOAT4_E1M2
MicroAPI::Cast<MMT, fp4x2_e1m2_t, trait1>(vregMid, (MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_UINT16
MicroAPI::Cast<MMT, int32_t, trait1>(vregMid, (MicroAPI::RegTensor<int32_t> &)vregIn, maskAll);
#else
MicroAPI::Cast<MMT, MST, trait1>(vregMid, vregIn, maskAll);
#endif
// Stage 2 (MMT ->). fp4 outputs are written through a packed fp4x2 view of vregOut.
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1
MicroAPI::Cast<fp4x2_e2m1_t, MMT, trait2>((MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregOut, vregMid, maskAll);
#elif ORIG_DTYPE_Y == DT_FLOAT4_E1M2
MicroAPI::Cast<fp4x2_e1m2_t, MMT, trait2>((MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregOut, vregMid, maskAll);
#else
MicroAPI::Cast<MDT, MMT, trait2>(vregOut, vregMid, maskAll);
#endif
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
vregOut, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
// Cast -> zero-interleave -> cast: every loaded MST register is cast to MMT
// (trait1, castMode1), interleaved with an all-zero MMT register, and both
// interleave outputs are cast to MDT (trait2, castMode2) and stored one after
// the other.
//
// len      Element count; truncated to uint32_t and handed to UpdateMask as the
//          running counter for the two store predicates.
// regLoop  Number of loop iterations (one load and two stores each).
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeCastInterCast(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
uint32_t count = static_cast<uint32_t>(len);
MMT zeroValue = 0;
// First-stage trait: only uint8_t -> bfloat16_t returns a value. For any other
// MST/MMT pair the lambda has no return statement, so initialising the
// CastTrait fails to compile when this function is instantiated.
static constexpr MicroAPI::CastTrait trait1 = []() {
if constexpr (std::is_same<MST, uint8_t>::value && std::is_same<MMT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
}
}();
// Second-stage trait: only bfloat16_t -> float returns a value (same pattern as trait1).
static constexpr MicroAPI::CastTrait trait2 = []() {
if constexpr (std::is_same<MMT, bfloat16_t>::value && std::is_same<MDT, float>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode2};
}
}();
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregIn;
MicroAPI::RegTensor<MMT> vregMid;
MicroAPI::RegTensor<MDT> vregOut1;
MicroAPI::RegTensor<MDT> vregOut2;
MicroAPI::MaskReg mask;
// All casts run under this all-true predicate; only the stores use `mask`.
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<uint8_t>();
// All-zero register used as the second operand of the interleave.
MicroAPI::RegTensor<MMT> vregZero;
MicroAPI::Duplicate(vregZero, zeroValue);
MicroAPI::RegTensor<MMT> vregInter1;
MicroAPI::RegTensor<MMT> vregInter2;
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn, srcAddr, regCopyInStep);
// Stage 1 (-> MMT). fp4 inputs are viewed as packed fp4x2 registers.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1
MicroAPI::Cast<MMT, fp4x2_e2m1_t, trait1>(vregMid, (MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_FLOAT4_E1M2
MicroAPI::Cast<MMT, fp4x2_e1m2_t, trait1>(vregMid, (MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregIn, maskAll);
#else
MicroAPI::Cast<MMT, MST, trait1>(vregMid, vregIn, maskAll);
#endif
MicroAPI::Interleave<MMT>(vregInter1, vregInter2, vregMid, vregZero);
// Stage 2 (MMT -> MDT) on both interleave outputs.
MicroAPI::Cast<MDT, MMT, trait2>(vregOut1, vregInter1, maskAll);
MicroAPI::Cast<MDT, MMT, trait2>(vregOut2, vregInter2, maskAll);
// Each of the two stores gets its own predicate from the running count.
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
vregOut1, regCopyOutStep, mask);
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
vregOut2, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
// Cast -> de-interleave -> cast: two MST registers are loaded per iteration,
// each is cast to MMT (trait1, castMode1), the pair is de-interleaved, and the
// first de-interleave output is cast to MDT (trait2, castMode2) and stored.
// The second de-interleave output is never read.
//
// len      Element count; truncated to uint32_t and handed to UpdateMask as the
//          running counter for the store predicate.
// regLoop  Number of loop iterations (two loads and one store each).
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
ldDist, stDist, castMode1, castMode2>::ComputeCastDeinterCast(const int64_t &len, uint16_t regLoop)
{
auto xLocal = inQueueX_.template DeQue<ST>();
auto yLocal = outQueue_.template AllocTensor<DT>();
__ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
__ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
uint32_t count = static_cast<uint32_t>(len);
// For uint32 -> half and uint32 -> bfloat16 the store stride from the tiling
// data is doubled.
// NOTE(review): the reason is not visible in this function - confirm against
// the tiling code that produces regCopyOutStep.
if constexpr ((std::is_same<MST, uint32_t>::value && std::is_same<MDT, half>::value) ||
(std::is_same<MST, uint32_t>::value && std::is_same<MDT, bfloat16_t>::value)) {
regCopyOutStep = regCopyOutStep * 2;
}
// First-stage trait (MST -> MMT). Only the listed type pairs return a trait;
// for any other pair the lambda has no return statement, so initialising the
// CastTrait fails to compile when this function is instantiated.
static constexpr MicroAPI::CastTrait trait1 = []() {
if constexpr (std::is_same<MST, float>::value && std::is_same<MMT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, castMode1};
} else if constexpr (std::is_same<MST, uint32_t>::value && std::is_same<MMT, float>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode1};
}
}();
// Second-stage trait (MMT -> MDT), selected the same way as trait1.
static constexpr MicroAPI::CastTrait trait2 = []() {
if constexpr (std::is_same<MMT, bfloat16_t>::value && std::is_same<MDT, uint32_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, castMode2};
} else if constexpr (std::is_same<MMT, float>::value && std::is_same<MDT, half>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::SAT, MaskMergeMode::ZEROING, castMode2};
} else if constexpr (std::is_same<MMT, float>::value && std::is_same<MDT, bfloat16_t>::value) {
return MicroAPI::CastTrait{RegLayout::ZERO, SatMode::SAT, MaskMergeMode::ZEROING, castMode2};
}
}();
__VEC_SCOPE__ {
MicroAPI::RegTensor<MST> vregIn1;
MicroAPI::RegTensor<MST> vregIn2;
MicroAPI::RegTensor<MMT> vregMid1;
MicroAPI::RegTensor<MMT> vregMid2;
MicroAPI::RegTensor<MMT> vregDeinter;
MicroAPI::RegTensor<MMT> vregDeinterNoUse;
MicroAPI::RegTensor<MDT> vregOut;
MicroAPI::MaskReg mask;
// All casts run under this all-true predicate; only the store uses `mask`.
MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<uint8_t>();
for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
// Two consecutive loads; the source pointer is post-updated by each.
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn1, srcAddr, regCopyInStep);
MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn2, srcAddr, regCopyInStep);
// Stage 1 (-> MMT). In UINT32-input builds the source registers are viewed as
// int64_t and the store predicate is derived in MMT-sized elements; all other
// builds cast from MST and derive the predicate in MDT-sized elements.
#if ORIG_DTYPE_X == DT_UINT32
mask = MicroAPI::UpdateMask<MMT>(count);
MicroAPI::Cast<MMT, int64_t, trait1>(vregMid1, (MicroAPI::RegTensor<int64_t> &)vregIn1, maskAll);
MicroAPI::Cast<MMT, int64_t, trait1>(vregMid2, (MicroAPI::RegTensor<int64_t> &)vregIn2, maskAll);
#else
mask = MicroAPI::UpdateMask<MDT>(count);
MicroAPI::Cast<MMT, MST, trait1>(vregMid1, vregIn1, maskAll);
MicroAPI::Cast<MMT, MST, trait1>(vregMid2, vregIn2, maskAll);
#endif
// Only vregDeinter feeds the second stage; vregDeinterNoUse is discarded.
MicroAPI::DeInterleave<MMT>(vregDeinter, vregDeinterNoUse, vregMid1, vregMid2);
// Stage 2 (MMT ->). fp4 outputs are written through a packed fp4x2 view of vregOut.
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1
MicroAPI::Cast<fp4x2_e2m1_t, MMT, trait2>((MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregOut, vregDeinter, maskAll);
#elif ORIG_DTYPE_Y == DT_FLOAT4_E1M2
MicroAPI::Cast<fp4x2_e1m2_t, MMT, trait2>((MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregOut, vregDeinter, maskAll);
#else
MicroAPI::Cast<MDT, MMT, trait2>(vregOut, vregDeinter, maskAll);
#endif
MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
vregOut, regCopyOutStep, mask);
}
}
inQueueX_.FreeTensor(xLocal);
outQueue_.EnQue(yLocal);
}
/*!
 * \brief Three-stage cast with a de-interleave before the last stage.
 *
 * Per loop iteration:
 *   1. load two MST registers and cast each to MMT (trait1);
 *   2. cast both MMT registers to bfloat16_t (trait2, CAST_RINT, no saturation);
 *   3. de-interleave the two bfloat16 registers; the second output is never read;
 *   4. cast the first de-interleave output to the destination type (trait3,
 *      CAST_RINT) - through a packed fp4x2 view of the output register when
 *      ORIG_DTYPE_Y is an fp4 type - and store it.
 *
 * The rounding modes are fixed in the traits below; the castMode1 / castMode2
 * template parameters are not used by this routine.
 *
 * \param len      Not consulted by this routine: every store uses the all-true
 *                 predicate, so no element-count mask is derived.
 * \param regLoop  Number of loop iterations (two loads and one store each).
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::ComputeCastCastDeinterCast(const int64_t &len, uint16_t regLoop)
{
    // The previously computed `count` local was never read; mark `len` as
    // deliberately unused instead of keeping a dead variable around.
    (void)len;
    auto xLocal = inQueueX_.template DeQue<ST>();
    auto yLocal = outQueue_.template AllocTensor<DT>();
    __ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
    __ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
    int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
    int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
    static constexpr MicroAPI::CastTrait trait1 = {RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, RoundMode::UNKNOWN};
    static constexpr MicroAPI::CastTrait trait2 = {RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, RoundMode::CAST_RINT};
    static constexpr MicroAPI::CastTrait trait3 = {RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, RoundMode::CAST_RINT};
    __VEC_SCOPE__ {
        MicroAPI::RegTensor<MST> vregIn1;
        MicroAPI::RegTensor<MST> vregIn2;
        MicroAPI::RegTensor<MMT> vregMid1;
        MicroAPI::RegTensor<MMT> vregMid2;
        MicroAPI::RegTensor<bfloat16_t> vregMidBf1;
        MicroAPI::RegTensor<bfloat16_t> vregMidBf2;
        MicroAPI::RegTensor<bfloat16_t> vregDeinter;
        MicroAPI::RegTensor<bfloat16_t> vregDeinterNoUse;
        MicroAPI::RegTensor<MDT> vregOut;
        // Single all-true predicate used by every cast and by the store.
        MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<uint8_t>();
        for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
            // Two consecutive loads; the source pointer is post-updated by each.
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn1, srcAddr, regCopyInStep);
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn2, srcAddr, regCopyInStep);
            // Stage 1: MST -> MMT.
            MicroAPI::Cast<MMT, MST, trait1>(vregMid1, vregIn1, maskAll);
            MicroAPI::Cast<MMT, MST, trait1>(vregMid2, vregIn2, maskAll);
            // Stage 2: MMT -> bfloat16.
            MicroAPI::Cast<bfloat16_t, MMT, trait2>(vregMidBf1, vregMid1, maskAll);
            MicroAPI::Cast<bfloat16_t, MMT, trait2>(vregMidBf2, vregMid2, maskAll);
            // Only vregDeinter feeds the final stage; vregDeinterNoUse is discarded.
            MicroAPI::DeInterleave<bfloat16_t>(vregDeinter, vregDeinterNoUse, vregMidBf1, vregMidBf2);
            // Stage 3: bfloat16 -> destination type.
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1
            MicroAPI::Cast<fp4x2_e2m1_t, bfloat16_t, trait3>((MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregOut, vregDeinter, maskAll);
#elif ORIG_DTYPE_Y == DT_FLOAT4_E1M2
            MicroAPI::Cast<fp4x2_e1m2_t, bfloat16_t, trait3>((MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregOut, vregDeinter, maskAll);
#else
            MicroAPI::Cast<MDT, bfloat16_t, trait3>(vregOut, vregDeinter, maskAll);
#endif
            MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
                vregOut, regCopyOutStep, maskAll);
        }
    }
    inQueueX_.FreeTensor(xLocal);
    outQueue_.EnQue(yLocal);
}
/*!
 * \brief Three-stage cast with a zero-interleave after the first stage.
 *
 * Per loop iteration:
 *   1. load one MST register and cast it to bfloat16_t (trait1) - through a
 *      packed fp4x2 view of the input register when ORIG_DTYPE_X is an fp4 type;
 *   2. interleave the bfloat16 register with an all-zero bfloat16 register;
 *   3. cast both interleave outputs to MMT (trait2);
 *   4. cast both MMT registers to the fp8 destination type selected by
 *      ORIG_DTYPE_Y (trait3) and store them one after the other.
 *
 * trait3 rounds with CAST_ROUND in HIFLOAT8 builds and with CAST_RINT otherwise.
 * The rounding modes are fixed in the traits below; the castMode1 / castMode2
 * template parameters are not used by this routine.
 *
 * \param len      Not consulted by this routine: every store uses the all-true
 *                 predicate, so no element-count mask is derived.
 * \param regLoop  Number of loop iterations (one load and two stores each).
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::ComputeCastInterCastCast(const int64_t &len, uint16_t regLoop)
{
    // The previously computed `count` local was never read; mark `len` as
    // deliberately unused instead of keeping a dead variable around.
    (void)len;
    auto xLocal = inQueueX_.template DeQue<ST>();
    auto yLocal = outQueue_.template AllocTensor<DT>();
    __ubuf__ MST *srcAddr = (__ubuf__ MST *)xLocal.GetPhyAddr();
    __ubuf__ MDT *dstAddr = (__ubuf__ MDT *)yLocal.GetPhyAddr();
    int32_t regCopyInStep = static_cast<int32_t>(tilingData_->regCopyInStep);
    int32_t regCopyOutStep = static_cast<int32_t>(tilingData_->regCopyOutStep);
    bfloat16_t zeroValue = 0;
    static constexpr MicroAPI::CastTrait trait1 = {RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, RoundMode::UNKNOWN};
    static constexpr MicroAPI::CastTrait trait2 = {RegLayout::ZERO, SatMode::UNKNOWN, MaskMergeMode::ZEROING, RoundMode::UNKNOWN};
#if ORIG_DTYPE_Y == DT_HIFLOAT8
    static constexpr MicroAPI::CastTrait trait3 = {RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, RoundMode::CAST_ROUND};
#else
    static constexpr MicroAPI::CastTrait trait3 = {RegLayout::ZERO, SatMode::NO_SAT, MaskMergeMode::ZEROING, RoundMode::CAST_RINT};
#endif
    __VEC_SCOPE__ {
        MicroAPI::RegTensor<MST> vregIn;
        MicroAPI::RegTensor<bfloat16_t> vregMidBf;
        MicroAPI::RegTensor<MMT> vregMid1;
        MicroAPI::RegTensor<MMT> vregMid2;
        MicroAPI::RegTensor<MDT> vregOut1;
        MicroAPI::RegTensor<MDT> vregOut2;
        // Single all-true predicate used by every cast and by both stores.
        MicroAPI::MaskReg maskAll = MicroAPI::CreateMask<uint8_t>();
        // All-zero register used as the second operand of the interleave.
        MicroAPI::RegTensor<bfloat16_t> vregZero;
        MicroAPI::Duplicate(vregZero, zeroValue);
        MicroAPI::RegTensor<bfloat16_t> vregInter1;
        MicroAPI::RegTensor<bfloat16_t> vregInter2;
        for (uint16_t loopIdx = 0; loopIdx < regLoop; loopIdx++) {
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(vregIn, srcAddr, regCopyInStep);
            // Stage 1: source -> bfloat16.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1
            MicroAPI::Cast<bfloat16_t, fp4x2_e2m1_t, trait1>(vregMidBf, (MicroAPI::RegTensor<fp4x2_e2m1_t> &)vregIn, maskAll);
#elif ORIG_DTYPE_X == DT_FLOAT4_E1M2
            MicroAPI::Cast<bfloat16_t, fp4x2_e1m2_t, trait1>(vregMidBf, (MicroAPI::RegTensor<fp4x2_e1m2_t> &)vregIn, maskAll);
#else
            MicroAPI::Cast<bfloat16_t, MST, trait1>(vregMidBf, vregIn, maskAll);
#endif
            MicroAPI::Interleave<bfloat16_t>(vregInter1, vregInter2, vregMidBf, vregZero);
            // Stage 2: bfloat16 -> MMT on both interleave outputs.
            MicroAPI::Cast<MMT, bfloat16_t, trait2>(vregMid1, vregInter1, maskAll);
            MicroAPI::Cast<MMT, bfloat16_t, trait2>(vregMid2, vregInter2, maskAll);
            // Stage 3: MMT -> fp8 destination type.
            // NOTE(review): there is no #else branch. If ORIG_DTYPE_Y is none of the
            // three fp8 types, vregOut1 / vregOut2 are stored without ever being
            // written. Presumably this routine is only instantiated for fp8 outputs -
            // confirm against the dispatch in Compute.
#if ORIG_DTYPE_Y == DT_HIFLOAT8
            MicroAPI::Cast<hifloat8_t, MMT, trait3>((MicroAPI::RegTensor<hifloat8_t> &)vregOut1, vregMid1, maskAll);
            MicroAPI::Cast<hifloat8_t, MMT, trait3>((MicroAPI::RegTensor<hifloat8_t> &)vregOut2, vregMid2, maskAll);
#elif ORIG_DTYPE_Y == DT_FLOAT8_E5M2
            MicroAPI::Cast<fp8_e5m2_t, MMT, trait3>((MicroAPI::RegTensor<fp8_e5m2_t> &)vregOut1, vregMid1, maskAll);
            MicroAPI::Cast<fp8_e5m2_t, MMT, trait3>((MicroAPI::RegTensor<fp8_e5m2_t> &)vregOut2, vregMid2, maskAll);
#elif ORIG_DTYPE_Y == DT_FLOAT8_E4M3FN
            MicroAPI::Cast<fp8_e4m3fn_t, MMT, trait3>((MicroAPI::RegTensor<fp8_e4m3fn_t> &)vregOut1, vregMid1, maskAll);
            MicroAPI::Cast<fp8_e4m3fn_t, MMT, trait3>((MicroAPI::RegTensor<fp8_e4m3fn_t> &)vregOut2, vregMid2, maskAll);
#endif
            MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
                vregOut1, regCopyOutStep, maskAll);
            MicroAPI::DataCopy<MDT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(dstAddr,
                vregOut2, regCopyOutStep, maskAll);
        }
    }
    inQueueX_.FreeTensor(xLocal);
    outQueue_.EnQue(yLocal);
}
/*!
 * \brief De-interleave two input registers and merge them with a mask / shift / or sequence.
 *
 * Per loop iteration: two MST registers are loaded and de-interleaved as MMT.
 * The first de-interleave output is AND-ed with a register filled with B4_MASK,
 * the second is shifted left by SHIFT_FOUR_BITS, and the OR of both results is
 * stored. No Cast is performed in this routine.
 * NOTE(review): this looks like packing two 4-bit values into one element -
 * confirm against the definitions of B4_MASK and SHIFT_FOUR_BITS.
 *
 * \param len      Element count; truncated to uint32_t and handed to UpdateMask
 *                 as the running counter for the store predicate.
 * \param regLoop  Number of loop iterations (two loads and one store each).
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::ComputeDeinterShift(const int64_t &len, uint16_t regLoop)
{
    auto srcTensor = inQueueX_.template DeQue<ST>();
    auto dstTensor = outQueue_.template AllocTensor<DT>();

    __ubuf__ MST *srcPtr = (__ubuf__ MST *)srcTensor.GetPhyAddr();
    __ubuf__ MDT *dstPtr = (__ubuf__ MDT *)dstTensor.GetPhyAddr();

    int32_t loadStride = static_cast<int32_t>(tilingData_->regCopyInStep);
    int32_t storeStride = static_cast<int32_t>(tilingData_->regCopyOutStep);
    uint32_t remaining = static_cast<uint32_t>(len);

    __VEC_SCOPE__ {
        MicroAPI::RegTensor<MST> srcReg0;
        MicroAPI::RegTensor<MST> srcReg1;
        MicroAPI::RegTensor<MMT> deintReg0;
        MicroAPI::RegTensor<MMT> deintReg1;
        MicroAPI::RegTensor<MMT> maskedReg;
        MicroAPI::RegTensor<MMT> shiftedReg;
        MicroAPI::RegTensor<MMT> b4MaskReg;
        MicroAPI::RegTensor<MMT> mergedReg;
        MicroAPI::MaskReg storeMask;
        MicroAPI::MaskReg fullMask = MicroAPI::CreateMask<uint8_t>();
        MicroAPI::Duplicate(b4MaskReg, B4_MASK);
        for (uint16_t iter = 0; iter < regLoop; ++iter) {
            // Two consecutive loads; the source pointer is post-updated by each.
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(
                srcReg0, srcPtr, loadStride);
            MicroAPI::DataCopy<MST, MicroAPI::PostLiteral::POST_MODE_UPDATE, ldDist>(
                srcReg1, srcPtr, loadStride);

            MicroAPI::DeInterleave<MMT>(deintReg0, deintReg1, srcReg0, srcReg1);

            // First output is masked, second is shifted, then both are OR-ed together.
            MicroAPI::And<MMT, MicroAPI::MaskMergeMode::ZEROING>(maskedReg, deintReg0, b4MaskReg, fullMask);
            MicroAPI::ShiftLefts(shiftedReg, deintReg1, SHIFT_FOUR_BITS, fullMask);
            MicroAPI::Or<MMT, MicroAPI::MaskMergeMode::ZEROING>(mergedReg, maskedReg, shiftedReg, fullMask);

            storeMask = MicroAPI::UpdateMask<MMT>(remaining);
            MicroAPI::DataCopy<MMT, MicroAPI::PostLiteral::POST_MODE_UPDATE, stDist>(
                dstPtr, mergedReg, storeStride, storeMask);
        }
    }

    inQueueX_.FreeTensor(srcTensor);
    outQueue_.EnQue(dstTensor);
}
/*!
 * \brief Write one finished output tile from the output queue to yGm_.
 *
 * Dequeues the next tensor from outQueue_, copies it to yGm_ starting at
 * gmOffset using the currently configured dataCopyOutParams_, and returns the
 * tensor to the queue.
 *
 * \param gmOffset  Index into yGm_ at which the copy starts.
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::CopyOut(const int64_t &gmOffset)
{
    auto readyTile = outQueue_.template DeQue<DT>();
    DataCopyPad(yGm_[gmOffset], readyTile, dataCopyOutParams_);
    outQueue_.FreeTensor(readyTile);
}
/*!
 * \brief Drive the CopyIn -> Compute -> CopyOut pipeline for this block.
 *
 * The last block (blockIdx_ == blockNum - 1) uses the "tail block" loop count,
 * tail size and tail register-loop count from the tiling data; every other
 * block uses the "former block" values. All tiles except the final one are
 * processed with ubFormer elements; the final tile is processed with the tail
 * size. The copy lengths (blockLen) are reprogrammed before the final tile.
 *
 * In builds where the input (or output) dtype is one of the 4-bit types listed
 * in the #if conditions, the per-tile element count used for offsets and copy
 * lengths is divided by B2_BITS.
 * NOTE(review): presumably because several 4-bit values share one ST / DT
 * element - confirm against the definition of B2_BITS.
 */
template <int id, typename ST, typename DT, typename MST, typename MMT, typename MDT,
          LoadDist ldDist, StoreDist stDist, RoundMode castMode1, RoundMode castMode2>
__aicore__ inline void CastMicro<id, ST, DT, MST, MMT, MDT,
    ldDist, stDist, castMode1, castMode2>::Process()
{
    // Start from the "former block" parameters and override them for the last block.
    int64_t tileCount = tilingData_->ubLoopOfFormerBlock;
    int64_t tailElems = tilingData_->ubTailOfFormerBlock;
    int64_t tailRegIters = tilingData_->ubTailOfFormerRegLoop;
    if (blockIdx_ == tilingData_->blockNum - 1) {
        tileCount = tilingData_->ubLoopOfTailBlock;
        tailElems = tilingData_->ubTailOfTailBlock;
        tailRegIters = tilingData_->ubTailOfTailRegLoop;
    }

    // Per-tile element counts, expressed in ST / DT elements.
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1 || ORIG_DTYPE_X == DT_FLOAT4_E1M2
    int64_t srcElemsPerTile = tilingData_->ubFormer / B2_BITS;
#else
    int64_t srcElemsPerTile = tilingData_->ubFormer;
#endif
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1 || ORIG_DTYPE_Y == DT_FLOAT4_E1M2 || ORIG_DTYPE_Y == DT_INT4
    int64_t dstElemsPerTile = tilingData_->ubFormer / B2_BITS;
#else
    int64_t dstElemsPerTile = tilingData_->ubFormer;
#endif

    dataCopyInParams_.blockLen = srcElemsPerTile * sizeof(ST);
    dataCopyOutParams_.blockLen = dstElemsPerTile * sizeof(DT);

    // Full-size tiles: everything except the final tile.
    int64_t srcOffset = 0;
    int64_t dstOffset = 0;
    const int64_t fullTiles = tileCount - 1;
    for (int64_t tile = 0; tile < fullTiles; ++tile) {
        CopyIn(srcOffset);
        Compute(tilingData_->ubFormer, static_cast<uint16_t>(tilingData_->ubFormerRegLoop));
        CopyOut(dstOffset);
        srcOffset += srcElemsPerTile;
        dstOffset += dstElemsPerTile;
    }

    // Final tile: switch the copy lengths to the tail size (rounded up before
    // the division by B2_BITS in the 4-bit builds).
#if ORIG_DTYPE_X == DT_FLOAT4_E2M1 || ORIG_DTYPE_X == DT_FLOAT4_E1M2
    dataCopyInParams_.blockLen = (tailElems + 1) / B2_BITS * sizeof(ST);
#else
    dataCopyInParams_.blockLen = tailElems * sizeof(ST);
#endif
#if ORIG_DTYPE_Y == DT_FLOAT4_E2M1 || ORIG_DTYPE_Y == DT_FLOAT4_E1M2 || ORIG_DTYPE_Y == DT_INT4
    dataCopyOutParams_.blockLen = (tailElems + 1) / B2_BITS * sizeof(DT);
#else
    dataCopyOutParams_.blockLen = tailElems * sizeof(DT);
#endif
    CopyIn(srcOffset);
    Compute(tailElems, static_cast<uint16_t>(tailRegIters));
    CopyOut(dstOffset);
}
}
#endif