/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
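/*!
* \file where.h
* \brief Tile-level "where" operators: select between two sources under a condition mask, with one variant
*        per tensor/scalar source combination (TWhereTT, TWhereTS, TWhereST, TWhereSS).
*/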
#ifndef TILEOP_TILE_OPERATOR_WHERE__H
#define TILEOP_TILE_OPERATOR_WHERE__H
#include "utils/layout.h"
#include "utils/tile_tensor.h"
/**
 * \brief Builds a comparison mask from one chunk of byte-wide condition values.
 *
 * The bytes at \p condition are converted to half precision into \p castCondition, \p compareCondition is
 * filled with the constant 1.0, and both half buffers are compared for equality; the comparison result is
 * written to \p vcmpBitResult.
 *
 * \tparam elementsCount    Static column capacity of the tiles that describe one chunk.
 * \param  condition        Address of the condition values, viewed here as uint8_t elements.
 * \param  castCondition    Scratch buffer that receives the condition converted to half.
 * \param  compareCondition Scratch buffer that is filled with half(1.0).
 * \param  vcmpBitResult    Destination of the comparison, viewed as elementsCount / 8 uint8_t columns.
 * \param  curCount         Number of valid elements in this chunk.
 */
template <unsigned elementsCount>
TILEOP void CaculateMask(
    uint64_t condition, __ubuf__ half* castCondition, __ubuf__ half* compareCondition, __ubuf__ uint8_t* vcmpBitResult,
    const unsigned curCount)
{
    constexpr unsigned kBitsPerByte = 8;
    constexpr unsigned kMaskCols = elementsCount / kBitsPerByte;

    using RawCondTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, elementsCount, pto::BLayout::RowMajor, -1, -1>;
    using HalfTile = pto::Tile<pto::TileType::Vec, half, 1, elementsCount, pto::BLayout::RowMajor, -1, -1>;
    using BitMaskTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, kMaskCols, pto::BLayout::RowMajor, -1, -1>;

    // Comparison output, one eighth as many uint8_t columns as there are condition elements.
    // NOTE(review): integer division rounds the valid width down when curCount is not a multiple of 8 -
    // confirm this is what tail chunks expect.
    BitMaskTile bitMask(1, curCount / kBitsPerByte);
    pto::TASSIGN(bitMask, reinterpret_cast<uint64_t>(vcmpBitResult));

    // Raw condition bytes supplied by the caller.
    RawCondTile rawCondition(1, curCount);
    pto::TASSIGN(rawCondition, static_cast<uint64_t>(condition));

    // Half-precision copy of the condition.
    HalfTile conditionAsHalf(1, curCount);
    pto::TASSIGN(conditionAsHalf, reinterpret_cast<uint64_t>(castCondition));

    // Reference operand: every element set to 1.0.
    HalfTile ones(1, curCount);
    pto::TASSIGN(ones, reinterpret_cast<uint64_t>(compareCondition));

    pto::TCVT(conditionAsHalf, rawCondition, pto::RoundMode::CAST_NONE);
    pto::TEXPANDS(ones, static_cast<half>(1.000000e+00f));
#ifdef __DAV_V220
    // On __DAV_V220 builds a vector-pipe barrier is issued between producing both operands and comparing them.
    pipe_barrier(PIPE_V);
#endif
    pto::TCMP(bitMask, conditionAsHalf, ones, pto::CmpMode::EQ);
}
/**
 * \brief Runs one pto::TSEL over a chunk: dst receives a selection between src0 and src1 driven by a mask.
 *
 * All operands are given as raw addresses and wrapped in single-row tiles before the instruction is issued.
 * NOTE(review): presumably TSEL takes src0 where the mask bit is set and src1 otherwise - confirm against
 * the TSEL documentation.
 *
 * \tparam T             Element type of dst, src0 and src1.
 * \tparam elementsCount Static column capacity of the data tiles (the mask tile has elementsCount / 8).
 * \param  dst           Address of the output chunk.
 * \param  vcmpBitResult Address of the mask, viewed as uint8_t columns.
 * \param  src0          Address of the first source chunk.
 * \param  src1          Address of the second source chunk.
 * \param  startAddrUB   Address of the auxiliary buffer handed to TSEL as its last operand.
 * \param  curCount      Number of valid elements in this chunk.
 */
template <typename T, unsigned elementsCount>
TILEOP void ProcessWhere(
    uint64_t dst, uint64_t vcmpBitResult, uint64_t src0, uint64_t src1, uint64_t startAddrUB, const unsigned curCount)
{
    constexpr unsigned kBitsPerByte = 8;
    constexpr unsigned kAuxValidCols = 4; // valid columns of the auxiliary tile
    constexpr unsigned kAuxCols = 32;     // static column capacity of the auxiliary tile

    using BitMaskTile =
        pto::Tile<pto::TileType::Vec, uint8_t, 1, elementsCount / kBitsPerByte, pto::BLayout::RowMajor, -1, -1>;
    using AuxTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, kAuxCols, pto::BLayout::RowMajor, -1, -1>;
    using DataTile = pto::Tile<pto::TileType::Vec, T, 1, elementsCount, pto::BLayout::RowMajor, -1, -1>;

    // Mask operand; valid width is curCount / 8 (rounded down).
    BitMaskTile bitMask(1, curCount / kBitsPerByte);
    pto::TASSIGN(bitMask, static_cast<uint64_t>(vcmpBitResult));

    // Auxiliary operand required by TSEL.
    AuxTile auxTile(1, kAuxValidCols);
    pto::TASSIGN(auxTile, static_cast<uint64_t>(startAddrUB));

    // Output and both sources share the same tile type and valid width.
    DataTile outTile(1, curCount);
    DataTile firstTile(1, curCount);
    DataTile secondTile(1, curCount);
    pto::TASSIGN(outTile, static_cast<uint64_t>(dst));
    pto::TASSIGN(firstTile, static_cast<uint64_t>(src0));
    pto::TASSIGN(secondTile, static_cast<uint64_t>(src1));
#ifdef __DAV_V220
    // On __DAV_V220 builds a vector-pipe barrier precedes the select.
    pipe_barrier(PIPE_V);
#endif
    pto::TSEL(outTile, bitMask, firstTile, secondTile, auxTile);
}
// Name under which the tensor/tensor variant is published.
// NOTE(review): spelled WHERETT, unlike the WHERE_TS / WHERE_ST / WHERE_SS macros used for the other
// variants in this file; left untouched because users outside this file may depend on the exact name.
#define OP_TILE_OP_WHERETT TWhereTT
/**
 * \brief "Where" with two tensor sources: dst is selected from src0 / src1 under \p condition.
 *
 * Layouts are read as 5-D (expectSize = 5). The four outer dimensions are walked with nested loops; the
 * innermost dimension is one row.
 *  - TCond::Type == bool: the row is cut into chunks of 1024 elements (plus one tail chunk). For each chunk
 *    CaculateMask turns the condition into a mask in tmpbuf and ProcessWhere issues the select.
 *  - otherwise: the condition row is passed to pto::TSEL directly as the mask tile, one TSEL per row.
 *
 * tmpbuf layout used here (in order): 1024 half (condition cast), 1024 half (constant ones),
 * 1024 / 8 bytes (mask), then the auxiliary TSEL buffer.
 *
 * src0 and src1 are addressed with dst's strides.
 * NOTE(review): this assumes both sources share dst's layout and element type (their tiles are built with
 * TDst::Type) - confirm against the callers.
 *
 * \param dst       Output tile tensor.
 * \param tmpbuf    Scratch buffer laid out as described above.
 * \param condition Condition tile tensor.
 * \param src0      First source tile tensor.
 * \param src1      Second source tile tensor.
 */
template <typename TDst, typename TTmp, typename TCond, typename TSrc0, typename TSrc1>
TILEOP void TWhereTT(TDst dst, TTmp tmpbuf, TCond condition, TSrc0 src0, TSrc1 src1)
{
    using DstType = typename TDst::Type;
    constexpr auto shapeSize = Std::tuple_size<typename TDst::Shape>::value;
    constexpr unsigned elementsPerCount = 1024;
    constexpr unsigned bitsOfByte = 8;
    constexpr size_t expectSize = 5;

    // Carve the scratch buffer into its consecutive regions.
    uint64_t tmpbufAddr = tmpbuf.GetAddr();
    __ubuf__ half* castCondition = reinterpret_cast<__ubuf__ half*>(tmpbufAddr);
    __ubuf__ half* compareCondition = castCondition + elementsPerCount;
    __ubuf__ uint8_t* vcmpBitResult = reinterpret_cast<__ubuf__ uint8_t*>(compareCondition + elementsPerCount);
    __ubuf__ uint8_t* startAddrUB = vcmpBitResult + elementsPerCount / bitsOfByte;

    const auto dstLayout = dst.GetLayout();
    const auto conditionLayout = condition.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto shape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto shape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto shape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto shape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto conditionShape = conditionLayout.template GetShapeDim<4, expectSize>();
    auto stride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto stride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto stride2 = dstLayout.template GetStrideDim<2, expectSize>();
    auto stride3 = dstLayout.template GetStrideDim<3, expectSize>();
    auto conditionStride0 = conditionLayout.template GetStrideDim<0, expectSize>();
    auto conditionStride1 = conditionLayout.template GetStrideDim<1, expectSize>();
    auto conditionStride2 = conditionLayout.template GetStrideDim<2, expectSize>();
    auto conditionStride3 = conditionLayout.template GetStrideDim<3, expectSize>();

    // Static widths of the innermost tile dimension, used by the direct (non-bool) path.
    constexpr auto tileW = Std::tuple_element<shapeSize - 1, typename TDst::TileShape>::type::value;
    constexpr auto conditionTileW = Std::tuple_element<shapeSize - 1, typename TCond::TileShape>::type::value;
    constexpr auto dstTypeSize = sizeof(DstType);
    constexpr auto conditionTypeSize = sizeof(typename TCond::Type);
    constexpr auto src0TypeSize = sizeof(typename TSrc0::Type);
    constexpr auto src1TypeSize = sizeof(typename TSrc1::Type);

    // Row split: full 1024-element chunks followed by an optional shorter tail.
    const unsigned numCountPerLine = shape4 / elementsPerCount;
    const unsigned elementsRemainPerLine = shape4 % elementsPerCount;
    const unsigned tailStart = numCountPerLine * elementsPerCount;

    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < shape3; ++n3Index) {
                    // Element offsets of the current row; identical for every chunk of the row.
                    const auto conditionRowOffset = n0Index * conditionStride0 + n1Index * conditionStride1 +
                                                    n2Index * conditionStride2 + n3Index * conditionStride3;
                    const auto rowOffset =
                        n0Index * stride0 + n1Index * stride1 + n2Index * stride2 + n3Index * stride3;
                    if constexpr (std::is_same_v<typename TCond::Type, bool>) {
                        for (LoopVar j = 0; j < numCountPerLine; j++) {
                            const auto conditionOffset = conditionRowOffset + j * elementsPerCount;
                            const auto offset = rowOffset + j * elementsPerCount;
                            CaculateMask<elementsPerCount>(
                                (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize), castCondition,
                                compareCondition, vcmpBitResult, elementsPerCount);
                            ProcessWhere<DstType, elementsPerCount>(
                                (uint64_t)(dst.GetAddr() + offset * dstTypeSize), (uint64_t)(vcmpBitResult),
                                (uint64_t)(src0.GetAddr() + offset * src0TypeSize),
                                (uint64_t)(src1.GetAddr() + offset * src1TypeSize), (uint64_t)(startAddrUB),
                                elementsPerCount);
                        }
                        if (elementsRemainPerLine) {
                            // Tail chunk: same steps with the remaining element count.
                            const auto conditionOffset = conditionRowOffset + tailStart;
                            const auto offset = rowOffset + tailStart;
                            CaculateMask<elementsPerCount>(
                                (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize), castCondition,
                                compareCondition, vcmpBitResult, elementsRemainPerLine);
                            ProcessWhere<DstType, elementsPerCount>(
                                (uint64_t)(dst.GetAddr() + offset * dstTypeSize), (uint64_t)(vcmpBitResult),
                                (uint64_t)(src0.GetAddr() + offset * src0TypeSize),
                                (uint64_t)(src1.GetAddr() + offset * src1TypeSize), (uint64_t)(startAddrUB),
                                elementsRemainPerLine);
                        }
                    } else {
                        // Direct path: the condition row itself is the mask operand of TSEL.
                        constexpr unsigned addressUsed = 4;
                        constexpr unsigned alignUint8 = 32;
                        using TileStartAddrUB =
                            pto::Tile<pto::TileType::Vec, uint8_t, 1, alignUint8, pto::BLayout::RowMajor, -1, -1>;
                        using TileDst =
                            pto::Tile<pto::TileType::Vec, DstType, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
                        using TileMask = pto::Tile<
                            pto::TileType::Vec, typename TCond::Type, 1, conditionTileW, pto::BLayout::RowMajor, -1,
                            -1>;

                        TileStartAddrUB startAddrUBTile(1, addressUsed);
                        pto::TASSIGN(startAddrUBTile, (uint64_t)(startAddrUB));

                        TileDst dstTile(1, shape4);
                        TileMask maskTile(1, conditionShape);
                        TileDst src0Tile(1, shape4);
                        TileDst src1Tile(1, shape4);
                        pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + rowOffset * dstTypeSize));
                        pto::TASSIGN(
                            maskTile, (uint64_t)(condition.GetAddr() + conditionRowOffset * conditionTypeSize));
                        pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + rowOffset * src0TypeSize));
                        pto::TASSIGN(src1Tile, (uint64_t)(src1.GetAddr() + rowOffset * src1TypeSize));
                        pto::TSEL(dstTile, maskTile, src0Tile, src1Tile, startAddrUBTile);
                    }
                }
            }
        }
    }
}
// Name under which the tensor/scalar variant is published.
#define OP_TILE_OP_WHERE_TS TWhereTS
/**
 * \brief "Where" with a tensor src0 and a scalar src1: dst is selected from src0 / src1 under \p condition.
 *
 * The scalar \p src1 is broadcast once (pto::TEXPANDS) into a 1024-element scratch row that then serves as
 * the second source of every chunk. Layouts are read as 5-D (expectSize = 5); each innermost row is cut
 * into 1024-element chunks plus one optional tail chunk.
 *  - TCond::Type == bool: CaculateMask builds the mask for the chunk in tmpbuf first.
 *  - otherwise: the condition buffer is passed to ProcessWhere as the mask unchanged.
 *
 * tmpbuf layout used here (in order): 1024 half (condition cast), 1024 half (constant ones),
 * 1024 / 8 bytes (mask), 1024 TDst::Type elements (broadcast src1), then the auxiliary TSEL buffer.
 *
 * \param dst       Output tile tensor.
 * \param tmpbuf    Scratch buffer laid out as described above.
 * \param condition Condition tile tensor.
 * \param src0      First source tile tensor; addressed with dst's strides.
 * \param src1      Scalar used as the second source.
 */
template <typename TDst, typename TTmp, typename TCond, typename TSrc0, typename TSrc1>
TILEOP void TWhereTS(TDst dst, TTmp tmpbuf, TCond condition, TSrc0 src0, TSrc1 src1)
{
    using DstType = typename TDst::Type;
    constexpr bool conditionIsBool = std::is_same_v<typename TCond::Type, bool>;
    constexpr unsigned elementsPerCount = 1024;
    constexpr unsigned bitsOfByte = 8;
    constexpr size_t expectSize = 5;

    // Carve the scratch buffer into its consecutive regions.
    uint64_t tmpbufAddr = tmpbuf.GetAddr();
    __ubuf__ half* castCondition = reinterpret_cast<__ubuf__ half*>(tmpbufAddr);
    __ubuf__ half* compareCondition = castCondition + elementsPerCount;
    __ubuf__ uint8_t* vcmpBitResult = reinterpret_cast<__ubuf__ uint8_t*>(compareCondition + elementsPerCount);
    __ubuf__ DstType* otherTempTensor =
        reinterpret_cast<__ubuf__ DstType*>(vcmpBitResult + elementsPerCount / bitsOfByte);
    __ubuf__ uint8_t* startAddrUB = reinterpret_cast<__ubuf__ uint8_t*>(otherTempTensor + elementsPerCount);

    // Broadcast the scalar into its scratch row once; every chunk reuses it.
    // NOTE(review): the tile's element type is TSrc1 while its storage is sized with TDst::Type -
    // presumably both are the same type at every call site; confirm.
    using TileTSrc1 = pto::Tile<pto::TileType::Vec, TSrc1, 1, elementsPerCount, pto::BLayout::RowMajor, -1, -1>;
    TileTSrc1 src1Tile(1, elementsPerCount);
    pto::TASSIGN(src1Tile, (uint64_t)(otherTempTensor));
    pto::TEXPANDS(src1Tile, src1);

    const auto dstLayout = dst.GetLayout();
    const auto conditionLayout = condition.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto shape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto shape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto shape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto shape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto stride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto stride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto stride2 = dstLayout.template GetStrideDim<2, expectSize>();
    auto stride3 = dstLayout.template GetStrideDim<3, expectSize>();
    auto conditionStride0 = conditionLayout.template GetStrideDim<0, expectSize>();
    auto conditionStride1 = conditionLayout.template GetStrideDim<1, expectSize>();
    auto conditionStride2 = conditionLayout.template GetStrideDim<2, expectSize>();
    auto conditionStride3 = conditionLayout.template GetStrideDim<3, expectSize>();
    constexpr auto dstTypeSize = sizeof(DstType);
    constexpr auto conditionTypeSize = sizeof(typename TCond::Type);
    constexpr auto src0TypeSize = sizeof(typename TSrc0::Type);

    // Row split: full 1024-element chunks followed by an optional shorter tail.
    const unsigned numCountPerLine = shape4 / elementsPerCount;
    const unsigned elementsRemainPerLine = shape4 % elementsPerCount;
    const unsigned tailStart = numCountPerLine * elementsPerCount;

    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < shape3; ++n3Index) {
                    // Element offsets of the current row; identical for every chunk of the row.
                    const auto conditionRowOffset = n0Index * conditionStride0 + n1Index * conditionStride1 +
                                                    n2Index * conditionStride2 + n3Index * conditionStride3;
                    const auto rowOffset =
                        n0Index * stride0 + n1Index * stride1 + n2Index * stride2 + n3Index * stride3;

                    for (LoopVar j = 0; j < numCountPerLine; j++) {
                        const auto conditionOffset = conditionRowOffset + j * elementsPerCount;
                        const auto offset = rowOffset + j * elementsPerCount;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        // NOTE(review): for a non-bool condition the buffer is used as the mask as-is while the
                        // offset still advances by 1024 elements per chunk - confirm against the mask layout.
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult, elementsPerCount);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(src0.GetAddr() + offset * src0TypeSize), (uint64_t)(otherTempTensor),
                            (uint64_t)(startAddrUB), elementsPerCount);
                    }
                    if (elementsRemainPerLine) {
                        // Tail chunk: same steps with the remaining element count.
                        const auto conditionOffset = conditionRowOffset + tailStart;
                        const auto offset = rowOffset + tailStart;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult,
                                elementsRemainPerLine);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(src0.GetAddr() + offset * src0TypeSize), (uint64_t)(otherTempTensor),
                            (uint64_t)(startAddrUB), elementsRemainPerLine);
                    }
                }
            }
        }
    }
}
// Name under which the scalar/tensor variant is published.
#define OP_TILE_OP_WHERE_ST TWhereST
/**
 * \brief "Where" with a scalar src0 and a tensor src1: dst is selected from src0 / src1 under \p condition.
 *
 * The scalar \p src0 is broadcast once (pto::TEXPANDS) into a 1024-element scratch row that then serves as
 * the first source of every chunk. Layouts are read as 5-D (expectSize = 5); each innermost row is cut
 * into 1024-element chunks plus one optional tail chunk.
 *  - TCond::Type == bool: CaculateMask builds the mask for the chunk in tmpbuf first.
 *  - otherwise: the condition buffer is passed to ProcessWhere as the mask unchanged.
 *
 * tmpbuf layout used here (in order): 1024 half (condition cast), 1024 half (constant ones),
 * 1024 / 8 bytes (mask), 1024 TDst::Type elements (broadcast src0), then the auxiliary TSEL buffer.
 *
 * \param dst       Output tile tensor.
 * \param tmpbuf    Scratch buffer laid out as described above.
 * \param condition Condition tile tensor.
 * \param src0      Scalar used as the first source.
 * \param src1      Second source tile tensor; addressed with dst's strides.
 */
template <typename TDst, typename TTmp, typename TCond, typename TSrc0, typename TSrc1>
TILEOP void TWhereST(TDst dst, TTmp tmpbuf, TCond condition, TSrc0 src0, TSrc1 src1)
{
    using DstType = typename TDst::Type;
    constexpr bool conditionIsBool = std::is_same_v<typename TCond::Type, bool>;
    constexpr unsigned elementsPerCount = 1024;
    constexpr unsigned bitsOfByte = 8;
    constexpr size_t expectSize = 5;

    // Carve the scratch buffer into its consecutive regions.
    uint64_t tmpbufAddr = tmpbuf.GetAddr();
    __ubuf__ half* castCondition = reinterpret_cast<__ubuf__ half*>(tmpbufAddr);
    __ubuf__ half* compareCondition = castCondition + elementsPerCount;
    __ubuf__ uint8_t* vcmpBitResult = reinterpret_cast<__ubuf__ uint8_t*>(compareCondition + elementsPerCount);
    __ubuf__ DstType* inputTempTensor =
        reinterpret_cast<__ubuf__ DstType*>(vcmpBitResult + elementsPerCount / bitsOfByte);
    __ubuf__ uint8_t* startAddrUB = reinterpret_cast<__ubuf__ uint8_t*>(inputTempTensor + elementsPerCount);

    // Broadcast the scalar into its scratch row once; every chunk reuses it.
    // NOTE(review): the tile's element type is TSrc0 while its storage is sized with TDst::Type -
    // presumably both are the same type at every call site; confirm.
    using TileTSrc0 = pto::Tile<pto::TileType::Vec, TSrc0, 1, elementsPerCount, pto::BLayout::RowMajor, -1, -1>;
    TileTSrc0 src0Tile(1, elementsPerCount);
    pto::TASSIGN(src0Tile, (uint64_t)(inputTempTensor));
    pto::TEXPANDS(src0Tile, src0);

    const auto dstLayout = dst.GetLayout();
    const auto conditionLayout = condition.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto shape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto shape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto shape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto shape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto stride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto stride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto stride2 = dstLayout.template GetStrideDim<2, expectSize>();
    auto stride3 = dstLayout.template GetStrideDim<3, expectSize>();
    auto conditionStride0 = conditionLayout.template GetStrideDim<0, expectSize>();
    auto conditionStride1 = conditionLayout.template GetStrideDim<1, expectSize>();
    auto conditionStride2 = conditionLayout.template GetStrideDim<2, expectSize>();
    auto conditionStride3 = conditionLayout.template GetStrideDim<3, expectSize>();
    constexpr auto dstTypeSize = sizeof(DstType);
    constexpr auto conditionTypeSize = sizeof(typename TCond::Type);
    constexpr auto src1TypeSize = sizeof(typename TSrc1::Type);

    // Row split: full 1024-element chunks followed by an optional shorter tail.
    const unsigned numCountPerLine = shape4 / elementsPerCount;
    const unsigned elementsRemainPerLine = shape4 % elementsPerCount;
    const unsigned tailStart = numCountPerLine * elementsPerCount;

    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < shape3; ++n3Index) {
                    // Element offsets of the current row; identical for every chunk of the row.
                    const auto conditionRowOffset = n0Index * conditionStride0 + n1Index * conditionStride1 +
                                                    n2Index * conditionStride2 + n3Index * conditionStride3;
                    const auto rowOffset =
                        n0Index * stride0 + n1Index * stride1 + n2Index * stride2 + n3Index * stride3;

                    for (LoopVar j = 0; j < numCountPerLine; j++) {
                        const auto conditionOffset = conditionRowOffset + j * elementsPerCount;
                        const auto offset = rowOffset + j * elementsPerCount;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        // NOTE(review): for a non-bool condition the buffer is used as the mask as-is while the
                        // offset still advances by 1024 elements per chunk - confirm against the mask layout.
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult, elementsPerCount);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(inputTempTensor), (uint64_t)(src1.GetAddr() + offset * src1TypeSize),
                            (uint64_t)(startAddrUB), elementsPerCount);
                    }
                    if (elementsRemainPerLine) {
                        // Tail chunk: same steps with the remaining element count.
                        const auto conditionOffset = conditionRowOffset + tailStart;
                        const auto offset = rowOffset + tailStart;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult,
                                elementsRemainPerLine);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(inputTempTensor), (uint64_t)(src1.GetAddr() + offset * src1TypeSize),
                            (uint64_t)(startAddrUB), elementsRemainPerLine);
                    }
                }
            }
        }
    }
}
// Name under which the scalar/scalar variant is published.
#define OP_TILE_OP_WHERE_SS TWhereSS
/**
 * \brief "Where" with two scalar sources: dst is selected from src0 / src1 under \p condition.
 *
 * Both scalars are broadcast once (pto::TEXPANDS) into 1024-element scratch rows that then serve as the
 * sources of every chunk. Layouts are read as 5-D (expectSize = 5); each innermost row is cut into
 * 1024-element chunks plus one optional tail chunk.
 *  - TCond::Type == bool: CaculateMask builds the mask for the chunk in tmpbuf first.
 *  - otherwise: the condition buffer is passed to ProcessWhere as the mask unchanged.
 *
 * tmpbuf layout used here (in order): 1024 half (condition cast), 1024 half (constant ones),
 * 1024 / 8 bytes (mask), 1024 TDst::Type elements (broadcast src0), 1024 TDst::Type elements
 * (broadcast src1), then the auxiliary TSEL buffer.
 *
 * \param dst       Output tile tensor.
 * \param tmpbuf    Scratch buffer laid out as described above.
 * \param condition Condition tile tensor.
 * \param src0      Scalar used as the first source.
 * \param src1      Scalar used as the second source.
 */
template <typename TDst, typename TTmp, typename TCond, typename TSrc0, typename TSrc1>
TILEOP void TWhereSS(TDst dst, TTmp tmpbuf, TCond condition, TSrc0 src0, TSrc1 src1)
{
    using DstType = typename TDst::Type;
    constexpr bool conditionIsBool = std::is_same_v<typename TCond::Type, bool>;
    constexpr unsigned elementsPerCount = 1024;
    constexpr unsigned bitsOfByte = 8;
    constexpr size_t expectSize = 5;

    // Carve the scratch buffer into its consecutive regions.
    uint64_t tmpbufAddr = tmpbuf.GetAddr();
    __ubuf__ half* castCondition = reinterpret_cast<__ubuf__ half*>(tmpbufAddr);
    __ubuf__ half* compareCondition = castCondition + elementsPerCount;
    __ubuf__ uint8_t* vcmpBitResult = reinterpret_cast<__ubuf__ uint8_t*>(compareCondition + elementsPerCount);
    __ubuf__ DstType* inputTempTensor =
        reinterpret_cast<__ubuf__ DstType*>(vcmpBitResult + elementsPerCount / bitsOfByte);
    __ubuf__ DstType* otherTempTensor = inputTempTensor + elementsPerCount;
    __ubuf__ uint8_t* startAddrUB = reinterpret_cast<__ubuf__ uint8_t*>(otherTempTensor + elementsPerCount);

    // Broadcast both scalars into their scratch rows once; every chunk reuses them.
    // NOTE(review): both tiles use TSrc0 as element type (including the one filled from src1) while their
    // storage is sized with TDst::Type - presumably TSrc0, TSrc1 and TDst::Type coincide; confirm.
    using TileTSrc0 = pto::Tile<pto::TileType::Vec, TSrc0, 1, elementsPerCount, pto::BLayout::RowMajor, -1, -1>;
    TileTSrc0 src0Tile(1, elementsPerCount);
    TileTSrc0 src1Tile(1, elementsPerCount);
    pto::TASSIGN(src0Tile, (uint64_t)(inputTempTensor));
    pto::TASSIGN(src1Tile, (uint64_t)(otherTempTensor));
    pto::TEXPANDS(src0Tile, src0);
    pto::TEXPANDS(src1Tile, src1);

    const auto dstLayout = dst.GetLayout();
    const auto conditionLayout = condition.GetLayout();
    auto shape0 = dstLayout.template GetShapeDim<0, expectSize>();
    auto shape1 = dstLayout.template GetShapeDim<1, expectSize>();
    auto shape2 = dstLayout.template GetShapeDim<2, expectSize>();
    auto shape3 = dstLayout.template GetShapeDim<3, expectSize>();
    auto shape4 = dstLayout.template GetShapeDim<4, expectSize>();
    auto stride0 = dstLayout.template GetStrideDim<0, expectSize>();
    auto stride1 = dstLayout.template GetStrideDim<1, expectSize>();
    auto stride2 = dstLayout.template GetStrideDim<2, expectSize>();
    auto stride3 = dstLayout.template GetStrideDim<3, expectSize>();
    auto conditionStride0 = conditionLayout.template GetStrideDim<0, expectSize>();
    auto conditionStride1 = conditionLayout.template GetStrideDim<1, expectSize>();
    auto conditionStride2 = conditionLayout.template GetStrideDim<2, expectSize>();
    auto conditionStride3 = conditionLayout.template GetStrideDim<3, expectSize>();
    constexpr auto dstTypeSize = sizeof(DstType);
    constexpr auto conditionTypeSize = sizeof(typename TCond::Type);

    // Row split: full 1024-element chunks followed by an optional shorter tail.
    const unsigned numCountPerLine = shape4 / elementsPerCount;
    const unsigned elementsRemainPerLine = shape4 % elementsPerCount;
    const unsigned tailStart = numCountPerLine * elementsPerCount;

    for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
        for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
            for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
                for (LoopVar n3Index = 0; n3Index < shape3; ++n3Index) {
                    // Element offsets of the current row; identical for every chunk of the row.
                    const auto conditionRowOffset = n0Index * conditionStride0 + n1Index * conditionStride1 +
                                                    n2Index * conditionStride2 + n3Index * conditionStride3;
                    const auto rowOffset =
                        n0Index * stride0 + n1Index * stride1 + n2Index * stride2 + n3Index * stride3;

                    for (LoopVar j = 0; j < numCountPerLine; j++) {
                        const auto conditionOffset = conditionRowOffset + j * elementsPerCount;
                        const auto offset = rowOffset + j * elementsPerCount;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        // NOTE(review): for a non-bool condition the buffer is used as the mask as-is while the
                        // offset still advances by 1024 elements per chunk - confirm against the mask layout.
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult, elementsPerCount);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(inputTempTensor), (uint64_t)(otherTempTensor), (uint64_t)(startAddrUB),
                            elementsPerCount);
                    }
                    if (elementsRemainPerLine) {
                        // Tail chunk: same steps with the remaining element count.
                        const auto conditionOffset = conditionRowOffset + tailStart;
                        const auto offset = rowOffset + tailStart;
                        const uint64_t conditionAddr =
                            (uint64_t)(condition.GetAddr() + conditionOffset * conditionTypeSize);
                        uint64_t maskAddr = conditionAddr;
                        if constexpr (conditionIsBool) {
                            CaculateMask<elementsPerCount>(
                                conditionAddr, castCondition, compareCondition, vcmpBitResult,
                                elementsRemainPerLine);
                            maskAddr = (uint64_t)(vcmpBitResult);
                        }
                        ProcessWhere<DstType, elementsPerCount>(
                            (uint64_t)(dst.GetAddr() + offset * dstTypeSize), maskAddr,
                            (uint64_t)(inputTempTensor), (uint64_t)(otherTempTensor), (uint64_t)(startAddrUB),
                            elementsRemainPerLine);
                    }
                }
            }
        }
    }
}
#endif