* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file unary.h
* \brief
*/
#pragma once
#include <string>
#include "interface/utils/common.h"
#include "interface/operation/opcode.h"
#include "interface/operation/operation_common.h"
#include "interface/function/function.h"
#include "interface/program/program.h"
#include "interface/configs/config_manager.h"
namespace npu::tile_fwk {
enum class BinaryOpType {
ADD,
SUB,
MUL,
DIV,
MAX,
MIN,
POW,
ADD_BRC,
SUB_BRC,
MUL_BRC,
DIV_BRC,
MAX_BRC,
MIN_BRC,
MOD_BRC,
S_ADD,
S_SUB,
S_MUL,
S_DIV,
S_MAX,
S_MIN,
MAXIMUM,
MINIMUM,
LRELU,
ATAN2,
CMP,
MOD,
REM,
REMR,
BITWISEAND,
BITWISEOR,
BITWISEXOR,
EXPANDEXPDIF,
COPYSIGN,
GCD,
FLOORDIV,
};
template <BinaryOpType T>
std::string GetBinaryOpName()
{
switch (T) {
case BinaryOpType::ADD:
return "ADD";
case BinaryOpType::SUB:
return "SUB";
case BinaryOpType::MUL:
return "MUL";
case BinaryOpType::DIV:
return "DIV";
case BinaryOpType::MAX:
return "MAX";
case BinaryOpType::MIN:
return "MIN";
case BinaryOpType::MAXIMUM:
return "MAXIMUM";
case BinaryOpType::MINIMUM:
return "MINIMUM";
case BinaryOpType::LRELU:
return "LRELU";
case BinaryOpType::ATAN2:
return "ATAN2";
case BinaryOpType::POW:
return "POW";
case BinaryOpType::MOD:
return "MOD";
case BinaryOpType::REM:
return "REM";
case BinaryOpType::REMR:
return "REMR";
case BinaryOpType::CMP:
return "CMP";
case BinaryOpType::S_ADD:
return "S_ADD";
case BinaryOpType::S_SUB:
return "S_SUB";
case BinaryOpType::S_MUL:
return "S_MUL";
case BinaryOpType::S_DIV:
return "S_DIV";
case BinaryOpType::S_MAX:
return "S_MAX";
case BinaryOpType::S_MIN:
return "S_MIN";
case BinaryOpType::BITWISEAND:
return "BITWISEAND";
case BinaryOpType::BITWISEOR:
return "BITWISEOR";
case BinaryOpType::BITWISEXOR:
return "BITWISEXOR";
case BinaryOpType::EXPANDEXPDIF:
return "EXPANDEXPDIF";
case BinaryOpType::COPYSIGN:
return "COPYSIGN";
case BinaryOpType::GCD:
return "GCD";
case BinaryOpType::FLOORDIV:
return "FLOORDIV";
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
return "";
}
}
template <BinaryOpType T, bool WithElement = false, bool WithBrc = false>
Opcode GetBinaryOpNameCode()
{
if constexpr (WithElement) {
#define CASE(X) \
case BinaryOpType::X: \
return Opcode::OP_##X##S
switch (T) {
CASE(ADD);
CASE(SUB);
CASE(MUL);
CASE(DIV);
CASE(MAX);
CASE(MIN);
CASE(MOD);
CASE(POW);
CASE(REM);
CASE(REMR);
CASE(S_ADD);
CASE(S_SUB);
CASE(S_MUL);
CASE(S_DIV);
CASE(S_MAX);
CASE(S_MIN);
CASE(BITWISEAND);
CASE(BITWISEOR);
CASE(BITWISEXOR);
CASE(GCD);
CASE(FLOORDIV);
case BinaryOpType::LRELU:
return Opcode::OP_LRELU;
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
}
#undef CASE
}
if constexpr (WithBrc) {
#define CASE(X) \
case BinaryOpType::X: \
return Opcode::OP_##X##_BRC
switch (T) {
CASE(ADD);
CASE(SUB);
CASE(MUL);
CASE(DIV);
CASE(MAX);
CASE(MIN);
CASE(GCD);
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
}
#undef CASE
}
#define CASE(X) \
case BinaryOpType::X: \
return Opcode::OP_##X
switch (T) {
CASE(ADD);
CASE(SUB);
CASE(MUL);
CASE(DIV);
CASE(S_ADD);
CASE(S_SUB);
CASE(S_MUL);
CASE(S_DIV);
CASE(S_MAX);
CASE(S_MIN);
CASE(MAXIMUM);
CASE(MINIMUM);
CASE(LRELU);
CASE(ATAN2);
CASE(POW);
CASE(MOD);
CASE(REM);
CASE(BITWISEAND);
CASE(BITWISEOR);
CASE(BITWISEXOR);
CASE(EXPANDEXPDIF);
CASE(COPYSIGN);
CASE(GCD);
CASE(FLOORDIV);
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
}
#undef CASE
}
struct LogicalInput {
const LogicalTensorPtr tensor;
TileInfo tileInfo;
};
std::vector<int64_t> BinaryOperationResultShape(LogicalTensorPtr operand1, LogicalTensorPtr operand2);
LogicalTensorPtr BinaryOperationBroadCast(const LogicalTensorPtr& operand, const std::vector<int>& broadCastShape);
int BrcAxisBinaryOp(LogicalTensorPtr operand1, LogicalTensorPtr operand2, int64_t axisNum);
void CheckBinOpOperandsValid(const LogicalTensorPtr& operand1, const LogicalTensorPtr& operand2);
void BinaryOperationOperandCheck(
const std::vector<LogicalTensorPtr>& iOperand, const std::vector<LogicalTensorPtr>& oOperand);
void BroadcastOperandTensor(
LogicalTensorPtr& operand, LogicalTensorPtr& other, LogicalTensorPtr result, Function& function,
const TileShape& tileShape, std::vector<int64_t> dstShape = {});
template <BinaryOpType T>
std::pair<LogicalTensorPtr, Operation*> TensorBinaryOperationWithOp(
Function& function, const Tensor& operand1, const Tensor& operand2)
{
auto oprandT1 = operand1.GetStorage();
auto oprandT2 = operand2.GetStorage();
if (oprandT1->shape.size() != oprandT2->shape.size()) {
std::vector<int> broadCastShape = GetBroadCastShape(oprandT1, oprandT2);
oprandT1 = BinaryOperationBroadCast(oprandT1, broadCastShape);
oprandT2 = BinaryOperationBroadCast(oprandT2, broadCastShape);
}
auto opName = GetBinaryOpName<T>();
CheckBinaryInputTensors(oprandT1, oprandT2, opName);
std::vector<SymbolicScalar> resultValidShape;
std::vector<int64_t> resultShape = BinaryOperationResultShape(oprandT1, oprandT2);
size_t shapeSize = resultShape.size();
if ((!oprandT1->GetDynValidShape().empty()) && (!oprandT2->GetDynValidShape().empty())) {
for (size_t i = 0; i < shapeSize; ++i) {
if (resultShape[i] == oprandT1->shape[i]) {
resultValidShape.push_back(oprandT1->GetDynValidShape()[i]);
} else {
resultValidShape.push_back(oprandT2->GetDynValidShape()[i]);
}
}
}
auto result = std::make_shared<LogicalTensor>(
function, oprandT1->Datatype(), resultShape, resultValidShape, oprandT1->Format());
auto& op = function.AddOperation(GetBinaryOpNameCode<T>(), {oprandT1, oprandT2}, {result});
return {result, &op};
}
template <BinaryOpType T>
LogicalTensorPtr TensorBinaryOperation(Function& function, const Tensor& operand1, const Tensor& operand2)
{
return TensorBinaryOperationWithOp<T>(function, operand1, operand2).first;
}
template <BinaryOpType T>
std::pair<LogicalTensorPtr, Operation*> TensorBinaryOperationScalarWithOp(
Function& function, LogicalTensorPtr operand1, const Element& value)
{
auto opName = GetBinaryOpName<T>();
CheckTensorDimRange(operand1, MIN_TENSOR_DIM, MAX_TENSOR_DIM, opName);
CheckTensorShapeSize(operand1, opName);
auto result =
std::make_shared<LogicalTensor>(function, operand1->Datatype(), operand1->shape, operand1->GetDynValidShape());
auto& op = function.AddOperation(GetBinaryOpNameCode<T, true>(), {operand1}, {result});
op.SetAttribute(OpAttributeKey::scalar, value);
return {result, &op};
}
template <BinaryOpType T>
LogicalTensorPtr TensorBinaryOperationScalar(Function& function, LogicalTensorPtr operand1, const Element& value)
{
return TensorBinaryOperationScalarWithOp<T>(function, operand1, value).first;
}
template <BinaryOpType T>
LogicalTensorPtr TensorBinaryOperationAllScalar(
Function& function, const Tensor& operand1, const Element& value, bool reverseOperand)
{
auto opName = GetBinaryOpName<T>();
auto storage = operand1.GetStorage();
CheckTensorDimRange(storage, MIN_TENSOR_DIM, MAX_TENSOR_DIM, opName);
CheckTensorShapeSize(storage, opName);
auto result = std::make_shared<LogicalTensor>(
function, storage->Datatype(), operand1.GetShape(), storage->GetDynValidShape());
auto& op = function.AddOperation(GetBinaryOpNameCode<T, true>(), {storage}, {result});
op.SetAttribute(OpAttributeKey::scalar, value);
op.SetAttribute(OP_ATTR_PREFIX + "reverseOperand", reverseOperand);
return result;
}
template <BinaryOpType T>
LogicalTensorPtr TensorBinaryOperationAllScalar(Function& function, const Tensor& operand1, const Tensor& operand2)
{
auto opName = GetBinaryOpName<T>();
CheckBinaryInputTensors(operand1.GetStorage(), operand2.GetStorage(), opName);
auto result = std::make_shared<LogicalTensor>(
function, operand1.GetStorage()->Datatype(), operand1.GetShape(), operand1.GetStorage()->GetDynValidShape());
function.AddOperation(GetBinaryOpNameCode<T, false>(), {operand1.GetStorage(), operand2.GetStorage()}, {result});
return result;
}
LogicalTensorPtr GenAllOneTensor(const Shape& shape, std::vector<SymbolicScalar> validShape, const DataType& dataType);
}