* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file op_infer_shape_impl.cpp
* \brief
*/
#include "op_infer_shape_impl.h"
#include "interface/operation/attr_holder.h"
#include "interface/operation/operation.h"
#include "interface/tensor/symbolic_scalar.h"
#include "interface/utils/common.h"
#include "tilefwk/error_code.h"
#include "tilefwk/error_code.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
const std::string COPY_OUT_FORCE_INFER_SHAPE = "copy_out_force_infer_shape";
void BinaryBrcinlineInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto validShape0 = op->GetIOperands()[0]->GetDynValidShape();
auto validShape1 = op->GetIOperands()[1]->GetDynValidShape();
std::vector<SymbolicScalar> outputValidShape = validShape0;
std::vector<int64_t> brcOperand;
if (op->GetAttr(OpAttributeKey::brcOperand, brcOperand)) {
for (size_t i = 0; i < outputValidShape.size(); i++) {
if (brcOperand[i] == 1) {
outputValidShape[i] = validShape1[i];
}
}
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outputValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ADD, Opcode::OP_ADD, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SUB, Opcode::OP_SUB, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MUL, Opcode::OP_MUL, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_DIV, Opcode::OP_DIV, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MAXIMUM, Opcode::OP_MAXIMUM, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MINIMUM, Opcode::OP_MINIMUM, BinaryBrcinlineInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_AXPY, Opcode::OP_AXPY, BinaryBrcinlineInferFunc);
void ElewiseInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inputNum = op->GetIOperands().size();
auto shapeDimNum = op->GetIOperands()[0]->GetDynValidShape().size();
std::vector<std::vector<SymbolicScalar>> dimValidShape(
shapeDimNum, std::vector<SymbolicScalar>(inputNum, SymbolicScalar()));
std::vector<std::vector<int64_t>> dimShape(shapeDimNum, std::vector<int64_t>(inputNum, 0));
for (size_t i = 0; i < op->GetIOperands().size(); ++i) {
auto iOperand = op->GetInputOperand(i);
auto validShape = op->GetIOperands()[i]->GetDynValidShape();
for (size_t dimIdx = 0; dimIdx < validShape.size(); ++dimIdx) {
dimValidShape[dimIdx][i] = validShape[dimIdx];
}
auto shape = op->GetIOperands()[i]->GetShape();
for (size_t dimIdx = 0; dimIdx < shape.size() && dimIdx < shapeDimNum; ++dimIdx) {
if (dimIdx == shape.size() - 1 && iOperand->GetProducers().size() == 1 &&
(*iOperand->GetProducers().begin())->GetOpcode() == Opcode::OP_BRCB) {
dimShape[dimIdx][i] = 1;
} else {
dimShape[dimIdx][i] = shape[dimIdx];
}
}
}
std::vector<int64_t> inputIdx(shapeDimNum, 0);
std::vector<SymbolicScalar> inputValidShape;
for (size_t i = 0; i < shapeDimNum; ++i) {
size_t oneDimNum = 0;
size_t noOneIndex = 0;
for (size_t j = 0; j < dimShape[i].size(); ++j) {
if (dimShape[i][j] == 1) {
oneDimNum++;
} else {
noOneIndex = j;
}
}
if (oneDimNum > 0 && oneDimNum < dimShape[i].size()) {
inputValidShape.push_back(dimValidShape[i][noOneIndex]);
inputIdx[i] = noOneIndex;
continue;
}
auto flag = false;
auto minDim = SymbolicScalar();
for (auto dim : dimValidShape[i]) {
if (!(dim.IsImmediate())) {
inputValidShape.push_back(dim);
flag = true;
break;
} else {
minDim = minDim.ConcreteValid() ? std::min(minDim.Concrete(), dim.Concrete()) : dim.Concrete();
}
}
if (!flag) {
inputValidShape.push_back(minDim);
}
}
int64_t mode = 0;
if (op->GetAttr(OP_ATTR_PREFIX + "cmp_mode", mode) && mode == 1) {
inputValidShape[inputValidShape.size() - 1] =
inputValidShape[inputValidShape.size() - 1] / 8;
}
int64_t whereBitMode = 0;
if (op->GetAttr(OP_ATTR_PREFIX + "whereBitMode", whereBitMode) && whereBitMode == 1 && inputIdx.back() == 0) {
inputValidShape[inputValidShape.size() - 1] = inputValidShape[inputValidShape.size() - 1] * 8;
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(inputValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ADDS, Opcode::OP_ADDS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MULS, Opcode::OP_MULS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_DIVS, Opcode::OP_S_DIVS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_POW, Opcode::OP_POW, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_POWS, Opcode::OP_POWS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_EXP, Opcode::OP_EXP, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_EXP2, Opcode::OP_EXP2, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_EXPM1, Opcode::OP_EXPM1, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SIGN, Opcode::OP_SIGN, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SIGNBIT, Opcode::OP_SIGNBIT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_TANH, Opcode::OP_TANH, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NEG, Opcode::OP_NEG, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROUND, Opcode::OP_ROUND, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_RSQRT, Opcode::OP_RSQRT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_RELU, Opcode::OP_RELU, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_LOG1P, Opcode::OP_LOG1P, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SQRT, Opcode::OP_SQRT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_CEIL, Opcode::OP_CEIL, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_FLOOR, Opcode::OP_FLOOR, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_TRUNC, Opcode::OP_TRUNC, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISENOT, Opcode::OP_BITWISENOT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_GCD, Opcode::OP_GCD, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_GCDS, Opcode::OP_GCDS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ABS, Opcode::OP_ABS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_LN, Opcode::OP_LN, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ISFINITE, Opcode::OP_ISFINITE, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_HUB, Opcode::OP_HUB, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ERF, Opcode::OP_ERF, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SIN, Opcode::OP_SIN, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_COS, Opcode::OP_COS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ERFC, Opcode::OP_ERFC, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ATAN, Opcode::OP_ATAN, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ATAN2, Opcode::OP_ATAN2, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_DIVS, Opcode::OP_DIVS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_RECIPROCAL, Opcode::OP_RECIPROCAL, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SUBS, Opcode::OP_SUBS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MAXS, Opcode::OP_MAXS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MINS, Opcode::OP_MINS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEANDS, Opcode::OP_BITWISEANDS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEORS, Opcode::OP_BITWISEORS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEXORS, Opcode::OP_BITWISEXORS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_ADDS, Opcode::OP_S_ADDS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_SUBS, Opcode::OP_S_SUBS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MULS, Opcode::OP_S_MULS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MAXS, Opcode::OP_S_MAXS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MINS, Opcode::OP_S_MINS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_LRELU, Opcode::OP_LRELU, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_ADD, Opcode::OP_S_ADD, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_SUB, Opcode::OP_S_SUB, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MUL, Opcode::OP_S_MUL, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_DIV, Opcode::OP_S_DIV, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MAX, Opcode::OP_S_MAX, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_S_MIN, Opcode::OP_S_MIN, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_CUM_SUM, Opcode::OP_CUM_SUM, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_CUM_PROD, Opcode::OP_CUM_PROD, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_TRIUL, Opcode::OP_TRIUL, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_REGISTER_COPY, Opcode::OP_REGISTER_COPY, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_CMP, Opcode::OP_CMP, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_CMPS, Opcode::OP_CMPS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_HYPOT, Opcode::OP_HYPOT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MOD, Opcode::OP_MOD, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MODS, Opcode::OP_MODS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISERIGHTSHIFT, Opcode::OP_BITWISERIGHTSHIFT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISELEFTSHIFT, Opcode::OP_BITWISELEFTSHIFT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISERIGHTSHIFTS, Opcode::OP_BITWISERIGHTSHIFTS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISELEFTSHIFTS, Opcode::OP_BITWISELEFTSHIFTS, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SBITWISERIGHTSHIFT, Opcode::OP_SBITWISERIGHTSHIFT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SBITWISELEFTSHIFT, Opcode::OP_SBITWISELEFTSHIFT, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEAND, Opcode::OP_BITWISEAND, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEOR, Opcode::OP_BITWISEOR, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_BITWISEXOR, Opcode::OP_BITWISEXOR, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_EXPANDEXPDIF, Opcode::OP_EXPANDEXPDIF, ElewiseInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_COPYSIGN, Opcode::OP_COPYSIGN, ElewiseInferFunc);
void RemInferShapeFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ElewiseInferFunc(op, outValidShapes);
outValidShapes.erase(outValidShapes.begin() + 1, outValidShapes.end());
std::vector<SymbolicScalar> tmpValidShape;
tmpValidShape.push_back(SymbolicScalar(op->GetOOperands()[1]->GetShape()[0]));
tmpValidShape.push_back(SymbolicScalar(op->GetOOperands()[1]->GetShape()[1]));
outValidShapes.push_back(tmpValidShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_REM, Opcode::OP_REM, RemInferShapeFunc);
REGISTER_INFER_SHAPE_FUNC(OP_REMS, Opcode::OP_REMS, RemInferShapeFunc);
REGISTER_INFER_SHAPE_FUNC(OP_REMRS, Opcode::OP_REMRS, RemInferShapeFunc);
template <int64_t TailScale>
void InferShapeWithTailScaleFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
if (inputValidShape.empty()) {
return;
}
size_t ndim = inputValidShape.size();
outValidShapes.emplace_back(inputValidShape);
outValidShapes.emplace_back(std::vector<SymbolicScalar>{inputValidShape[ndim - 1] * TailScale});
}
REGISTER_INFER_SHAPE_FUNC(OP_FLOORDIV, Opcode::OP_FLOORDIV, InferShapeWithTailScaleFunc<4>);
REGISTER_INFER_SHAPE_FUNC(OP_FLOORDIVS, Opcode::OP_FLOORDIVS, InferShapeWithTailScaleFunc<3>);
REGISTER_INFER_SHAPE_FUNC(OP_SINH, Opcode::OP_SINH, InferShapeWithTailScaleFunc<4>);
REGISTER_INFER_SHAPE_FUNC(OP_COSH, Opcode::OP_COSH, InferShapeWithTailScaleFunc<1>);
REGISTER_INFER_SHAPE_FUNC(OP_TAN, Opcode::OP_TAN, InferShapeWithTailScaleFunc<7>);
REGISTER_INFER_SHAPE_FUNC(OP_ASINH, Opcode::OP_ASINH, InferShapeWithTailScaleFunc<4>);
REGISTER_INFER_SHAPE_FUNC(OP_ACOSH, Opcode::OP_ACOSH, InferShapeWithTailScaleFunc<3>);
void AsinAcosInferShapeFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
if (inputValidShape.empty()) {
return;
}
size_t ndim = inputValidShape.size();
outValidShapes.emplace_back(inputValidShape);
if (ndim >= 2) {
outValidShapes.emplace_back(
std::vector<SymbolicScalar>{inputValidShape[ndim - 2] * inputValidShape[ndim - 1] * 5});
} else {
outValidShapes.emplace_back(std::vector<SymbolicScalar>{inputValidShape[ndim - 1] * 5});
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ASIN, Opcode::OP_ASIN, AsinAcosInferShapeFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ACOS, Opcode::OP_ACOS, AsinAcosInferShapeFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ATANH, Opcode::OP_ATANH, InferShapeWithTailScaleFunc<4>);
void PadInferShapeFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
if (inputValidShape.empty()) {
return;
}
size_t ndim = inputValidShape.size();
std::vector<SymbolicScalar> outValidShape = inputValidShape;
int64_t padRight = 0;
int64_t padBottom = 0;
op->GetAttr(OP_ATTR_PREFIX + "pad_right", padRight);
op->GetAttr(OP_ATTR_PREFIX + "pad_bottom", padBottom);
if (ndim >= 1 && padRight > 0) {
outValidShape[ndim - 1] = outValidShape[ndim - 1] + padRight;
}
if (ndim >= 2 && padBottom > 0) {
outValidShape[ndim - 2] = outValidShape[ndim - 2] + padBottom;
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_PAD, Opcode::OP_PAD, PadInferShapeFunc);
void FillPadInferShapeFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
for (auto c : op->GetOOperands()[0]->oriShape) {
outValidShape.push_back(SymbolicScalar(c));
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_FILLPAD, Opcode::OP_FILLPAD, FillPadInferShapeFunc);
void IndexOutCastInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
size_t input_dim = 2;
auto inValidShape = op->GetIOperands()[input_dim]->GetDynValidShape();
for (size_t i = 0; i < inValidShape.size(); ++i) {
outValidShape.push_back(inValidShape[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
auto indexOutCastOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (indexOutCastOpAttribute == nullptr) {
VECTOR_LOGW("IndexOutCast [%d] has no copyOpAttr.", op->GetOpMagic());
return;
}
indexOutCastOpAttribute->SetFromDynValidShape(
OpImmediate::Specified(op->GetIOperands()[input_dim]->GetDynValidShape()));
}
REGISTER_INFER_SHAPE_FUNC(OP_INDEX_OUTCAST, Opcode::OP_INDEX_OUTCAST, IndexOutCastInferFunc);
void CastInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto srcValidShape = op->GetIOperands()[0]->GetDynValidShape();
outValidShapes.push_back(srcValidShape);
if (op->GetOOperands().size() > 1) {
std::vector<SymbolicScalar> tmpValidShape;
for (size_t i = 0; i < op->GetOOperands()[1]->GetShape().size(); ++i) {
tmpValidShape.push_back(SymbolicScalar(op->GetOOperands()[1]->GetShape()[i]));
}
outValidShapes.push_back(tmpValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_CAST, Opcode::OP_CAST, CastInferFunc);
void GatherElementInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
auto inValidShape = op->GetIOperands()[1]->GetDynValidShape();
for (size_t i = 0; i < inValidShape.size(); ++i) {
outValidShape.push_back(inValidShape[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_GATHER_ELEMENT, Opcode::OP_GATHER_ELEMENT, GatherElementInferFunc);
void GatherMaskFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
uint8_t patternMode = op->GetIntAttribute(OP_ATTR_PREFIX + "patternMode");
if (patternMode == 1 || patternMode == 2) {
res.back() = res.back() / 2;
} else if (patternMode == 3 || patternMode == 4 || patternMode == 5 || patternMode == 6) {
res.back() = res.back() / 4;
}
outValidShapes.push_back(res);
}
REGISTER_INFER_SHAPE_FUNC(OP_GATHER_MASK, Opcode::OP_GATHER_MASK, GatherMaskFunc);
void ScatterInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
auto inValidShape = op->GetIOperands()[0]->GetDynValidShape();
for (size_t i = 0; i < inValidShape.size(); ++i) {
outValidShape.push_back(inValidShape[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_SCATTER_ELEMENT, Opcode::OP_SCATTER_ELEMENT, ScatterInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SCATTER, Opcode::OP_SCATTER, ScatterInferFunc);
void IndexAddInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto selfValidShape = op->GetIOperands()[0]->GetDynValidShape();
auto srcValidShape = op->GetIOperands()[1]->GetDynValidShape();
outValidShapes.push_back(selfValidShape);
outValidShapes.push_back({1, srcValidShape[srcValidShape.size() - 1]});
}
REGISTER_INFER_SHAPE_FUNC(OP_INDEX_ADD, Opcode::OP_INDEX_ADD, IndexAddInferFunc);
void IndexAddUBInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inValidShape = op->GetIOperands()[0]->GetDynValidShape();
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(inValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_INDEX_ADD_UB, Opcode::OP_INDEX_ADD_UB, IndexAddUBInferFunc);
void LogicalNotInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ElewiseInferFunc(op, outValidShapes);
outValidShapes.erase(outValidShapes.begin() + 1, outValidShapes.end());
auto data_type = op->GetIOperands()[0]->Datatype();
DataType select_dtype;
if (data_type == DT_FP32 || data_type == DT_BF16) {
select_dtype = DT_FP32;
} else {
select_dtype = DT_FP16;
}
constexpr int64_t COUNT_SIZE = 2048;
constexpr int64_t vcmp_bit_size = COUNT_SIZE / 8;
constexpr size_t ALIGN_SIZE = 32;
int64_t total_size = COUNT_SIZE * 2 + COUNT_SIZE * BytesOf(select_dtype) * 2 + vcmp_bit_size + 8;
total_size = (total_size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE;
int64_t shape = total_size / BytesOf(select_dtype);
outValidShapes.push_back({shape});
}
REGISTER_INFER_SHAPE_FUNC(OP_LOGICALNOT, Opcode::OP_LOGICALNOT, LogicalNotInferFunc);
void LogicalAndInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ElewiseInferFunc(op, outValidShapes);
outValidShapes.erase(outValidShapes.begin() + 1, outValidShapes.end());
const int64_t COUNT_SIZE = 64;
outValidShapes.push_back({COUNT_SIZE * 5 + COUNT_SIZE / 8 + 1});
}
REGISTER_INFER_SHAPE_FUNC(OP_LOGICALAND, Opcode::OP_LOGICALAND, LogicalAndInferFunc);
void ViewTypeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto inputOperand = op->GetIOperands()[0];
auto outputOperand = op->GetOOperands()[0];
auto validShape = inputOperand->GetDynValidShape();
auto changedDim =
validShape[validShape.size() - 1] * BytesOf(inputOperand->Datatype()) / BytesOf(outputOperand->Datatype());
validShape[validShape.size() - 1] = changedDim;
outValidShapes.push_back(validShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_VIEW_TYPE, Opcode::OP_VIEW_TYPE, ViewTypeInferFunc);
void IndexPutInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
outValidShapes.push_back(op->GetIOperands()[0]->GetDynValidShape());
}
REGISTER_INFER_SHAPE_FUNC(OP_INDEX_PUT, Opcode::OP_INDEX_PUT, IndexPutInferFunc);
void PairReduceInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto dimSize = op->GetIOperands()[0]->GetDynValidShape().size();
std::vector<SymbolicScalar> outValidShape;
for (size_t i = 0; i < dimSize; i++) {
outValidShape.push_back(op->GetIOperands()[0]->GetDynValidShape()[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_PAIRMAX, Opcode::OP_PAIRMAX, PairReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PAIRMIN, Opcode::OP_PAIRMIN, PairReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PAIRSUM, Opcode::OP_PAIRSUM, PairReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PAIRPROD, Opcode::OP_PAIRPROD, PairReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PAIRARGMAX, Opcode::OP_PAIRARGMAX, PairReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PAIRARGMIN, Opcode::OP_PAIRARGMIN, PairReduceInferFunc);
void ElewiseBrcInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
for (size_t i = 0; i < op->GetIOperands()[0]->GetDynValidShape().size(); ++i) {
auto leftIShapeDim = op->GetIOperands()[0]->GetDynValidShape()[i];
if (leftIShapeDim.IsImmediate() && leftIShapeDim.Concrete() == 1) {
outValidShape.push_back(op->GetIOperands()[1]->GetDynValidShape()[i]);
} else {
outValidShape.push_back(leftIShapeDim);
}
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ADD_BRC, Opcode::OP_ADD_BRC, ElewiseBrcInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_SUB_BRC, Opcode::OP_SUB_BRC, ElewiseBrcInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MUL_BRC, Opcode::OP_MUL_BRC, ElewiseBrcInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_DIV_BRC, Opcode::OP_DIV_BRC, ElewiseBrcInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_MAX_BRC, Opcode::OP_MAX_BRC, ElewiseBrcInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_GCD_BRC, Opcode::OP_GCD_BRC, ElewiseBrcInferFunc);
void BroadcastInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
if (op->GetAttr(OP_ATTR_PREFIX + "validShape", outValidShape)) {
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
return;
}
auto outTensor = op->GetOOperands()[0];
for (size_t i = 0; i < op->GetIOperands()[0]->GetDynValidShape().size(); ++i) {
if (op->GetIOperands()[0]->oriShape[i] != 1) {
outValidShape.push_back(op->GetIOperands()[0]->GetDynValidShape()[i]);
} else {
outValidShape.push_back(SymbolicScalar(op->GetOOperands()[0]->GetShape()[i]));
}
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_EXPAND, Opcode::OP_EXPAND, BroadcastInferFunc);
void OneHotInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape(op->GetIOperands()[0]->GetDynValidShape());
int lastDim = op->GetIntAttribute(OP_ATTR_PREFIX + "numClasses");
outValidShape.push_back(SymbolicScalar(lastDim));
outValidShapes.push_back(outValidShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_ONEHOT, Opcode::OP_ONEHOT, OneHotInferFunc);
void QuantizeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> dstValidShape(op->GetIOperands()[0]->GetDynValidShape());
outValidShapes.push_back(dstValidShape);
auto inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
if (!inputValidShape.empty()) {
std::vector<SymbolicScalar> tmpValidShape;
if (inputValidShape.size() >= 2) {
SymbolicScalar rows = inputValidShape[inputValidShape.size() - 2];
SymbolicScalar tmpBufSize = rows * 32;
tmpValidShape.push_back(tmpBufSize);
} else {
tmpValidShape.push_back(SymbolicScalar(256));
}
outValidShapes.push_back(tmpValidShape);
} else {
std::vector<SymbolicScalar> tmpValidShape;
tmpValidShape.push_back(SymbolicScalar(256));
outValidShapes.push_back(tmpValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_QUANTIZE_SYM, Opcode::OP_QUANTIZE_SYM, QuantizeInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_QUANTIZE_ASYM, Opcode::OP_QUANTIZE_ASYM, QuantizeInferFunc);
void DequantizeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape(op->GetIOperands()[0]->GetDynValidShape());
outValidShapes.push_back(outValidShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_DEQUANTIZE, Opcode::OP_DEQUANTIZE, DequantizeInferFunc);
void RangeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
Element size = op->GetElementAttribute(OP_ATTR_PREFIX + "SIZE");
outValidShape.push_back(SymbolicScalar(size.GetSignedData()));
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_RANGE, Opcode::OP_RANGE, RangeInferFunc);
void UniformInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
auto shapeAttr = op->GetVectorIntAttribute(OP_ATTR_PREFIX + "SHAPE");
for (auto dim : shapeAttr) {
outValidShape.push_back(SymbolicScalar(static_cast<int64_t>(dim)));
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_UNIFORM, Opcode::OP_UNIFORM, UniformInferFunc);
void ReduceInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
auto outValidShape = inputValidShapes[0];
int axis = op->GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
outValidShape[axis] = SymbolicScalar(1);
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ROWSUMLINE, Opcode::OP_ROWSUMLINE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWMAXLINE, Opcode::OP_ROWMAXLINE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWMINLINE, Opcode::OP_ROWMINLINE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWPRODLINE, Opcode::OP_ROWPRODLINE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWMAX_SINGLE, Opcode::OP_ROWMAX_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWMIN_SINGLE, Opcode::OP_ROWMIN_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWSUM_SINGLE, Opcode::OP_ROWSUM_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWPROD_SINGLE, Opcode::OP_ROWPROD_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWARGMAXWITHVALUE_SINGLE, Opcode::OP_ROWARGMAXWITHVALUE_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWARGMINWITHVALUE_SINGLE, Opcode::OP_ROWARGMINWITHVALUE_SINGLE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWARGMAXWITHVALUE_LINE, Opcode::OP_ROWARGMAXWITHVALUE_LINE, ReduceInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_ROWARGMINWITHVALUE_LINE, Opcode::OP_ROWARGMINWITHVALUE_LINE, ReduceInferFunc);
void WhereInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ElewiseInferFunc(op, outValidShapes);
}
REGISTER_INFER_SHAPE_FUNC(OP_WHERE_TT, Opcode::OP_WHERE_TT, WhereInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_WHERE_TS, Opcode::OP_WHERE_TS, WhereInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_WHERE_ST, Opcode::OP_WHERE_ST, WhereInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_WHERE_SS, Opcode::OP_WHERE_SS, WhereInferFunc);
void InferFunc4Gather(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto iOperands = op->GetIOperands();
ASSERT(VectorErrorCode::ERR_PARAM_COUNT_INVALID, iOperands.size() >= NUM2) << "iOperands.size() should be >= NUM2";
int axis = op->GetIntAttribute(OP_ATTR_PREFIX + "axis");
int src0Rank = iOperands[0]->GetShape().size();
if (axis < 0) {
axis = axis + src0Rank;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, axis >= 0 && axis < src0Rank) << "InferFunc4Gather, axis is invalid";
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : iOperands) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
std::vector<SymbolicScalar> outValidShape = inputValidShapes[0];
outValidShape.erase(outValidShape.begin() + axis);
outValidShape.insert(outValidShape.begin() + axis, inputValidShapes[1].begin(), inputValidShapes[1].end());
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_GATHER_FROM_UB, Opcode::OP_GATHER_FROM_UB, InferFunc4Gather);
REGISTER_INFER_SHAPE_FUNC(OP_GATHER, Opcode::OP_GATHER, InferFunc4Gather);
void InferFuncGatherInL1(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto iOperands = op->GetIOperands();
ASSERT(MatmulErrorCode::ERR_PARAM_MISMATCH, iOperands.size() == 3) << "iOperands.size() should be 3";
auto srcValidShape = iOperands[0]->GetDynValidShape();
auto offsetValidShape = iOperands[1]->GetDynValidShape();
auto srcStartColumnOffset = op->GetIntAttribute(OpAttributeKey::startOffset);
ASSERT(MatmulErrorCode::ERR_PARAM_MISMATCH, op->GetOOperands().size() == 1)
<< "op->GetOOperands().size() should be 1";
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(
{offsetValidShape[1], std::min(srcValidShape[1] - srcStartColumnOffset, output->GetShape()[1])});
}
}
REGISTER_INFER_SHAPE_FUNC(OP_GATHER_IN_L1, Opcode::OP_GATHER_IN_L1, InferFuncGatherInL1);
* 定制,
* parma [a,b]
* indices [1,c]
* axis=-2
* result [c,b]
*/
void InferFuncGatherInUB(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto iOperands = op->GetIOperands();
ASSERT(VectorErrorCode::ERR_PARAM_COUNT_INVALID, iOperands.size() == 3) << "iOperands.size() should be 3";
auto srcValidShape = iOperands[0]->GetDynValidShape();
auto indicesValidShape = iOperands[1]->GetDynValidShape();
ASSERT(VectorErrorCode::ERR_PARAM_COUNT_INVALID, op->GetOOperands().size() == 1)
<< "op->GetOOperands().size() should be 1";
for (auto output : op->GetOOperands()) {
outValidShapes.push_back({indicesValidShape[1], srcValidShape[1]});
}
}
REGISTER_INFER_SHAPE_FUNC(OP_GATHER_IN_UB, Opcode::OP_GATHER_IN_UB, InferFuncGatherInUB);
void MatmulInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
for (auto inputTensor : op->GetIOperands()) {
auto inputValidShape = inputTensor->GetDynValidShape();
if (inputTensor->GetMemoryTypeOriginal() == MemoryType::MEM_L0A) {
outValidShape.push_back(inputValidShape[0]);
} else if (inputTensor->GetMemoryTypeOriginal() == MemoryType::MEM_L0B) {
outValidShape.push_back(inputValidShape[1]);
}
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_A_MUL_B, Opcode::OP_A_MUL_B, MatmulInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_A_MUL_BT, Opcode::OP_A_MUL_BT, MatmulInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_AT_MUL_B, Opcode::OP_AT_MUL_B, MatmulInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_AT_MUL_BT, Opcode::OP_AT_MUL_BT, MatmulInferFunc);
void LoadBTFBInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
for (auto output : op->GetOOperands()) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, !output->GetDynValidShape().empty())
<< "output->GetDynValidShape() should not be empty";
outValidShapes.push_back(output->GetDynValidShape());
}
}
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_BT, Opcode::OP_L1_TO_BT, LoadBTFBInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_FIX_QUANT_PRE, Opcode::OP_L1_TO_FIX_QUANT_PRE, LoadBTFBInferFunc);
void MatmulACCInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> outValidShape;
for (auto inputTensor : op->GetIOperands()) {
if (inputTensor->GetMemoryTypeOriginal() == MemoryType::MEM_L0C) {
outValidShape = inputTensor->GetDynValidShape();
break;
}
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_A_MULACC_B, Opcode::OP_A_MULACC_B, MatmulACCInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_A_MULACC_BT, Opcode::OP_A_MULACC_BT, MatmulACCInferFunc);
void LoadL0C2L1InferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyAttr = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyAttr != nullptr) {
auto fromValidShape = op->GetIOperands()[0]->GetDynValidShape();
copyAttr->SetFromDynValidShape(OpImmediate::Specified(fromValidShape));
} else {
VECTOR_LOGW(
"%s[%d] has no copy out attr, set output valid shape same as input.", op->GetOpcodeStr().c_str(),
op->GetOpMagic());
outValidShapes.emplace_back(op->GetIOperands()[0]->GetDynValidShape());
return;
}
auto offsets = copyAttr->GetToOffset();
auto inputShapes = copyAttr->GetToDynValidShape();
std::vector<SymbolicScalar> outDynShape = op->GetOOperands()[0]->GetDynValidShape();
if (outDynShape.empty()) {
outDynShape.resize(op->GetOOperands()[0]->GetShape().size(), SymbolicScalar(0));
}
std::vector<SymbolicScalar> outShape;
for (size_t i = 0; i < inputShapes.size(); i++) {
auto inputShape = inputShapes[i].GetSpecifiedValue();
auto offset = offsets[i].GetSpecifiedValue();
SymbolicScalar actualDim = std::max(outDynShape[i], (inputShape + offset) * (inputShape != 0));
outShape.emplace_back(actualDim);
}
for (auto output : op->GetOOperands()) {
outValidShapes.emplace_back(outShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_L0C_TO_L1, Opcode::OP_L0C_TO_L1, LoadL0C2L1InferFunc);
void Load2L1InferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(inputValidShapes[0]);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_UB_COPY_ND2NZ, Opcode::OP_UB_COPY_ND2NZ, Load2L1InferFunc);
void Load2L1MXScaleInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
ASSERT(
MatmulErrorCode::ERR_PARAM_INVALID, !op->GetIOperands().empty() && op->GetIOperands()[0] != nullptr &&
op->GetIOperands()[0]->GetDynValidShape().size() == SHAPE_DIM3)
<< "op->GetIOperands() should not be empty and GetDynValidShape().size() should be SHAPE_DIM3";
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (op->GetOOperands()[0] != nullptr && !(op->GetOOperands()[0]->GetDynValidShape().empty())) {
outValidShapes.push_back(op->GetOOperands()[0]->GetDynValidShape());
if (copyOpAttribute != nullptr && (copyOpAttribute->GetToDynValidShape()).empty()) {
auto toDynShape = OpImmediate::Specified(op->GetOOperands()[0]->GetDynValidShape());
copyOpAttribute->SetToDynValidShape(toDynShape);
}
return;
}
}
REGISTER_INFER_SHAPE_FUNC(OP_L1_COPY_IN_B_SCALE, Opcode::OP_L1_COPY_IN_B_SCALE, Load2L1MXScaleInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L1_COPY_IN_A_SCALE, Opcode::OP_L1_COPY_IN_A_SCALE, Load2L1MXScaleInferFunc);
template <bool isTrans = false>
void LoadL0InferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
if (op->HasAttr(Matrix::L1_TO_L0_OFFSET) && op->HasAttr(Matrix::L1_TO_L0_TILE)) {
std::vector<SymbolicScalar> offset;
std::vector<SymbolicScalar> tile;
op->GetAttr(Matrix::L1_TO_L0_OFFSET, offset);
op->GetAttr(Matrix::L1_TO_L0_TILE, tile);
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, offset.size() == SHAPE_DIM2) << "offset.size() should be SHAPE_DIM2";
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, tile.size() == SHAPE_DIM2) << "tile.size() should be SHAPE_DIM2";
ASSERT(
MatmulErrorCode::ERR_PARAM_INVALID, !op->GetIOperands().empty() && op->GetIOperands()[0] != nullptr &&
op->GetIOperands()[0]->GetDynValidShape().size() == SHAPE_DIM2)
<< "op->GetIOperands() check failed or GetDynValidShape().size() should be SHAPE_DIM2";
std::vector<SymbolicScalar> srcValidShape = op->GetIOperands()[0]->GetDynValidShape();
std::vector<SymbolicScalar> dstValidShape = GetViewValidShape(
srcValidShape, SymbolicScalar::Concrete(offset, 0), offset, SymbolicScalar::Concrete(tile, 0));
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, dstValidShape.size() == SHAPE_DIM2)
<< "dstValidShape.size() should be SHAPE_DIM2";
if constexpr (isTrans) {
std::swap(dstValidShape[0], dstValidShape[1]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(dstValidShape);
}
return;
}
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, inputTensor != nullptr) << "inputTensor should not be nullptr";
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty() || inputValidShapes[0].size() != SHAPE_DIM2) {
return;
}
for (auto output : op->GetOOperands()) {
if constexpr (isTrans) {
outValidShapes.push_back({inputValidShapes[0][1], inputValidShapes[0][0]});
} else {
outValidShapes.push_back(inputValidShapes[0]);
}
}
}
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0A, Opcode::OP_L1_TO_L0A, LoadL0InferFunc<false>);
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0B, Opcode::OP_L1_TO_L0B, LoadL0InferFunc<false>);
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0_AT, Opcode::OP_L1_TO_L0_AT, LoadL0InferFunc<true>);
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0_BT, Opcode::OP_L1_TO_L0_BT, LoadL0InferFunc<true>);
void LoadL0MXInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
std::vector<SymbolicScalar> offset;
std::vector<SymbolicScalar> tile;
op->GetAttr(Matrix::L1_TO_L0_OFFSET, offset);
op->GetAttr(Matrix::L1_TO_L0_TILE, tile);
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, offset.size() == SHAPE_DIM3) << "offset.size() should be SHAPE_DIM3";
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, tile.size() == SHAPE_DIM3) << "tile.size() should be SHAPE_DIM3";
ASSERT(
MatmulErrorCode::ERR_PARAM_INVALID, !op->GetIOperands().empty() && op->GetIOperands()[0] != nullptr &&
op->GetIOperands()[0]->GetDynValidShape().size() == SHAPE_DIM3)
<< "op->GetIOperands() check failed or GetDynValidShape().size() should be SHAPE_DIM3";
std::vector<SymbolicScalar> srcValidShape = op->GetIOperands()[0]->GetDynValidShape();
std::vector<SymbolicScalar> dstValidShape = GetViewValidShape(
srcValidShape, SymbolicScalar::Concrete(offset, 0), offset, SymbolicScalar::Concrete(tile, 0));
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, dstValidShape.size() == SHAPE_DIM3)
<< "dstValidShape.size() should be SHAPE_DIM3";
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(dstValidShape);
}
return;
}
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0A_SCALE, Opcode::OP_L1_TO_L0A_SCALE, LoadL0MXInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L1_TO_L0B_SCALE, Opcode::OP_L1_TO_L0B_SCALE, LoadL0MXInferFunc);
void L1CopyInConvInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
const std::string L1_TILE_SHAPE = "l1_tile_shape";
const std::string IS_FMAP_FLAG = "IS_FMAP";
ASSERT(MatmulErrorCode::ERR_PARAM_INVALID, op->HasAttr(L1_TILE_SHAPE)) << "op should have L1_TILE_SHAPE attr";
std::vector<SymbolicScalar> tile;
op->GetAttr(L1_TILE_SHAPE, tile);
ASSERT(MatmulErrorCode::ERR_PARAM_INVALID, op->HasAttr(IS_FMAP_FLAG)) << "op should have IS_FMAP_FLAG attr";
bool isFmap = false;
op->GetAttr(IS_FMAP_FLAG, isFmap);
if (isFmap) {
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, tile.size() == SHAPE_DIM5 || tile.size() == SHAPE_DIM6)
<< "tile.size() should be SHAPE_DIM5 or SHAPE_DIM6";
} else {
ASSERT(MatmulErrorCode::ERR_CONFIG_TILE, tile.size() == SHAPE_DIM4) << "tile.size() should be SHAPE_DIM4";
}
std::vector<SymbolicScalar> outShape;
for (size_t i = 0; i < tile.size(); i++) {
outShape.push_back(tile[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outShape);
}
}
void L1ToL0ConvInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
const std::string L0_TILE_SHAPE = "l0_tile_shape";
ASSERT(MatmulErrorCode::ERR_PARAM_INVALID, op->HasAttr(L0_TILE_SHAPE)) << "op should have L0_TILE_SHAPE attr";
std::vector<SymbolicScalar> tile;
op->GetAttr(L0_TILE_SHAPE, tile);
std::vector<SymbolicScalar> outShape;
for (size_t i = 0; i < tile.size(); i++) {
outShape.push_back(tile[i]);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outShape);
}
}
void L0CCopyOutConvInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(MatmulErrorCode::ERR_RUNTIME_NULLPTR, op != nullptr) << "op should not be nullptr";
if (!(op->GetOOperands()[0]->GetDynValidShape().empty())) {
outValidShapes.push_back(op->GetOOperands()[0]->GetDynValidShape());
}
}
REGISTER_INFER_SHAPE_FUNC(OP_L1_COPY_IN_CONV, Opcode::OP_L1_COPY_IN_CONV, L1CopyInConvInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_LOAD3D_CONV, Opcode::OP_LOAD3D_CONV, L1ToL0ConvInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_LOAD2D_CONV, Opcode::OP_LOAD2D_CONV, L1ToL0ConvInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L0C_COPY_OUT_CONV, Opcode::OP_L0C_COPY_OUT_CONV, L0CCopyOutConvInferFunc);
void TransDataDefaultInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
SymbolicScalar validShapeN = inputValidShape[0];
SymbolicScalar validShapeC = inputValidShape[1];
SymbolicScalar validShapeH = inputValidShape[2];
SymbolicScalar validShapeW = inputValidShape[3];
SymbolicScalar validShapeN0 = 16;
SymbolicScalar validShapeC0 = SymbolicScalar(BLOCK_SIZE / BytesOf(op->GetIOperands()[0]->Datatype()));
SymbolicScalar validShapeN1 = validShapeN / validShapeN0;
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, validShapeC0 > 0) << "The validShapeC0 is not valid !";
SymbolicScalar validShapeC1 = validShapeC / validShapeC0;
if (op->GetOpcode() == Opcode::OP_NCHW2NC1HWC0) {
std::vector<SymbolicScalar> outputValidShape = {
validShapeN, validShapeC1, validShapeH, validShapeW, validShapeC0};
outValidShapes.push_back(outputValidShape);
} else if (op->GetOpcode() == Opcode::OP_NCHW2Fractal_Z) {
std::vector<SymbolicScalar> outputValidShape = {
validShapeC1 * validShapeH * validShapeW, validShapeN1, validShapeN0, validShapeC0};
outValidShapes.push_back(outputValidShape);
}
std::vector<int64_t> tmpValidShape = op->GetOOperands()[1]->GetShape();
outValidShapes.push_back(SymbolicScalar::FromConcrete(tmpValidShape));
}
void TransDataNCDHW2NDC1HWC0InferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
SymbolicScalar validShapeN = inputValidShape[0];
SymbolicScalar validShapeD = inputValidShape[1];
SymbolicScalar validShapeC = inputValidShape[2];
SymbolicScalar validShapeH = inputValidShape[3];
SymbolicScalar validShapeW = inputValidShape[4];
SymbolicScalar validShapeC0 = SymbolicScalar(BLOCK_SIZE / BytesOf(op->GetIOperands()[0]->Datatype()));
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, validShapeC0 > 0) << "The validShapeC0 is not valid !";
SymbolicScalar validShapeC1 = validShapeC / validShapeC0;
std::vector<SymbolicScalar> outputValidShape = {validShapeN, validShapeD, validShapeC1,
validShapeH, validShapeW, validShapeC0};
outValidShapes.push_back(outputValidShape);
std::vector<int64_t> tmpValidShape = op->GetOOperands()[1]->GetShape();
outValidShapes.push_back(SymbolicScalar::FromConcrete(tmpValidShape));
}
void TransDataNC1HWC02NCHWInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
SymbolicScalar validShapeN = inputValidShape[0];
SymbolicScalar validShapeC1 = inputValidShape[1];
SymbolicScalar validShapeH = inputValidShape[2];
SymbolicScalar validShapeW = inputValidShape[3];
SymbolicScalar validShapeC0 = inputValidShape[4];
std::vector<SymbolicScalar> outputValidShape = {validShapeN, validShapeC1 * validShapeC0, validShapeH, validShapeW};
outValidShapes.push_back(outputValidShape);
std::vector<int64_t> tmpValidShape = op->GetOOperands()[1]->GetShape();
outValidShapes.push_back(SymbolicScalar::FromConcrete(tmpValidShape));
}
void TransDataNDC1HWC02NCDHWInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
SymbolicScalar validShapeD = inputValidShape[0];
SymbolicScalar validShapeC1 = inputValidShape[1];
SymbolicScalar validShapeH = inputValidShape[2];
SymbolicScalar validShapeW = inputValidShape[3];
SymbolicScalar validShapeC0 = inputValidShape[4];
std::vector<SymbolicScalar> outputValidShape = {
1, validShapeD, validShapeC1 * validShapeC0, validShapeH, validShapeW};
outValidShapes.push_back(outputValidShape);
std::vector<int64_t> tmpValidShape = op->GetOOperands()[1]->GetShape();
outValidShapes.push_back(SymbolicScalar::FromConcrete(tmpValidShape));
}
void TransDataNCDHW2Fractal_Z_3DInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
SymbolicScalar validShapeN = inputValidShape[0];
SymbolicScalar validShapeC = inputValidShape[1];
SymbolicScalar validShapeD = inputValidShape[2];
SymbolicScalar validShapeH = inputValidShape[3];
SymbolicScalar validShapeW = inputValidShape[4];
SymbolicScalar validShapeC0 = SymbolicScalar(BLOCK_SIZE / BytesOf(op->GetIOperands()[0]->Datatype()));
SymbolicScalar validShapeN0 = SymbolicScalar(16);
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, validShapeC0 > 0) << "The validShapeC0 is not valid !";
SymbolicScalar validShapeC1 = validShapeC / validShapeC0;
SymbolicScalar validShapeN1 = validShapeN / validShapeN0;
std::vector<SymbolicScalar> outputValidShape = {
validShapeD * validShapeC1 * validShapeH * validShapeW, validShapeN1, validShapeN0, validShapeC0};
outValidShapes.push_back(outputValidShape);
std::vector<int64_t> tmpValidShape = op->GetOOperands()[1]->GetShape();
outValidShapes.push_back(SymbolicScalar::FromConcrete(tmpValidShape));
}
void TransDataInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
switch (op->GetOpcode()) {
case Opcode::OP_NCHW2NC1HWC0:
TransDataDefaultInferFunc(op, outValidShapes);
return;
case Opcode::OP_NCHW2Fractal_Z:
TransDataDefaultInferFunc(op, outValidShapes);
return;
case Opcode::OP_NC1HWC02NCHW:
TransDataNC1HWC02NCHWInferFunc(op, outValidShapes);
return;
case Opcode::OP_NCDHW2FRACTAL_Z_3D:
TransDataNCDHW2Fractal_Z_3DInferFunc(op, outValidShapes);
return;
case Opcode::OP_NCDHW2NDC1HWC0:
TransDataNCDHW2NDC1HWC0InferFunc(op, outValidShapes);
return;
case Opcode::OP_NDC1HWC02NCDHW:
TransDataNDC1HWC02NCDHWInferFunc(op, outValidShapes);
return;
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "The transDataType is not supported";
}
}
REGISTER_INFER_SHAPE_FUNC(OP_NCHW2NC1HWC0, Opcode::OP_NCHW2NC1HWC0, TransDataInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NCHW2Fractal_Z, Opcode::OP_NCHW2Fractal_Z, TransDataInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NC1HWC02NCHW, Opcode::OP_NC1HWC02NCHW, TransDataInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NCDHW2NDC1HWC0, Opcode::OP_NCDHW2NDC1HWC0, TransDataInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NCDHW2FRACTAL_Z_3D, Opcode::OP_NCDHW2FRACTAL_Z_3D, TransDataInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_NDC1HWC02NCDHW, Opcode::OP_NDC1HWC02NCDHW, TransDataInferFunc);
void CopyInInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (!(op->GetOOperands()[0]->GetDynValidShape().empty())) {
outValidShapes.push_back(op->GetOOperands()[0]->GetDynValidShape());
if (copyOpAttribute != nullptr && (copyOpAttribute->GetToDynValidShape()).empty()) {
auto toDynShape = OpImmediate::Specified(op->GetOOperands()[0]->GetDynValidShape());
copyOpAttribute->SetToDynValidShape(toDynShape);
}
return;
}
auto toValidShape = copyOpAttribute->GetToDynValidShape();
std::vector<SymbolicScalar> toValidShapeSym(toValidShape.size());
OpImmediate::NormalizeValue(toValidShapeSym, 0, toValidShape, 0, false);
auto toValidShapeValue = SymbolicScalar::Concrete(toValidShapeSym, -1);
auto tileShape = copyOpAttribute->GetShape();
std::vector<SymbolicScalar> tileShapeSym(tileShape.size());
OpImmediate::NormalizeValue(tileShapeSym, 0, tileShape, 0, false);
if (!toValidShape.empty()) {
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(toValidShapeSym);
}
return;
}
if (!(op->GetOOperands()[0]->GetDynValidShape().empty())) {
outValidShapes.push_back(op->GetOOperands()[0]->GetDynValidShape());
auto toDynShape = OpImmediate::Specified(op->GetOOperands()[0]->GetDynValidShape());
copyOpAttribute->SetToDynValidShape(toDynShape);
return;
}
if (op->GetIOperands()[0]->GetProducers().empty()) {
if (toValidShape.empty()) {
std::vector<SymbolicScalar> toValidShapeVec;
for (auto dim : copyOpAttribute->GetShape()) {
toValidShapeVec.push_back(dim.GetSpecifiedValue());
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(toValidShapeVec);
}
}
auto toDynShape = OpImmediate::Specified(outValidShapes[0]);
copyOpAttribute->SetToDynValidShape(toDynShape);
return;
}
std::vector<std::vector<SymbolicScalar>> inputShapes;
for (auto inputTensor : op->GetIOperands()) {
inputShapes.push_back(inputTensor->GetDynValidShape());
}
auto offset = copyOpAttribute->GetFromOffset();
std::vector<SymbolicScalar> oriOffset;
for (auto offsetValue : offset) {
oriOffset.push_back(offsetValue.GetSpecifiedValue());
}
std::vector<SymbolicScalar> outputShape;
for (size_t i = 0U; i < inputShapes[0].size(); i++) {
SymbolicScalar actualDim = std::max(0, std::min((inputShapes[0][i] - oriOffset[i]), tileShapeSym[i]));
outputShape.push_back(actualDim);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outputShape);
}
auto toDynShape = OpImmediate::Specified(outputShape);
copyOpAttribute->SetToDynValidShape(toDynShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_COPY_IN, Opcode::OP_COPY_IN, CopyInInferFunc);
void ReshapeCopyInInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
if (!op->GetIOperands().empty() && !op->GetIOperands()[0]->GetDynValidShape().empty()) {
auto validShape = op->GetIOperands().front()->GetDynValidShape();
outValidShapes.push_back(validShape);
auto copyAttr = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyAttr == nullptr) {
return;
}
copyAttr->SetFromDynValidShape(OpImmediate::Specified(validShape));
}
}
REGISTER_INFER_SHAPE_FUNC(OP_RESHAPE_COPY_IN, Opcode::OP_RESHAPE_COPY_IN, ReshapeCopyInInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L1_RESHAPE_COPY_IN, Opcode::OP_L1_RESHAPE_COPY_IN, ReshapeCopyInInferFunc);
void ReshapeCopyOutInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
if (!op->GetOOperands().empty() && !op->GetOOperands()[0]->GetDynValidShape().empty()) {
auto validshape = op->GetOOperands()[0]->GetDynValidShape();
outValidShapes.push_back(validshape);
auto copyAttr = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyAttr == nullptr) {
return;
}
copyAttr->SetToDynValidShape(OpImmediate::Specified(validshape));
}
}
REGISTER_INFER_SHAPE_FUNC(OP_RESHAPE_COPY_OUT, Opcode::OP_RESHAPE_COPY_OUT, ReshapeCopyOutInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L0C_RESHAPE_COPY_OUT, Opcode::OP_L0C_RESHAPE_COPY_OUT, ReshapeCopyOutInferFunc);
void ShmemPutInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (op->iOperand[2]->GetDynOffset().size() != 0) {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[2]->GetDynOffset()));
} else {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[2]->GetOffset()));
}
std::vector<SymbolicScalar> shmemPutDynValidShape(copyOpAttribute->GetFromDynValidShape().size());
OpImmediate::NormalizeValue(shmemPutDynValidShape, 0, copyOpAttribute->GetFromDynValidShape(), 0, false);
for (size_t i = 0; i < op->GetOOperands().size(); i++) {
if (i == 0) {
outValidShapes.push_back(shmemPutDynValidShape);
} else {
outValidShapes.push_back(SymbolicScalar::FromConcrete(op->GetOOperands()[i]->GetShape()));
}
}
}
REGISTER_INFER_SHAPE_FUNC(OP_SHMEM_PUT, Opcode::OP_SHMEM_PUT, ShmemPutInferFunc);
void ShmemGetInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
std::vector<SymbolicScalar> shmemGetDynValidShape(copyOpAttribute->GetFromDynValidShape().size());
OpImmediate::NormalizeValue(shmemGetDynValidShape, 0, copyOpAttribute->GetFromDynValidShape(), 0, false);
for (size_t i = 0; i < op->GetOOperands().size(); i++) {
if (i == 0) {
outValidShapes.push_back(shmemGetDynValidShape);
} else {
outValidShapes.push_back(SymbolicScalar::FromConcrete(op->GetOOperands()[i]->GetShape()));
}
}
}
REGISTER_INFER_SHAPE_FUNC(OP_SHMEM_GET, Opcode::OP_SHMEM_GET, ShmemGetInferFunc);
void ShmemStoreInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (op->iOperand[1]->GetDynOffset().size() != 0) {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[1]->GetDynOffset()));
} else {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[1]->GetOffset()));
}
std::vector<SymbolicScalar> shmemStoreDynValidShape(copyOpAttribute->GetFromDynValidShape().size());
OpImmediate::NormalizeValue(shmemStoreDynValidShape, 0, copyOpAttribute->GetFromDynValidShape(), 0, false);
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(shmemStoreDynValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_SHMEM_STORE, Opcode::OP_SHMEM_STORE, ShmemStoreInferFunc);
void ShmemLoadInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (op->iOperand[1]->GetDynOffset().size() != 0) {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[1]->GetDynOffset()));
} else {
copyOpAttribute->SetFromOffset(OpImmediate::Specified(op->iOperand[1]->GetOffset()));
}
std::vector<SymbolicScalar> shmemLoadDynValidShape(copyOpAttribute->GetToDynValidShape().size());
OpImmediate::NormalizeValue(shmemLoadDynValidShape, 0, copyOpAttribute->GetToDynValidShape(), 0, false);
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(shmemLoadDynValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_SHMEM_LOAD, Opcode::OP_SHMEM_LOAD, ShmemLoadInferFunc);
void CopyOutInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyOpAttribute != nullptr) {
copyOpAttribute->SetFromDynValidShape(OpImmediate::Specified(op->GetIOperands()[0]->GetDynValidShape()));
} else {
VECTOR_LOGW("Copyout [%d] has no copy out attr.", op->GetOpMagic());
outValidShapes.push_back(op->GetIOperands()[0]->GetDynValidShape());
return;
}
bool needInferShape = false;
if (!(op->GetOOperands()[0]->GetDynValidShape().empty()) &&
!op->GetOOperands()[0]->GetAttr(COPY_OUT_FORCE_INFER_SHAPE, needInferShape)) {
outValidShapes.push_back(op->GetOOperands()[0]->GetDynValidShape());
return;
}
op->GetOOperands()[0]->SetAttr(COPY_OUT_FORCE_INFER_SHAPE, true);
auto offset = copyOpAttribute->GetToOffset();
std::vector<SymbolicScalar> oriOffset;
for (auto offsetValue : offset) {
oriOffset.push_back(offsetValue.GetSpecifiedValue());
}
std::vector<std::vector<SymbolicScalar>> inputShapes;
std::vector<std::vector<int64_t>> staticInputShapes;
for (auto inputTensor : op->GetIOperands()) {
inputShapes.push_back(inputTensor->GetDynValidShape());
staticInputShapes.push_back(inputTensor->GetShape());
}
std::vector<SymbolicScalar> outDynShape = op->GetOOperands()[0]->GetDynValidShape();
if (outDynShape.empty()) {
for (size_t i = 0; i < op->GetOOperands()[0]->GetShape().size(); ++i) {
outDynShape.push_back(SymbolicScalar(0));
}
}
std::vector<SymbolicScalar> outShape;
for (size_t i = 0U; i < inputShapes[0].size(); i++) {
SymbolicScalar actualDim;
if (staticInputShapes[0][i] ==
op->GetOOperands()[0]->GetShape()[i]) {
actualDim = std::max(SymbolicScalar(0), inputShapes[0][i] + oriOffset[i]);
} else {
actualDim = std::max(outDynShape[i], (inputShapes[0][i] + oriOffset[i]) * (inputShapes[0][i] != 0));
}
outShape.push_back(actualDim);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_COPY_OUT, Opcode::OP_COPY_OUT, CopyOutInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_UB_COPY_L1, Opcode::OP_UB_COPY_L1, CopyOutInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_L0C_COPY_UB, Opcode::OP_L0C_COPY_UB, CopyOutInferFunc);
void TransposeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
for (auto output : op->GetOOperands()) {
std::vector<SymbolicScalar> res;
res.insert(res.end(), inputValidShapes[0].begin(), inputValidShapes[0].end());
auto axises = op->GetVectorIntAttribute<int>(OP_ATTR_PREFIX + "shape");
size_t index0 = axises[0];
size_t index1 = axises[1];
if (index0 < res.size() && index1 < res.size()) {
std::swap(res[index0], res[index1]);
}
outValidShapes.push_back(res);
}
if (op->GetOpcode() == Opcode::OP_TRANSPOSE_MOVEIN) {
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyOpAttribute != nullptr) {
copyOpAttribute->SetToDynValidShape(OpImmediate::Specified(outValidShapes[0]));
}
}
if (op->GetOpcode() == Opcode::OP_TRANSPOSE_MOVEOUT) {
auto copyOpAttribute = std::dynamic_pointer_cast<CopyOpAttribute>(op->GetOpAttribute());
if (copyOpAttribute != nullptr) {
copyOpAttribute->SetFromDynValidShape(OpImmediate::Specified(outValidShapes[0]));
std::vector<SymbolicScalar> outDynShape = op->GetOOperands()[0]->GetDynValidShape();
if (!outDynShape.empty()) {
auto dynOffset = copyOpAttribute->GetToOffset();
std::vector<SymbolicScalar> outShape;
for (size_t i = 0U; i < dynOffset.size(); i++) {
SymbolicScalar actualDim = std::max(
outDynShape[i],
(outValidShapes[0][i] + dynOffset[i].GetSpecifiedValue()) * (outValidShapes[0][i] != 0));
outShape.push_back(actualDim);
}
outValidShapes[0] = outShape;
}
}
}
}
REGISTER_INFER_SHAPE_FUNC(OP_TRANSPOSE_VNCHWCONV, Opcode::OP_TRANSPOSE_VNCHWCONV, TransposeInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_TRANSPOSE_MOVEIN, Opcode::OP_TRANSPOSE_MOVEIN, TransposeInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_TRANSPOSE_MOVEOUT, Opcode::OP_TRANSPOSE_MOVEOUT, TransposeInferFunc);
void PermuteInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<SymbolicScalar> validShape;
const auto& oOperands = op->GetOOperands();
std::vector<SymbolicScalar> inputValidShape;
if (!op->GetIOperands().empty()) {
inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
}
if (inputValidShape.empty()) {
return;
}
std::vector<int> perm = op->GetVectorIntAttribute<int>(OpAttributeKey::perm);
std::vector<SymbolicScalar> resultValidShape;
resultValidShape.reserve(perm.size());
for (int axis : perm) {
resultValidShape.push_back(inputValidShape[axis]);
}
if (op->GetAttr(OP_ATTR_PREFIX + "validShape", validShape) && !validShape.empty()) {
outValidShapes.push_back(validShape);
} else if (!resultValidShape.empty()) {
outValidShapes.push_back(resultValidShape);
}
for (size_t idx = 1; idx < oOperands.size(); ++idx) {
auto outputValidShape = oOperands[idx]->GetDynValidShape();
if (outputValidShape.empty()) {
outputValidShape = resultValidShape;
}
outValidShapes.push_back(outputValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_PERMUTE, Opcode::OP_PERMUTE, PermuteInferFunc);
REGISTER_INFER_SHAPE_FUNC(OP_PERMUTE_ELEMENT, Opcode::OP_PERMUTE_ELEMENT, PermuteInferFunc);
void ViewInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto viewOpAttribute = std::dynamic_pointer_cast<ViewOpAttribute>(op->GetOpAttribute());
if (viewOpAttribute == nullptr) {
VECTOR_LOGW("View [%d] has no view attr.", op->GetOpMagic());
outValidShapes.push_back(op->GetIOperands()[0]->GetDynValidShape());
return;
}
auto toValidShape = viewOpAttribute->GetToDynValidShape();
if (!toValidShape.empty()) {
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(toValidShape);
}
} else {
auto inputValidShape = op->GetIOperands()[0]->GetDynValidShape();
if (inputValidShape.empty()) {
auto shapeImm = OpImmediate::Specified(op->GetIOperands()[0]->GetShape());
inputValidShape.resize(shapeImm.size());
OpImmediate::NormalizeValue(inputValidShape, 0, shapeImm, 0, false);
}
auto newDynValidShape = GetViewValidShape(
inputValidShape, viewOpAttribute->GetFromOffset(), viewOpAttribute->GetFromDynOffset(),
op->GetOOperands()[0]->oriShape);
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(newDynValidShape);
}
viewOpAttribute->SetToDynValidShape(newDynValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_VIEW, Opcode::OP_VIEW, ViewInferFunc);
void AssembleInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto assembleOpAttribute = std::dynamic_pointer_cast<AssembleOpAttribute>(op->GetOpAttribute());
if (assembleOpAttribute != nullptr) {
auto fromValidShape = op->GetIOperands()[0]->GetDynValidShape();
assembleOpAttribute->SetFromDynValidShape(fromValidShape);
} else {
VECTOR_LOGW("Copyout [%d] has no copy out attr.", op->GetOpMagic());
outValidShapes.push_back(op->GetIOperands()[0]->GetDynValidShape());
return;
}
auto offset = assembleOpAttribute->GetToOffset();
auto inputShapes = op->GetIOperands()[0]->GetDynValidShape();
std::vector<SymbolicScalar> outDynShape = op->GetOOperands()[0]->GetDynValidShape();
if (outDynShape.empty()) {
for (size_t i = 0; i < op->GetOOperands()[0]->GetShape().size(); ++i) {
outDynShape.push_back(SymbolicScalar(0));
}
}
std::vector<SymbolicScalar> outShape;
for (size_t i = 0U; i < inputShapes.size(); i++) {
SymbolicScalar actualDim;
if (offset[i] == 0) {
actualDim = std::max(outDynShape[i], inputShapes[i]);
} else {
actualDim = std::max(outDynShape[i], (inputShapes[i] + offset[i]) * (inputShapes[i] != 0));
}
outShape.push_back(actualDim);
}
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_ASSEMBLE, Opcode::OP_ASSEMBLE, AssembleInferFunc);
const std::string TOPK_AXIS = OP_ATTR_PREFIX + "axis";
const std::string TOPK_ORDER = OP_ATTR_PREFIX + "order";
const std::string TOPK_ALGO = OP_ATTR_PREFIX + "algo";
const std::string TOPK_KVALUE = OP_ATTR_PREFIX + "kvalue";
const std::string EXTRACT_MASKMODE = OP_ATTR_PREFIX + "makeMode";
const std::string SORT_AXIS = OP_ATTR_PREFIX + "axis";
constexpr int32_t blockSize = 32;
constexpr int32_t kFactorSize = 4;
constexpr int32_t kBlockFpNum = 8;
void BitSortFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
auto topk_axis = op->GetIntAttribute(TOPK_AXIS);
res[topk_axis] = res[topk_axis] * NUM2;
outValidShapes.emplace_back(res);
if (inputValidShapes[0].size() == 1) {
outValidShapes.emplace_back(std::vector<SymbolicScalar>{res[topk_axis]});
} else {
outValidShapes.emplace_back(std::vector<SymbolicScalar>{1, res[topk_axis]});
}
}
REGISTER_INFER_SHAPE_FUNC(OP_BITSORT, Opcode::OP_BITSORT, BitSortFunc);
void MrgSortFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
auto topk_axis = op->GetIntAttribute(TOPK_AXIS);
auto topk_kvalue = op->GetIntAttribute(TOPK_KVALUE);
res[topk_axis] = std::min(res[topk_axis], (topk_kvalue + kBlockFpNum - 1) / kBlockFpNum * kBlockFpNum * NUM2);
outValidShapes.emplace_back(res);
if (inputValidShapes[0].size() == 1) {
outValidShapes.emplace_back(std::vector<SymbolicScalar>{res[topk_axis]});
} else {
outValidShapes.emplace_back(std::vector<SymbolicScalar>{1, res[topk_axis]});
}
}
REGISTER_INFER_SHAPE_FUNC(OP_MRGSORT, Opcode::OP_MRGSORT, MrgSortFunc);
void TiledMrgSortFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
auto outValidShape = inputValidShapes[0];
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_TILEDMRGSORT, Opcode::OP_TILEDMRGSORT, TiledMrgSortFunc);
void ExtractFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
res.back() = op->GetIntAttribute(TOPK_KVALUE);
outValidShapes.push_back(res);
}
REGISTER_INFER_SHAPE_FUNC(OP_EXTRACT, Opcode::OP_EXTRACT, ExtractFunc);
void QuantMXInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
constexpr int64_t kQuantMXGroupSize = 32;
constexpr int64_t kQuantMXFp32ScalingFactor = 2;
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> groupedValidShape;
std::vector<SymbolicScalar> scalingValidShape;
if (op->GetIntAttribute(OpAttributeKey::mxQuantPerformanceMode) != 0) {
if (inputValidShapes[0].size() == 1) {
groupedValidShape = {(inputValidShapes[0][0] + kQuantMXGroupSize - 1) / kQuantMXGroupSize};
} else {
groupedValidShape.reserve(inputValidShapes[0].size() - 1);
for (size_t i = 0; i + 2 < inputValidShapes[0].size(); ++i) {
groupedValidShape.push_back(inputValidShapes[0][i]);
}
groupedValidShape.push_back(
inputValidShapes[0][inputValidShapes[0].size() - 2] *
((inputValidShapes[0].back() + kQuantMXGroupSize - 1) / kQuantMXGroupSize));
}
scalingValidShape = groupedValidShape;
if (op->GetIOperands()[0]->Datatype() == DataType::DT_FP32) {
scalingValidShape.back() = scalingValidShape.back() * kQuantMXFp32ScalingFactor;
}
} else {
groupedValidShape = inputValidShapes[0];
groupedValidShape.back() = (groupedValidShape.back() + kQuantMXGroupSize - 1) / kQuantMXGroupSize;
scalingValidShape = inputValidShapes[0];
}
outValidShapes.push_back(inputValidShapes[0]);
outValidShapes.push_back(groupedValidShape);
outValidShapes.push_back(groupedValidShape);
outValidShapes.push_back(scalingValidShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_QUANT_MX, Opcode::OP_QUANT_MX, QuantMXInferFunc);
void RadixSelectFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto input = op->GetIOperands()[0];
std::vector<SymbolicScalar> res(input->GetDynValidShape());
res.back() = op->GetIntAttribute(TOPK_KVALUE);
outValidShapes.push_back(res);
outValidShapes.push_back(res);
std::vector<SymbolicScalar> tmpValidShape;
std::vector<int64_t> srcShape = input->GetShape();
auto lastDim = srcShape[srcShape.size() - 1];
tmpValidShape.emplace_back(
static_cast<int64_t>(NUM_VALUE_2 * lastDim * BytesOf(input->Datatype())) +
static_cast<int64_t>(NUM_VALUE_6 * lastDim) + static_cast<int64_t>(NUM_VALUE_1024) +
static_cast<int64_t>(NUM_VALUE_1024 > NUM_VALUE_8 * lastDim ? NUM_VALUE_1024 : NUM_VALUE_8 * lastDim));
outValidShapes.push_back(tmpValidShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_RADIX_SELECT, Opcode::OP_RADIX_SELECT, RadixSelectFunc);
void VecDupInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& validShapes)
{
std::vector<SymbolicScalar> validShape;
op->GetAttr(OP_ATTR_PREFIX + "validShape", validShape);
validShapes.push_back(validShape);
}
REGISTER_INFER_SHAPE_FUNC(OP_VEC_DUP, Opcode::OP_VEC_DUP, VecDupInferFunc);
void ReshapeInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& validShapes)
{
std::vector<SymbolicScalar> validShape;
if (op->GetAttr(OP_ATTR_PREFIX + "validShape", validShape) && validShape.size() != 0) {
validShapes.push_back(validShape);
} else {
auto dstShape = op->GetOOperands()[0]->GetShape();
validShapes.push_back(SymbolicScalar::FromConcrete(dstShape));
}
}
REGISTER_INFER_SHAPE_FUNC(OP_RESHAPE, Opcode::OP_RESHAPE, ReshapeInferFunc);
void BrcbInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
auto dimSize = op->GetIOperands()[0]->GetDynValidShape().size();
std::vector<SymbolicScalar> outValidShape;
for (size_t i = 0; i < dimSize - 1; i++) {
outValidShape.push_back(op->GetIOperands()[0]->GetDynValidShape()[i]);
}
int64_t lastDimShape = blockSize / BytesOf(op->GetIOperands()[0]->Datatype());
outValidShape.push_back(lastDimShape);
for (auto output : op->GetOOperands()) {
outValidShapes.push_back(outValidShape);
}
}
REGISTER_INFER_SHAPE_FUNC(OP_BRCB, Opcode::OP_BRCB, BrcbInferFunc);
void TwoTileMrgSortFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
outValidShapes.push_back(res);
}
REGISTER_INFER_SHAPE_FUNC(OP_TWOTILEMRGSORT, Opcode::OP_TWOTILEMRGSORT, TwoTileMrgSortFunc);
void ExtractSingleFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
std::vector<std::vector<SymbolicScalar>> inputValidShapes;
for (auto inputTensor : op->GetIOperands()) {
inputValidShapes.push_back(inputTensor->GetDynValidShape());
}
if (inputValidShapes.empty()) {
return;
}
std::vector<SymbolicScalar> res(inputValidShapes[0]);
res.back() = res.back() / 2;
outValidShapes.push_back(res);
}
REGISTER_INFER_SHAPE_FUNC(OP_EXTRACT_SINGLE, Opcode::OP_EXTRACT_SINGLE, ExtractSingleFunc);
void PReLUInferFunc(Operation* op, std::vector<std::vector<SymbolicScalar>>& outValidShapes)
{
ASSERT(VectorErrorCode::ERR_PARAM_COUNT_INVALID, op->GetIOperands().size() == 2)
<< "PReLU input operand size should be 2";
ASSERT(VectorErrorCode::ERR_PARAM_COUNT_INVALID, op->GetOOperands().size() == 2)
<< "PReLU output operand size should be 2";
auto input0 = op->GetIOperands()[0];
std::vector<SymbolicScalar> output0ValidShape = input0->GetDynValidShape();
std::vector<SymbolicScalar> output1ValidShape;
auto input0ShapeDim = input0->GetDynValidShape().size();
if (input0ShapeDim == 2) {
output1ValidShape.emplace_back(input0->GetDynValidShape().back());
} else {
constexpr int64_t ALIGN_SIZE = 32;
int64_t elementCount = ALIGN_SIZE / BytesOf(input0->Datatype());
output1ValidShape.emplace_back(elementCount);
}
outValidShapes.emplace_back(std::move(output0ValidShape));
outValidShapes.emplace_back(std::move(output1ValidShape));
}
REGISTER_INFER_SHAPE_FUNC(OP_PRELU, Opcode::OP_PRELU, PReLUInferFunc);
}