/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
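/*!
* \file pair_binary.h
* \brief Pair-wise binary tile operators (TPairSum, TPairMax, TPairMin, TPairProd, TPairArgMax, TPairArgMin)
*        and the helpers that slice tile tensors into 2-D pto tiles for them.
*/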
#ifndef TILEOP_TILE_OPERATOR_PAIR_BINARY_H
#define TILEOP_TILE_OPERATOR_PAIR_BINARY_H
#include "utils/layout.h"
#include "utils/tile_tensor.h"
/**
 * Finds the first axis, scanning from `index` upwards, on which the shapes of
 * `src0` and `src1` differ.
 *
 * The scan is unrolled at compile time through recursion on `index`; the
 * per-axis extents themselves are read from the layouts at run time.
 *
 * @tparam index     Axis to inspect in this step (callers start at 0).
 * @tparam shapeSize Number of axes to inspect.
 * @return The first differing axis, or 0 when every inspected axis matches
 *         (so 0 is returned both for "axis 0 differs" and "nothing differs").
 */
template <typename T0, typename T1, size_t index = 0, size_t shapeSize = 0>
TILEOP int GetReduceAxisIndex(T0 src0, T1 src1)
{
    if constexpr (index >= shapeSize) {
        // Ran past the last axis without finding a mismatch.
        return 0;
    } else {
        const auto lhsLayout = src0.GetLayout();
        const auto rhsLayout = src1.GetLayout();
        const auto lhsExtent = lhsLayout.template GetShapeDim<index, shapeSize>();
        const auto rhsExtent = rhsLayout.template GetShapeDim<index, shapeSize>();
        return (lhsExtent != rhsExtent) ? static_cast<int>(index)
                                        : GetReduceAxisIndex<T0, T1, index + 1, shapeSize>(src0, src1);
    }
}
/**
 * Multiplies the static extents of `TileShape` over the half-open axis range
 * [startIndex, endIndex).
 *
 * @return The product of the selected extents, or 1 when the range is empty.
 */
template <int startIndex, int endIndex, typename TileShape>
TILEOP constexpr int GetTileShapeMergeResult()
{
    if constexpr (startIndex < endIndex) {
        constexpr auto headExtent = Std::tuple_element<startIndex, TileShape>::type::value;
        // Product of the remaining axes (startIndex, endIndex).
        const int tailProduct = GetTileShapeMergeResult<startIndex + 1, endIndex, TileShape>();
        return headExtent * tailProduct;
    } else {
        return 1;
    }
}
/**
 * Compile-time height or width of the 2-D tile used for one tensor operand.
 *
 * When `reduceAxisIndex` is before the last axis:
 *   - height = TileShape[reduceAxisIndex]
 *   - width  = product of TileShape over axes (reduceAxisIndex, shapeSize)
 * Otherwise the values come from TileOp::GetTensorTileShapeDim<T0, 3, 5>() (height)
 * and TileOp::GetTensorTileShapeDim<T0, 4, 5>() (width).
 * NOTE(review): presumably these are the last two axes of the tile shape viewed
 * as rank 5 - confirm against GetTensorTileShapeDim.
 *
 * @tparam isHeight true to get the tile height, false to get the tile width.
 */
template <int reduceAxisIndex, size_t shapeSize, bool isHeight, typename T0>
TILEOP constexpr int GetOneDimensionTileSize()
{
    // True when the differing axis is not strictly before the last axis.
    constexpr bool useTrailingTileDims = !(reduceAxisIndex < shapeSize - 1);
    if constexpr (isHeight) {
        if constexpr (useTrailingTileDims) {
            return TileOp::GetTensorTileShapeDim<T0, 3, 5>();
        } else {
            return Std::tuple_element<reduceAxisIndex, typename T0::TileShape>::type::value;
        }
    } else {
        if constexpr (useTrailingTileDims) {
            return TileOp::GetTensorTileShapeDim<T0, 4, 5>();
        } else {
            return GetTileShapeMergeResult<reduceAxisIndex + 1, shapeSize, typename T0::TileShape>();
        }
    }
}
/**
 * Run-time extent used as the tile's valid height.
 *
 * Reads the layout's shape on axis `reduceAxisIndex`, or on the axis just before
 * it when `reduceAxisIndex` is the last axis (the last axis is then the width).
 * Callers in this file do not call this for rank-1 tensors, where no axis
 * precedes the last one.
 */
template <int reduceAxisIndex, size_t shapeSize, typename T0>
TILEOP int GetHeightValidSize(T0 tensor)
{
    constexpr int heightAxis = (reduceAxisIndex == shapeSize - 1) ? reduceAxisIndex - 1 : reduceAxisIndex;
    const auto layout = tensor.GetLayout();
    return layout.template GetShapeDim<heightAxis, shapeSize>();
}
/**
 * Extent used as the tile's valid width.
 *
 * When `reduceAxisIndex` is the last axis, this is the layout's run-time shape on
 * that axis. Otherwise every axis after `reduceAxisIndex` is folded into the
 * width, and the value is the static TileShape product over those axes (not the
 * run-time shape).
 */
template <int reduceAxisIndex, size_t shapeSize, typename T0>
TILEOP int GetWidthValidSize(T0 tensor)
{
    if constexpr (reduceAxisIndex != shapeSize - 1) {
        return GetTileShapeMergeResult<reduceAxisIndex + 1, shapeSize, typename T0::TileShape>();
    } else {
        const auto layout = tensor.GetLayout();
        return layout.template GetShapeDim<reduceAxisIndex, shapeSize>();
    }
}
/**
 * Issues the pto instruction matching `op` on three already-bound tiles.
 *
 * ADD -> pto::TPARTADD, MAX -> pto::TPARTMAX, MIN -> pto::TPARTMIN,
 * MUL -> pto::TPARTMUL. Any other `op` value compiles to an empty body.
 */
template <PairBinaryOp op, typename T0, typename T1, typename T2>
TILEOP void PairBinaryComputeImpl(T0 dst, T1 src0, T2 src1)
{
    if constexpr (op == PairBinaryOp::ADD) {
        pto::TPARTADD(dst, src0, src1);
    } else if constexpr (op == PairBinaryOp::MAX) {
        pto::TPARTMAX(dst, src0, src1);
    } else if constexpr (op == PairBinaryOp::MIN) {
        pto::TPARTMIN(dst, src0, src1);
    } else if constexpr (op == PairBinaryOp::MUL) {
        pto::TPARTMUL(dst, src0, src1);
    }
}
/**
 * Applies the pair-binary `op` to `dst`, `src0` and `src1` for a compile-time
 * known differing axis `reduceAxisIndex`.
 *
 * Each tensor is viewed as a batch of 2-D tiles:
 *   - if `reduceAxisIndex` is the last axis, the tile is (axis - 1, axis) and
 *     the axes before `axis - 1` are looped over;
 *   - otherwise the tile height is `axis`, all following axes are folded into
 *     the width, and the axes before `axis` are looped over.
 * At most three outer axes (0, 1, 2) are looped. The loop bounds come from
 * `dst`'s shape; each operand is addressed through its own strides and
 * element size.
 *
 * When `reduceAxisIndex` is not below the rank of `T0::Shape` the body is empty,
 * so callers may instantiate every candidate axis unconditionally.
 *
 * @param dst  Destination tile tensor.
 * @param src0 First source tile tensor.
 * @param src1 Second source tile tensor.
 */
template <PairBinaryOp op, int reduceAxisIndex, typename T0, typename T1, typename T2>
TILEOP void InnerPairBinaryCompute(T0 dst, T1 src0, T2 src1)
{
    constexpr auto shapeSize = Std::tuple_size<typename T0::Shape>::value;
    if constexpr (reduceAxisIndex < shapeSize) {
        const auto dstLayout = dst.GetLayout();
        const auto src0Layout = src0.GetLayout();
        const auto src1Layout = src1.GetLayout();
        constexpr auto dstTypeSize = sizeof(typename T0::Type);
        constexpr auto src0TypeSize = sizeof(typename T1::Type);
        constexpr auto src1TypeSize = sizeof(typename T2::Type);

        // Number of leading axes that sit outside the 2-D tile and must be looped.
        constexpr bool reduceIsLast = (reduceAxisIndex == shapeSize - 1);
        constexpr int outerDimCount = reduceIsLast ? reduceAxisIndex - 1 : reduceAxisIndex;

        // Axes that are not looped keep extent 1 and stride 0.
        size_t dstShape0 = 1;
        size_t dstShape1 = 1;
        size_t dstShape2 = 1;
        size_t dstStride0 = 0;
        size_t dstStride1 = 0;
        size_t dstStride2 = 0;
        size_t src0Stride0 = 0;
        size_t src0Stride1 = 0;
        size_t src0Stride2 = 0;
        size_t src1Stride0 = 0;
        size_t src1Stride1 = 0;
        size_t src1Stride2 = 0;
        if constexpr (outerDimCount >= 1) {
            dstShape0 = dstLayout.template GetShapeDim<0, shapeSize>();
            dstStride0 = dstLayout.template GetStrideDim<0, shapeSize>();
            src0Stride0 = src0Layout.template GetStrideDim<0, shapeSize>();
            src1Stride0 = src1Layout.template GetStrideDim<0, shapeSize>();
        }
        if constexpr (outerDimCount >= 2) {
            dstShape1 = dstLayout.template GetShapeDim<1, shapeSize>();
            dstStride1 = dstLayout.template GetStrideDim<1, shapeSize>();
            src0Stride1 = src0Layout.template GetStrideDim<1, shapeSize>();
            src1Stride1 = src1Layout.template GetStrideDim<1, shapeSize>();
        }
        if constexpr (outerDimCount >= 3) {
            dstShape2 = dstLayout.template GetShapeDim<2, shapeSize>();
            dstStride2 = dstLayout.template GetStrideDim<2, shapeSize>();
            src0Stride2 = src0Layout.template GetStrideDim<2, shapeSize>();
            src1Stride2 = src1Layout.template GetStrideDim<2, shapeSize>();
        }

        // Static (compile-time) tile dimensions of each operand.
        constexpr auto dstTileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T0>();
        constexpr auto dstTileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T0>();
        constexpr auto src0TileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T1>();
        constexpr auto src0TileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T1>();
        constexpr auto src1TileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T2>();
        constexpr auto src1TileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T2>();

        // Run-time valid height; a rank-1 tensor has no height axis, so it is 1.
        int dstShape3;
        int src0Shape3;
        int src1Shape3;
        if constexpr (shapeSize != 1) {
            dstShape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T0>(dst);
            src0Shape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T1>(src0);
            src1Shape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T2>(src1);
        } else {
            dstShape3 = 1;
            src0Shape3 = 1;
            src1Shape3 = 1;
        }
        // Valid width of each operand.
        auto dstShape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T0>(dst);
        auto src0Shape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T1>(src0);
        auto src1Shape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T2>(src1);

        // Each tile carries the element type of its own tensor, matching the
        // per-operand element size used for the byte offsets below.
        using DstTileDefine =
            pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1>;
        using Src0TileDefine =
            pto::Tile<pto::TileType::Vec, typename T1::Type, src0TileH, src0TileW, pto::BLayout::RowMajor, -1, -1>;
        using Src1TileDefine =
            pto::Tile<pto::TileType::Vec, typename T2::Type, src1TileH, src1TileW, pto::BLayout::RowMajor, -1, -1>;

        for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
            for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
                for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
                    DstTileDefine dstTile(dstShape3, dstShape4);
                    Src0TileDefine src0Tile(src0Shape3, src0Shape4);
                    Src1TileDefine src1Tile(src1Shape3, src1Shape4);
                    // Element offsets of the current 2-D slice within each tensor.
                    auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
                    auto src0Offset = n0Index * src0Stride0 + n1Index * src0Stride1 + n2Index * src0Stride2;
                    auto src1Offset = n0Index * src1Stride0 + n1Index * src1Stride1 + n2Index * src1Stride2;
                    // Bind each tile to base address + element offset scaled by element size.
                    pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + dstOffset * dstTypeSize));
                    pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + src0Offset * src0TypeSize));
                    pto::TASSIGN(src1Tile, (uint64_t)(src1.GetAddr() + src1Offset * src1TypeSize));
                    PairBinaryComputeImpl<op>(dstTile, src0Tile, src1Tile);
                }
            }
        }
    }
}
/**
 * Arg-variant of the pair-binary kernel: combines (value, index) pairs from two
 * sources into a destination (value, index) pair for a compile-time known
 * differing axis `reduceAxisIndex`.
 *
 * The tensors are sliced into 2-D tiles exactly as in InnerPairBinaryCompute:
 * up to three leading axes are looped, bounded by `dstValue`'s shape. For each
 * slice, pto::TPARTARGMAX is issued when `op` is PairBinaryOp::ARGMAX and
 * pto::TPARTARGMIN for every other `op`.
 *
 * Index tiles reuse the tile dimensions, valid sizes and element strides of the
 * matching value tensor, but their byte offsets are scaled by the index
 * tensor's own element size, so value and index element types may differ in
 * width.
 * NOTE(review): this assumes each index tensor has the same element strides as
 * its value tensor - confirm with the tensor allocation code.
 *
 * When `reduceAxisIndex` is not below the rank of `T0::Shape` the body is empty.
 */
template <PairBinaryOp op, int reduceAxisIndex, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5>
TILEOP void InnerPairArgBinaryCompute(T0 dstValue, T1 dstIndex, T2 src0Value, T3 src0Index, T4 src1Value, T5 src1Index)
{
    constexpr auto shapeSize = Std::tuple_size<typename T0::Shape>::value;
    if constexpr (reduceAxisIndex < shapeSize) {
        const auto dstLayout = dstValue.GetLayout();
        const auto src0Layout = src0Value.GetLayout();
        const auto src1Layout = src1Value.GetLayout();
        // Element sizes of the value tensors ...
        constexpr auto dstTypeSize = sizeof(typename T0::Type);
        constexpr auto src0TypeSize = sizeof(typename T2::Type);
        constexpr auto src1TypeSize = sizeof(typename T4::Type);
        // ... and of the index tensors, which may have a different element width.
        constexpr auto dstIndexTypeSize = sizeof(typename T1::Type);
        constexpr auto src0IndexTypeSize = sizeof(typename T3::Type);
        constexpr auto src1IndexTypeSize = sizeof(typename T5::Type);

        // Number of leading axes that sit outside the 2-D tile and must be looped.
        constexpr bool reduceIsLast = (reduceAxisIndex == shapeSize - 1);
        constexpr int outerDimCount = reduceIsLast ? reduceAxisIndex - 1 : reduceAxisIndex;

        // Axes that are not looped keep extent 1 and stride 0.
        size_t dstShape0 = 1;
        size_t dstShape1 = 1;
        size_t dstShape2 = 1;
        size_t dstStride0 = 0;
        size_t dstStride1 = 0;
        size_t dstStride2 = 0;
        size_t src0Stride0 = 0;
        size_t src0Stride1 = 0;
        size_t src0Stride2 = 0;
        size_t src1Stride0 = 0;
        size_t src1Stride1 = 0;
        size_t src1Stride2 = 0;
        if constexpr (outerDimCount >= 1) {
            dstShape0 = dstLayout.template GetShapeDim<0, shapeSize>();
            dstStride0 = dstLayout.template GetStrideDim<0, shapeSize>();
            src0Stride0 = src0Layout.template GetStrideDim<0, shapeSize>();
            src1Stride0 = src1Layout.template GetStrideDim<0, shapeSize>();
        }
        if constexpr (outerDimCount >= 2) {
            dstShape1 = dstLayout.template GetShapeDim<1, shapeSize>();
            dstStride1 = dstLayout.template GetStrideDim<1, shapeSize>();
            src0Stride1 = src0Layout.template GetStrideDim<1, shapeSize>();
            src1Stride1 = src1Layout.template GetStrideDim<1, shapeSize>();
        }
        if constexpr (outerDimCount >= 3) {
            dstShape2 = dstLayout.template GetShapeDim<2, shapeSize>();
            dstStride2 = dstLayout.template GetStrideDim<2, shapeSize>();
            src0Stride2 = src0Layout.template GetStrideDim<2, shapeSize>();
            src1Stride2 = src1Layout.template GetStrideDim<2, shapeSize>();
        }

        // Static (compile-time) tile dimensions, taken from the value tensors.
        constexpr auto dstTileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T0>();
        constexpr auto dstTileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T0>();
        constexpr auto src0TileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T2>();
        constexpr auto src0TileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T2>();
        constexpr auto src1TileH = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, true, T4>();
        constexpr auto src1TileW = GetOneDimensionTileSize<reduceAxisIndex, shapeSize, false, T4>();

        // Run-time valid height; a rank-1 tensor has no height axis, so it is 1.
        int dstShape3;
        int src0Shape3;
        int src1Shape3;
        if constexpr (shapeSize != 1) {
            dstShape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T0>(dstValue);
            src0Shape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T2>(src0Value);
            src1Shape3 = GetHeightValidSize<reduceAxisIndex, shapeSize, T4>(src1Value);
        } else {
            dstShape3 = 1;
            src0Shape3 = 1;
            src1Shape3 = 1;
        }
        // Valid width of each operand.
        auto dstShape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T0>(dstValue);
        auto src0Shape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T2>(src0Value);
        auto src1Shape4 = GetWidthValidSize<reduceAxisIndex, shapeSize, T4>(src1Value);

        using DstValueTileDefine =
            pto::Tile<pto::TileType::Vec, typename T0::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1>;
        using DstIndexTileDefine =
            pto::Tile<pto::TileType::Vec, typename T1::Type, dstTileH, dstTileW, pto::BLayout::RowMajor, -1, -1>;
        using Src0ValueTileDefine =
            pto::Tile<pto::TileType::Vec, typename T2::Type, src0TileH, src0TileW, pto::BLayout::RowMajor, -1, -1>;
        using Src0IndexTileDefine =
            pto::Tile<pto::TileType::Vec, typename T3::Type, src0TileH, src0TileW, pto::BLayout::RowMajor, -1, -1>;
        using Src1ValueTileDefine =
            pto::Tile<pto::TileType::Vec, typename T4::Type, src1TileH, src1TileW, pto::BLayout::RowMajor, -1, -1>;
        using Src1IndexTileDefine =
            pto::Tile<pto::TileType::Vec, typename T5::Type, src1TileH, src1TileW, pto::BLayout::RowMajor, -1, -1>;

        for (LoopVar n0Index = 0; n0Index < dstShape0; ++n0Index) {
            for (LoopVar n1Index = 0; n1Index < dstShape1; ++n1Index) {
                for (LoopVar n2Index = 0; n2Index < dstShape2; ++n2Index) {
                    DstValueTileDefine dstValueTile(dstShape3, dstShape4);
                    DstIndexTileDefine dstIndexTile(dstShape3, dstShape4);
                    Src0ValueTileDefine src0ValueTile(src0Shape3, src0Shape4);
                    Src0IndexTileDefine src0IndexTile(src0Shape3, src0Shape4);
                    Src1ValueTileDefine src1ValueTile(src1Shape3, src1Shape4);
                    Src1IndexTileDefine src1IndexTile(src1Shape3, src1Shape4);
                    // Element offsets of the current 2-D slice within each tensor.
                    auto dstOffset = n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2;
                    auto src0Offset = n0Index * src0Stride0 + n1Index * src0Stride1 + n2Index * src0Stride2;
                    auto src1Offset = n0Index * src1Stride0 + n1Index * src1Stride1 + n2Index * src1Stride2;
                    // Value tiles: byte offset = element offset * value element size.
                    // Index tiles: same element offset, scaled by the index element size.
                    pto::TASSIGN(dstValueTile, (uint64_t)(dstValue.GetAddr() + dstOffset * dstTypeSize));
                    pto::TASSIGN(dstIndexTile, (uint64_t)(dstIndex.GetAddr() + dstOffset * dstIndexTypeSize));
                    pto::TASSIGN(src0ValueTile, (uint64_t)(src0Value.GetAddr() + src0Offset * src0TypeSize));
                    pto::TASSIGN(src0IndexTile, (uint64_t)(src0Index.GetAddr() + src0Offset * src0IndexTypeSize));
                    pto::TASSIGN(src1ValueTile, (uint64_t)(src1Value.GetAddr() + src1Offset * src1TypeSize));
                    pto::TASSIGN(src1IndexTile, (uint64_t)(src1Index.GetAddr() + src1Offset * src1IndexTypeSize));
                    if constexpr (op == PairBinaryOp::ARGMAX) {
                        pto::TPARTARGMAX(dstValueTile, src0ValueTile, src1ValueTile, dstIndexTile, src0IndexTile, src1IndexTile);
                    } else {
                        pto::TPARTARGMIN(dstValueTile, src0ValueTile, src1ValueTile, dstIndexTile, src0IndexTile, src1IndexTile);
                    }
                }
            }
        }
    }
}
/**
 * Run-time dispatcher for the pair-binary kernels.
 *
 * Finds the first axis on which `src0` and `src1` differ in shape (0 when they
 * match everywhere) and forwards to the InnerPairBinaryCompute instantiation
 * for that axis. Axes 0..4 are handled; any other value does nothing.
 */
template <PairBinaryOp op, typename T0, typename T1, typename T2>
TILEOP void PairBinaryCompute(T0 dst, T1 src0, T2 src1)
{
    constexpr auto rank = Std::tuple_size<typename T0::Shape>::value;
    const size_t axis = GetReduceAxisIndex<T1, T2, 0, rank>(src0, src1);
    // Turn the run-time axis into a template argument.
    switch (axis) {
        case 0:
            InnerPairBinaryCompute<op, 0>(dst, src0, src1);
            break;
        case 1:
            InnerPairBinaryCompute<op, 1>(dst, src0, src1);
            break;
        case 2:
            InnerPairBinaryCompute<op, 2>(dst, src0, src1);
            break;
        case 3:
            InnerPairBinaryCompute<op, 3>(dst, src0, src1);
            break;
        case 4:
            InnerPairBinaryCompute<op, 4>(dst, src0, src1);
            break;
        default:
            break;
    }
}
/**
 * Run-time dispatcher for the arg-variant pair-binary kernels.
 *
 * Finds the first axis on which `src0Value` and `src1Value` differ in shape
 * (0 when they match everywhere) and forwards all six tensors to the
 * InnerPairArgBinaryCompute instantiation for that axis. Axes 0..4 are
 * handled; any other value does nothing.
 */
template <PairBinaryOp op, typename T0, typename T1, typename T2, typename T3, typename T4, typename T5>
TILEOP void PairArgBinaryCompute(T0 dstValue, T1 dstIndex, T2 src0Value, T3 src0Index, T4 src1Value, T5 src1Index)
{
    constexpr auto rank = Std::tuple_size<typename T0::Shape>::value;
    const size_t axis = GetReduceAxisIndex<T2, T4, 0, rank>(src0Value, src1Value);
    // Turn the run-time axis into a template argument.
    switch (axis) {
        case 0:
            InnerPairArgBinaryCompute<op, 0>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
            break;
        case 1:
            InnerPairArgBinaryCompute<op, 1>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
            break;
        case 2:
            InnerPairArgBinaryCompute<op, 2>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
            break;
        case 3:
            InnerPairArgBinaryCompute<op, 3>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
            break;
        case 4:
            InnerPairArgBinaryCompute<op, 4>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
            break;
        default:
            break;
    }
}
#define OP_TILE_OP_PAIRSUM TPairSum
/**
 * Pair-sum entry point (also reachable through the OP_TILE_OP_PAIRSUM macro).
 * Runs PairBinaryCompute with PairBinaryOp::ADD, which ends in pto::TPARTADD.
 */
template <typename T0, typename T1, typename T2>
TILEOP void TPairSum(T0 dst, T1 src0, T2 src1)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::ADD;
    PairBinaryCompute<kOp>(dst, src0, src1);
}
#define OP_TILE_OP_PAIRMAX TPairMax
/**
 * Pair-max entry point (also reachable through the OP_TILE_OP_PAIRMAX macro).
 * Runs PairBinaryCompute with PairBinaryOp::MAX, which ends in pto::TPARTMAX.
 */
template <typename T0, typename T1, typename T2>
TILEOP void TPairMax(T0 dst, T1 src0, T2 src1)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::MAX;
    PairBinaryCompute<kOp>(dst, src0, src1);
}
#define OP_TILE_OP_PAIRMIN TPairMin
/**
 * Pair-min entry point (also reachable through the OP_TILE_OP_PAIRMIN macro).
 * Runs PairBinaryCompute with PairBinaryOp::MIN, which ends in pto::TPARTMIN.
 */
template <typename T0, typename T1, typename T2>
TILEOP void TPairMin(T0 dst, T1 src0, T2 src1)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::MIN;
    PairBinaryCompute<kOp>(dst, src0, src1);
}
#define OP_TILE_OP_PAIRPROD TPairProd
/**
 * Pair-product entry point (also reachable through the OP_TILE_OP_PAIRPROD macro).
 * Runs PairBinaryCompute with PairBinaryOp::MUL, which ends in pto::TPARTMUL.
 */
template <typename T0, typename T1, typename T2>
TILEOP void TPairProd(T0 dst, T1 src0, T2 src1)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::MUL;
    PairBinaryCompute<kOp>(dst, src0, src1);
}
#define OP_TILE_OP_PAIRARGMAX TPairArgMax
/**
 * Pair-argmax entry point (also reachable through the OP_TILE_OP_PAIRARGMAX macro).
 * Runs PairArgBinaryCompute with PairBinaryOp::ARGMAX on the destination and
 * both source (value, index) tensor pairs, which ends in pto::TPARTARGMAX.
 */
template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5>
TILEOP void TPairArgMax(T0 dstValue, T1 dstIndex, T2 src0Value, T3 src0Index, T4 src1Value, T5 src1Index)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::ARGMAX;
    PairArgBinaryCompute<kOp>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
}
#define OP_TILE_OP_PAIRARGMIN TPairArgMin
/**
 * Pair-argmin entry point (also reachable through the OP_TILE_OP_PAIRARGMIN macro).
 * Runs PairArgBinaryCompute with PairBinaryOp::ARGMIN on the destination and
 * both source (value, index) tensor pairs, which ends in pto::TPARTARGMIN.
 */
template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5>
TILEOP void TPairArgMin(T0 dstValue, T1 dstIndex, T2 src0Value, T3 src0Index, T4 src1Value, T5 src1Index)
{
    constexpr PairBinaryOp kOp = PairBinaryOp::ARGMIN;
    PairArgBinaryCompute<kOp>(dstValue, dstIndex, src0Value, src0Index, src1Value, src1Index);
}
#endif