* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
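* \file indexadd.h
* \brief Tile-level IndexAdd operators: add alpha-scaled source slices into a UB (TIndexAddUB) or GM (TIndexAdd)
* destination at the positions selected by an index tensor.
*/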
#ifndef TILEOP_TILE_OPERATOR_INDEXADD_H
#define TILEOP_TILE_OPERATOR_INDEXADD_H
#include "utils/layout.h"
#include "utils/tile_tensor.h"
/**
 * Adds one alpha-scaled source row into one destination row, both resident in UB (non-last-axis IndexAdd).
 *
 * The tiles are re-pointed at dstAddr + dstOffset and src1Addr + src1Offset. When alpha differs from 1
 * (by more than TileOp::EPSILON), the source row is scaled in place with TMULS. It is then added into
 * the destination row with TADD.
 *
 * For Scalar == bfloat16_t, tempTile (bf16) is used to round intermediate results to bf16 precision
 * (TCVT with CAST_RINT, then back with CAST_NONE):
 *   - the scaled source row, whenever scaling happened;
 *   - the destination row after the add, when the index type is int32_t or scaling happened.
 *
 * On __DAV_V220 builds, a PIPE_V barrier separates consecutive dependent vector instructions.
 * TIndexAddUB issues a PIPE_V -> PIPE_S sync before each call to this function.
 */
template <
    typename T0, typename T2, typename T3, typename dstTileDefine, typename tempTileDefine, typename src1TileDefine,
    typename Scalar>
TILEOP void IndexAddUBNotLastAxisCompute(
    dstTileDefine dstTile, tempTileDefine tempTile, src1TileDefine src1Tile, Scalar alpha,
    __ubuf__ typename T0::Type* dstAddr, __ubuf__ bfloat16_t* tempAddr, __ubuf__ typename T2::Type* src1Addr,
    size_t dstOffset, size_t src1Offset)
{
    constexpr bool roundThroughBf16 = Std::is_same_v<Scalar, bfloat16_t>;

    // Point the tiles at the rows handled by this call.
    pto::TASSIGN(dstTile, (uint64_t)(dstAddr + dstOffset));
    pto::TASSIGN(src1Tile, (uint64_t)(src1Addr + src1Offset));
    if constexpr (roundThroughBf16) {
        pto::TASSIGN(tempTile, (uint64_t)(tempAddr));
    }

    // Vector pipe waits for the scalar-side setup above.
    set_flag(PIPE_S, PIPE_V, EVENT_ID7);
    wait_flag(PIPE_S, PIPE_V, EVENT_ID7);

    const bool scaleByAlpha = abs(static_cast<float>(alpha) - 1) > TileOp::EPSILON;
    if (scaleByAlpha) {
        pto::TMULS(src1Tile, src1Tile, alpha);
#ifdef __DAV_V220
        pipe_barrier(PIPE_V);
#endif
        if constexpr (roundThroughBf16) {
            // Round the scaled source row to bf16 precision, then widen it back.
            pto::TCVT(tempTile, src1Tile, pto::RoundMode::CAST_RINT);
#ifdef __DAV_V220
            pipe_barrier(PIPE_V);
#endif
            pto::TCVT(src1Tile, tempTile, pto::RoundMode::CAST_NONE);
#ifdef __DAV_V220
            pipe_barrier(PIPE_V);
#endif
        }
    }

    pto::TADD(dstTile, dstTile, src1Tile);

    if constexpr (roundThroughBf16) {
        const bool roundSum = Std::is_same_v<typename T3::Type, int32_t> || scaleByAlpha;
        if (roundSum) {
            // Round the accumulated destination row to bf16 precision, then widen it back.
#ifdef __DAV_V220
            pipe_barrier(PIPE_V);
#endif
            pto::TCVT(tempTile, dstTile, pto::RoundMode::CAST_RINT);
#ifdef __DAV_V220
            pipe_barrier(PIPE_V);
#endif
            pto::TCVT(dstTile, tempTile, pto::RoundMode::CAST_NONE);
        }
    }
}
/**
 * Last-axis IndexAdd with destination, source and index all in UB, computed element by element on the scalar pipe.
 *
 * For every source element (i, j, k, l, idx), the destination element in the same (i, j, k, l) row at
 * column idxAddr[idx] is incremented: dst[i, j, k, l, idxAddr[idx]] += alpha * src1[i, j, k, l, idx].
 *
 * Pass 1 runs only when alpha differs from 1 by more than TileOp::EPSILON. It scales src1 IN PLACE.
 * For half and bfloat16_t, the product is formed in float and rounded to the Scalar precision before
 * being written back.
 *
 * Pass 2 accumulates into dst. For half and bfloat16_t, the sum is formed in float and rounded to the
 * Scalar precision. The exception is alpha == 1 (within EPSILON) with an int64_t index type, where the
 * float sum is stored without that rounding.
 *
 * Synchronization: waits PIPE_V -> PIPE_S on entry, and signals/waits PIPE_S -> PIPE_V on exit.
 *
 * @param src1Shape0..4   extents of src1 over the five generalized dims
 * @param dstStride0..3   strides of dst for the first four dims (last dim assumed contiguous by the offset math)
 * @param src1Stride0..3  strides of src1 for the first four dims
 *
 * NOTE(review): the alpha checks rely on `abs` resolving to a floating-point overload; an integer `abs`
 * would truncate |alpha - 1| — confirm against the toolchain headers.
 */
template <typename T0, typename T2, typename T3, typename Scalar>
TILEOP void IndexAddUBLastAxisCompute(
    T0 dst, T2 src1, T3 src2, Scalar alpha, size_t src1Shape0, size_t src1Shape1, size_t src1Shape2, size_t src1Shape3,
    size_t src1Shape4, size_t dstStride0, size_t dstStride1, size_t dstStride2, size_t dstStride3, size_t src1Stride0,
    size_t src1Stride1, size_t src1Stride2, size_t src1Stride3)
{
    auto dstAddr = (__ubuf__ typename T0::Type*)((uint64_t)(dst.GetAddr()));
    auto src1Addr = (__ubuf__ typename T2::Type*)((uint64_t)(src1.GetAddr()));
    auto idxAddr = (__ubuf__ typename T3::Type*)((uint64_t)(src2.GetAddr()));
    set_flag(PIPE_V, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_V, PIPE_S, EVENT_ID7);

    // alpha is loop-invariant: evaluate both comparisons once instead of once per element.
    const auto alphaDelta = abs(static_cast<float>(alpha) - 1);
    const bool needScale = alphaDelta > TileOp::EPSILON;
    // Only consulted by the half / bfloat16_t branches below.
    [[maybe_unused]] const bool keepFp32Sum =
        alphaDelta < TileOp::EPSILON && Std::is_same_v<typename T3::Type, int64_t>;

    // Pass 1: scale src1 in place. Only source positions are needed, so the index tensor is not read here.
    if (needScale) {
        for (LoopVar i = 0; i < src1Shape0; ++i) {
            for (LoopVar j = 0; j < src1Shape1; ++j) {
                for (LoopVar k = 0; k < src1Shape2; ++k) {
                    for (LoopVar l = 0; l < src1Shape3; ++l) {
                        const uint64_t src1RowBase =
                            i * src1Stride0 + j * src1Stride1 + k * src1Stride2 + l * src1Stride3;
                        for (LoopVar idx = 0; idx < src1Shape4; ++idx) {
                            const uint64_t src1Offset = src1RowBase + idx;
                            if constexpr (Std::is_same_v<Scalar, half>) {
                                float mulsResult =
                                    static_cast<float>(src1Addr[src1Offset]) * static_cast<float>(alpha);
                                src1Addr[src1Offset] = static_cast<half>(mulsResult);
                            } else if constexpr (Std::is_same_v<Scalar, bfloat16_t>) {
                                float mulsResult = src1Addr[src1Offset] * TileOp::Bf16ToFp32(alpha);
                                bfloat16_t mulsResBf16 = TileOp::Fp32ToBf16R(mulsResult);
                                src1Addr[src1Offset] = TileOp::Bf16ToFp32(mulsResBf16);
                            } else {
                                Scalar mulsResult = static_cast<Scalar>(src1Addr[src1Offset]) * alpha;
                                src1Addr[src1Offset] = static_cast<typename T2::Type>(mulsResult);
                            }
                        }
                    }
                }
            }
        }
    }

    // Pass 2: scatter-add the (possibly scaled) source into dst at the gathered columns.
    for (LoopVar i = 0; i < src1Shape0; ++i) {
        for (LoopVar j = 0; j < src1Shape1; ++j) {
            for (LoopVar k = 0; k < src1Shape2; ++k) {
                for (LoopVar l = 0; l < src1Shape3; ++l) {
                    const uint64_t dstRowBase = i * dstStride0 + j * dstStride1 + k * dstStride2 + l * dstStride3;
                    const uint64_t src1RowBase =
                        i * src1Stride0 + j * src1Stride1 + k * src1Stride2 + l * src1Stride3;
                    for (LoopVar idx = 0; idx < src1Shape4; ++idx) {
                        const uint64_t dstOffset = dstRowBase + idxAddr[idx];
                        const uint64_t src1Offset = src1RowBase + idx;
                        if constexpr (Std::is_same_v<Scalar, half>) {
                            float addResult =
                                static_cast<float>(dstAddr[dstOffset]) + static_cast<float>(src1Addr[src1Offset]);
                            if (keepFp32Sum) {
                                dstAddr[dstOffset] = addResult;
                            } else {
                                dstAddr[dstOffset] = static_cast<half>(addResult);
                            }
                        } else if constexpr (Std::is_same_v<Scalar, bfloat16_t>) {
                            float addResult = dstAddr[dstOffset] + src1Addr[src1Offset];
                            if (keepFp32Sum) {
                                dstAddr[dstOffset] = addResult;
                            } else {
                                bfloat16_t addResBf16 = TileOp::Fp32ToBf16R(addResult);
                                dstAddr[dstOffset] = TileOp::Bf16ToFp32(addResBf16);
                            }
                        } else {
                            Scalar addResult =
                                static_cast<Scalar>(dstAddr[dstOffset]) + static_cast<Scalar>(src1Addr[src1Offset]);
                            dstAddr[dstOffset] = static_cast<typename T0::Type>(addResult);
                        }
                    }
                }
            }
        }
    }
    set_flag(PIPE_S, PIPE_V, EVENT_ID7);
    wait_flag(PIPE_S, PIPE_V, EVENT_ID7);
}
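dst: destination in UB (accumulated in place)
src0: self
src1: source in UB
src2: index in UB
tempTensor: UB scratch, used as a bf16 rounding buffer for non-last-axis bfloat16 computation
axis is the dimension index after the shape is generalized to 5 dims; the actual axis is axis + shapeSize - 5
*/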
// IndexAdd with destination, source and index all in UB:
//   dst[..., src2[p], ...] += alpha * src1[..., p, ...]   (p runs along the generalized `axis`).
// The index tensor is addressed only by the position along `axis` (idxAddr + p), i.e. it is read as 1-D.
// src0 is not read here; dst is accumulated into directly.
// Returns early, leaving dst untouched, if any dst dimension is 0.
// axis == 4 is delegated to IndexAddUBLastAxisCompute (scalar, element-wise). Other axes issue one
// IndexAddUBNotLastAxisCompute call (vector TADD) per slice, with tiles whose row width comes from
// GetAnyAxisMergeResult over the static TileShape.
// NOTE(review): presumably that width is the merged size of all dims after the indexed one — confirm the
// `axis + shapeSize - 3` argument against GetAnyAxisMergeResult's indexing convention.
// NOTE(review): tempAddr has type `__ubuf__ T4::Type*` but is passed where `__ubuf__ bfloat16_t*` is expected,
// so this assumes T4::Type is bfloat16_t — TODO confirm.
template <int axis, typename T0, typename T1, typename T2, typename T3, typename T4, typename Scalar>
TILEOP void TIndexAddUB(T0 dst, T1 src0, T2 src1, T3 src2, T4 tempTensor, Scalar alpha)
{
    // Rank of the original (pre-generalization) shape.
    constexpr auto shapeSize = Std::tuple_size<typename T0::Shape>::value;
    // dst / src1 shapes and strides over the generalized dims (1st..5th of MAX_DIMS).
    const auto dstLayout = dst.GetLayout();
    auto dstShape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto dstShape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto dstShape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
    auto dstShape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    auto dstShape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
    auto dstStride0 = dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
    auto dstStride1 = dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
    auto dstStride2 = dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
    auto dstStride3 = dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
    const auto src1Layout = src1.GetLayout();
    auto src1Shape0 = src1Layout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
    auto src1Shape1 = src1Layout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
    auto src1Shape2 = src1Layout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
    auto src1Shape3 = src1Layout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
    auto src1Shape4 = src1Layout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
    auto src1Stride0 = src1Layout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
    auto src1Stride1 = src1Layout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
    auto src1Stride2 = src1Layout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
    auto src1Stride3 = src1Layout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
    // Raw UB addresses of every operand.
    auto dstAddr = (__ubuf__ typename T0::Type*)((uint64_t)(dst.GetAddr()));
    auto tempAddr = (__ubuf__ typename T4::Type*)((uint64_t)(tempTensor.GetAddr()));
    auto src1Addr = (__ubuf__ typename T2::Type*)((uint64_t)(src1.GetAddr()));
    auto idxAddr = (__ubuf__ typename T3::Type*)((uint64_t)(src2.GetAddr()));
    // Empty destination: nothing to do.
    if (!dstShape0 || !dstShape1 || !dstShape2 || !dstShape3 || !dstShape4) {
        return;
    }
    if constexpr (axis == 4) {
        // Indexed dimension is the last one: element-wise scalar path.
        IndexAddUBLastAxisCompute(
            dst, src1, src2, alpha, src1Shape0, src1Shape1, src1Shape2, src1Shape3, src1Shape4, dstStride0, dstStride1,
            dstStride2, dstStride3, src1Stride0, src1Stride1, src1Stride2, src1Stride3);
    } else {
        // Single-row tile types (1 x W) used to view one slice of dst / temp / src1.
        constexpr auto dstTileW =
            TileOp::GetAnyAxisMergeResult<axis + shapeSize - 3, shapeSize, typename T0::TileShape>();
        constexpr auto tempTileW =
            TileOp::GetAnyAxisMergeResult<axis + shapeSize - 3, shapeSize, typename T4::TileShape>();
        constexpr auto src1TileW =
            TileOp::GetAnyAxisMergeResult<axis + shapeSize - 3, shapeSize, typename T2::TileShape>();
        using dstTileDefine = pto::Tile<pto::TileType::Vec, typename T0::Type, 1, dstTileW, pto::BLayout::RowMajor>;
        using tempTileDefine = pto::Tile<pto::TileType::Vec, bfloat16_t, 1, tempTileW, pto::BLayout::RowMajor>;
        using src1TileDefine = pto::Tile<pto::TileType::Vec, typename T2::Type, 1, src1TileW, pto::BLayout::RowMajor>;
        dstTileDefine dstTile;
        tempTileDefine tempTile;
        src1TileDefine src1Tile;
        // In every branch below, each iteration first makes the scalar pipe wait for the vector pipe
        // (PIPE_V -> PIPE_S) before reading the next index and re-pointing the tiles.
        if constexpr (axis == 0) {
            for (LoopVar i = 0; i < src1Shape0; ++i) {
                set_flag(PIPE_V, PIPE_S, EVENT_ID7);
                wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
                auto index = *(idxAddr + i);
                auto dstOffset = index * dstStride0;
                auto src1Offset = i * src1Stride0;
                IndexAddUBNotLastAxisCompute<T0, T2, T3>(
                    dstTile, tempTile, src1Tile, alpha, dstAddr, tempAddr, src1Addr, dstOffset, src1Offset);
            }
        } else if constexpr (axis == 1) {
            for (LoopVar i = 0; i < src1Shape0; ++i) {
                for (LoopVar j = 0; j < src1Shape1; ++j) {
                    set_flag(PIPE_V, PIPE_S, EVENT_ID7);
                    wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
                    auto index = *(idxAddr + j);
                    auto dstOffset = i * dstStride0 + index * dstStride1;
                    auto src1Offset = i * src1Stride0 + j * src1Stride1;
                    IndexAddUBNotLastAxisCompute<T0, T2, T3>(
                        dstTile, tempTile, src1Tile, alpha, dstAddr, tempAddr, src1Addr, dstOffset, src1Offset);
                }
            }
        } else if constexpr (axis == 2) {
            for (LoopVar i = 0; i < src1Shape0; ++i) {
                for (LoopVar j = 0; j < src1Shape1; ++j) {
                    for (LoopVar k = 0; k < src1Shape2; ++k) {
                        set_flag(PIPE_V, PIPE_S, EVENT_ID7);
                        wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
                        auto index = *(idxAddr + k);
                        auto dstOffset = i * dstStride0 + j * dstStride1 + index * dstStride2;
                        auto src1Offset = i * src1Stride0 + j * src1Stride1 + k * src1Stride2;
                        IndexAddUBNotLastAxisCompute<T0, T2, T3>(
                            dstTile, tempTile, src1Tile, alpha, dstAddr, tempAddr, src1Addr, dstOffset, src1Offset);
                    }
                }
            }
        } else {
            // axis == 3 (any remaining non-last axis value takes this branch).
            for (LoopVar i = 0; i < src1Shape0; ++i) {
                for (LoopVar j = 0; j < src1Shape1; ++j) {
                    for (LoopVar k = 0; k < src1Shape2; ++k) {
                        for (LoopVar l = 0; l < src1Shape3; ++l) {
                            set_flag(PIPE_V, PIPE_S, EVENT_ID7);
                            wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
                            auto index = *(idxAddr + l);
                            auto dstOffset = i * dstStride0 + j * dstStride1 + k * dstStride2 + index * dstStride3;
                            auto src1Offset = i * src1Stride0 + j * src1Stride1 + k * src1Stride2 + l * src1Stride3;
                            IndexAddUBNotLastAxisCompute<T0, T2, T3>(
                                dstTile, tempTile, src1Tile, alpha, dstAddr, tempAddr, src1Addr, dstOffset, src1Offset);
                        }
                    }
                }
            }
        }
    }
}
/**
 * Atomically adds one alpha-scaled source row (UB) into one destination row (GM) for non-last-axis IndexAdd.
 *
 * The tensors are re-pointed at dstAddr + dstOffset, src1Addr + src1Offset and tmpAddr. Then:
 *   - alpha == 1 (within TileOp::EPSILON): src1Tile is stored directly with an AtomicAdd TSTORE
 *     (after PIPE_S -> PIPE_MTE3 sync).
 *   - Scalar == int8_t: each of src1Tile.GetValidCol() elements is multiplied by alpha on the scalar pipe
 *     into tmpAddr (PIPE_S -> PIPE_MTE3 sync), then tmpTile is stored with AtomicAdd.
 *   - otherwise: tmpTile = src1Tile * alpha via TMULS (PIPE_S -> PIPE_V, then PIPE_V -> PIPE_MTE3 sync),
 *     then tmpTile is stored with AtomicAdd.
 * The source row in UB is never modified.
 */
template <
    typename T0, typename T2, typename dstGlobalData, typename tmpTileDefine, typename src1TileDefine, typename Scalar>
TILEOP void IndexAddNotLastAxisCompute(
    dstGlobalData dstGlobal, tmpTileDefine tmpTile, src1TileDefine src1Tile, Scalar alpha,
    __gm__ typename T0::Type* dstAddr, __ubuf__ typename T2::Type* tmpAddr, __ubuf__ typename T2::Type* src1Addr,
    size_t dstOffset, size_t src1Offset)
{
    pto::TASSIGN(dstGlobal, dstAddr + dstOffset);
    pto::TASSIGN(src1Tile, (uint64_t)(src1Addr + src1Offset));
    pto::TASSIGN(tmpTile, (uint64_t)(tmpAddr));

    const bool scaleByAlpha = abs(static_cast<float>(alpha) - 1) > TileOp::EPSILON;
    if (!scaleByAlpha) {
        // Unscaled: atomically add the source row as-is.
        set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
        wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
        pto::TSTORE<src1TileDefine, dstGlobalData, pto::AtomicType::AtomicAdd>(dstGlobal, src1Tile);
        return;
    }

    if constexpr (Std::is_same_v<Scalar, int8_t>) {
        // int8_t is scaled element by element on the scalar pipe.
        for (LoopVar col = 0; col < src1Tile.GetValidCol(); ++col) {
            Scalar product = static_cast<Scalar>(src1Addr[src1Offset + col]) * alpha;
            tmpAddr[col] = product;
        }
        set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
        wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
    } else {
        // Vector scaling into the staging tile.
        set_flag(PIPE_S, PIPE_V, EVENT_ID7);
        wait_flag(PIPE_S, PIPE_V, EVENT_ID7);
        pto::TMULS(tmpTile, src1Tile, alpha);
        set_flag(PIPE_V, PIPE_MTE3, EVENT_ID7);
        wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID7);
    }
    pto::TSTORE<tmpTileDefine, dstGlobalData, pto::AtomicType::AtomicAdd>(dstGlobal, tmpTile);
}
/**
 * Atomically adds one alpha-scaled source element (UB) into one destination element (GM) for last-axis IndexAdd.
 *
 * The (possibly scaled) value src1Addr[src1Offset] is written to tmpAddr[0] on the scalar pipe.
 * For half and bfloat16_t, the product is formed in float and rounded to the Scalar precision.
 * After a PIPE_S -> PIPE_MTE3 sync, tmpTile (pointed at tmpAddr) is stored into dstAddr + dstOffset
 * with an AtomicAdd TSTORE. The caller sizes dstGlobal / tmpTile to a single element.
 */
template <typename T0, typename T2, typename dstGlobalData, typename tmpTileDefine, typename Scalar>
TILEOP void IndexAddLastAxisCompute(
    dstGlobalData dstGlobal, tmpTileDefine tmpTile, Scalar alpha, __gm__ typename T0::Type* dstAddr,
    __ubuf__ typename T2::Type* tmpAddr, __ubuf__ typename T2::Type* src1Addr, size_t dstOffset, size_t src1Offset)
{
    const bool scaleByAlpha = abs(static_cast<float>(alpha) - 1) > TileOp::EPSILON;
    if (!scaleByAlpha) {
        // alpha == 1: stage the source element unchanged.
        tmpAddr[0] = src1Addr[src1Offset];
    } else if constexpr (Std::is_same_v<Scalar, half>) {
        float product = static_cast<float>(src1Addr[src1Offset]) * static_cast<float>(alpha);
        tmpAddr[0] = static_cast<half>(product);
    } else if constexpr (Std::is_same_v<Scalar, bfloat16_t>) {
        float product = src1Addr[src1Offset] * TileOp::Bf16ToFp32(alpha);
        bfloat16_t productBf16 = TileOp::Fp32ToBf16R(product);
        tmpAddr[0] = TileOp::Bf16ToFp32(productBf16);
    } else {
        Scalar product = static_cast<Scalar>(src1Addr[src1Offset]) * alpha;
        tmpAddr[0] = product;
    }

    // Point the GM view at the target element and the tile at the staged value, then atomically add.
    pto::TASSIGN(dstGlobal, dstAddr + dstOffset);
    pto::TASSIGN(tmpTile, (uint64_t)tmpAddr);
    set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
    wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
    pto::TSTORE<tmpTileDefine, dstGlobalData, pto::AtomicType::AtomicAdd>(dstGlobal, tmpTile);
}
/**
 * Flat element offset of a position given by its first four (generalized) coordinates.
 *
 * offset = c[0]*dstStrides[0] + c[1]*dstStrides[1] + c[2]*dstStrides[2] + c[3]*dstStrides[3], where
 * c[d] = idx[d], except that for d == axis (when axis is 0..3) the coordinate is replaced by the gathered
 * index value idxAddr[idx[axis]]. For any other axis value (e.g. 4), no substitution is made and idxAddr
 * is not read. The fifth coordinate is not included in the result.
 *
 * @param dstStrides  strides of the first four dims (4 entries are read)
 * @param idx         coordinates of the first four dims (4 entries are read)
 * @param idxAddr     index tensor in UB
 * @return            the combined offset (size_t arithmetic)
 */
template <int axis, typename T3>
TILEOP size_t GetTileOffset(size_t dstStrides[], size_t idx[], __ubuf__ typename T3::Type* idxAddr)
{
    size_t coords[4] = {idx[0], idx[1], idx[2], idx[3]};
    if constexpr (axis >= 0 && axis < 4) {
        // Gather: the indexed dimension uses the index tensor's value instead of the loop coordinate.
        coords[axis] = static_cast<size_t>(idxAddr[idx[axis]]);
    }
    return coords[0] * dstStrides[0] + coords[1] * dstStrides[1] + coords[2] * dstStrides[2] +
           coords[3] * dstStrides[3];
}
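dst: dst in GM
src0: self in GM
src1: source in UB
src2: index in UB
tmpTensor: UB staging buffer for alpha-scaled values before the atomic-add store
coord: tile coordinate; dst is offset by its GM offset
axis is the dimension index after the shape is generalized to 5 dims; the actual axis is axis + shapeSize - 5
*/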
// IndexAdd with the destination in GM and source / index in UB, using atomic-add stores:
//   dst[..., src2[p], ...] += alpha * src1[..., p, ...]   (p runs along the generalized `axis`).
// The index tensor is addressed only by the position along `axis`, i.e. it is read as 1-D.
// src0 is not read here.
// Returns early, leaving dst untouched, if any dst dimension is 0.
// axis == 4: one single-element atomic store per source element (IndexAddLastAxisCompute).
// Other axes: one atomic store per last-dim row of src1 (IndexAddNotLastAxisCompute); the index value
// replaces the coordinate of the indexed dim in the dst offset.
// NOTE(review): no MTE3 -> S/V sync follows the last TSTORE; presumably the caller synchronizes before
// src1/tmp buffers are reused — confirm.
template <int axis, typename T0, typename T1, typename T2, typename T3, typename T4, typename C, typename Scalar>
TILEOP void TIndexAdd(T0 dst, T1 src0, T2 src1, T3 src2, T4 tmpTensor, C coord, Scalar alpha)
{
    // Rank of the original (pre-generalization) shape.
    constexpr auto shapeSize = Std::tuple_size<typename T0::Shape>::value;
    const auto dstLayout = dst.GetLayout();
    size_t dstShapes[] = {
        dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>(), dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>(),
        dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>(), dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>(),
        dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>()};
    // Empty destination: nothing to do.
    if (!dstShapes[0] || !dstShapes[1] || !dstShapes[2] || !dstShapes[3] || !dstShapes[4]) {
        return;
    }
    // Only the first four strides are needed; the last dim is addressed by adding its coordinate directly.
    size_t dstStrides[] = {
        dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>(), dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>(),
        dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>(), dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>()};
    const auto src1Layout = src1.GetLayout();
    size_t src1Shapes[] = {
        src1Layout.template GetShapeDim<DIM_1ST, MAX_DIMS>(), src1Layout.template GetShapeDim<DIM_2ND, MAX_DIMS>(),
        src1Layout.template GetShapeDim<DIM_3RD, MAX_DIMS>(), src1Layout.template GetShapeDim<DIM_4TH, MAX_DIMS>(),
        src1Layout.template GetShapeDim<DIM_5TH, MAX_DIMS>()};
    size_t src1Strides[] = {
        src1Layout.template GetStrideDim<DIM_1ST, MAX_DIMS>(), src1Layout.template GetStrideDim<DIM_2ND, MAX_DIMS>(),
        src1Layout.template GetStrideDim<DIM_3RD, MAX_DIMS>(), src1Layout.template GetStrideDim<DIM_4TH, MAX_DIMS>()};
    using dstType = typename T0::Type;
    using src1Type = typename T2::Type;
    using idxType = typename T3::Type;
    using tmpType = typename T4::Type;
    auto dstAddr = (__gm__ dstType*)((uint64_t)(dst.GetAddr()));
    auto tmpAddr = (__ubuf__ tmpType*)((uint64_t)(tmpTensor.GetAddr()));
    auto src1Addr = (__ubuf__ src1Type*)((uint64_t)(src1.GetAddr()));
    auto idxAddr = (__ubuf__ idxType*)((uint64_t)(src2.GetAddr()));
    // Move dst to this tile's origin; all offsets below are relative to it.
    size_t gmOffset = static_cast<size_t>(dstLayout.template GetGmOffset<C, MAX_DIMS>(coord));
    dstAddr += gmOffset;
    // NOTE(review): tmpTileW reads TileShape element 1, i.e. assumes tmpTensor is 2-D — confirm.
    constexpr auto tmpTileW = Std::tuple_element<1, typename T4::TileShape>::type::value;
    constexpr auto src1TileW = Std::tuple_element<shapeSize - 1, typename T2::TileShape>::type::value;
    using dstGlobalData = pto::GlobalTensor<dstType, pto::Shape<-1, -1, -1, -1, -1>, pto::Stride<-1, -1, -1, -1, -1>>;
    using tmpTileDefine = pto::Tile<pto::TileType::Vec, tmpType, 1, tmpTileW, pto::BLayout::RowMajor, -1, -1>;
    using src1TileDefine = pto::Tile<pto::TileType::Vec, src1Type, 1, src1TileW, pto::BLayout::RowMajor, -1, -1>;
    if constexpr (axis == 4) {
        // Last axis: every source element goes to its own dst column, so stores are one element wide.
        dstGlobalData dstGlobal(dstAddr, pto::Shape(1, 1, 1, 1, 1), pto::Stride(0, 0, 0, 0, 0));
        tmpTileDefine tmpTile(1, 1);
        for (LoopVar i = 0; i < src1Shapes[0]; ++i) {
            for (LoopVar j = 0; j < src1Shapes[1]; ++j) {
                for (LoopVar k = 0; k < src1Shapes[2]; ++k) {
                    for (LoopVar l = 0; l < src1Shapes[3]; ++l) {
                        for (LoopVar m = 0; m < src1Shapes[4]; ++m) {
                            // Scalar pipe waits for the previous MTE3 store before overwriting tmpAddr[0].
                            set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                            wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                            size_t idx[] = {i, j, k, l};
                            // dst column comes from the index tensor; src1 column is m itself.
                            auto dstOffset = GetTileOffset<axis, T3>(dstStrides, idx, idxAddr) + idxAddr[m];
                            auto src1Offset = GetTileOffset<axis, T3>(src1Strides, idx, idxAddr) + m;
                            IndexAddLastAxisCompute<T0, T2>(
                                dstGlobal, tmpTile, alpha, dstAddr, tmpAddr, src1Addr, dstOffset, src1Offset);
                        }
                    }
                }
            }
        }
    } else {
        // Non-last axis: one last-dim row (src1Shapes[4] elements) per (i, j, k, l).
        // NOTE(review): the GM view uses dstShapes[4] while the tiles use src1Shapes[4]; presumably these are
        // equal for valid IndexAdd inputs (only the indexed dim may differ) — confirm.
        dstGlobalData dstGlobal(dstAddr, pto::Shape(1, 1, 1, 1, dstShapes[4]), pto::Stride(0, 0, 0, 0, 0));
        tmpTileDefine tmpTile(1, src1Shapes[4]);
        src1TileDefine src1Tile(1, src1Shapes[4]);
        for (LoopVar i = 0; i < src1Shapes[0]; ++i) {
            for (LoopVar j = 0; j < src1Shapes[1]; ++j) {
                for (LoopVar k = 0; k < src1Shapes[2]; ++k) {
                    for (LoopVar l = 0; l < src1Shapes[3]; ++l) {
                        // Scalar pipe waits for the previous MTE3 store before this row's setup,
                        // which may overwrite tmpAddr.
                        set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                        wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                        size_t idx[] = {i, j, k, l};
                        // dst: index value substituted at `axis`; src1: plain offset (axis 4 => no substitution).
                        auto dstOffset = GetTileOffset<axis, T3>(dstStrides, idx, idxAddr);
                        auto src1Offset = GetTileOffset<4, T3>(src1Strides, idx, idxAddr);
                        IndexAddNotLastAxisCompute<T0, T2>(
                            dstGlobal, tmpTile, src1Tile, alpha, dstAddr, tmpAddr, src1Addr, dstOffset, src1Offset);
                    }
                }
            }
        }
    }
}
#endif