/**
 * Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
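/*!
 * \file index_outcast.h
 * \brief Tile operators that store source tiles from UB to global memory at row offsets read from an index tensor.
 */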
#ifndef TILEOP_TILE_OPERATOR_INDEX_OUTCAST__H
#define TILEOP_TILE_OPERATOR_INDEX_OUTCAST__H
#include "utils/layout.h"
#include "utils/tile_tensor.h"
/**
 * Cache-mode-2 path of the index outcast operator.
 *
 * Walks a `b` x `s` grid of indices held in `src1` (UB). For every non-negative index one source tile
 * taken from `src` (UB) is stored to `dst` (GM), starting `index * srcShape4` elements past the
 * destination base address. Negative indices produce no store, but both cursors still advance.
 *
 * Template parameters:
 *   srcrawShape1   source rows reserved per outer iteration; the source cursor is topped up by
 *                  `(srcrawShape1 - s) * srcTileW` elements after each inner sweep
 *   srcTileH/W     static tile shape handed to pto::Tile; srcTileW is also the per-index source stride
 *   src1SAligned   index elements reserved per outer iteration (cursor is topped up by `src1SAligned - s`)
 *   DstDtype/SrcDtype/IdxDtype  element types used to reinterpret the three base addresses
 *
 * Parameters:
 *   dst, src, src1      tensors providing the GM destination, UB source and UB index base addresses
 *   b, s                outer / inner trip counts of the index grid
 *   srcShape3/4         runtime valid shape given to the source tile; srcShape4 is also the length of the
 *                       GM row written per index and the multiplier applied to the index
 */
template <
    typename T0, typename T1, typename T2, int srcrawShape1, int srcTileH, int srcTileW, int src1SAligned,
    typename DstDtype, typename SrcDtype, typename IdxDtype>
TILEOP void TIndexOutcastMode2(T0 dst, T1 src, T2 src1, unsigned b, unsigned s, unsigned srcShape3, unsigned srcShape4)
{
    using RowTile = pto::Tile<pto::TileType::Vec, SrcDtype, srcTileH, srcTileW, pto::BLayout::RowMajor, -1, -1>;
    using DynShape5 = pto::Shape<-1, -1, -1, -1, -1>;
    using DynStride5 = pto::Stride<-1, -1, -1, -1, -1>;
    using RowGlobal = pto::GlobalTensor<DstDtype, DynShape5, DynStride5>;

    // MTE2 -> scalar handshake; presumably makes the UB contents filled by MTE2 visible before the
    // scalar reads of the index buffer below (confirm against the pipeline sync documentation).
    set_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
    wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);

    __ubuf__ SrcDtype* rowCursor = reinterpret_cast<__ubuf__ SrcDtype*>(src.GetAddr());
    __ubuf__ IdxDtype* idxCursor = reinterpret_cast<__ubuf__ IdxDtype*>(src1.GetAddr());
    __gm__ DstDtype* gmBase = reinterpret_cast<__gm__ DstDtype*>(dst.GetAddr());

    constexpr auto rowPitch = srcTileW;
    // Padding skipped after each inner sweep so that every outer iteration starts on its own slab.
    const auto srcTailSkip = (srcrawShape1 - s) * rowPitch;
    const auto idxTailSkip = src1SAligned - s;

    for (LoopVar outer = 0; outer < b; ++outer) {
        for (LoopVar inner = 0; inner < s; ++inner, rowCursor += rowPitch, ++idxCursor) {
            const int64_t targetRow = static_cast<int64_t>(*idxCursor);
            if (targetRow < 0) {
                // Negative index: nothing is written for this slot.
                continue;
            }

            __gm__ DstDtype* gmRow = gmBase + targetRow * srcShape4;

            RowTile rowTile(srcShape3, srcShape4);
            pto::TASSIGN(rowTile, reinterpret_cast<uint64_t>(rowCursor));

            RowGlobal rowGlobal(gmRow, pto::Shape(1, 1, 1, 1, srcShape4), pto::Stride(1, 1, 1, 1, 1));
            pto::TSTORE(rowGlobal, rowTile);
        }
        rowCursor += srcTailSkip;
        idxCursor += idxTailSkip;
    }
}
/**
 * Index outcast entry point: stores source tiles from `src` (UB) into `dst` (GM) at positions selected by
 * the indices held in `src1` (UB).
 *
 * Template parameters:
 *   cacheMode  2      -> the work is delegated to TIndexOutcastMode2 (the coordinate offset is not applied
 *                        on that path);
 *              1      -> blocked addressing: the index is split into `index / blockSize` and
 *                        `index % blockSize`, and the destination offset is
 *                        `(index / blockSize) * blockSize * dstShape4 + (index % blockSize) * 32 / sizeof(DstDtype)`;
 *              other  -> linear addressing: destination offset is `index * dstShape4`.
 *   blockSize  block length used by cacheMode 1; must be non-zero for that mode.
 *   T0/T1/T2   tensor types of dst / src / src1; their `Type` member supplies the element types.
 *   C          coordinate type passed to the destination layout's GetGmOffset.
 *
 * Parameters:
 *   dst         GM destination tensor
 *   src         UB source tensor
 *   src1        UB index tensor; negative indices are skipped
 *   coordinate  converted to an element offset that is added to the destination base address
 *
 * Returns immediately when any of srcShape1, srcShape2, srcShape4, src1Shape3 or src1Shape4 is zero.
 */
template <unsigned cacheMode, unsigned blockSize, typename T0, typename T1, typename T2, typename C>
TILEOP void TIndexOutcast(T0 dst, T1 src, T2 src1, C coordinate)
{
    constexpr auto expectSize = 5;
    const auto uLayout = src.GetLayout();
    auto srcShape1 = uLayout.template GetShapeDim<1, expectSize>();
    auto srcShape2 = uLayout.template GetShapeDim<2, expectSize>();
    auto srcShape3 = uLayout.template GetShapeDim<3, expectSize>();
    auto srcShape4 = uLayout.template GetShapeDim<4, expectSize>();
    const auto iLayout = src1.GetLayout();
    auto src1Shape3 = iLayout.template GetShapeDim<3, expectSize>();
    auto src1Shape4 = iLayout.template GetShapeDim<4, expectSize>();
    const auto dLayout = dst.GetLayout();
    auto dstShape3 = dLayout.template GetShapeDim<3, expectSize>();
    auto dstShape4 = dLayout.template GetShapeDim<4, expectSize>();
    auto offset = dLayout.template GetGmOffset<C, expectSize>(coordinate);

    using DstDtype = typename T0::Type;
    using SrcDtype = typename T1::Type;
    using IdxDtype = typename T2::Type;

    constexpr auto srcrawShape1 = TileOp::GetTensorTileShapeDim<T1, 2, 5>();
    constexpr auto srcTileH = TileOp::GetTensorTileShapeDim<T1, 3, 5>();
    constexpr auto srcTileW = TileOp::GetTensorTileShapeDim<T1, 4, 5>();
    constexpr auto src1SAligned = TileOp::GetTensorTileShapeDim<T2, 4, 5>();

    if (srcShape1 == 0 || srcShape2 == 0 || srcShape4 == 0 || src1Shape3 == 0 || src1Shape4 == 0) {
        return;
    }

    if constexpr (cacheMode == 2) {
        TIndexOutcastMode2<T0, T1, T2, srcrawShape1, srcTileH, srcTileW, src1SAligned, DstDtype, SrcDtype, IdxDtype>(
            dst, src, src1, src1Shape3, src1Shape4, srcShape3, srcShape4);
        return;
    }

    // Store descriptors shared by both remaining addressing modes.
    using SrcTileDefine = pto::Tile<pto::TileType::Vec, SrcDtype, srcTileH, srcTileW, pto::BLayout::RowMajor, -1, -1>;
    using ShapeDim5 = pto::Shape<-1, -1, -1, -1, -1>;
    using StrideDim5 = pto::Stride<-1, -1, -1, -1, -1>;
    using DstGlobalType = pto::GlobalTensor<DstDtype, ShapeDim5, StrideDim5>;

    auto alignTS2TS3 = srcTileH * srcTileW;
    auto alignSrc1 = src1SAligned;
    // Destination offsets are computed in 64 bits (as the mode-2 helper does) so that a large index
    // multiplied by the row length cannot wrap in 32-bit arithmetic or be truncated.
    const int64_t dstRowStride = static_cast<int64_t>(dstShape4);

    __ubuf__ SrcDtype* srcBase = reinterpret_cast<__ubuf__ SrcDtype*>(src.GetAddr());
    __ubuf__ IdxDtype* src1Base = reinterpret_cast<__ubuf__ IdxDtype*>(src1.GetAddr());
    __gm__ DstDtype* dstBase = reinterpret_cast<__gm__ DstDtype*>(dst.GetAddr());
    dstBase += offset;

    for (LoopVar i = 0; i < srcShape1; ++i) {
        // NOTE(review): `j` does not enter any address computation, so the k-sweep below is repeated
        // unchanged for every j — confirm srcShape2 is expected to be 1 here.
        for (LoopVar j = 0; j < srcShape2; ++j) {
            for (LoopVar k = 0; k < src1Shape4; ++k) {
                // Handshake the scalar pipe with MTE3, MTE2 and V before the scalar read of the index;
                // presumably orders that read after outstanding stores/loads/vector work (confirm).
                set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
                set_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
                wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID7);
                set_flag(PIPE_V, PIPE_S, EVENT_ID7);
                wait_flag(PIPE_V, PIPE_S, EVENT_ID7);

                const int64_t idxVal = static_cast<int64_t>(src1Base[k]);
                if (idxVal < 0) {
                    // Negative index: nothing is written for this slot.
                    continue;
                }

                __ubuf__ SrcDtype* srcPtr = srcBase + k * srcrawShape1;

                __gm__ DstDtype* newDst = dstBase;
                if constexpr (cacheMode == 1) {
                    static_assert(blockSize != 0, "TIndexOutcast: cacheMode 1 requires a non-zero blockSize");
                    const int64_t blockCount = idxVal / blockSize;
                    const int64_t index = idxVal % blockSize;
                    // In-block position advances in 32-byte units, expressed in DstDtype elements.
                    newDst += blockCount * blockSize * dstRowStride +
                              index * 32 / static_cast<int64_t>(sizeof(DstDtype));
                } else {
                    newDst += idxVal * dstRowStride;
                }

                // Scalar -> MTE3 handshake issued right before the store, as in the original sequence.
                set_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);
                wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID7);

                SrcTileDefine srcTile(srcShape3, srcShape4);
                pto::TASSIGN(srcTile, reinterpret_cast<uint64_t>(srcPtr));
                DstGlobalType dstGlobal(newDst, pto::Shape(1, 1, 1, 1, srcShape4), pto::Stride(1, 1, 1, 1, 1));
                pto::TSTORE(dstGlobal, srcTile);
            }
        }
        // Advance all three bases to the next outer slab.
        srcBase += alignTS2TS3;
        src1Base += alignSrc1;
        dstBase += dstShape3 * dstShape4;
    }
}
#endif