/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
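/*!
 * \file mte.h
 * \brief Tile copy helpers between __gm__ and __ubuf__ memory (UBCopyIn / UBCopyOut), plus index-driven
 *        row scatter to __gm__ (TIndexoutcast) and offset-driven element gather from __gm__ (Load).
 */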
#ifndef __LOGICALTENSOR_TILEOP_MTE__
#define __LOGICALTENSOR_TILEOP_MTE__
#include "tileop_common.h"
namespace TileOp {
template <typename T, unsigned T0, unsigned T1, unsigned UBS, unsigned GMS>
TILEOP void UBCopyIn(__ubuf__ T* dst, __gm__ T* src)
{
constexpr uint16_t nBurst = T0;
constexpr uint32_t lenBurst = T1 * sizeof(T);
constexpr uint32_t gmGap = (GMS - T1) * sizeof(T);
constexpr uint32_t blockSize = 32 / sizeof(T);
constexpr uint32_t ubGap = (UBS - T1) / blockSize;
static_assert(nBurst < ((1ULL << 12) - 1ULL));
static_assert(lenBurst < ((1ULL << 21) - 1ULL));
static_assert(gmGap < ((1ULL << 32) - 1ULL));
if constexpr (T1 == 0) {
return;
}
if constexpr (sizeof(T) == 1) {
copy_gm_to_ubuf_align_b8(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , gmGap, ubGap);
} else if (sizeof(T) == 2) {
copy_gm_to_ubuf_align_b16(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , gmGap, ubGap);
} else {
copy_gm_to_ubuf_align_b32(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , gmGap, ubGap);
}
}
template <
typename T, unsigned T0, unsigned T1, unsigned T2, unsigned T3, unsigned T4, unsigned UBS1, unsigned UBS2,
unsigned UBS3, unsigned UBS4, unsigned GMS1, unsigned GMS2, unsigned GMS3, unsigned GMS4>
TILEOP void UBCopyIn(__ubuf__ T* dst, __gm__ T* src)
{
static_assert((UBS4 * sizeof(T)) % 32 == 0, "UB tile must be 32B aligned!");
for (int i0 = 0; i0 < T0; i0++) {
__gm__ T* src0 = src;
__ubuf__ T* dst0 = dst;
for (int i1 = 0; i1 < T1; i1++) {
__gm__ T* src1 = src0;
__ubuf__ T* dst1 = dst0;
for (int i2 = 0; i2 < T2; i2++) {
TileOp::UBCopyIn<T, T3, T4, UBS4, GMS4>(dst1, src1);
src1 += GMS3 * GMS4;
dst1 += UBS3 * UBS4;
}
src0 += GMS2 * GMS3 * GMS4;
dst0 += UBS2 * UBS3 * UBS4;
}
src += GMS1 * GMS2 * GMS3 * GMS4;
dst += UBS1 * UBS2 * UBS3 * UBS4;
}
}
template <typename T, unsigned T0, unsigned T1, unsigned GMS, unsigned UBS>
TILEOP void UBCopyOut(__gm__ T* dst, __ubuf__ T* src)
{
constexpr uint16_t nBurst = T0;
constexpr uint32_t lenBurst = T1 * sizeof(T);
constexpr uint32_t gmGap = (GMS - T1) * sizeof(T);
constexpr uint32_t blockSize = 32 / sizeof(T);
constexpr uint32_t ubGap = (UBS - T1) / blockSize;
static_assert(nBurst < ((1ULL << 12) - 1ULL));
static_assert(lenBurst < ((1ULL << 21) - 1ULL));
static_assert(gmGap < ((1ULL << 32) - 1ULL));
if constexpr (sizeof(T) == 1) {
copy_ubuf_to_gm_align_b8(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , ubGap, gmGap);
} else if (sizeof(T) == 2) {
copy_ubuf_to_gm_align_b16(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , ubGap, gmGap);
} else {
copy_ubuf_to_gm_align_b32(
dst, src, 0 , nBurst, lenBurst, 0 , 0 , ubGap, gmGap);
}
}
template <
typename T, unsigned T0, unsigned T1, unsigned T2, unsigned T3, unsigned T4, unsigned GMS1, unsigned GMS2,
unsigned GMS3, unsigned GMS4, unsigned UBS1, unsigned UBS2, unsigned UBS3, unsigned UBS4>
TILEOP void UBCopyOut(__gm__ T* dst, __ubuf__ T* src)
{
static_assert((UBS4 * sizeof(T)) % 32 == 0, "UB tile must be 32B aligned!");
for (int i0 = 0; i0 < T0; i0++) {
__gm__ T* dst0 = dst;
__ubuf__ T* src0 = src;
for (int i1 = 0; i1 < T1; i1++) {
__gm__ T* dst1 = dst0;
__ubuf__ T* src1 = src0;
for (int i2 = 0; i2 < T2; i2++) {
TileOp::UBCopyOut<T, T3, T4, GMS4, UBS4>(dst1, src1);
dst1 += GMS3 * GMS4;
src1 += UBS3 * UBS4;
}
dst0 += GMS2 * GMS3 * GMS4;
src0 += UBS2 * UBS3 * UBS4;
}
dst += GMS1 * GMS2 * GMS3 * GMS4;
src += UBS1 * UBS2 * UBS3 * UBS4;
}
}
template <
typename T, unsigned T0, unsigned T1, unsigned T2, unsigned T3, unsigned T4, unsigned UBS1, unsigned UBS2,
unsigned UBS3, unsigned UBS4, unsigned GMS1, unsigned GMS2, unsigned GMS3, unsigned GMS4, unsigned isNop>
TILEOP void UBCopyIn(__ubuf__ T* dst, __gm__ T* src)
{
}
template <
typename T, typename T2, unsigned src0OriShape1, unsigned src1OriShape1, unsigned GmShape1, unsigned src0rawShape1,
unsigned cacheMode, unsigned blockSize>
TILEOP void TIndexoutcast(__gm__ T* dst, __ubuf__ T* src0, __ubuf__ T2* src1)
{
for (auto i = 0; i < src1OriShape1; i++) {
set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
set_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
set_flag(PIPE_V, PIPE_S, EVENT_ID7);
wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
T2 curValue = *(reinterpret_cast<__ubuf__ T2*>(src1 + i));
if constexpr (cacheMode == 1) {
T2 blockCount = curValue / blockSize;
T2 index = curValue % blockSize;
__gm__ T* new_dst = dst + blockCount * blockSize * GmShape1 + index * 32 / sizeof(T);
set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
copy_ubuf_to_gm(
new_dst, src0 + i * src0OriShape1, 0 , src0OriShape1 / 32 * sizeof(T), 1, 0, blockSize - 1);
} else {
__gm__ T* new_dst = dst + curValue * GmShape1;
set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
TileOp::UBCopyOut<T, 1, src0OriShape1, GmShape1, src0rawShape1>(new_dst, src0 + i * src0rawShape1);
}
}
}
template <
typename T, typename T2, unsigned src1OriShape0, unsigned src1OriShape1, unsigned src1rawShape1,
unsigned src0OriShape3, unsigned src0rawShape1, unsigned src0rawShape3, unsigned cacheMode>
TILEOP void TIndexoutcast(__gm__ T* dst, __ubuf__ T* src, __ubuf__ T2* index)
{
constexpr unsigned b = src1OriShape0;
constexpr unsigned s1 = src1OriShape1;
constexpr unsigned s1_32aligned = src1rawShape1;
constexpr unsigned nd = src0OriShape3;
constexpr unsigned nd_32aligned = src0rawShape3;
set_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
__gm__ T* curDst = dst;
__ubuf__ T2* dstIdx = index;
__ubuf__ T* curSrc = src;
for (int i = 0; i < b; ++i) {
for (int j = 0; j < s1; ++j) {
curDst = dst + *dstIdx * nd;
set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
copy_ubuf_to_gm_align_b32(
curDst, curSrc, 0 , 1, nd * sizeof(T), 0, 0, (nd_32aligned - nd) * sizeof(T) / BLOCK_SIZE, 0);
curSrc += nd_32aligned;
dstIdx++;
}
curSrc += (src0rawShape1 - s1) * nd_32aligned;
dstIdx += s1_32aligned - s1;
}
}
template <
typename T, typename T2, unsigned src0OriShape0, unsigned src0OriShape1, unsigned src0OriShape3,
unsigned src0rawShape1, unsigned src0rawShape2, unsigned src0rawShape3, unsigned src1OriShape0,
unsigned src1OriShape1, unsigned src1rawShape3, unsigned GmShape2, unsigned GmShape3, unsigned cacheMode,
unsigned blockSize>
TILEOP void TIndexoutcast(__gm__ T* dst, __ubuf__ T* src0, __ubuf__ T2* src1)
{
if (cacheMode == 2) {
TIndexoutcast<
T, T2, src1OriShape0, src1OriShape1, src1rawShape3, src0OriShape3, src0rawShape1, src0rawShape3, cacheMode>(
dst, src0, src1);
return;
}
static_assert(src0OriShape1 == 1, "src0OriShape1 now only support 1");
static_assert(blockSize != 0, "blockSize can not be zero");
int alignTS2TS3 = src0rawShape2 * src0rawShape3;
int alignSrc1 = src1rawShape3;
for (int i = 0; i < src0OriShape0; ++i) {
for (int j = 0; j < src0OriShape1; ++j) {
TileOp::TIndexoutcast<T, T2, src0OriShape3, src1OriShape1, GmShape3, src0rawShape3, cacheMode, blockSize>(
dst, src0, src1);
}
src0 += alignTS2TS3;
src1 += alignSrc1;
dst += GmShape2 * GmShape3;
}
}
template <typename T1, typename T2, int64_t rawShape1>
TILEOP void Load(__ubuf__ T1* dst, __gm__ T1* src, __ubuf__ T2* offsets, int64_t originShape0, int64_t originShape1)
{
static_assert(std::is_same_v<T2, int32_t> || std::is_same_v<T2, int64_t>);
pipe_barrier(PIPE_ALL);
if (rawShape1 == originShape1) {
int64_t total = originShape0 * originShape1;
for (int64_t i = 0; i < total; i++) {
dst[i] = src[offsets[i]];
}
} else {
int64_t idx = 0;
for (int64_t i = 0; i < originShape0; i++) {
for (int64_t j = 0; j < originShape1; j++) {
dst[idx] = src[offsets[idx]];
idx++;
}
idx += rawShape1 - originShape1;
}
}
pipe_barrier(PIPE_ALL);
}
template <typename T1, typename T2, int64_t rawShape1, int64_t rawShape2>
TILEOP void Load(
__ubuf__ T1* dst, __gm__ T1* src, __ubuf__ T2* offsets, int64_t originShape0, int64_t originShape1,
int64_t originShape2)
{
static_assert(std::is_same_v<T2, int32_t> || std::is_same_v<T2, int64_t>);
pipe_barrier(PIPE_ALL);
if (rawShape1 == originShape1 && rawShape2 == originShape2) {
int64_t total = originShape0 * originShape1 * originShape2;
for (int64_t i = 0; i < total; i++) {
dst[i] = src[offsets[i]];
}
} else {
int64_t idx = 0;
for (int64_t i = 0; i < originShape0; i++) {
for (int64_t j = 0; j < originShape1; j++) {
for (int64_t k = 0; k < originShape2; k++) {
dst[idx] = src[offsets[idx]];
idx++;
}
idx += rawShape2 - originShape2;
}
idx += (rawShape1 - originShape1) * rawShape2;
}
}
pipe_barrier(PIPE_ALL);
}
}
#endif