* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file calc_vector.cpp
* \brief
*/
#include "interface/interpreter/function.h"
#include "interface/interpreter/operation.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
// Sentinel axis value (-2). It is not referenced in the visible part of this file; presumably it marks
// an axis that gather-style handlers should skip -- TODO confirm against its users.
constexpr int GATHER_IGNORE_AXIS = -2;
/**
 * \brief Interpreter handler shared by the two-input element-wise ops.
 *
 * Checks the operand counts, optionally narrows an operand that was produced by an OP_BRCB
 * operation to a {rows, 1} view, and then forwards (output, lhs, rhs) to the calc:: routine
 * selected by the compile-time \p opcode.
 *
 * \tparam opcode Operation to execute; opcodes without a case below hit the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context providing the operation and its input/output data views.
 */
template <Opcode opcode>
void ExecuteOpBinary(ExecuteOperationContext* ctx)
{
    // The ADD/SUB/MUL/DIV broadcast variants are treated specially in two places below.
    constexpr bool isArithBrc = opcode == Opcode::OP_ADD_BRC || opcode == Opcode::OP_SUB_BRC ||
                                opcode == Opcode::OP_MUL_BRC || opcode == Opcode::OP_DIV_BRC;
    // These ops are allowed to carry up to two outputs.
    constexpr bool allowsSecondOutput = opcode == Opcode::OP_BITWISEXOR || opcode == Opcode::OP_COPYSIGN ||
                                        opcode == Opcode::OP_POW || opcode == Opcode::OP_FLOORDIV ||
                                        opcode == Opcode::OP_REM;

    if (isArithBrc) {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == SIZE_TWO);
    } else if (allowsSecondOutput) {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    } else {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    }
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);

    auto out = ctx->ooperandInplaceDataViewList->at(0);
    auto lhsFull = ctx->ioperandDataViewList->at(0);
    auto rhsFull = ctx->ioperandDataViewList->at(1);
    auto lhs = lhsFull;
    auto rhs = rhsFull;

    // True when the first producer of the tensor is an OP_BRCB operation.
    const auto firstProducerIsBrcb = [](const auto& tensor) {
        return !tensor->GetProducers().empty() &&
               (*tensor->GetProducers().begin())->GetOpcode() == Opcode::OP_BRCB;
    };
    auto lhsTensor = ctx->op->GetIOperands()[0];
    auto rhsTensor = ctx->op->GetIOperands()[1];
    const bool lhsFromBrcb = firstProducerIsBrcb(lhsTensor);
    const bool rhsFromBrcb = firstProducerIsBrcb(rhsTensor);

    // Only one side is narrowed; the left operand takes precedence when both come from BRCB.
    if (lhsFromBrcb) {
        lhs = lhsFull->View({lhsFull->GetShape()[0], 1}, lhsFull->GetOffset());
    } else if (rhsFromBrcb) {
        rhs = rhsFull->View({rhsFull->GetShape()[0], 1}, rhsFull->GetOffset());
    }
    if (lhsFromBrcb || rhsFromBrcb) {
        INTERPRETER_LOGI(
            "AxisCombine: detected by BRCB, opcode=%s lhsFromBrcb=%d rhsFromBrcb=%d", ctx->op->GetOpcodeStr().c_str(),
            static_cast<int>(lhsFromBrcb), static_cast<int>(rhsFromBrcb));
        INTERPRETER_LOGI(
            "AxisCombine: lhs(shape=%s validShape=%s offset=%s) rhs(shape=%s validShape=%s offset=%s)",
            IntVecToStr(lhs->GetShape()).c_str(), IntVecToStr(lhs->GetValidShape()).c_str(),
            IntVecToStr(lhs->GetOffset()).c_str(), IntVecToStr(rhs->GetShape()).c_str(),
            IntVecToStr(rhs->GetValidShape()).c_str(), IntVecToStr(rhs->GetOffset()).c_str());
    }

    if (isArithBrc) {
        // Propagate the op's "input_combine_axis_done" flag to both operands.
        const bool axisCombine = ctx->op->GetBoolAttribute("input_combine_axis_done");
        lhs->SetAxisCombine(axisCombine);
        rhs->SetAxisCombine(axisCombine);
    }

    // Plain and *_BRC variants of an op share the same calc:: routine.
    switch (opcode) {
        case Opcode::OP_ADD:
        case Opcode::OP_ADD_BRC:
            calc::Add(out, lhs, rhs);
            break;
        case Opcode::OP_SUB:
        case Opcode::OP_SUB_BRC:
            calc::Sub(out, lhs, rhs);
            break;
        case Opcode::OP_MUL:
        case Opcode::OP_MUL_BRC:
            calc::Mul(out, lhs, rhs);
            break;
        case Opcode::OP_DIV:
        case Opcode::OP_DIV_BRC:
            calc::Div(out, lhs, rhs);
            break;
        case Opcode::OP_GCD:
        case Opcode::OP_GCD_BRC:
            calc::Gcd(out, lhs, rhs);
            break;
        case Opcode::OP_PAIRSUM:
            calc::PairSum(out, lhs, rhs);
            break;
        case Opcode::OP_PAIRPROD:
            calc::PairProd(out, lhs, rhs);
            break;
        case Opcode::OP_PAIRMAX:
            calc::PairMax(out, lhs, rhs);
            break;
        case Opcode::OP_PAIRMIN:
            calc::PairMin(out, lhs, rhs);
            break;
        case Opcode::OP_FLOORDIV:
            calc::FloorDiv(out, lhs, rhs);
            break;
        case Opcode::OP_ATAN2:
            calc::Atan2(out, lhs, rhs);
            break;
        case Opcode::OP_POW:
            calc::Pow(out, lhs, rhs);
            break;
        case Opcode::OP_REM:
            calc::Remainder(out, lhs, rhs);
            break;
        case Opcode::OP_S_MAX:
            calc::Max(out, lhs, rhs);
            break;
        case Opcode::OP_S_MIN:
            calc::Min(out, lhs, rhs);
            break;
        case Opcode::OP_BITWISEAND:
            calc::BitwiseAnd(out, lhs, rhs);
            break;
        case Opcode::OP_BITWISEOR:
            calc::BitwiseOr(out, lhs, rhs);
            break;
        case Opcode::OP_BITWISEXOR:
            calc::BitwiseXor(out, lhs, rhs);
            break;
        case Opcode::OP_EXPANDEXPDIF:
            calc::ExpandExpDif(out, lhs, rhs);
            break;
        case Opcode::OP_COPYSIGN:
            calc::CopySign(out, lhs, rhs);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false);
    }
}
// Registrations for the binary handler. Note that the OP_S_ADD/OP_S_SUB/OP_S_MUL/OP_S_DIV and
// OP_MAXIMUM/OP_MINIMUM opcodes reuse the instantiations of OP_ADD/OP_SUB/OP_MUL/OP_DIV and
// OP_S_MAX/OP_S_MIN respectively.
REGISTER_CALC_OP(OP_ADD, Opcode::OP_ADD, ExecuteOpBinary<Opcode::OP_ADD>);
REGISTER_CALC_OP(OP_ADD_BRC, Opcode::OP_ADD_BRC, ExecuteOpBinary<Opcode::OP_ADD_BRC>);
REGISTER_CALC_OP(OP_SUB, Opcode::OP_SUB, ExecuteOpBinary<Opcode::OP_SUB>);
REGISTER_CALC_OP(OP_SUB_BRC, Opcode::OP_SUB_BRC, ExecuteOpBinary<Opcode::OP_SUB_BRC>);
REGISTER_CALC_OP(OP_MUL, Opcode::OP_MUL, ExecuteOpBinary<Opcode::OP_MUL>);
REGISTER_CALC_OP(OP_MUL_BRC, Opcode::OP_MUL_BRC, ExecuteOpBinary<Opcode::OP_MUL_BRC>);
REGISTER_CALC_OP(OP_DIV, Opcode::OP_DIV, ExecuteOpBinary<Opcode::OP_DIV>);
REGISTER_CALC_OP(OP_DIV_BRC, Opcode::OP_DIV_BRC, ExecuteOpBinary<Opcode::OP_DIV_BRC>);
REGISTER_CALC_OP(OP_FLOORDIV, Opcode::OP_FLOORDIV, ExecuteOpBinary<Opcode::OP_FLOORDIV>);
REGISTER_CALC_OP(OP_ATAN2, Opcode::OP_ATAN2, ExecuteOpBinary<Opcode::OP_ATAN2>);
REGISTER_CALC_OP(OP_POW, Opcode::OP_POW, ExecuteOpBinary<Opcode::OP_POW>);
REGISTER_CALC_OP(OP_REM, Opcode::OP_REM, ExecuteOpBinary<Opcode::OP_REM>);
REGISTER_CALC_OP(OP_S_ADD, Opcode::OP_S_ADD, ExecuteOpBinary<Opcode::OP_ADD>);
REGISTER_CALC_OP(OP_S_SUB, Opcode::OP_S_SUB, ExecuteOpBinary<Opcode::OP_SUB>);
REGISTER_CALC_OP(OP_S_MUL, Opcode::OP_S_MUL, ExecuteOpBinary<Opcode::OP_MUL>);
REGISTER_CALC_OP(OP_S_DIV, Opcode::OP_S_DIV, ExecuteOpBinary<Opcode::OP_DIV>);
REGISTER_CALC_OP(OP_PAIRMAX, Opcode::OP_PAIRMAX, ExecuteOpBinary<Opcode::OP_PAIRMAX>);
REGISTER_CALC_OP(OP_PAIRMIN, Opcode::OP_PAIRMIN, ExecuteOpBinary<Opcode::OP_PAIRMIN>);
REGISTER_CALC_OP(OP_PAIRSUM, Opcode::OP_PAIRSUM, ExecuteOpBinary<Opcode::OP_PAIRSUM>);
REGISTER_CALC_OP(OP_PAIRPROD, Opcode::OP_PAIRPROD, ExecuteOpBinary<Opcode::OP_PAIRPROD>);
REGISTER_CALC_OP(OP_S_MAX, Opcode::OP_S_MAX, ExecuteOpBinary<Opcode::OP_S_MAX>);
REGISTER_CALC_OP(OP_S_MIN, Opcode::OP_S_MIN, ExecuteOpBinary<Opcode::OP_S_MIN>);
REGISTER_CALC_OP(OP_MAXIMUM, Opcode::OP_MAXIMUM, ExecuteOpBinary<Opcode::OP_S_MAX>);
REGISTER_CALC_OP(OP_MINIMUM, Opcode::OP_MINIMUM, ExecuteOpBinary<Opcode::OP_S_MIN>);
REGISTER_CALC_OP(OP_BITWISEAND, Opcode::OP_BITWISEAND, ExecuteOpBinary<Opcode::OP_BITWISEAND>);
REGISTER_CALC_OP(OP_BITWISEOR, Opcode::OP_BITWISEOR, ExecuteOpBinary<Opcode::OP_BITWISEOR>);
REGISTER_CALC_OP(OP_BITWISEXOR, Opcode::OP_BITWISEXOR, ExecuteOpBinary<Opcode::OP_BITWISEXOR>);
REGISTER_CALC_OP(OP_EXPANDEXPDIF, Opcode::OP_EXPANDEXPDIF, ExecuteOpBinary<Opcode::OP_EXPANDEXPDIF>);
REGISTER_CALC_OP(OP_COPYSIGN, Opcode::OP_COPYSIGN, ExecuteOpBinary<Opcode::OP_COPYSIGN>);
REGISTER_CALC_OP(OP_GCD, Opcode::OP_GCD, ExecuteOpBinary<Opcode::OP_GCD>);
REGISTER_CALC_OP(OP_GCD_BRC, Opcode::OP_GCD_BRC, ExecuteOpBinary<Opcode::OP_GCD_BRC>);
/**
 * \brief Handler for OP_MOD: forwards the first output and both inputs to calc::Fmod.
 * \param ctx Execution context; must carry exactly two inputs and at most two outputs.
 */
void ExecuteOpFmod(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);

    auto quotientOut = ctx->ooperandInplaceDataViewList->at(0);
    auto dividend = ctx->ioperandDataViewList->at(0);
    auto divisor = ctx->ioperandDataViewList->at(1);
    calc::Fmod(quotientOut, dividend, divisor);
}
REGISTER_CALC_OP(OP_MOD, Opcode::OP_MOD, ExecuteOpFmod);
/**
 * \brief Handler for OP_MODS: tensor-with-scalar variant that forwards to calc::FmodS.
 *
 * The scalar operand is read from the op's OpAttributeKey::scalar attribute into an Element
 * that is pre-initialised to FP32 0.0.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpFmods(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto scalar = Element(DT_FP32, 0.0f);
    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    ctx->op->GetAttr(OpAttributeKey::scalar, scalar);
    calc::FmodS(dst, src, scalar);
}
REGISTER_CALC_OP(OP_MODS, Opcode::OP_MODS, ExecuteOpFmods);
/**
 * \brief Handler for OP_VEC_DUP: passes the op's scalar attribute and the output to calc::ExpandS.
 * \param ctx Execution context; must carry no inputs and exactly one output.
 */
void ExecuteOpVecDup(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 0);

    // Starts as FP32 0.0 and is then handed to GetAttr together with the scalar attribute key.
    auto fillValue = Element(DT_FP32, 0.0f);
    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    ctx->op->GetAttr(OpAttributeKey::scalar, fillValue);
    calc::ExpandS(dst, fillValue);
}
REGISTER_CALC_OP(OP_VEC_DUP, Opcode::OP_VEC_DUP, ExecuteOpVecDup);
/**
 * \brief Handler for OP_WHERE_TT: both selectable operands are tensors.
 *
 * Inputs are read as (condition, input, other) from positions 0..2 and forwarded with the first
 * output to calc::WhereTT. Operand counts are not asserted here; vector::at() guards the accesses.
 * \param ctx Execution context providing the data views.
 */
void ExecuteOpWhereTT(ExecuteOperationContext* ctx)
{
    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto mask = ctx->ioperandDataViewList->at(0);
    auto whenTrue = ctx->ioperandDataViewList->at(1);
    auto whenFalse = ctx->ioperandDataViewList->at(2);
    calc::WhereTT(dst, mask, whenTrue, whenFalse);
}
REGISTER_CALC_OP(OP_WHERE_TT, Opcode::OP_WHERE_TT, ExecuteOpWhereTT);
/**
 * \brief Handler for OP_WHERE_TS: "input" is a tensor (input 1), "other" is the op's scalar attribute.
 *
 * Forwards (output 0, condition, input, other) to calc::WhereTS.
 * \param ctx Execution context providing the data views and the operation.
 */
void ExecuteOpWhereTS(ExecuteOperationContext* ctx)
{
    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto mask = ctx->ioperandDataViewList->at(0);
    auto tensorBranch = ctx->ioperandDataViewList->at(1);
    auto scalarBranch = ctx->op->GetElementAttribute(OpAttributeKey::scalar);
    calc::WhereTS(dst, mask, tensorBranch, scalarBranch);
}
REGISTER_CALC_OP(OP_WHERE_TS, Opcode::OP_WHERE_TS, ExecuteOpWhereTS);
/**
 * \brief Handler for OP_WHERE_ST: "input" is the op's scalar attribute, "other" is a tensor (input 1).
 *
 * Forwards (output 0, condition, input, other) to calc::WhereST.
 * \param ctx Execution context providing the data views and the operation.
 */
void ExecuteOpWhereST(ExecuteOperationContext* ctx)
{
    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto mask = ctx->ioperandDataViewList->at(0);
    auto tensorBranch = ctx->ioperandDataViewList->at(1);
    auto scalarBranch = ctx->op->GetElementAttribute(OpAttributeKey::scalar);
    // Argument order: the scalar is the "input" operand, the tensor is the "other" operand.
    calc::WhereST(dst, mask, scalarBranch, tensorBranch);
}
REGISTER_CALC_OP(OP_WHERE_ST, Opcode::OP_WHERE_ST, ExecuteOpWhereST);
/**
 * \brief Handler for OP_WHERE_SS: both selectable operands are scalars.
 *
 * The two scalars are read from the op's OpAttributeKey::vectorScalar attribute (element 0 is
 * "input", element 1 is "other") and forwarded with the condition tensor and the first output to
 * calc::WhereSS.
 * \param ctx Execution context providing the data views and the operation.
 */
void ExecuteOpWhereSS(ExecuteOperationContext* ctx)
{
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto condition = ctx->ioperandDataViewList->at(0);
    auto scalars = ctx->op->GetVectorElementAttribute(OpAttributeKey::vectorScalar);
    // Bounds-checked access: an attribute with fewer than two scalars now fails loudly instead of
    // reading past the end of the container (assumes the attribute is a std::vector -- TODO confirm).
    auto input = scalars.at(0);
    auto other = scalars.at(1);
    calc::WhereSS(result, condition, input, other);
}
REGISTER_CALC_OP(OP_WHERE_SS, Opcode::OP_WHERE_SS, ExecuteOpWhereSS);
/**
 * \brief Interpreter handler shared by the single-output row reductions (sum/max/min/prod).
 *
 * Reads the reduction axis from the op's "AXIS" attribute. When the output's extent along that
 * axis is not 1, the output is replaced by a zero-offset view whose extent along the axis is 1
 * before the calc:: routine selected by \p opcode is invoked.
 *
 * \tparam opcode Reduction to execute; opcodes without a case below hit the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
template <Opcode opcode>
void ExecuteOpReduce(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    int axis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "AXIS");

    const bool needsCollapsedView = dst->GetShape()[axis] != 1;
    if (needsCollapsedView) {
        std::vector<int64_t> collapsedShape = dst->GetShape();
        collapsedShape[axis] = 1;
        dst = dst->View(collapsedShape, std::vector<int64_t>(collapsedShape.size(), 0));
    }

    switch (opcode) {
        // "*_SINGLE" variants
        case Opcode::OP_ROWSUM_SINGLE:
            calc::RowSumSingle(dst, src, axis);
            break;
        case Opcode::OP_ROWMAX_SINGLE:
            calc::RowMaxSingle(dst, src, axis);
            break;
        case Opcode::OP_ROWMIN_SINGLE:
            calc::RowMinSingle(dst, src, axis);
            break;
        case Opcode::OP_ROWPROD_SINGLE:
            calc::RowProdSingle(dst, src, axis);
            break;
        // "*LINE" variants (note: the sum variant maps to calc::RowSumExpand)
        case Opcode::OP_ROWSUMLINE:
            calc::RowSumExpand(dst, src, axis);
            break;
        case Opcode::OP_ROWMAXLINE:
            calc::RowMaxLine(dst, src, axis);
            break;
        case Opcode::OP_ROWMINLINE:
            calc::RowMinLine(dst, src, axis);
            break;
        case Opcode::OP_ROWPRODLINE:
            calc::RowProdLine(dst, src, axis);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false) << "opcode not support" << ctx->op->GetOpcodeStr();
    }
}
REGISTER_CALC_OP(OP_ROWSUM_SINGLE, Opcode::OP_ROWSUM_SINGLE, ExecuteOpReduce<Opcode::OP_ROWSUM_SINGLE>);
REGISTER_CALC_OP(OP_ROWSUMLINE, Opcode::OP_ROWSUMLINE, ExecuteOpReduce<Opcode::OP_ROWSUMLINE>);
REGISTER_CALC_OP(OP_ROWMAX_SINGLE, Opcode::OP_ROWMAX_SINGLE, ExecuteOpReduce<Opcode::OP_ROWMAX_SINGLE>);
REGISTER_CALC_OP(OP_ROWMAXLINE, Opcode::OP_ROWMAXLINE, ExecuteOpReduce<Opcode::OP_ROWMAXLINE>);
REGISTER_CALC_OP(OP_ROWMIN_SINGLE, Opcode::OP_ROWMIN_SINGLE, ExecuteOpReduce<Opcode::OP_ROWMIN_SINGLE>);
REGISTER_CALC_OP(OP_ROWMINLINE, Opcode::OP_ROWMINLINE, ExecuteOpReduce<Opcode::OP_ROWMINLINE>);
REGISTER_CALC_OP(OP_ROWPROD_SINGLE, Opcode::OP_ROWPROD_SINGLE, ExecuteOpReduce<Opcode::OP_ROWPROD_SINGLE>);
REGISTER_CALC_OP(OP_ROWPRODLINE, Opcode::OP_ROWPRODLINE, ExecuteOpReduce<Opcode::OP_ROWPRODLINE>);
/**
 * \brief Interpreter handler for the single-output arg-reductions (OP_ROWARGMAX_SINGLE / OP_ROWARGMIN_SINGLE).
 *
 * Reads the reduction axis from the op's "AXIS" attribute. When the output's extent along that
 * axis is not 1, the output is replaced by a zero-offset view whose extent along the axis is 1
 * before calc::RowArgMaxSingle / calc::RowArgMinSingle is invoked.
 *
 * \tparam opcode Arg-reduction to execute; any other opcode hits the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context; must carry exactly one input.
 */
template <Opcode opcode>
void ExecuteOpArgReduceSingle(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto oop = ctx->ooperandInplaceDataViewList->at(0);
    auto iop = ctx->ioperandDataViewList->at(0);
    int axis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    if (oop->GetShape()[axis] != 1) {
        std::vector<int64_t> oopShape = oop->GetShape();
        oopShape[axis] = 1;
        oop = oop->View(oopShape, std::vector<int64_t>(oopShape.size(), 0));
    }
    switch (opcode) {
        case Opcode::OP_ROWARGMAX_SINGLE:
            calc::RowArgMaxSingle(oop, iop, axis);
            break;
        case Opcode::OP_ROWARGMIN_SINGLE:
            calc::RowArgMinSingle(oop, iop, axis);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false) << "opcode not support" << ctx->op->GetOpcodeStr();
    }
}
// These opcodes must be bound to ExecuteOpArgReduceSingle: ExecuteOpReduce has no case for them
// and would always end in its UNSUPPORTED_OPCODE assert.
REGISTER_CALC_OP(OP_ROWARGMAX_SINGLE, Opcode::OP_ROWARGMAX_SINGLE, ExecuteOpArgReduceSingle<Opcode::OP_ROWARGMAX_SINGLE>);
REGISTER_CALC_OP(OP_ROWARGMIN_SINGLE, Opcode::OP_ROWARGMIN_SINGLE, ExecuteOpArgReduceSingle<Opcode::OP_ROWARGMIN_SINGLE>);
/**
 * \brief Interpreter handler for arg-reductions that also produce the reduced value.
 *
 * Outputs are read positionally as (value, index, temp) from output slots 0..2 and forwarded,
 * together with the single input and the op's "AXIS" attribute, to the calc:: routine selected
 * by \p opcode.
 *
 * \tparam opcode One of the OP_ROWARG{MAX,MIN}WITHVALUE_{SINGLE,LINE} opcodes; anything else hits
 *                the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context; must carry exactly one input (three outputs are accessed via at()).
 */
template <Opcode opcode>
void ExecuteOpArgReduceWithValue(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto outValue = ctx->ooperandInplaceDataViewList->at(0);
    auto outIndex = ctx->ooperandInplaceDataViewList->at(1);
    auto outTemp = ctx->ooperandInplaceDataViewList->at(2);
    auto iop = ctx->ioperandDataViewList->at(0);
    int axis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    switch (opcode) {
        case Opcode::OP_ROWARGMAXWITHVALUE_SINGLE:
            calc::RowArgMaxWithValueSingle(outValue, outIndex, outTemp, iop, axis);
            break;
        case Opcode::OP_ROWARGMINWITHVALUE_SINGLE:
            calc::RowArgMinWithValueSingle(outValue, outIndex, outTemp, iop, axis);
            break;
        case Opcode::OP_ROWARGMAXWITHVALUE_LINE:
            calc::RowArgMaxWithValueLine(outValue, outIndex, outTemp, iop, axis);
            break;
        case Opcode::OP_ROWARGMINWITHVALUE_LINE:
            calc::RowArgMinWithValueLine(outValue, outIndex, outTemp, iop, axis);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false) << "opcode not support" << ctx->op->GetOpcodeStr();
    }
}
// These opcodes must be bound to ExecuteOpArgReduceWithValue: ExecuteOpReduce asserts at most two
// outputs and has no case for them, so it could never execute this op family.
REGISTER_CALC_OP(
    OP_ROWARGMAXWITHVALUE_SINGLE, Opcode::OP_ROWARGMAXWITHVALUE_SINGLE,
    ExecuteOpArgReduceWithValue<Opcode::OP_ROWARGMAXWITHVALUE_SINGLE>);
REGISTER_CALC_OP(
    OP_ROWARGMAXWITHVALUE_LINE, Opcode::OP_ROWARGMAXWITHVALUE_LINE,
    ExecuteOpArgReduceWithValue<Opcode::OP_ROWARGMAXWITHVALUE_LINE>);
REGISTER_CALC_OP(
    OP_ROWARGMINWITHVALUE_SINGLE, Opcode::OP_ROWARGMINWITHVALUE_SINGLE,
    ExecuteOpArgReduceWithValue<Opcode::OP_ROWARGMINWITHVALUE_SINGLE>);
REGISTER_CALC_OP(
    OP_ROWARGMINWITHVALUE_LINE, Opcode::OP_ROWARGMINWITHVALUE_LINE,
    ExecuteOpArgReduceWithValue<Opcode::OP_ROWARGMINWITHVALUE_LINE>);
/**
 * \brief Interpreter handler for OP_PAIRARGMAX / OP_PAIRARGMIN.
 *
 * Inputs are read positionally as (value1, index1, value2, index2) and outputs as (value, index);
 * all six views are forwarded to calc::PairArgMax / calc::PairArgMin. The calc:: routines take no
 * axis argument, so the op's "AXIS" attribute is not read here.
 *
 * \tparam opcode OP_PAIRARGMAX or OP_PAIRARGMIN; anything else hits the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context; four inputs and two outputs are accessed via at().
 */
template <Opcode opcode>
void ExecuteOpPairArgRedyce(ExecuteOperationContext* ctx)
{
    auto outValue = ctx->ooperandInplaceDataViewList->at(0);
    auto outIndex = ctx->ooperandInplaceDataViewList->at(1);
    auto value1 = ctx->ioperandDataViewList->at(0);
    auto index1 = ctx->ioperandDataViewList->at(1);
    auto value2 = ctx->ioperandDataViewList->at(2);
    auto index2 = ctx->ioperandDataViewList->at(3);
    switch (opcode) {
        case Opcode::OP_PAIRARGMAX:
            calc::PairArgMax(outValue, outIndex, value1, index1, value2, index2);
            break;
        case Opcode::OP_PAIRARGMIN:
            calc::PairArgMin(outValue, outIndex, value1, index1, value2, index2);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false) << "opcode not support" << ctx->op->GetOpcodeStr();
    }
}
// These opcodes must be bound to ExecuteOpPairArgRedyce: ExecuteOpReduce asserts exactly one input
// and has no case for them, so it could never execute this op family.
REGISTER_CALC_OP(OP_PAIRARGMAX, Opcode::OP_PAIRARGMAX, ExecuteOpPairArgRedyce<Opcode::OP_PAIRARGMAX>);
REGISTER_CALC_OP(OP_PAIRARGMIN, Opcode::OP_PAIRARGMIN, ExecuteOpPairArgRedyce<Opcode::OP_PAIRARGMIN>);
/**
 * \brief Handler for OP_CAST: forwards output 0, input 0 and the cast mode to calc::Cast.
 *
 * The mode is the op's integer "mode" attribute converted to CastMode.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpCast(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= 0x2);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    const auto rawMode = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "mode");
    CastMode castMode = static_cast<CastMode>(rawMode);
    calc::Cast(dst, src, castMode);
}
REGISTER_CALC_OP(OP_CAST, Opcode::OP_CAST, ExecuteOpCast);
/**
 * \brief Interpreter handler shared by the one-input, one-output element-wise ops.
 *
 * Forwards (output 0, input 0) to the calc:: routine selected by the compile-time \p opcode.
 *
 * \tparam opcode Operation to execute; opcodes without a case below hit the UNSUPPORTED_OPCODE assert.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
template <Opcode opcode>
void ExecuteOpUnary(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);

    switch (opcode) {
        case Opcode::OP_EXP:
            calc::Exp(dst, src);
            break;
        case Opcode::OP_SINH:
            calc::Sinh(dst, src);
            break;
        case Opcode::OP_COSH:
            calc::Cosh(dst, src);
            break;
        case Opcode::OP_ASIN:
            calc::Asin(dst, src);
            break;
        case Opcode::OP_ACOS:
            calc::Acos(dst, src);
            break;
        case Opcode::OP_ASINH:
            calc::ASinh(dst, src);
            break;
        case Opcode::OP_ACOSH:
            calc::ACosh(dst, src);
            break;
        case Opcode::OP_ATANH:
            calc::Atanh(dst, src);
            break;
        case Opcode::OP_NEG:
            calc::Neg(dst, src);
            break;
        case Opcode::OP_SIGN:
            calc::Sign(dst, src);
            break;
        case Opcode::OP_SIGNBIT:
            calc::Signbit(dst, src);
            break;
        case Opcode::OP_TANH:
            calc::Tanh(dst, src);
            break;
        case Opcode::OP_TAN:
            calc::Tan(dst, src);
            break;
        case Opcode::OP_RSQRT:
            calc::Rsqrt(dst, src);
            break;
        case Opcode::OP_SQRT:
            calc::Sqrt(dst, src);
            break;
        case Opcode::OP_RECIPROCAL:
            calc::Reciprocal(dst, src);
            break;
        case Opcode::OP_RELU:
            calc::Relu(dst, src);
            break;
        case Opcode::OP_ATAN:
            calc::Atan(dst, src);
            break;
        case Opcode::OP_BITWISENOT:
            calc::BitwiseNot(dst, src);
            break;
        case Opcode::OP_ABS:
            calc::Abs(dst, src);
            break;
        case Opcode::OP_BRCB:
            calc::Brcb(dst, src);
            break;
        case Opcode::OP_LN:
            calc::Ln(dst, src);
            break;
        case Opcode::OP_ISFINITE:
            calc::IsFinite(dst, src);
            break;
        case Opcode::OP_ERF:
            calc::Erf(dst, src);
            break;
        case Opcode::OP_SIN:
            calc::Sin(dst, src);
            break;
        case Opcode::OP_COS:
            calc::Cos(dst, src);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false);
    }
}
// One registration per unary opcode, each bound to the matching instantiation.
REGISTER_CALC_OP(OP_EXP, Opcode::OP_EXP, ExecuteOpUnary<Opcode::OP_EXP>);
REGISTER_CALC_OP(OP_SINH, Opcode::OP_SINH, ExecuteOpUnary<Opcode::OP_SINH>);
REGISTER_CALC_OP(OP_COSH, Opcode::OP_COSH, ExecuteOpUnary<Opcode::OP_COSH>);
REGISTER_CALC_OP(OP_ASIN, Opcode::OP_ASIN, ExecuteOpUnary<Opcode::OP_ASIN>);
REGISTER_CALC_OP(OP_ACOS, Opcode::OP_ACOS, ExecuteOpUnary<Opcode::OP_ACOS>);
REGISTER_CALC_OP(OP_ASINH, Opcode::OP_ASINH, ExecuteOpUnary<Opcode::OP_ASINH>);
REGISTER_CALC_OP(OP_ACOSH, Opcode::OP_ACOSH, ExecuteOpUnary<Opcode::OP_ACOSH>);
REGISTER_CALC_OP(OP_ATANH, Opcode::OP_ATANH, ExecuteOpUnary<Opcode::OP_ATANH>);
REGISTER_CALC_OP(OP_NEG, Opcode::OP_NEG, ExecuteOpUnary<Opcode::OP_NEG>);
REGISTER_CALC_OP(OP_SIGN, Opcode::OP_SIGN, ExecuteOpUnary<Opcode::OP_SIGN>);
REGISTER_CALC_OP(OP_SIGNBIT, Opcode::OP_SIGNBIT, ExecuteOpUnary<Opcode::OP_SIGNBIT>);
REGISTER_CALC_OP(OP_TANH, Opcode::OP_TANH, ExecuteOpUnary<Opcode::OP_TANH>);
REGISTER_CALC_OP(OP_TAN, Opcode::OP_TAN, ExecuteOpUnary<Opcode::OP_TAN>);
REGISTER_CALC_OP(OP_RSQRT, Opcode::OP_RSQRT, ExecuteOpUnary<Opcode::OP_RSQRT>);
REGISTER_CALC_OP(OP_SQRT, Opcode::OP_SQRT, ExecuteOpUnary<Opcode::OP_SQRT>);
REGISTER_CALC_OP(OP_RECIPROCAL, Opcode::OP_RECIPROCAL, ExecuteOpUnary<Opcode::OP_RECIPROCAL>);
REGISTER_CALC_OP(OP_RELU, Opcode::OP_RELU, ExecuteOpUnary<Opcode::OP_RELU>);
REGISTER_CALC_OP(OP_ATAN, Opcode::OP_ATAN, ExecuteOpUnary<Opcode::OP_ATAN>);
REGISTER_CALC_OP(OP_BITWISENOT, Opcode::OP_BITWISENOT, ExecuteOpUnary<Opcode::OP_BITWISENOT>);
REGISTER_CALC_OP(OP_ABS, Opcode::OP_ABS, ExecuteOpUnary<Opcode::OP_ABS>);
REGISTER_CALC_OP(OP_BRCB, Opcode::OP_BRCB, ExecuteOpUnary<Opcode::OP_BRCB>);
REGISTER_CALC_OP(OP_LN, Opcode::OP_LN, ExecuteOpUnary<Opcode::OP_LN>);
REGISTER_CALC_OP(OP_ISFINITE, Opcode::OP_ISFINITE, ExecuteOpUnary<Opcode::OP_ISFINITE>);
REGISTER_CALC_OP(OP_ERF, Opcode::OP_ERF, ExecuteOpUnary<Opcode::OP_ERF>);
REGISTER_CALC_OP(OP_SIN, Opcode::OP_SIN, ExecuteOpUnary<Opcode::OP_SIN>);
REGISTER_CALC_OP(OP_COS, Opcode::OP_COS, ExecuteOpUnary<Opcode::OP_COS>);
/**
 * \brief Handler for OP_CEIL: forwards (output 0, input 0) to calc::Ceil.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpCeil(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Ceil(dst, src);
}
REGISTER_CALC_OP(OP_CEIL, Opcode::OP_CEIL, ExecuteOpCeil);
/**
 * \brief Handler for OP_FLOOR: forwards (output 0, input 0) to calc::Floor.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpFloor(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Floor(dst, src);
}
REGISTER_CALC_OP(OP_FLOOR, Opcode::OP_FLOOR, ExecuteOpFloor);
/**
 * \brief Handler for OP_TRUNC: forwards (output 0, input 0) to calc::Trunc.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpTrunc(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Trunc(dst, src);
}
REGISTER_CALC_OP(OP_TRUNC, Opcode::OP_TRUNC, ExecuteOpTrunc);
/**
 * \brief Handler for OP_EXP2: forwards (output 0, input 0) to calc::Exp2.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpExp2(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Exp2(dst, src);
}
REGISTER_CALC_OP(OP_EXP2, Opcode::OP_EXP2, ExecuteOpExp2);
/**
 * \brief Handler for OP_PAD: forwards output 0, input 0 and the op's scalar attribute to calc::Pad.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpPad(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    // Starts as FP32 0.0 and is then handed to GetAttr together with the scalar attribute key.
    auto padValue = Element(DT_FP32, 0.0f);
    ctx->op->GetAttr(OpAttributeKey::scalar, padValue);
    calc::Pad(dst, src, padValue);
}
REGISTER_CALC_OP(OP_PAD, Opcode::OP_PAD, ExecuteOpPad);
/**
 * \brief Handler for OP_FILLPAD: forwards output 0, input 0 and the op's scalar attribute to calc::FillPad.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpFillPad(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    // Starts as FP32 0.0 and is then handed to GetAttr together with the scalar attribute key.
    auto fillValue = Element(DT_FP32, 0.0f);
    ctx->op->GetAttr(OpAttributeKey::scalar, fillValue);
    calc::FillPad(dst, src, fillValue);
}
REGISTER_CALC_OP(OP_FILLPAD, Opcode::OP_FILLPAD, ExecuteOpFillPad);
/**
 * \brief Handler for OP_ROUND: forwards output 0, input 0 and the "decimals" attribute to calc::Round.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpRound(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    int numDecimals = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "decimals");
    calc::Round(dst, src, numDecimals);
}
REGISTER_CALC_OP(OP_ROUND, Opcode::OP_ROUND, ExecuteOpRound);
/**
 * \brief Handler for OP_EXPM1: forwards (output 0, input 0) to calc::Expm1.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpExpm1(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Expm1(dst, src);
}
REGISTER_CALC_OP(OP_EXPM1, Opcode::OP_EXPM1, ExecuteOpExpm1);
/**
 * \brief Handler for OP_ERFC: forwards (output 0, input 0) to calc::Erfc.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpErfc(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    calc::Erfc(dst, src);
}
REGISTER_CALC_OP(OP_ERFC, Opcode::OP_ERFC, ExecuteOpErfc);
/**
 * \brief Handler for OP_ONEHOT: forwards output 0, input 0 and the "numClasses" attribute to calc::OneHot.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpOneHot(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& labels = ctx->ioperandDataViewList->at(0);
    int classCount = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "numClasses");
    calc::OneHot(dst, labels, classCount);
}
REGISTER_CALC_OP(OP_ONEHOT, Opcode::OP_ONEHOT, ExecuteOpOneHot);
/**
 * \brief Handler for OP_EXPAND: forwards (output 0, input 0) to calc::Expand.
 * \param ctx Execution context; must carry exactly one input and exactly one output.
 */
void ExecuteOpExpand(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    calc::Expand(dst, src);
}
REGISTER_CALC_OP(OP_EXPAND, Opcode::OP_EXPAND, ExecuteOpExpand);
/**
 * \brief Handler shared by OP_TRANSPOSE_MOVEOUT and OP_TRANSPOSE_MOVEIN.
 *
 * The two axes to swap are elements 0 and 1 of the op's "shape" integer-vector attribute.
 * - Without a CopyOpAttribute the operands are transposed directly.
 * - For MOVEOUT the destination is re-wrapped as a LogicalTensorData over the output's data,
 *   using the input shape with the two axes swapped and the attribute's evaluated "to" offset.
 * - Otherwise (MOVEIN) the source is re-wrapped over the input's data, using the output shape
 *   with the two axes swapped and the attribute's evaluated "from" offset.
 *
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpTransposeMoveOut(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    std::vector<int64_t> swapAxes = ctx->op->GetVectorIntAttribute(OP_ATTR_PREFIX + "shape");

    auto copyAttr = std::dynamic_pointer_cast<CopyOpAttribute>(ctx->op->GetOpAttribute());
    if (!copyAttr) {
        calc::Transpose(dst, src, swapAxes[0], swapAxes[1]);
        return;
    }

    // The copy shape is evaluated here but is not otherwise used by this handler.
    std::vector<int64_t> shape = ctx->opInter->EvaluateOpImmediate(ctx->frame, copyAttr->GetShape());

    if (ctx->op->GetOpcode() == Opcode::OP_TRANSPOSE_MOVEOUT) {
        std::vector<int64_t> dstOffset = ctx->opInter->EvaluateOpImmediate(ctx->frame, copyAttr->GetToOffset());
        std::vector<int64_t> swappedSrcShape = src->GetShape();
        std::swap(swappedSrcShape[swapAxes[0]], swappedSrcShape[swapAxes[1]]);
        auto dstWindow = std::make_shared<LogicalTensorData>(dst->GetData(), swappedSrcShape, dstOffset);
        calc::Transpose(dstWindow, src, swapAxes[0], swapAxes[1]);
    } else {
        std::vector<int64_t> srcOffset =
            ctx->opInter->EvaluateOpImmediate(ctx->frame, copyAttr->GetFromOffset());
        std::vector<int64_t> swappedDstShape = dst->GetShape();
        std::swap(swappedDstShape[swapAxes[0]], swappedDstShape[swapAxes[1]]);
        auto srcWindow = std::make_shared<LogicalTensorData>(src->GetData(), swappedDstShape, srcOffset);
        calc::Transpose(dst, srcWindow, swapAxes[0], swapAxes[1]);
    }
}
REGISTER_CALC_OP(OP_TRANSPOSE_MOVEOUT, Opcode::OP_TRANSPOSE_MOVEOUT, ExecuteOpTransposeMoveOut);
REGISTER_CALC_OP(OP_TRANSPOSE_MOVEIN, Opcode::OP_TRANSPOSE_MOVEIN, ExecuteOpTransposeMoveOut);
/**
 * \brief Handler for OP_TRANSPOSE_VNCHWCONV: transposes input 0 into output 0 via calc::Transpose.
 *
 * The two axes to swap are elements 0 and 1 of the op's "shape" integer-vector attribute.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpTranspose(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    auto swapAxes = ctx->op->GetVectorIntAttribute(OP_ATTR_PREFIX + "shape");
    calc::Transpose(dst, src, swapAxes[0], swapAxes[1]);
}
REGISTER_CALC_OP(OP_TRANSPOSE_VNCHWCONV, Opcode::OP_TRANSPOSE_VNCHWCONV, ExecuteOpTranspose);
/**
 * \brief Handler shared by OP_PERMUTE and OP_PERMUTE_ELEMENT.
 *
 * Both operands are first narrowed to views built from their own valid shape and offset; the
 * views and the op's OpAttributeKey::perm vector are then forwarded to calc::Permute.
 * \param ctx Execution context; must carry exactly one input and at most two outputs.
 */
void ExecuteOpPermute(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto& dst = ctx->ooperandInplaceDataViewList->at(0);
    auto& src = ctx->ioperandDataViewList->at(0);
    std::vector<int64_t> axisOrder = ctx->op->GetVectorIntAttribute(OpAttributeKey::perm);

    auto srcValidView = src->View(src->GetValidShape(), src->GetOffset());
    auto dstValidView = dst->View(dst->GetValidShape(), dst->GetOffset());
    calc::Permute(dstValidView, srcValidView, axisOrder);
}
REGISTER_CALC_OP(OP_PERMUTE, Opcode::OP_PERMUTE, ExecuteOpPermute);
REGISTER_CALC_OP(OP_PERMUTE_ELEMENT, Opcode::OP_PERMUTE_ELEMENT, ExecuteOpPermute);
/**
 * \brief Handler for OP_LOGICALNOT: forwards (output 0, input 0) to calc::LogicalNot.
 * \param ctx Execution context; must carry exactly one input (the output count is not asserted).
 */
void ExecuteOpLogicalNot(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto src = ctx->ioperandDataViewList->at(0);
    calc::LogicalNot(dst, src);
}
REGISTER_CALC_OP(OP_LOGICALNOT, Opcode::OP_LOGICALNOT, ExecuteOpLogicalNot);
/**
 * \brief Handler for OP_LOGICALAND: forwards (output 0, input 0, input 1) to calc::LogicalAnd.
 * \param ctx Execution context; must carry exactly two inputs (the output count is not asserted).
 */
void ExecuteOpLogicalAnd(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto first = ctx->ioperandDataViewList->at(0);
    auto second = ctx->ioperandDataViewList->at(1);
    calc::LogicalAnd(dst, first, second);
}
REGISTER_CALC_OP(OP_LOGICALAND, Opcode::OP_LOGICALAND, ExecuteOpLogicalAnd);
/**
 * \brief Handler for OP_INDEX_OUTCAST: scatter-updates via calc::ScatterUpdate.
 *
 * Inputs are read positionally as (src, index, dst). The "axis", panzBlockSize and cacheMode
 * attributes are forwarded unchanged. When dst and output 0 report different sizes, the mismatch
 * is logged and the update is written to a fresh LogicalTensorData built over dst's data instead
 * of to output 0.
 * \param ctx Execution context; must carry exactly three inputs.
 */
void ExecuteOpIndexOutcast(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_THREE);

    auto out = ctx->ooperandInplaceDataViewList->at(0);
    auto updates = ctx->ioperandDataViewList->at(0);
    auto indices = ctx->ioperandDataViewList->at(1);
    auto target = ctx->ioperandDataViewList->at(2);

    int axis = ctx->op->GetIntAttribute("axis");
    int blockSize = ctx->op->GetIntAttribute(OpAttributeKey::panzBlockSize);
    std::string cacheMode = ctx->op->GetStringAttribute(OpAttributeKey::cacheMode);

    // Built unconditionally; only consumed on the size-mismatch path below.
    auto targetBacked = std::make_shared<LogicalTensorData>(target->GetData());

    const bool sizesMatch = target->GetSize() == out->GetSize();
    if (sizesMatch) {
        calc::ScatterUpdate(out, updates, indices, target, axis, cacheMode, blockSize);
        return;
    }

    INTERPRETER_EVENT("%s", ctx->op->Dump().c_str());
    INTERPRETER_EVENT(
        "dst validShape: %s ---> oop validShape: %s", IntVecToStr(target->GetShape()).c_str(),
        IntVecToStr(out->GetShape()).c_str());
    INTERPRETER_EVENT("IndexOutcast: oop validShape is not equal to dst validShape");
    calc::ScatterUpdate(targetBacked, updates, indices, target, axis, cacheMode, blockSize);
}
REGISTER_CALC_OP(OP_INDEX_OUTCAST, Opcode::OP_INDEX_OUTCAST, ExecuteOpIndexOutcast);
/**
 * \brief Handler for OP_SCATTER_ELEMENT: scatter with a scalar source via calc::ScatterElement.
 *
 * Inputs are read positionally as (self, indices). The scalar source comes from the op's
 * OpAttributeKey::scalar attribute; "axis" and "scatter_mode" are read as integer attributes.
 * \param ctx Execution context; must carry exactly two inputs.
 */
void ExecuteOpScatterElement(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto base = ctx->ioperandDataViewList->at(0);
    auto indexView = ctx->ioperandDataViewList->at(1);

    int axis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "axis");
    // Starts as FP32 0.0 and is then handed to GetAttr together with the scalar attribute key.
    auto scalarSrc = Element(DT_FP32, 0.0f);
    ctx->op->GetAttr(OpAttributeKey::scalar, scalarSrc);
    int reduceMode = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "scatter_mode");

    calc::ScatterElement(dst, base, indexView, scalarSrc, axis, reduceMode);
}
REGISTER_CALC_OP(OP_SCATTER_ELEMENT, Opcode::OP_SCATTER_ELEMENT, ExecuteOpScatterElement);
/**
 * \brief Handler for OP_SCATTER: scatter with a tensor source via calc::Scatter.
 *
 * Inputs are read positionally as (self, indices, src); "axis" and "scatter_mode" are read as
 * integer attributes and forwarded unchanged.
 * \param ctx Execution context; must carry exactly three inputs.
 */
void ExecuteOpScatter(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_THREE);

    auto dst = ctx->ooperandInplaceDataViewList->at(0);
    auto base = ctx->ioperandDataViewList->at(0);
    auto indexView = ctx->ioperandDataViewList->at(1);
    auto updates = ctx->ioperandDataViewList->at(2);

    int axis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "axis");
    int reduceMode = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "scatter_mode");

    calc::Scatter(dst, base, indexView, updates, axis, reduceMode);
}
REGISTER_CALC_OP(OP_SCATTER, Opcode::OP_SCATTER, ExecuteOpScatter);
/**
 * \brief Derives the end value of a range from its start, element count and step.
 *
 * The result is start + size * step - step / 2, evaluated with T as the start/step type.
 * NOTE(review): the "- step / 2" term presumably keeps the end strictly inside the last step so that
 * rounding cannot produce an extra element - confirm against calc::Range.
 *
 * \tparam T        Arithmetic type used to hold start, step and the resulting end.
 * \tparam dataType Element data type tag. DT_INT32 / DT_INT64 read the signed payload of start and
 *                  step; DT_FP32 reads their float payload. For any other tag start and step are
 *                  treated as zero instead of being read uninitialised.
 * \param start First value of the range.
 * \param size  Number of elements; always read through GetSignedData().
 * \param step  Distance between consecutive values.
 * \return An Element of type dataType holding the computed end value.
 */
template <typename T, DataType dataType>
Element GetEndBySize(Element start, Element size, Element step)
{
    // Value-initialise so that an unexpected dataType can never read indeterminate values.
    T startValue{};
    T stepValue{};
    // dataType is a compile-time constant, so only the matching branch is instantiated.
    if constexpr (dataType == DT_INT32 || dataType == DT_INT64) {
        startValue = start.GetSignedData();
        stepValue = step.GetSignedData();
    } else if constexpr (dataType == DT_FP32) {
        startValue = static_cast<float>(start.GetFloatData());
        stepValue = static_cast<float>(step.GetFloatData());
    }
    T endValue = startValue + size.GetSignedData() * stepValue - stepValue / 2;
    return Element(dataType, endValue);
}
/**
 * \brief Interpreter handler for OP_RANGE.
 *
 * Reads the START, SIZE and STEP element attributes. When the op carries a dynScalar attribute with a
 * valid concrete value, that value is used as a tile index and the start is shifted by
 * step * tileIndex. The end value is derived with GetEndBySize for INT32, INT64 and FP32 starts; any
 * other start data type trips an INVALID_TENSOR_DTYPE assertion. The range is produced by calc::Range
 * into output view 0.
 */
void ExecuteOpRange(ExecuteOperationContext* ctx)
{
    const auto& op = ctx->op;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto start = op->GetElementAttribute(OP_ATTR_PREFIX + "START");
    auto size = op->GetElementAttribute(OP_ATTR_PREFIX + "SIZE");
    auto step = op->GetElementAttribute(OP_ATTR_PREFIX + "STEP");

    // Start of the current tile; equals the global start unless a concrete tile index is available.
    Element tileStart = start;
    if (op->HasAttr(OpAttributeKey::dynScalar)) {
        SymbolicScalar tileIdx = op->GetSymbolicScalarAttribute(OpAttributeKey::dynScalar);
        if (tileIdx.ConcreteValid()) {
            int64_t tileIdxVal = tileIdx.Concrete();
            Element tileIdxElem = Element(start.GetDataType(), tileIdxVal);
            tileStart = start + step * tileIdxElem;
        }
    }

    Element end;
    switch (start.GetDataType()) {
        case DT_INT32:
            end = GetEndBySize<int32_t, DT_INT32>(tileStart, size, step);
            break;
        case DT_INT64:
            end = GetEndBySize<int64_t, DT_INT64>(tileStart, size, step);
            break;
        case DT_FP32:
            end = GetEndBySize<float, DT_FP32>(tileStart, size, step);
            break;
        default: {
            ASSERT(ExecuteOperationScene::INVALID_TENSOR_DTYPE, false)
                << "Unsupported DataType " << DataType2String(start.GetDataType());
            break;
        }
    }
    calc::Range(result, tileStart, end, step);
}
REGISTER_CALC_OP(OP_RANGE, Opcode::OP_RANGE, ExecuteOpRange);
/**
 * \brief Interpreter handler for OP_UNIFORM.
 *
 * The vectorScalar attribute must hold at least four entries, read in this order:
 * [0] key, [1] counter1, [2] rounds, [3] output data type (stored as an int32 and cast to DataType).
 * counter0 defaults to UINT64 0 and is replaced by the dynScalar attribute when that attribute is
 * present and has a valid concrete value. All values are forwarded to calc::Uniform with output view 0.
 */
void ExecuteOpUniform(ExecuteOperationContext* ctx)
{
    auto oop = ctx->ooperandInplaceDataViewList->at(0);
    auto scalars = ctx->op->GetVectorElementAttribute(OpAttributeKey::vectorScalar);
    // Guard the fixed-position reads below against a short or missing attribute vector.
    ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, scalars.size() >= SIZE_FOUR)
        << "Uniform expects at least 4 vectorScalar entries (key, counter1, rounds, dtype), got "
        << scalars.size();
    Element key = scalars[0];
    Element counter1 = scalars[1];
    Element rounds = scalars[2];
    DataType dtype = static_cast<DataType>(scalars[3].Cast<int32_t>());
    Element counter0(DT_UINT64, static_cast<uint64_t>(0));
    if (ctx->op->HasAttr(OpAttributeKey::dynScalar)) {
        SymbolicScalar dynScalar = ctx->op->GetSymbolicScalarAttribute(OpAttributeKey::dynScalar);
        if (dynScalar.ConcreteValid()) {
            counter0 = Element(DT_UINT64, static_cast<uint64_t>(dynScalar.Concrete()));
        }
    }
    calc::Uniform(oop, key, counter0, counter1, rounds, dtype);
}
REGISTER_CALC_OP(OP_UNIFORM, Opcode::OP_UNIFORM, ExecuteOpUniform);
// Interpreter handler for OP_QUANT_MX.
// Validates the operand counts (1 input, 4 outputs) and the three required attributes
// (mxQuantMode, mxQuantAxis, mxQuantPerformanceMode), then forwards to calc::QuantMX.
// Only modes 0 and 1 are accepted, and only the last axis of the input (after normalising a
// negative axis) is supported.
void ExecuteOpQuantMX(ExecuteOperationContext* ctx)
{
ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 0x4);
// Each required attribute is fetched inside an ASSERT so a missing attribute is reported by name.
int64_t mode = 0;
ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, ctx->op->GetAttr(OpAttributeKey::mxQuantMode, mode))
<< "QuantMX missing required attribute: " << OpAttributeKey::mxQuantMode;
constexpr int64_t kMxQuantModeRoundUp = 0;
constexpr int64_t kMxQuantModeRoundDown = 1;
ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, mode == kMxQuantModeRoundDown || mode == kMxQuantModeRoundUp)
<< "QuantMX interpreter currently only supports ROUND_DOWN (OCP) and ROUND_UP (NV) modes.";
int64_t axis = 0;
ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, ctx->op->GetAttr(OpAttributeKey::mxQuantAxis, axis))
<< "QuantMX missing required attribute: " << OpAttributeKey::mxQuantAxis;
int64_t performanceMode = 0;
ASSERT(
ExecuteOperationScene::RUNTIME_EXCEPTION,
ctx->op->GetAttr(OpAttributeKey::mxQuantPerformanceMode, performanceMode))
<< "QuantMX missing required attribute: " << OpAttributeKey::mxQuantPerformanceMode;
// Output views in positional order: quantised data, exponent, max, scaling.
auto out = ctx->ooperandInplaceDataViewList->at(0);
auto exp = ctx->ooperandInplaceDataViewList->at(1);
auto max = ctx->ooperandInplaceDataViewList->at(2);
auto scaling = ctx->ooperandInplaceDataViewList->at(3);
auto src = ctx->ioperandDataViewList->at(0);
const auto srcRank = static_cast<int64_t>(src->GetShape().size());
// A negative axis counts from the end of the shape.
const auto normalizedAxis = axis < 0 ? axis + srcRank : axis;
ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, normalizedAxis >= 0 && normalizedAxis < srcRank)
<< "QuantMX axis is out of range. Current axis: " << axis << ", input rank: " << srcRank;
ASSERT(ExecuteOperationScene::RUNTIME_EXCEPTION, normalizedAxis == srcRank - 1)
<< "QuantMX interpreter currently only supports the last axis. Current axis: " << axis
<< ", input rank: " << srcRank;
// The axis itself is not passed on; performanceMode is collapsed to a boolean flag.
calc::QuantMX(out, exp, max, scaling, src, performanceMode != 0, mode);
}
REGISTER_CALC_OP(OP_QUANT_MX, Opcode::OP_QUANT_MX, ExecuteOpQuantMX);
/**
 * \brief Interpreter handler for OP_LOG1P.
 *
 * Requires exactly one output view and one input view and passes both by reference to calc::Log1p.
 */
void ExecuteOpLog1p(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    calc::Log1p(ctx->ooperandInplaceDataViewList->at(0), ctx->ioperandDataViewList->at(0));
}
REGISTER_CALC_OP(OP_LOG1P, Opcode::OP_LOG1P, ExecuteOpLog1p);
/**
 * \brief Interpreter handler for OP_CMP (tensor-vs-tensor comparison).
 *
 * Uses input views 0 and 1 as the two operands. The comparison operation and mode are read from the
 * prefixed integer attributes "cmp_operation" and "cmp_mode" and cast to their enum types before
 * calling calc::Compare with output view 0. Operand counts are not asserted here.
 */
void ExecuteOpCompare(ExecuteOperationContext* ctx)
{
    const auto& op = ctx->op;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto lhsView = inputViews.at(0);
    auto rhsView = inputViews.at(1);

    auto cmpOperation = static_cast<CmpOperationType>(op->GetIntAttribute(OP_ATTR_PREFIX + "cmp_operation"));
    auto cmpMode = static_cast<CmpModeType>(op->GetIntAttribute(OP_ATTR_PREFIX + "cmp_mode"));

    calc::Compare(result, lhsView, rhsView, cmpOperation, cmpMode);
}
REGISTER_CALC_OP(OP_CMP, Opcode::OP_CMP, ExecuteOpCompare);
/**
 * \brief Interpreter handler for OP_CMPS (tensor-vs-scalar comparison).
 *
 * Requires exactly one input view. The scalar operand is read from the OpAttributeKey::scalar
 * attribute into an Element that starts out as FP32 0.0; the return value of GetAttr is not
 * inspected. The comparison operation and mode come from the prefixed integer attributes
 * "cmp_operation" and "cmp_mode". The result is written through calc::Cmps into output view 0.
 */
void ExecuteOpCmps(ExecuteOperationContext* ctx)
{
    // This checks the INPUT list, so report it under the input-count scene.
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto oop = ctx->ooperandInplaceDataViewList->at(0);
    auto iop_self = ctx->ioperandDataViewList->at(0);
    auto element = Element(DT_FP32, 0.0f);
    ctx->op->GetAttr(OpAttributeKey::scalar, element);
    auto operation = static_cast<CmpOperationType>(ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "cmp_operation"));
    auto mode = static_cast<CmpModeType>(ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "cmp_mode"));
    calc::Cmps(oop, iop_self, element, operation, mode);
}
REGISTER_CALC_OP(OP_CMPS, Opcode::OP_CMPS, ExecuteOpCmps);
/**
 * \brief Interpreter handler for OP_HYPOT.
 *
 * Passes output view 0 and input views 0 and 1 to calc::Hypot. Operand counts are not asserted here.
 */
void ExecuteOpHypot(ExecuteOperationContext* ctx)
{
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto firstView = inputViews.at(0);
    auto secondView = inputViews.at(1);
    calc::Hypot(result, firstView, secondView);
}
REGISTER_CALC_OP(OP_HYPOT, Opcode::OP_HYPOT, ExecuteOpHypot);
/**
 * \brief Interpreter handler for OP_PRELU.
 *
 * Passes output view 0, input view 0 (self) and input view 1 (weight) to calc::PReLU.
 * Operand counts are not asserted here.
 */
void ExecuteOpPReLU(ExecuteOperationContext* ctx)
{
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto selfView = inputViews.at(0);
    auto weightView = inputViews.at(1);
    calc::PReLU(result, selfView, weightView);
}
REGISTER_CALC_OP(OP_PRELU, Opcode::OP_PRELU, ExecuteOpPReLU);
/**
 * \brief Interpreter handler shared by OP_EXTRACT and OP_EXTRACT_SINGLE.
 *
 * Requires exactly one input view. The mask mode is read from the integer attribute
 * "op_attr_makeMode" and the ordering flag from "op_attr_order"; both are forwarded to calc::Extract
 * along with output view 0.
 */
void ExecuteOpExtract(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    const auto& op = ctx->op;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto srcView = ctx->ioperandDataViewList->at(0);

    auto maskModeAttr = op->GetIntAttribute("op_attr_makeMode");
    int descendingFlag = op->GetIntAttribute("op_attr_order");

    calc::Extract(result, srcView, maskModeAttr, descendingFlag);
}
REGISTER_CALC_OP(OP_EXTRACT, Opcode::OP_EXTRACT, ExecuteOpExtract);
REGISTER_CALC_OP(OP_EXTRACT_SINGLE, Opcode::OP_EXTRACT_SINGLE, ExecuteOpExtract);
/**
 * \brief Interpreter handler for OP_GATHER.
 *
 * Passes output view 0, input view 0 (params), input view 1 (indices) and the integer attribute
 * "op_attr_axis" to calc::Gather. Operand counts are not asserted here.
 */
void ExecuteOpGather(ExecuteOperationContext* ctx)
{
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto paramsView = inputViews.at(0);
    auto indexView = inputViews.at(1);
    int gatherAxis = ctx->op->GetIntAttribute("op_attr_axis");
    calc::Gather(result, paramsView, indexView, gatherAxis);
}
REGISTER_CALC_OP(OP_GATHER, Opcode::OP_GATHER, ExecuteOpGather);
/**
 * \brief Interpreter handler for OP_GATHER_IN_UB.
 *
 * Passes output view 0 and input views 0..2 (params, indices, page table) to calc::GatherINUB together
 * with the blockSize attribute. The axis argument is always the file-level constant
 * GATHER_IGNORE_AXIS. Operand counts are not asserted here.
 */
void ExecuteOpGatherINUB(ExecuteOperationContext* ctx)
{
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto paramsView = inputViews.at(0);
    auto indexView = inputViews.at(1);
    auto pageTableView = inputViews.at(2);
    int blockSizeAttr = ctx->op->GetIntAttribute(OpAttributeKey::blockSize);
    calc::GatherINUB(result, paramsView, indexView, pageTableView, blockSizeAttr, GATHER_IGNORE_AXIS);
}
REGISTER_CALC_OP(OP_GATHER_IN_UB, Opcode::OP_GATHER_IN_UB, ExecuteOpGatherINUB);
/**
 * \brief Interpreter handler shared by OP_INDEX_ADD_UB and OP_INDEX_ADD.
 *
 * Requires three input views (self, src, indices) and at most two output views. The multiplier alpha
 * defaults to FP32 1.0 and is overridden by the scalar attribute when present. For every opcode other
 * than OP_INDEX_ADD the result goes to output view 0; for OP_INDEX_ADD the self view is passed as
 * both destination and source.
 */
void ExecuteOpIndexAdd(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_THREE);
    auto& inputViews = *ctx->ioperandDataViewList;
    const auto& op = ctx->op;

    auto& outView = ctx->ooperandInplaceDataViewList->at(0);
    auto& selfView = inputViews.at(0);
    auto& srcView = inputViews.at(1);
    auto& indexView = inputViews.at(2);

    auto scale = Element(DT_FP32, 1.0);
    if (op->HasAttribute(OpAttributeKey::scalar)) {
        scale = op->GetElementAttribute(OpAttributeKey::scalar);
    }
    int addAxis = op->GetIntAttribute(OP_ATTR_PREFIX + "axis");

    if (op->GetOpcode() != Opcode::OP_INDEX_ADD) {
        calc::IndexAdd(outView, selfView, srcView, indexView, addAxis, scale);
    } else {
        // OP_INDEX_ADD accumulates into the self view itself.
        calc::IndexAdd(selfView, selfView, srcView, indexView, addAxis, scale);
    }
}
REGISTER_CALC_OP(OP_INDEX_ADD_UB, Opcode::OP_INDEX_ADD_UB, ExecuteOpIndexAdd);
REGISTER_CALC_OP(OP_INDEX_ADD, Opcode::OP_INDEX_ADD, ExecuteOpIndexAdd);
/**
 * \brief Interpreter handler for OP_TRIUL (upper / lower triangular extraction).
 *
 * Requires exactly one output view and one input view. The diagonal offset is stored as a symbolic
 * scalar in the dynScalar attribute and is evaluated against the current frame through
 * ctx->opInter->EvaluateOpImmediate. The isUpper attribute selects calc::TriU (true) or
 * calc::TriL (false).
 */
void ExecuteOpTri(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& output = ctx->ooperandInplaceDataViewList->at(0);
    auto& input = ctx->ioperandDataViewList->at(0);
    // Wrap the symbolic diagonal in an immediate list so the interpreter can resolve it for this frame.
    SymbolicScalar diaSym = ctx->op->GetSymbolicScalarAttribute(OpAttributeKey::dynScalar);
    std::vector<OpImmediate> diaImmList = {OpImmediate::Specified(diaSym)};
    int diagonal = static_cast<int>(ctx->opInter->EvaluateOpImmediate(ctx->frame, diaImmList)[0]);
    bool isUpper = ctx->op->GetBoolAttribute(OpAttributeKey::isUpper);
    if (isUpper) {
        calc::TriU(output, input, diagonal);
    } else {
        calc::TriL(output, input, diagonal);
    }
}
REGISTER_CALC_OP(OP_TRIUL, Opcode::OP_TRIUL, ExecuteOpTri);
/**
 * \brief Interpreter handler for OP_CUM_SUM.
 *
 * Requires exactly one output view and one input view and calls calc::CumSum along the axis given by
 * the prefixed integer attribute "axis".
 */
void ExecuteOpCumSum(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& srcView = ctx->ioperandDataViewList->at(0);
    int scanAxis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "axis");
    calc::CumSum(dstView, srcView, scanAxis);
}
REGISTER_CALC_OP(OP_CUM_SUM, Opcode::OP_CUM_SUM, ExecuteOpCumSum);
/**
 * \brief Interpreter handler for OP_CUM_PROD.
 *
 * Requires exactly one output view and one input view and calls calc::CumProd along the axis given by
 * the prefixed integer attribute "axis".
 */
void ExecuteOpCumProd(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& srcView = ctx->ioperandDataViewList->at(0);
    int scanAxis = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "axis");
    calc::CumProd(dstView, srcView, scanAxis);
}
REGISTER_CALC_OP(OP_CUM_PROD, Opcode::OP_CUM_PROD, ExecuteOpCumProd);
/**
 * \brief Interpreter handler for OP_INDEX_PUT.
 *
 * Requires exactly one output view and at most SIZE_SIX input views. Input 0 is self, input 1 is the
 * values tensor, and every remaining input (position SIZE_TWO onwards) is collected as an index
 * tensor. The accumulate attribute is forwarded unchanged to calc::IndexPut.
 */
void ExecuteOpIndexPut(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() <= SIZE_SIX);
    auto& inputViews = *ctx->ioperandDataViewList;

    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto selfView = inputViews.at(0);
    auto valueView = inputViews.at(1);

    // Everything after self and values is an index tensor.
    std::vector<LogicalTensorDataPtr> indexViews;
    for (size_t pos = SIZE_TWO; pos < inputViews.size(); ++pos) {
        indexViews.push_back(inputViews.at(pos));
    }

    bool accumulateFlag = ctx->op->GetBoolAttribute(OpAttributeKey::accumulate);
    calc::IndexPut(result, selfView, indexViews, valueView, accumulateFlag);
}
REGISTER_CALC_OP(OP_INDEX_PUT, Opcode::OP_INDEX_PUT, ExecuteOpIndexPut);
/**
 * \brief Interpreter handler for OP_MRGSORT.
 *
 * Requires exactly one input view. Reads the integer attributes "op_attr_axis" and "op_attr_kvalue"
 * and forwards them with output view 0 and the input view to calc::MrgSort.
 */
void ExecuteOpMrgSort(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    const auto& op = ctx->op;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto srcView = ctx->ioperandDataViewList->at(0);
    auto sortAxis = op->GetIntAttribute("op_attr_axis");
    auto topCount = op->GetIntAttribute("op_attr_kvalue");
    calc::MrgSort(result, srcView, sortAxis, topCount);
}
REGISTER_CALC_OP(OP_MRGSORT, Opcode::OP_MRGSORT, ExecuteOpMrgSort);
/**
 * \brief Interpreter handler for OP_TWOTILEMRGSORT.
 *
 * Passes output view 0 and input view 0 to calc::TwoTileMrgSort. Operand counts are not asserted here.
 */
void ExecuteOpTwoTileMrgSort(ExecuteOperationContext* ctx)
{
    auto srcView = ctx->ioperandDataViewList->at(0);
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    calc::TwoTileMrgSort(result, srcView);
}
REGISTER_CALC_OP(OP_TWOTILEMRGSORT, Opcode::OP_TWOTILEMRGSORT, ExecuteOpTwoTileMrgSort);
/**
 * \brief Interpreter handler for OP_SORT_UB.
 *
 * Uses input view 0 as the source and output views 0 and 1 as the value and index results. The axis
 * and ordering flag are read from "op_attr_axis" and "op_attr_order" and forwarded to calc::Sort.
 * Operand counts are not asserted here.
 */
void ExecuteOpSort(ExecuteOperationContext* ctx)
{
    auto srcView = ctx->ioperandDataViewList->at(0);
    auto& outputViews = *ctx->ooperandInplaceDataViewList;
    auto valueView = outputViews.at(0);
    auto indexView = outputViews.at(1);
    auto sortAxis = ctx->op->GetIntAttribute("op_attr_axis");
    int descendingFlag = ctx->op->GetIntAttribute("op_attr_order");
    calc::Sort(valueView, indexView, srcView, sortAxis, descendingFlag);
}
REGISTER_CALC_OP(OP_SORT_UB, Opcode::OP_SORT_UB, ExecuteOpSort);
/**
 * \brief Interpreter handler shared by OP_TOPK and OP_RADIX_SELECT.
 *
 * Requires exactly one input view. Output views 0 and 1 receive the values and indices. The axis, k and
 * ordering flag are read from "op_attr_axis", "op_attr_kvalue" and "op_attr_order"; note that
 * calc::TopK takes k before the axis.
 */
void ExecuteOpTopK(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& outputViews = *ctx->ooperandInplaceDataViewList;
    const auto& op = ctx->op;

    auto valueView = outputViews.at(0);
    auto indexView = outputViews.at(1);
    auto srcView = ctx->ioperandDataViewList->at(0);

    auto selectAxis = op->GetIntAttribute("op_attr_axis");
    auto topCount = op->GetIntAttribute("op_attr_kvalue");
    int descendingFlag = op->GetIntAttribute("op_attr_order");

    calc::TopK(valueView, indexView, srcView, topCount, selectAxis, descendingFlag);
}
REGISTER_CALC_OP(OP_TOPK, Opcode::OP_TOPK, ExecuteOpTopK);
REGISTER_CALC_OP(OP_RADIX_SELECT, Opcode::OP_RADIX_SELECT, ExecuteOpTopK);
/**
 * \brief Interpreter handler for OP_QUANTIZE_SYM.
 *
 * Calls calc::Quantize with output view 0, input view 0 (data), input view 1 (scale) and a null
 * zero-point argument. Operand counts are not asserted here.
 */
void ExecuteOpQuantizeSym(ExecuteOperationContext* ctx)
{
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto& dataView = inputViews.at(0);
    auto& scaleView = inputViews.at(1);
    calc::Quantize(dstView, dataView, scaleView, nullptr);
}
REGISTER_CALC_OP(OP_QUANTIZE_SYM, Opcode::OP_QUANTIZE_SYM, ExecuteOpQuantizeSym);
/**
 * \brief Interpreter handler for OP_QUANTIZE_ASYM.
 *
 * Calls calc::Quantize with output view 0 and input views 0..2 (data, scale, zero points).
 * Operand counts are not asserted here.
 */
void ExecuteOpQuantizeAsym(ExecuteOperationContext* ctx)
{
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto& dataView = inputViews.at(0);
    auto& scaleView = inputViews.at(1);
    auto& zeroPointView = inputViews.at(2);
    calc::Quantize(dstView, dataView, scaleView, zeroPointView);
}
REGISTER_CALC_OP(OP_QUANTIZE_ASYM, Opcode::OP_QUANTIZE_ASYM, ExecuteOpQuantizeAsym);
/**
 * \brief Interpreter handler for OP_DEQUANTIZE.
 *
 * Calls calc::Dequantize with output view 0 and input views 0..2 (data, scale, zero points).
 * Operand counts are not asserted here.
 */
void ExecuteOpDequantize(ExecuteOperationContext* ctx)
{
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto& dataView = inputViews.at(0);
    auto& scaleView = inputViews.at(1);
    auto& zeroPointView = inputViews.at(2);
    calc::Dequantize(dstView, dataView, scaleView, zeroPointView);
}
REGISTER_CALC_OP(OP_DEQUANTIZE, Opcode::OP_DEQUANTIZE, ExecuteOpDequantize);
/**
 * \brief Interpreter handler for OP_BITSORT.
 *
 * Requires exactly one input view. Reads the integer attributes "op_attr_axis", "op_attr_order" and
 * "op_attr_offset" and forwards them with output view 0 and the input view to calc::BitSort.
 */
void ExecuteOpBitSort(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    const auto& op = ctx->op;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto srcView = ctx->ioperandDataViewList->at(0);

    auto sortAxis = op->GetIntAttribute("op_attr_axis");
    int descendingFlag = op->GetIntAttribute("op_attr_order");
    int indexOffset = op->GetIntAttribute("op_attr_offset");

    calc::BitSort(result, srcView, sortAxis, descendingFlag, indexOffset);
}
REGISTER_CALC_OP(OP_BITSORT, Opcode::OP_BITSORT, ExecuteOpBitSort);
/**
 * \brief Interpreter handler for OP_TILEDMRGSORT.
 *
 * Requires exactly four input views, which are passed in order to calc::TiledMrgSort together with
 * output view 0 and the integer attributes "op_attr_validBit" and "op_attr_kvalue".
 */
void ExecuteOpTiledMrgSort(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_FOUR);
    auto& inputViews = *ctx->ioperandDataViewList;
    const auto& op = ctx->op;

    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto firstSrc = inputViews.at(0);
    auto secondSrc = inputViews.at(1);
    auto thirdSrc = inputViews.at(2);
    auto fourthSrc = inputViews.at(3);

    auto validBitAttr = op->GetIntAttribute("op_attr_validBit");
    auto topCount = op->GetIntAttribute("op_attr_kvalue");

    calc::TiledMrgSort(result, firstSrc, secondSrc, thirdSrc, fourthSrc, validBitAttr, topCount);
}
REGISTER_CALC_OP(OP_TILEDMRGSORT, Opcode::OP_TILEDMRGSORT, ExecuteOpTiledMrgSort);
/**
 * \brief Interpreter handler for OP_TOPK_SORT.
 *
 * Requires exactly one input view and two output views (value and temp). The prefixed integer
 * attribute "start_index" is forwarded to calc::TopkSort along with the three views.
 */
void ExecuteOpTopkSort(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    ASSERT(
        ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH,
        ctx->ooperandInplaceDataViewList->size() == 0x2);
    auto& outputViews = *ctx->ooperandInplaceDataViewList;

    auto srcView = ctx->ioperandDataViewList->at(0);
    auto valueView = outputViews.at(0);
    auto tempView = outputViews.at(1);
    int firstIndex = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "start_index");

    calc::TopkSort(valueView, tempView, srcView, firstIndex);
}
REGISTER_CALC_OP(OP_TOPK_SORT, Opcode::OP_TOPK_SORT, ExecuteOpTopkSort);
/**
 * \brief Interpreter handler for OP_TOPK_MERGE.
 *
 * Requires exactly one input view and one output view and calls calc::TopkMerge with the prefixed
 * integer attribute "merge_size".
 */
void ExecuteOpTopkMerge(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    auto srcView = ctx->ioperandDataViewList->at(0);
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    int mergeSizeAttr = ctx->op->GetIntAttribute(OP_ATTR_PREFIX + "merge_size");
    calc::TopkMerge(result, srcView, mergeSizeAttr);
}
REGISTER_CALC_OP(OP_TOPK_MERGE, Opcode::OP_TOPK_MERGE, ExecuteOpTopkMerge);
/**
 * \brief Interpreter handler for OP_TOPK_EXTRACT.
 *
 * Requires exactly one input view and one output view. Reads the prefixed integer attributes "k" and
 * "is_index" (the latter converted to bool) and forwards them to calc::TopkExtract.
 */
void ExecuteOpTopkExtract(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    const auto& op = ctx->op;
    auto srcView = ctx->ioperandDataViewList->at(0);
    auto result = ctx->ooperandInplaceDataViewList->at(0);

    int topCount = op->GetIntAttribute(OP_ATTR_PREFIX + "k");
    bool wantIndex = static_cast<bool>(op->GetIntAttribute(OP_ATTR_PREFIX + "is_index"));

    calc::TopkExtract(result, srcView, topCount, wantIndex);
}
REGISTER_CALC_OP(OP_TOPK_EXTRACT, Opcode::OP_TOPK_EXTRACT, ExecuteOpTopkExtract);
/**
 * \brief Interpreter handler for OP_REDUCE_ACC.
 *
 * Requires exactly one output view and hands the whole input view list to calc::ReduceAcc; the
 * number of inputs is not constrained here.
 */
void ExecuteOpReduceAcc(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    calc::ReduceAcc(ctx->ooperandInplaceDataViewList->at(0), *ctx->ioperandDataViewList);
}
REGISTER_CALC_OP(OP_REDUCE_ACC, Opcode::OP_REDUCE_ACC, ExecuteOpReduceAcc);
/**
 * \brief Interpreter handler for tensor-with-scalar binary operations.
 *
 * \tparam opcode Selects the calc routine at compile time. Opcodes without a case below trip an
 *                UNSUPPORTED_OPCODE assertion.
 *
 * Requires exactly one input view. OP_BITWISEXORS, OP_REMRS, OP_FLOORDIVS, OP_REMS and OP_POWS may
 * carry up to two output views; every other opcode must have exactly one. Only output view 0 is
 * written here.
 * The scalar operand is read from the OpAttributeKey::scalar attribute into an Element that starts
 * out as FP32 0.0; the return value of GetAttr is not inspected. The prefixed boolean attribute
 * "reverseOperand" is read for every opcode but only forwarded to SubS, DivS, FloorDivS, RemainderS
 * and RemainderRS.
 */
template <Opcode opcode>
void ExecuteOpBinaryScalar(ExecuteOperationContext* ctx)
{
    // The scalar XOR opcode is OP_BITWISEXORS; OP_BITWISEXOR is the tensor-tensor variant and is never
    // instantiated for this template, so listing it here left XORS on the strict single-output path.
    if constexpr (
        opcode == Opcode::OP_BITWISEXORS || opcode == Opcode::OP_REMRS || opcode == Opcode::OP_FLOORDIVS ||
        opcode == Opcode::OP_REMS || opcode == Opcode::OP_POWS) {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    } else {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    }
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& ret = ctx->ooperandInplaceDataViewList->at(0);
    auto& lhs = ctx->ioperandDataViewList->at(0);
    auto element = Element(DT_FP32, 0.0f);
    ctx->op->GetAttr(OpAttributeKey::scalar, element);
    bool reverse = ctx->op->GetBoolAttribute(OP_ATTR_PREFIX + "reverseOperand");
    switch (opcode) {
        case Opcode::OP_ADDS:
            calc::AddS(ret, lhs, element);
            break;
        case Opcode::OP_SUBS:
            calc::SubS(ret, lhs, element, reverse);
            break;
        case Opcode::OP_MULS:
            calc::MulS(ret, lhs, element);
            break;
        case Opcode::OP_MAXS:
            calc::MaxS(ret, lhs, element);
            break;
        case Opcode::OP_MINS:
            calc::MinS(ret, lhs, element);
            break;
        case Opcode::OP_DIVS:
            calc::DivS(ret, lhs, element, reverse);
            break;
        case Opcode::OP_FLOORDIVS:
            calc::FloorDivS(ret, lhs, element, reverse);
            break;
        case Opcode::OP_REMS:
            calc::RemainderS(ret, lhs, element, reverse);
            break;
        case Opcode::OP_POWS:
            calc::PowS(ret, lhs, element);
            break;
        case Opcode::OP_REMRS:
            calc::RemainderRS(ret, lhs, element, reverse);
            break;
        case Opcode::OP_S_MAXS:
            calc::MaxS(ret, lhs, element);
            break;
        case Opcode::OP_S_MINS:
            calc::MinS(ret, lhs, element);
            break;
        case Opcode::OP_LRELU:
            calc::LReLU(ret, lhs, element);
            break;
        case Opcode::OP_BITWISEANDS:
            calc::BitwiseAndS(ret, lhs, element);
            break;
        case Opcode::OP_BITWISEORS:
            calc::BitwiseOrS(ret, lhs, element);
            break;
        case Opcode::OP_BITWISEXORS:
            calc::BitwiseXorS(ret, lhs, element);
            break;
        case Opcode::OP_GCDS:
            calc::GcdS(ret, lhs, element);
            break;
        default:
            ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false);
    }
}
// Registrations for the tensor-with-scalar binary handlers; each opcode is bound to the
// ExecuteOpBinaryScalar instantiation named in its third argument.
REGISTER_CALC_OP(OP_ADDS, Opcode::OP_ADDS, ExecuteOpBinaryScalar<Opcode::OP_ADDS>);
REGISTER_CALC_OP(OP_SUBS, Opcode::OP_SUBS, ExecuteOpBinaryScalar<Opcode::OP_SUBS>);
REGISTER_CALC_OP(OP_MULS, Opcode::OP_MULS, ExecuteOpBinaryScalar<Opcode::OP_MULS>);
REGISTER_CALC_OP(OP_DIVS, Opcode::OP_DIVS, ExecuteOpBinaryScalar<Opcode::OP_DIVS>);
REGISTER_CALC_OP(OP_FLOORDIVS, Opcode::OP_FLOORDIVS, ExecuteOpBinaryScalar<Opcode::OP_FLOORDIVS>);
REGISTER_CALC_OP(OP_MAXS, Opcode::OP_MAXS, ExecuteOpBinaryScalar<Opcode::OP_MAXS>);
REGISTER_CALC_OP(OP_MINS, Opcode::OP_MINS, ExecuteOpBinaryScalar<Opcode::OP_MINS>);
REGISTER_CALC_OP(OP_LRELU, Opcode::OP_LRELU, ExecuteOpBinaryScalar<Opcode::OP_LRELU>);
REGISTER_CALC_OP(OP_BITWISEANDS, Opcode::OP_BITWISEANDS, ExecuteOpBinaryScalar<Opcode::OP_BITWISEANDS>);
REGISTER_CALC_OP(OP_BITWISEORS, Opcode::OP_BITWISEORS, ExecuteOpBinaryScalar<Opcode::OP_BITWISEORS>);
REGISTER_CALC_OP(OP_BITWISEXORS, Opcode::OP_BITWISEXORS, ExecuteOpBinaryScalar<Opcode::OP_BITWISEXORS>);
REGISTER_CALC_OP(OP_GCDS, Opcode::OP_GCDS, ExecuteOpBinaryScalar<Opcode::OP_GCDS>);
REGISTER_CALC_OP(OP_REMS, Opcode::OP_REMS, ExecuteOpBinaryScalar<Opcode::OP_REMS>);
REGISTER_CALC_OP(OP_REMRS, Opcode::OP_REMRS, ExecuteOpBinaryScalar<Opcode::OP_REMRS>);
REGISTER_CALC_OP(OP_POWS, Opcode::OP_POWS, ExecuteOpBinaryScalar<Opcode::OP_POWS>);
// The OP_S_* add/sub/mul/div opcodes reuse the instantiations of their non-S_ counterparts, whereas
// OP_S_MAXS / OP_S_MINS use their own instantiations (which have dedicated switch cases).
REGISTER_CALC_OP(OP_S_ADDS, Opcode::OP_S_ADDS, ExecuteOpBinaryScalar<Opcode::OP_ADDS>);
REGISTER_CALC_OP(OP_S_SUBS, Opcode::OP_S_SUBS, ExecuteOpBinaryScalar<Opcode::OP_SUBS>);
REGISTER_CALC_OP(OP_S_MULS, Opcode::OP_S_MULS, ExecuteOpBinaryScalar<Opcode::OP_MULS>);
REGISTER_CALC_OP(OP_S_DIVS, Opcode::OP_S_DIVS, ExecuteOpBinaryScalar<Opcode::OP_DIVS>);
REGISTER_CALC_OP(OP_S_MAXS, Opcode::OP_S_MAXS, ExecuteOpBinaryScalar<Opcode::OP_S_MAXS>);
REGISTER_CALC_OP(OP_S_MINS, Opcode::OP_S_MINS, ExecuteOpBinaryScalar<Opcode::OP_S_MINS>);
/**
 * \brief Interpreter handler for OP_GATHER_ELEMENT.
 *
 * Requires exactly two input views (params, indices) and at most two output views; only output view 0
 * is written. The axis is read from the integer attribute "op_attr_axis" and passed to
 * calc::GatherElements.
 */
void ExecuteOpGatherElement(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& paramsView = inputViews.at(0);
    auto& indexView = inputViews.at(1);
    int gatherAxis = ctx->op->GetIntAttribute("op_attr_axis");
    calc::GatherElements(dstView, paramsView, indexView, gatherAxis);
}
REGISTER_CALC_OP(OP_GATHER_ELEMENT, Opcode::OP_GATHER_ELEMENT, ExecuteOpGatherElement);
/**
 * \brief Interpreter handler shared by OP_GATHER_MASK and OP_GATHER_MASK_BUILDIN.
 *
 * Requires exactly one output view and one input view and calls calc::GatherMask with the integer
 * attribute "op_attr_patternMode".
 */
void ExecuteOpGatherMask(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);
    auto& dstView = ctx->ooperandInplaceDataViewList->at(0);
    auto& srcView = ctx->ioperandDataViewList->at(0);
    int patternModeAttr = ctx->op->GetIntAttribute("op_attr_patternMode");
    calc::GatherMask(dstView, srcView, patternModeAttr);
}
REGISTER_CALC_OP(OP_GATHER_MASK, Opcode::OP_GATHER_MASK, ExecuteOpGatherMask);
REGISTER_CALC_OP(OP_GATHER_MASK_BUILDIN, Opcode::OP_GATHER_MASK_BUILDIN, ExecuteOpGatherMask);
/**
 * \brief Interpreter handler for tensor-by-tensor bitwise shifts.
 *
 * \tparam opcode OP_BITWISERIGHTSHIFT or OP_BITWISELEFTSHIFT; any other value trips an
 *                UNSUPPORTED_OPCODE assertion.
 *
 * Requires exactly two input views (value, shift amount) and at most two output views; only output
 * view 0 is written.
 */
template <Opcode opcode>
void ExecuteOpBitwiseShift(ExecuteOperationContext* ctx)
{
    ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == SIZE_TWO);
    auto& inputViews = *ctx->ioperandDataViewList;
    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto valueView = inputViews.at(0);
    auto shiftView = inputViews.at(1);

    if constexpr (opcode == Opcode::OP_BITWISERIGHTSHIFT) {
        calc::BitwiseRightShift(result, valueView, shiftView);
    } else if constexpr (opcode == Opcode::OP_BITWISELEFTSHIFT) {
        calc::BitwiseLeftShift(result, valueView, shiftView);
    } else {
        ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false);
    }
}
/**
 * \brief Interpreter handler for bitwise shifts where one operand is a scalar attribute.
 *
 * \tparam opcode OP_BITWISERIGHTSHIFTS / OP_BITWISELEFTSHIFTS pass (tensor, scalar) to the calc
 *                routine; OP_SBITWISERIGHTSHIFT / OP_SBITWISELEFTSHIFT pass (scalar, tensor). Any
 *                other value trips an UNSUPPORTED_OPCODE assertion.
 *
 * Requires exactly one input view. The two scalar-first opcodes may carry up to two output views, the
 * others exactly one; only output view 0 is written. The scalar is read from the
 * OpAttributeKey::scalar attribute into an Element that starts out as INT32 0; the return value of
 * GetAttr is not inspected.
 */
template <Opcode opcode>
void ExecuteOpBitwiseShiftScalar(ExecuteOperationContext* ctx)
{
    constexpr bool scalarFirst =
        (opcode == Opcode::OP_SBITWISERIGHTSHIFT) || (opcode == Opcode::OP_SBITWISELEFTSHIFT);
    if (scalarFirst) {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() <= SIZE_TWO);
    } else {
        ASSERT(ExecuteOperationScene::CTX_OUTPUT_COUNT_MISMATCH, ctx->ooperandInplaceDataViewList->size() == 1);
    }
    ASSERT(ExecuteOperationScene::CTX_INPUT_COUNT_MISMATCH, ctx->ioperandDataViewList->size() == 1);

    auto result = ctx->ooperandInplaceDataViewList->at(0);
    auto tensorView = ctx->ioperandDataViewList->at(0);
    auto scalarValue = Element(DT_INT32, 0);
    ctx->op->GetAttr(OpAttributeKey::scalar, scalarValue);

    if constexpr (opcode == Opcode::OP_BITWISERIGHTSHIFTS) {
        calc::BitwiseRightShiftS(result, tensorView, scalarValue);
    } else if constexpr (opcode == Opcode::OP_BITWISELEFTSHIFTS) {
        calc::BitwiseLeftShiftS(result, tensorView, scalarValue);
    } else if constexpr (opcode == Opcode::OP_SBITWISERIGHTSHIFT) {
        calc::SBitwiseRightShift(result, scalarValue, tensorView);
    } else if constexpr (opcode == Opcode::OP_SBITWISELEFTSHIFT) {
        calc::SBitwiseLeftShift(result, scalarValue, tensorView);
    } else {
        ASSERT(ExecuteOperationScene::UNSUPPORTED_OPCODE, false);
    }
}
// Registrations for the bitwise shift handlers: the two tensor-by-tensor opcodes use
// ExecuteOpBitwiseShift, the four scalar-operand opcodes use ExecuteOpBitwiseShiftScalar, each
// instantiated with its own opcode.
REGISTER_CALC_OP(
OP_BITWISERIGHTSHIFT, Opcode::OP_BITWISERIGHTSHIFT, ExecuteOpBitwiseShift<Opcode::OP_BITWISERIGHTSHIFT>);
REGISTER_CALC_OP(OP_BITWISELEFTSHIFT, Opcode::OP_BITWISELEFTSHIFT, ExecuteOpBitwiseShift<Opcode::OP_BITWISELEFTSHIFT>);
REGISTER_CALC_OP(
OP_BITWISERIGHTSHIFTS, Opcode::OP_BITWISERIGHTSHIFTS, ExecuteOpBitwiseShiftScalar<Opcode::OP_BITWISERIGHTSHIFTS>);
REGISTER_CALC_OP(
OP_BITWISELEFTSHIFTS, Opcode::OP_BITWISELEFTSHIFTS, ExecuteOpBitwiseShiftScalar<Opcode::OP_BITWISELEFTSHIFTS>);
REGISTER_CALC_OP(
OP_SBITWISERIGHTSHIFT, Opcode::OP_SBITWISERIGHTSHIFT, ExecuteOpBitwiseShiftScalar<Opcode::OP_SBITWISERIGHTSHIFT>);
REGISTER_CALC_OP(
OP_SBITWISELEFTSHIFT, Opcode::OP_SBITWISELEFTSHIFT, ExecuteOpBitwiseShiftScalar<Opcode::OP_SBITWISELEFTSHIFT>);
}