* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file binary_scalar.h
* \brief
*/
#ifndef TILEOP_TILE_OPERATOR_BINARY_SCALAR__H
#define TILEOP_TILE_OPERATOR_BINARY_SCALAR__H
#include "binary.h"
template <BinaryScalarOp op, auto PrecisionType = 0, typename LastUse, typename T0, typename T1, typename Scalar>
TILEOP void BinaryScalarComputeImpl(T0 dst, T1 src0, Scalar src1)
{
constexpr auto n1 = Std::tuple_element<DIM_1ST, LastUse>::type::value;
constexpr auto n2 = Std::tuple_element<DIM_2ND, LastUse>::type::value;
if constexpr (op == BinaryScalarOp::ADD) {
PTO_WITH_LAST_USE(pto::TADDS(dst, src0, src1), n1, n2);
return;
}
if constexpr (op == BinaryScalarOp::SUB) {
if constexpr (std::is_same<Scalar, half>::value) {
PTO_WITH_LAST_USE(
pto::TADDS(dst, src0, static_cast<half>(static_cast<float>(-1) * static_cast<float>(src1))), n1, n2);
} else {
PTO_WITH_LAST_USE(pto::TADDS(dst, src0, -src1), n1, n2);
}
return;
}
if constexpr (op == BinaryScalarOp::MUL) {
PTO_WITH_LAST_USE(pto::TMULS(dst, src0, src1), n1, n2);
return;
}
if constexpr (op == BinaryScalarOp::DIV) {
PTO_WITH_LAST_USE(pto::TDIVS<PrecisionType>(dst, src0, src1), n1, n2);
return;
}
if constexpr (op == BinaryScalarOp::MAX) {
PTO_WITH_LAST_USE(pto::TMAXS(dst, src0, src1), n1, n2);
return;
}
if constexpr (op == BinaryScalarOp::MIN) {
PTO_WITH_LAST_USE(pto::TMINS(dst, src0, src1), n1, n2);
return;
}
if constexpr (op == BinaryScalarOp::BITWISEAND) {
pto::TANDS(dst, src0, src1);
return;
}
if constexpr (op == BinaryScalarOp::BITWISEOR) {
pto::TORS(dst, src0, src1);
return;
}
if constexpr (op == BinaryScalarOp::MOD) {
pto::TFMODS<PrecisionType>(dst, src0, src1);
return;
}
if constexpr (op == BinaryScalarOp::LRELU) {
pto::TLRELU(dst, src0, src1);
return;
}
}
template <BinaryScalarOp op, auto PrecisionType = 0, typename LastUse, typename T0, typename T1, typename Scalar>
TILEOP void BinaryScalarCompute(T0 dst, T1 src0, Scalar src1)
{
const auto dstLayout = dst.GetLayout();
auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto dstTile = PtoTile<T0>(dst);
auto src0Tile = PtoTile<T1>(src0);
for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
dstTile.Assign(dst, tileOffsets);
src0Tile.Assign(src0, tileOffsets);
BinaryScalarComputeImpl<op, PrecisionType, LastUse>(dstTile.Data(), src0Tile.Data(), src1);
}
}
}
}
#define OP_TILE_OP_ADDS TAddS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TAddS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::ADD, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_SUBS TSubS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TSubS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::SUB, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_MULS TMulS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMulS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::MUL, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_DIVS TDivS
template <
auto PrecisionType = pto::DivAlgorithm::DEFAULT, typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0,
typename T1>
TILEOP void TDivS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::DIV, PrecisionType, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_MAXS TMaxS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMaxS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::MAX, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_MINS TMinS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TMinS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::MIN, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_LRELU TLReLU
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TLReLU(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::LRELU, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_BITWISEANDS TBitwiseAndS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TBitwiseAndS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::BITWISEAND, 0, LastUse>(dst, src0, src1);
}
#define OP_TILE_OP_BITWISEORS TBitwiseOrS
template <typename LastUse = LastUse2Dim<0, 0>, typename Scalar, typename T0, typename T1>
TILEOP void TBitwiseOrS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::BITWISEOR, 0, LastUse>(dst, src0, src1);
}
TILEOP int gcds(int a, int b)
{
if (a < 0) {
a = 0 - a;
}
if (b < 0) {
b = 0 - b;
}
while (a % b != 0) {
int c = a % b;
a = b;
b = c;
}
return b;
}
#define OP_TILE_OP_GCDS TGcdS
template <typename Scalar, typename T0, typename T1>
TILEOP void TGcdS(T0 dst, T1 src0, Scalar src1)
{
const auto dstLayout = dst.GetLayout();
auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto shape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
auto shape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
auto dstStride0 = dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
auto dstStride1 = dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
auto dstStride2 = dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
auto dstStride3 = dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
constexpr auto dstTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
constexpr auto dstTileW = TileOp::GetTensorTileShapeDim<T0, 4, 5>();
auto src0Addr = (__ubuf__ typename T1::Type*)((uint64_t)(src0.GetAddr()));
auto dstAddr = (__ubuf__ typename T0::Type*)((uint64_t)(dst.GetAddr()));
set_flag(PIPE_V, PIPE_S, EVENT_ID7);
wait_flag(PIPE_V, PIPE_S, EVENT_ID7);
for (LoopVar n = 0; n < shape0; n++) {
for (LoopVar j = 0; j < shape1; j++) {
for (LoopVar k = 0; k < shape2; k++) {
for (LoopVar m = 0; m < shape3; m++) {
for (LoopVar i = 0; i < shape4; i++) {
int tmpStride = n * dstStride0 + j * dstStride1 + k * dstStride2 + m * dstStride3 + i;
dstAddr[tmpStride] = gcds(src0Addr[tmpStride], src1);
}
}
}
}
}
set_flag(PIPE_S, PIPE_V, EVENT_ID7);
wait_flag(PIPE_S, PIPE_V, EVENT_ID7);
}
#define OP_TILE_OP_MODS TModS
template <
auto PrecisionType = pto::FmodSAlgorithm::DEFAULT, typename LastUse = LastUse2Dim<0, 0>, typename Scalar,
typename T0, typename T1>
TILEOP void TModS(T0 dst, T1 src0, Scalar src1)
{
BinaryScalarCompute<BinaryScalarOp::MOD, PrecisionType, LastUse>(dst, src0, src1);
}
template <BinaryScalarOp op, auto PrecisionType = 0, typename T0, typename T1, typename Scalar, typename T2>
TILEOP void BinaryScalarTmpComputeImpl(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
if constexpr (op == BinaryScalarOp::BITWISEXOR) {
pto::TXORS(dst, src0, src1, tmp);
return;
}
if constexpr (op == BinaryScalarOp::REM) {
pto::TREMS<PrecisionType>(dst, src0, src1, tmp);
return;
}
if constexpr (op == BinaryScalarOp::POW) {
pto::TPOWS<PrecisionType>(dst, src0, src1, tmp);
return;
}
}
template <BinaryScalarOp op, auto PrecisionType = 0, typename T0, typename T1, typename Scalar, typename T2>
TILEOP void BinaryScalarTmpCompute(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
const auto dstLayout = dst.GetLayout();
auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto dstTile = PtoTile<T0>(dst);
auto src0Tile = PtoTile<T1>(src0);
auto tmpTile = PtoTile<T2>(tmp);
for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
dstTile.Assign(dst, tileOffsets);
src0Tile.Assign(src0, tileOffsets);
tmpTile.Assign(tmp, tileOffsets);
BinaryScalarTmpComputeImpl<op, PrecisionType>(dstTile.Data(), src0Tile.Data(), src1, tmpTile.Data());
}
}
}
}
#define OP_TILE_OP_BITWISEXORS TBitwiseXorS
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TBitwiseXorS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
BinaryScalarTmpCompute<BinaryScalarOp::BITWISEXOR, 0>(dst, src0, src1, tmp);
}
#define OP_TILE_OP_REMS TRemainderS
template <typename Scalar, auto PrecisionType = pto::RemSAlgorithm::DEFAULT, typename T0, typename T1, typename T2>
TILEOP void TRemainderS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
BinaryScalarTmpCompute<BinaryScalarOp::REM, PrecisionType>(dst, src0, src1, tmp);
}
#define OP_TILE_OP_POWS TPowS
template <auto PrecisionType = pto::PowAlgorithm::DEFAULT, typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TPowS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
BinaryScalarTmpCompute<BinaryScalarOp::POW, PrecisionType>(dst, src0, src1, tmp);
}
#define OP_TILE_OP_REMRS TRemainderRS
template <typename Scalar, auto PrecisionType = pto::RemAlgorithm::DEFAULT, typename T0, typename T1, typename T2>
TILEOP void TRemainderRS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
const auto dstLayout = dst.GetLayout();
auto shape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto shape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto shape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto shape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
auto shape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
auto dstTile = PtoTile<T0>(dst);
auto src0Tile = PtoTile<T1>(src0);
constexpr auto tmpTileH = TileOp::GetTensorTileShapeDim<T0, 3, 5>();
constexpr auto tmpTileW = TileOp::GetTensorTileShapeDim<T2, 4, 5>();
using tmp0TileDefine =
pto::Tile<pto::TileType::Vec, typename T2::Type, tmpTileH, tmpTileW, pto::BLayout::RowMajor, -1, -1>;
using tmp1TileDefine =
pto::Tile<pto::TileType::Vec, typename T2::Type, 2, tmpTileW, pto::BLayout::RowMajor, -1, -1>;
tmp0TileDefine tmp0Tile(shape3, shape4);
tmp1TileDefine tmp1Tile(2, shape4);
for (LoopVar n0Index = 0; n0Index < shape0; ++n0Index) {
for (LoopVar n1Index = 0; n1Index < shape1; ++n1Index) {
for (LoopVar n2Index = 0; n2Index < shape2; ++n2Index) {
auto tileOffsets = TileOffset(n0Index, n1Index, n2Index);
dstTile.Assign(dst, tileOffsets);
src0Tile.Assign(src0, tileOffsets);
pto::TASSIGN(tmp0Tile, (uint64_t)(tmp.GetAddr()));
pto::TASSIGN(tmp1Tile, (uint64_t)(tmp.GetAddr() + shape3 * tmpTileW * sizeof(typename T2::Type)));
pto::TEXPANDS(tmp0Tile, src1);
#ifdef __DAV_V220
pipe_barrier(PIPE_V);
#endif
pto::TREM<PrecisionType>(dstTile.Data(), tmp0Tile, src0Tile.Data(), tmp1Tile);
}
}
}
}
#define OP_TILE_OP_FLOORDIVS TFloorDivS
template <typename Scalar, typename T0, typename T1, typename T2>
TILEOP void TFloorDivS(T0 dst, T1 src0, Scalar src1, T2 tmp)
{
static_assert(std::is_same_v<typename T1::Type, int32_t>);
const auto dstLayout = dst.GetLayout();
auto dstShape0 = dstLayout.template GetShapeDim<DIM_1ST, MAX_DIMS>();
auto dstShape1 = dstLayout.template GetShapeDim<DIM_2ND, MAX_DIMS>();
auto dstShape2 = dstLayout.template GetShapeDim<DIM_3RD, MAX_DIMS>();
auto dstShape3 = dstLayout.template GetShapeDim<DIM_4TH, MAX_DIMS>();
auto dstShape4 = dstLayout.template GetShapeDim<DIM_5TH, MAX_DIMS>();
if (dstShape0 == 0 || dstShape1 == 0 || dstShape2 == 0 || dstShape3 == 0 || dstShape4 == 0) {
return;
}
auto dstStride0 = dstLayout.template GetStrideDim<DIM_1ST, MAX_DIMS>();
auto dstStride1 = dstLayout.template GetStrideDim<DIM_2ND, MAX_DIMS>();
auto dstStride2 = dstLayout.template GetStrideDim<DIM_3RD, MAX_DIMS>();
auto dstStride3 = dstLayout.template GetStrideDim<DIM_4TH, MAX_DIMS>();
constexpr auto tileW = TileOp::GetTensorTileShapeDim<T0, DIM_5TH, MAX_DIMS>();
constexpr auto dstTypeSize = sizeof(typename T0::Type);
for (LoopVar n0Index = 0; n0Index < dstShape0; n0Index++) {
for (LoopVar n1Index = 0; n1Index < dstShape1; n1Index++) {
for (LoopVar n2Index = 0; n2Index < dstShape2; n2Index++) {
for (LoopVar n3Index = 0; n3Index < dstShape3; n3Index++) {
auto offset =
n0Index * dstStride0 + n1Index * dstStride1 + n2Index * dstStride2 + n3Index * dstStride3;
#ifdef __DAV_V220
using IntTileDefine =
pto::Tile<pto::TileType::Vec, typename T0::Type, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
using FloatTileDefine =
pto::Tile<pto::TileType::Vec, float, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
IntTileDefine src0Tile(1, dstShape4);
IntTileDefine dstTile(1, dstShape4);
FloatTileDefine tmp0Tile(1, dstShape4);
FloatTileDefine tmp1Tile(1, dstShape4);
pto::TASSIGN(tmp0Tile, (uint64_t)(tmp.GetAddr()));
pto::TASSIGN(tmp1Tile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));
pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + offset * dstTypeSize));
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + offset * dstTypeSize));
pto::TCVT(tmp0Tile, src0Tile, pto::RoundMode::CAST_NONE, pto::SaturationMode::OFF);
pipe_barrier(PIPE_V);
pto::TDIVS(tmp1Tile, tmp0Tile, static_cast<float>(src1));
pipe_barrier(PIPE_V);
pto::TCVT(dstTile, tmp1Tile, pto::RoundMode::CAST_FLOOR);
pipe_barrier(PIPE_V);
#else
using DataTileDefine =
pto::Tile<pto::TileType::Vec, typename T0::Type, 1, tileW, pto::BLayout::RowMajor, -1, -1>;
using MaskTileDefine =
pto::Tile<pto::TileType::Vec, uint8_t, 1, tileW * 4, pto::BLayout::RowMajor, -1, -1>;
DataTileDefine src0Tile(1, dstShape4);
DataTileDefine dstTile(1, dstShape4);
DataTileDefine tmp0DataTile(1, dstShape4);
DataTileDefine tmp1DataTile(1, dstShape4);
DataTileDefine tmp2DataTile(1, dstShape4);
MaskTileDefine tmp0MaskTile(1, dstShape4);
MaskTileDefine tmp1MaskTile(1, dstShape4);
pto::TASSIGN(tmp0DataTile, (uint64_t)(tmp.GetAddr()));
pto::TASSIGN(tmp1DataTile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));
pto::TASSIGN(tmp2DataTile, (uint64_t)(tmp.GetAddr() + 2 * tileW * dstTypeSize));
pto::TASSIGN(src0Tile, (uint64_t)(src0.GetAddr() + offset * dstTypeSize));
pto::TASSIGN(dstTile, (uint64_t)(dst.GetAddr() + offset * dstTypeSize));
pto::TASSIGN(tmp0MaskTile, (uint64_t)(tmp.GetAddr()));
pto::TASSIGN(tmp1MaskTile, (uint64_t)(tmp.GetAddr() + tileW * dstTypeSize));
if (src1 == 0) {
constexpr int32_t pos = 0x7FFF7F7F, neg = 0x80008080;
pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::LT);
pto::TSELS(dstTile, tmp0MaskTile, dstTile, tmp1DataTile, pos);
pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::GE);
pto::TSELS(dstTile, tmp0MaskTile, dstTile, tmp1DataTile, neg);
} else {
uint8_t src1Mask = 0;
if (src1 < 0) {
src1Mask = 0xff;
}
pto::TCMPS(tmp0MaskTile, src0Tile, 0, CmpMode::LT);
pto::TXORS(tmp1MaskTile, tmp0MaskTile, src1Mask, dstTile);
pto::TDIVS(dstTile, src0Tile, src1);
pto::TMULS(tmp0DataTile, dstTile, -src1);
pto::TADD(tmp2DataTile, tmp0DataTile, src0Tile);
pto::TCMPS(tmp0MaskTile, tmp2DataTile, 0, CmpMode::NE);
pto::TAND(tmp0MaskTile, tmp1MaskTile, tmp0MaskTile);
pto::TADDS(tmp2DataTile, dstTile, -1);
pto::TSEL(dstTile, tmp0MaskTile, tmp2DataTile, dstTile, tmp1DataTile);
}
#endif
}
}
}
}
}
#endif