* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file reduction.cpp
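* \brief Tensor-level entry points (Amax, Amin, Sum, Prod, ArgMax, ArgMin, RowSumExpand, RowMaxExpand) and the tile expansion functions for the reduction operators.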
*/
#include "unary.h"
#include "interface/configs/config_manager.h"
#include "interface/utils/operator_tracer.h"
#include "passes/pass_utils/pass_utils.h"
#include "tilefwk/error_code.h"
#include "tilefwk/tilefwk_op.h"
#include "tensor_transformation.h"
namespace npu::tile_fwk {
// Selects which final reduction operation TileReduceNew emits after its pairwise folding stage.
enum class ReduceType {
// Emits "TILE_ROW" + op, writing into a view of the result tensor.
NORMAL,
// Emits "TILE_ROWEXP" + op into a temporary tensor, then TILE_REGISTER_COPY operations that fill the result.
EXPAND,
// Emits "TILE_ROW" + op + "_SINGLE" or "TILE_ROW" + op + "LINE"; both carry an AXIS attribute.
SINGLE,
};
/**
 * \brief Emit the tile-level operations that reduce `in` along `axis` into `result`.
 *
 * Works in two stages:
 *  1. Pairwise folding: while the working width along `axis` is at least NUM2 * vecTile[axis], the first half
 *     is combined with the second half chunk by chunk using "TILE_PAIR" + op operations, with any part beyond
 *     the power-of-two width folded in from the original input.
 *  2. Final reduction: one operation, chosen by `reduceType`, is emitted on the leading
 *     min(in->shape[axis], vecTile[axis]) elements of whatever the folding stage produced.
 *
 * \param function   Function that receives the generated operations and tensors.
 * \param tileShape  Supplies the vector tile sizes (GetVecTile()).
 * \param op         Reduction name. "MAX_COMBINE_AXIS" / "SUM_COMBINE_AXIS" are folded as "MAX" / "SUM" in
 *                   stage 1, but the unmodified `op` is used to build the stage-2 operation name.
 * \param reduceType Selects the stage-2 operation (see the switch below).
 * \param in         Input tile.
 * \param result     Output tile.
 * \param axis       Axis to reduce; negative values count from the last dimension.
 */
void TileReduceNew(
Function& function, const TileShape& tileShape, const std::string& op, ReduceType reduceType,
const LogicalTensorPtr& in, const LogicalTensorPtr& result, int axis = -1)
{
// Normalise a negative axis to its non-negative equivalent.
axis = axis < 0 ? in->shape.size() + axis : axis;
// Every shape starts as the full input shape; below only the entry at `axis` is changed
// (the EXPAND branch additionally rewrites entries 0 and 1 of regShape / regOffset).
std::vector<int64_t> tileAshape = in->shape;
std::vector<int64_t> tileBshape = in->shape;
std::vector<int64_t> regShape = in->shape;
std::vector<int64_t> regAshape = in->shape;
std::vector<int64_t> regBshape = in->shape;
std::vector<int64_t> regOffset(regShape.size(), 0);
std::vector<int64_t> tileAoffset(regShape.size(), 0);
std::vector<int64_t> tileBoffset(regShape.size(), 0);
std::vector<int64_t> remainderShape = in->shape;
std::vector<int64_t> remainderOffset(remainderShape.size(), 0);
// opNew is the operation name used for the pairwise stage only.
auto opNew = op;
if (opNew == "MAX_COMBINE_AXIS") {
opNew = "MAX";
}
if (opNew == "SUM_COMBINE_AXIS") {
opNew = "SUM";
}
// `source` is the tensor currently being folded; it is replaced by each round's output.
auto source = in;
auto vecTile = tileShape.GetVecTile();
// Axis extent rounded up to a multiple of vecTile[axis]; padSize is the amount added by the rounding.
int64_t width = (source->shape[axis] + vecTile[axis] - 1) / vecTile[axis] * vecTile[axis];
int padSize = width - source->shape[axis];
int remainder = 0;
// p2width ends up as the largest value of the form vecTile[axis] * 2^k that does not exceed width.
int p2width = vecTile[axis];
while (width >= p2width) {
p2width = p2width << 1;
}
p2width = p2width >> 1;
// The part of the rounded width beyond p2width is tracked separately as the "remainder" region.
remainder = width - p2width;
remainderShape[axis] = remainder;
remainderOffset[axis] = p2width;
width = p2width;
// Each round halves the working width and combines the two halves of `source`.
while (width >= NUM2 * vecTile[axis])
{
width = width >> 1;
// tileA covers [0, width) along axis; tileB starts at `width` and is clipped to what `source` still holds.
tileAshape[axis] = width;
tileBshape[axis] = std::min(width, source->shape[axis] - width);
tileBoffset[axis] = width;
auto tileA = source->View(function, tileAshape, tileAoffset);
auto tileB = source->View(function, tileBshape, tileBoffset);
// Output of this round; EXPAND takes its shape from `result`, the other modes from `source`.
auto resultA = std::make_shared<LogicalTensor>(
function, in->Datatype(), reduceType == npu::tile_fwk::ReduceType::EXPAND ? result->shape : source->shape,
reduceType == npu::tile_fwk::ReduceType::EXPAND ? result->GetDynValidShape() : source->GetDynValidShape());
// Combine the halves one vecTile[axis]-sized chunk at a time.
for (int j = 0; j < width; j += vecTile[axis]) {
regAshape[axis] = vecTile[axis];
regBshape[axis] = std::min(vecTile[axis], tileB->shape[axis] - j);
regOffset[axis] = j;
auto regA = tileA->View(function, regAshape, regOffset);
auto regB = tileB->View(function, regBshape, regOffset);
auto regResult = resultA->View(function, regAshape, regOffset);
function.AddOperation("TILE_PAIR" + opNew, {regA, regB}, {regResult});
}
// The remainder region is smaller than this round's width: nothing to fold in yet.
if (remainder < width) {
source = resultA;
continue;
}
// Shrink the remainder view by padSize if it would extend past the real input extent.
if ((remainderShape[axis] + remainderOffset[axis] > in->shape[axis])) {
remainderShape[axis] = remainderShape[axis] - padSize;
}
// The remainder is always viewed on the original input `in`, not on `source`.
auto tileRemainder = in->View(function, remainderShape, remainderOffset);
auto resultAnext =
std::make_shared<LogicalTensor>(function, in->Datatype(), resultA->shape, resultA->GetDynValidShape());
// Fold a width-sized slice of the remainder into this round's output, chunk by chunk.
for (int j = 0; j < width; j += vecTile[axis]) {
regAshape[axis] = vecTile[axis];
regBshape[axis] = std::min(vecTile[axis], tileRemainder->shape[axis] - j);
regOffset[axis] = j;
auto regA = resultA->View(function, regAshape, regOffset);
auto regB = tileRemainder->View(function, regBshape, regOffset);
auto regResult = resultAnext->View(function, regAshape, regOffset);
function.AddOperation("TILE_PAIR" + opNew, {regA, regB}, {regResult});
}
// Advance the remainder bookkeeping past the slice that was just consumed.
remainder -= width;
remainderOffset[axis] += width;
remainderShape[axis] -= width;
source = resultAnext;
}
// Stage 2: a single reduction over the leading min(extent, vecTile[axis]) elements of `source`.
regShape[axis] = std::min(in->shape[axis], vecTile[axis]);
regOffset[axis] = 0;
// `temp` is created for every reduceType but is only read in the EXPAND branch below.
auto temp = std::make_shared<LogicalTensor>(
function, in->Datatype(), reduceType == npu::tile_fwk::ReduceType::EXPAND ? result->shape : source->shape,
reduceType == npu::tile_fwk::ReduceType::EXPAND ? result->GetDynValidShape() : source->GetDynValidShape());
auto sourceReg = source->View(function, regShape, regOffset);
switch (reduceType) {
case npu::tile_fwk::ReduceType::NORMAL: {
// Reduce straight into the matching view of `result`.
auto resultReg = result->View(function, regShape, regOffset);
function.AddOperation("TILE_ROW" + op, {sourceReg}, {resultReg});
break;
}
case npu::tile_fwk::ReduceType::EXPAND: {
// Reduce into `temp`, then copy that view into `result` in vecTile[1]-wide steps along dimension 1.
// NOTE(review): indexes dimensions 0 and 1 directly, so this branch assumes a 2-D tile - confirm.
auto resultReg = temp->View(function, regShape, regOffset);
function.AddOperation("TILE_ROWEXP" + op, {sourceReg}, {resultReg});
auto resultReg1 = temp->View(function, regShape, regOffset);
for (int j = 0; j < result->shape[1]; j += vecTile[1])
{
regShape[0] = in->shape[0];
regShape[1] = vecTile[1];
regOffset[0] = 0;
regOffset[1] = j;
resultReg = result->View(function, regShape, regOffset);
function.AddOperation("TILE_REGISTER_COPY", {resultReg1}, {resultReg});
}
break;
}
case npu::tile_fwk::ReduceType::SINGLE: {
// Default shape of the extra temporary output: one row of BLOCK_SIZE / element-size entries.
std::vector<int64_t> tmpShape = {1, static_cast<int>(BLOCK_SIZE / BytesOf(in->Datatype()))};
if (op == "SUM" || (static_cast<size_t>(axis) == (in->shape.size() - 1))) {
if (static_cast<size_t>(axis) == (in->shape.size() - 1)) {
// Reducing the last axis: size the temporary from the reduced extent relative to
// REPEAT_BYTE / element-size, then emit the "_SINGLE" variant.
tmpShape[0] = sourceReg->shape[std::max(0, axis - 1)];
if (static_cast<size_t>(sourceReg->shape[axis]) <= REPEAT_BYTE / BytesOf(in->Datatype())) {
tmpShape[0] = 1;
} else if (
static_cast<size_t>(sourceReg->shape[axis]) <= NUM2 * REPEAT_BYTE / BytesOf(in->Datatype())) {
tmpShape[1] = REPEAT_BYTE / BytesOf(in->Datatype());
} else {
tmpShape[1] = (((sourceReg->shape[axis] * BytesOf(in->Datatype())) / REPEAT_BYTE) / NUM2) *
REPEAT_BYTE / BytesOf(in->Datatype());
}
// A 1-D input gets a 1-D temporary.
if (in->shape.size() == 1) {
tmpShape = {tmpShape[1]};
}
auto tempTensor = std::make_shared<LogicalTensor>(function, in->Datatype(), tmpShape);
tempTensor->dynValidShape_ = SymbolicScalar::FromConcrete(tmpShape);
auto& newOp = function.AddOperation("TILE_ROW" + op + "_SINGLE", {sourceReg}, {result, tempTensor});
newOp.SetAttribute(OP_ATTR_PREFIX + "AXIS", axis);
} else {
// op == "SUM" on a non-last axis: "LINE" variant with a temporary of
// (reduced extent + 1) / NUM2 rows and the last dimension rounded up to a multiple of BLOCK_NUM.
tmpShape[0] = (sourceReg->shape[axis] + 1) / NUM2;
tmpShape[1] = (sourceReg->shape[in->shape.size() - 1] + BLOCK_NUM - 1) / BLOCK_NUM * BLOCK_NUM;
auto tempTensor = std::make_shared<LogicalTensor>(function, in->Datatype(), tmpShape);
tempTensor->dynValidShape_ = SymbolicScalar::FromConcrete(tmpShape);
auto& newOp = function.AddOperation("TILE_ROW" + op + "LINE", {sourceReg}, {result, tempTensor});
newOp.SetAttribute(OP_ATTR_PREFIX + "AXIS", axis);
}
} else {
// Any other op on a non-last axis: "LINE" variant without a temporary output.
auto& newOp = function.AddOperation("TILE_ROW" + op + "LINE", {sourceReg}, {result});
newOp.SetAttribute(OP_ATTR_PREFIX + "AXIS", axis);
}
break;
}
default:
break;
}
}
/**
 * \brief Expand an ARGMAX / ARGMIN along `axis` into tile-level operations.
 *
 * The input is cut along `axis` into slices of at most vecTile[axis] elements. Each slice gets a
 * "TILE_ROW" + op + "WITHVALUE_SINGLE" operation (when `axis` is the last dimension) or a
 * "TILE_ROW" + op + "WITHVALUE_LINE" operation (otherwise) producing a value tile, an index tile and a scratch
 * tile, followed by an OP_ADDS operation whose scalar attribute is the slice's start position. The per-slice
 * (value, index) pairs are then merged two at a time with OP_PAIRARGMAX / OP_PAIRARGMIN until a single pair
 * is left, and its index is written to `result`.
 *
 * \param function  Function that receives the generated operations and tensors.
 * \param tileShape Supplies the vector tile sizes.
 * \param op        "ARGMAX" merges with OP_PAIRARGMAX; any other value merges with OP_PAIRARGMIN.
 * \param in        Input tile.
 * \param result    Output index tile; its shape is reused for every intermediate value / index tensor.
 * \param axis      Axis to reduce; negative values count from the last dimension.
 */
void TileArgReduce(
    Function& function, const TileShape& tileShape, const std::string& op,
    const LogicalTensorPtr& in, const LogicalTensorPtr& result, int axis = -1)
{
    // FP32 inputs carry INT32 indices; every other input type carries INT16 indices.
    auto indexDtype = (in->Datatype() == DataType::DT_FP32) ? DataType::DT_INT32 : DataType::DT_INT16;
    if (axis < 0) {
        axis = in->shape.size() + axis;
    }
    auto& vecTile = tileShape.GetVecTile();

    std::vector<int64_t> partialShape = result->shape;
    std::vector<int64_t> sliceShape = in->shape;
    std::vector<int64_t> sliceOffset(sliceShape.size(), 0);

    // Valid shape shared by all intermediates: the input's valid shape with the reduced axis set to 1.
    std::vector<SymbolicScalar> inputValidShape = in->GetDynValidShape();
    std::vector<SymbolicScalar> reducedValidShape;
    if (!inputValidShape.empty()) {
        reducedValidShape = inputValidShape;
        reducedValidShape[axis] = SymbolicScalar(1);
    }

    // Stage 1: one arg-reduction per slice, with the slice start attached to the index tile.
    std::vector<LogicalTensorPtr> values;
    std::vector<LogicalTensorPtr> indices;
    for (int start = 0; start < in->shape[axis]; start += vecTile[axis]) {
        sliceShape[axis] = std::min(vecTile[axis], in->shape[axis] - start);
        sliceOffset[axis] = start;
        auto inputTile = in->View(function, sliceShape, sliceOffset);
        auto valueTile = std::make_shared<LogicalTensor>(function, in->Datatype(), partialShape, reducedValidShape);
        auto indexTile = std::make_shared<LogicalTensor>(function, indexDtype, partialShape, reducedValidShape);

        const bool reducesLastAxis = static_cast<size_t>(axis) == (inputTile->shape.size() - 1);

        // Shape of the scratch output expected by the tile operation.
        std::vector<int64_t> scratchShape;
        if (inputTile->shape.size() == 1) {
            scratchShape = {sliceShape[axis]};
        } else if (reducesLastAxis) {
            scratchShape = std::vector<int64_t>{sliceShape[axis - 1], sliceShape[axis]};
        } else {
            scratchShape =
                std::vector<int64_t>{1, static_cast<int64_t>(REPEAT_BYTE / BytesOf(in->Datatype()) * NUM3)};
        }
        auto scratchTile =
            std::make_shared<LogicalTensor>(function, in->Datatype(), scratchShape, reducedValidShape);

        auto& sliceOp = function.AddOperation(
            "TILE_ROW" + op + (reducesLastAxis ? "WITHVALUE_SINGLE" : "WITHVALUE_LINE"), {inputTile},
            {valueTile, indexTile, scratchTile});
        sliceOp.SetAttribute(OP_ATTR_PREFIX + "AXIS", axis);

        auto shiftedIndex = std::make_shared<LogicalTensor>(function, indexDtype, partialShape, reducedValidShape);
        auto& shiftOp = function.AddOperation(Opcode::OP_ADDS, {indexTile}, {shiftedIndex});
        shiftOp.SetAttribute(OpAttributeKey::scalar, Element(indexDtype, start));
        shiftOp.SetAttribute(OP_ATTR_PREFIX + "reverseOperand", false);

        values.push_back(valueTile);
        indices.push_back(shiftedIndex);
    }

    // Stage 2: merge neighbouring pairs round by round until one (value, index) pair remains.
    auto sliceCount = (in->shape[axis] + vecTile[axis] - 1) / vecTile[axis];
    int remaining = sliceCount;
    while (remaining > 1) {
        std::vector<LogicalTensorPtr> mergedValues;
        std::vector<LogicalTensorPtr> mergedIndices;
        for (int k = 0; k < remaining; k += NUM_VALUE_2) {
            if (k + 1 == remaining) {
                // Odd element without a partner: carried into the next round unchanged.
                mergedValues.push_back(values[k]);
                mergedIndices.push_back(indices[k]);
                continue;
            }
            auto pairValue = std::make_shared<LogicalTensor>(
                function, in->Datatype(), values[0]->shape, values[0]->GetDynValidShape());
            auto pairIndex = std::make_shared<LogicalTensor>(
                function, indexDtype, indices[0]->shape, indices[0]->GetDynValidShape());
            auto pairOpcode = (op == "ARGMAX") ? Opcode::OP_PAIRARGMAX : Opcode::OP_PAIRARGMIN;
            function.AddOperation(
                pairOpcode, {values[k], indices[k], values[k + 1], indices[k + 1]}, {pairValue, pairIndex});
            mergedValues.push_back(pairValue);
            mergedIndices.push_back(pairIndex);
        }
        values = std::move(mergedValues);
        indices = std::move(mergedIndices);
        remaining = (remaining + 1) / NUM_VALUE_2;
    }

    // Stage 3: write the surviving index into `result`.
    const bool widenViaFp32 = indexDtype == DataType::DT_INT16 && result->Datatype() == DataType::DT_INT32;
    if (!widenViaFp32) {
        // Index type is written as-is: an OP_ADDS with scalar 0 moves it into `result`.
        auto& copyOp = function.AddOperation(Opcode::OP_ADDS, {indices[0]}, {result});
        copyOp.SetAttribute(OpAttributeKey::scalar, Element(indexDtype, 0));
        copyOp.SetAttribute(OP_ATTR_PREFIX + "reverseOperand", false);
        return;
    }

    // INT16 index with an INT32 result: cast INT16 -> FP32 -> result.
    auto fp32Index =
        std::make_shared<LogicalTensor>(function, DataType::DT_FP32, result->shape, indices[0]->GetDynValidShape());
    auto& toFp32 = function.AddOperation(Opcode::OP_CAST, {indices[0]}, {fp32Index});
    toFp32.SetAttribute(OP_ATTR_PREFIX + "mode", CastMode::CAST_NONE);
    toFp32.SetAttribute(OP_ATTR_PREFIX + "satmode", static_cast<int64_t>(SaturationMode::OFF));
    auto& toResult = function.AddOperation(Opcode::OP_CAST, {fp32Index}, {result});
    toResult.SetAttribute(OP_ATTR_PREFIX + "mode", CastMode::CAST_NONE);
    toResult.SetAttribute(OP_ATTR_PREFIX + "satmode", static_cast<int64_t>(SaturationMode::OFF));
}
void ReduceSingle(
size_t cur, const std::string& op, Input& input, const LogicalTensorPtr result, TileInfo& resultTileInfo, int axis,
Function& function, const TileShape& tileShape, std::vector<int> order)
{
if (order[cur] == axis && cur < order.size() - 1) {
std::swap(order[cur], order[cur + 1]);
}
if (order[cur] == axis) {
auto inputTile = input.tensor.GetStorage()->View(function, input.tileInfo.shape, input.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
if (op == "ARGMAX" || op == "ARGMIN") {
TileArgReduce(function, tileShape, op, inputTile, resultTile, axis);
} else {
TileReduceNew(function, tileShape, op, npu::tile_fwk::ReduceType::SINGLE, inputTile, resultTile, axis);
}
return;
}
auto vecTile = tileShape.GetVecTile();
for (int i = 0; i < result->shape[order[cur]]; i += vecTile[order[cur]]) {
resultTileInfo.offset[order[cur]] = i;
resultTileInfo.shape[order[cur]] =
std::min(result->shape[order[cur]] - resultTileInfo.offset[order[cur]], vecTile[order[cur]]);
input.tileInfo.offset[order[cur]] = i % input.tensor.GetShape()[order[cur]];
input.tileInfo.shape[order[cur]] =
std::min(input.tensor.GetShape()[order[cur]] - input.tileInfo.offset[order[cur]], vecTile[order[cur]]);
ReduceSingle(cur + 1, op, input, result, resultTileInfo, axis, function, tileShape, order);
}
}
/**
 * \brief Tile expansion entry point for the single-axis reductions.
 *
 * Validates the op name and the operand's shape / offset ranks, normalises a negative axis, and starts the
 * recursive tiling in ReduceSingle with the dimensions in their natural order 0, 1, ..., rank - 1.
 *
 * \param function  Function that receives the generated operations.
 * \param tileShape Supplies the vector tile sizes.
 * \param op        One of MAX, MIN, SUM, PROD, ARGMAX, ARGMIN, MAX_COMBINE_AXIS, SUM_COMBINE_AXIS.
 * \param operand   Input tensor.
 * \param result    Output tensor.
 * \param axis      Axis to reduce; negative values count from the last dimension.
 */
void TiledReduceSingle(
    Function& function, const TileShape& tileShape, const std::string& op, const LogicalTensorPtr& operand,
    const LogicalTensorPtr& result, int axis)
{
    ASSERT(
        VectorErrorCode::ERR_PARAM_INVALID, op == "MAX" || op == "MIN" || op == "SUM" || op == "PROD" ||
                                                op == "ARGMAX" || op == "ARGMIN" || op == "MAX_COMBINE_AXIS" ||
                                                op == "SUM_COMBINE_AXIS")
        << "Not support op:" << op;
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, operand->shape.size() == operand->offset.size())
        << "The shape size of operand and offset should be equal";

    axis = (axis < 0) ? operand->shape.size() + axis : axis;

    // Natural dimension order; ReduceSingle moves the reduced axis to the end while recursing.
    std::vector<int> naturalOrder;
    for (size_t dim = 0; dim < operand->shape.size(); ++dim) {
        naturalOrder.push_back(dim);
    }

    TileInfo inputWindow(operand->shape, operand->offset);
    TileInfo resultWindow(result->shape, result->offset);
    auto input = Input{operand, inputWindow};
    ReduceSingle(0, op, input, result, resultWindow, axis, function, tileShape, naturalOrder);
}
/**
 * \brief Record a tensor-level single-axis reduction operation on `function`.
 *
 * Validates the op name and the operand's shape / offset ranks, normalises and range-checks `axis`, propagates
 * the operand's dynamic valid shape to the result with the reduced axis set to 1 (when the operand has one),
 * and adds the matching OP_ROW*_SINGLE operation carrying the axis as its AXIS attribute.
 *
 * \param function Function that receives the operation.
 * \param op       One of MAX, MIN, SUM, PROD, ARGMAX, ARGMIN, MAX_COMBINE_AXIS, SUM_COMBINE_AXIS.
 * \param operand  Input tensor.
 * \param result   Output tensor; its dynamic valid shape may be updated.
 * \param axis     Axis to reduce; negative values count from the last dimension.
 */
[[maybe_unused]] void TensorReduceSingle(
    Function& function, const std::string& op, const Tensor& operand, Tensor& result, int axis)
{
    ASSERT(
        VectorErrorCode::ERR_PARAM_INVALID, op == "MAX" || op == "MIN" || op == "SUM" || op == "PROD" ||
                                                op == "ARGMAX" || op == "ARGMIN" || op == "MAX_COMBINE_AXIS" ||
                                                op == "SUM_COMBINE_AXIS")
        << "Not support op:" << op;
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, operand.GetShape().size() == operand.GetStorage()->offset.size())
        << "The shape size of operand and offset should be equal";

    // The axis indexes the valid-shape vector below, so it must be normalised and in range first.
    const int rank = static_cast<int>(operand.GetShape().size());
    if (axis < 0) {
        axis += rank;
    }
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, axis >= 0 && axis < rank) << "Reduce axis out of range:" << axis;

    // NOTE(review): MAX_COMBINE_AXIS and SUM_COMBINE_AXIS pass the check above but have no branch here, so
    // they fall back to OP_ROWMAX_SINGLE - confirm this is intended (especially for SUM_COMBINE_AXIS).
    auto opCode = Opcode::OP_ROWMAX_SINGLE;
    if (op == "MAX") {
        opCode = Opcode::OP_ROWMAX_SINGLE;
    } else if (op == "MIN") {
        opCode = Opcode::OP_ROWMIN_SINGLE;
    } else if (op == "SUM") {
        opCode = Opcode::OP_ROWSUM_SINGLE;
    } else if (op == "PROD") {
        opCode = Opcode::OP_ROWPROD_SINGLE;
    } else if (op == "ARGMAX") {
        opCode = Opcode::OP_ROWARGMAX_SINGLE;
    } else if (op == "ARGMIN") {
        opCode = Opcode::OP_ROWARGMIN_SINGLE;
    }

    // Result valid shape = operand valid shape with the reduced axis collapsed to 1.
    if (!operand.GetStorage()->GetDynValidShape().empty()) {
        std::vector<SymbolicScalar> outValidShape;
        for (auto shape : operand.GetStorage()->GetDynValidShape()) {
            outValidShape.push_back(shape);
        }
        outValidShape[axis] = SymbolicScalar(1);
        result.GetStorage()->UpdateDynValidShape(outValidShape);
    }

    auto& newOp = function.AddOperation(opCode, {operand.GetStorage()}, {result.GetStorage()});
    newOp.SetAttribute(OP_ATTR_PREFIX + "AXIS", static_cast<int>(axis));
}
/**
 * \brief Add a "REDUCE_" + op + "_SINGLE" operation to the program and return its output tensor.
 *
 * The output has the operand's datatype and shape {operand.GetShape()[0], 1}.
 *
 * \param op      Reduction name inserted into the operation name.
 * \param operand Input tensor.
 * \return The newly created result tensor.
 */
[[maybe_unused]] Tensor ReduceSingle(const std::string& op, const Tensor& operand)
{
    Tensor reduced(operand.GetStorage()->tensor->datatype, {operand.GetShape()[0], 1});
    Program::GetInstance().AddOperation(
        "REDUCE_" + op + "_SINGLE", {operand.GetStorage()}, {reduced.GetStorage()});
    return reduced;
}
/**
 * \brief Check that `axis` is valid for `self` and that the current tile shape suits a last-axis reduction.
 *
 * After CheckAxisRange, a reduction over the last dimension additionally requires the current vector tile
 * size of that dimension to be a multiple of BLOCK_SIZE / element-size.
 *
 * \param self Tensor being reduced.
 * \param axis Reduced axis (already normalised by the callers in this file).
 */
static void ValidateReductionAxis(const Tensor& self, int axis)
{
    CheckAxisRange(self, axis);

    const int innermost = self.GetShape().size() - 1;
    const int elemsPerBlock = BLOCK_SIZE / BytesOf(self.GetStorage()->tensor->datatype);
    auto vecTile = TileShape::Current().GetVecTile();

    // The alignment rule applies only when the last dimension is the one being reduced.
    if (axis != innermost) {
        return;
    }
    ASSERT(VectorErrorCode::ERR_CONFIG_ALIGNMENT, vecTile[innermost] % elemsPerBlock == 0)
        << "Reduce op: the tileShape of last axis need to 32Byte align!";
}
/**
 * \brief Drop the reduced dimension from a keep-dim reduction result when the caller asked for that.
 *
 * The result is returned unchanged when `keepDim` is set or when `self` is one-dimensional. Otherwise the
 * entry at `axis` is removed from both the result shape and the dynamic valid shape (taken from `self`), and
 * the result is reshaped accordingly.
 *
 * \param result  Reduction output whose shape still has extent 1 at `axis`.
 * \param self    Original input tensor; source of the dynamic valid shape.
 * \param axis    Reduced axis (non-negative).
 * \param keepDim Whether the reduced dimension should be kept.
 * \return `result` itself or a reshaped tensor without the reduced dimension.
 */
static Tensor ProcessResultShape(const Tensor& result, const Tensor& self, int axis, bool keepDim)
{
    const int lastDim = self.GetShape().size() - 1;
    if (keepDim || lastDim == 0) {
        return result;
    }

    auto outShape = result.GetShape();
    outShape.erase(outShape.begin() + axis);

    std::vector<SymbolicScalar> outValidShape;
    for (auto shape : self.GetStorage()->GetDynValidShape()) {
        outValidShape.push_back(shape);
    }
    // The dynamic valid shape may be empty (other code in this file checks for that); erasing from an empty
    // vector at begin() + axis would be undefined behaviour, so only erase when the entry exists.
    if (static_cast<size_t>(axis) < outValidShape.size()) {
        outValidShape.erase(outValidShape.begin() + axis);
    }
    return Reshape(result, outShape, outValidShape);
}
/**
 * \brief Maximum of `self` along `axis`.
 *
 * Accepts FP16, BF16, INT16, INT32 and FP32 tensors of rank 1 to 4. The reduction is recorded with extent 1
 * at `axis`; ProcessResultShape removes that dimension afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return Tensor with the same datatype as `self`.
 */
Tensor Amax(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();

    std::unordered_set<DataType> allowedDtypes = {DT_FP16, DT_BF16, DT_INT16, DT_INT32, DT_FP32};
    CheckTensorDataType(self.GetStorage(), allowedDtypes, "AMAX");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "AMAX");
    CheckTensorShapeSize(self.GetStorage(), "AMAX");

    if (axis < 0) {
        axis = self.GetShape().size() + axis;
    }
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto reducedShape = self.GetShape();
    reducedShape[axis] = 1;
    Tensor result(self.GetStorage()->Datatype(), reducedShape);
    CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "MAX", self, result, axis);
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Index of the maximum of `self` along `axis`.
 *
 * Accepts FP16, BF16 and FP32 tensors of rank 1 to 4. FP16 inputs are cast to FP32 before the reduction is
 * recorded. The result is an INT32 tensor with extent 1 at `axis`; ProcessResultShape removes that dimension
 * afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return INT32 tensor of indices.
 */
Tensor ArgMax(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();
    std::unordered_set<DataType> supportedTypes = {DT_FP16, DT_BF16, DT_FP32};
    CheckTensorDataType(self.GetStorage(), supportedTypes, "ARGMAX");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "ARGMAX");
    CheckTensorShapeSize(self.GetStorage(), "ARGMAX");
    axis = axis < 0 ? self.GetShape().size() + axis : axis;
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto resultShape = self.GetShape();
    resultShape[axis] = 1;
    Tensor result(DataType::DT_INT32, resultShape);

    if (self.GetDataType() == DT_FP16) {
        // FP16 inputs are reduced on an FP32 copy.
        auto castSelf = CALL(CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(),
            self.GetStorage(), DataType::DT_FP32, CastMode::CAST_NONE);
        CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "ARGMAX", castSelf, result, axis);
    } else {
        CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "ARGMAX", self, result, axis);
    }
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Index of the minimum of `self` along `axis`.
 *
 * Accepts FP16, BF16 and FP32 tensors of rank 1 to 4. FP16 inputs are cast to FP32 before the reduction is
 * recorded. The result is an INT32 tensor with extent 1 at `axis`; ProcessResultShape removes that dimension
 * afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return INT32 tensor of indices.
 */
Tensor ArgMin(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();
    std::unordered_set<DataType> supportedTypes = {DT_FP16, DT_BF16, DT_FP32};
    CheckTensorDataType(self.GetStorage(), supportedTypes, "ARGMIN");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "ARGMIN");
    CheckTensorShapeSize(self.GetStorage(), "ARGMIN");
    axis = axis < 0 ? self.GetShape().size() + axis : axis;
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto resultShape = self.GetShape();
    resultShape[axis] = 1;
    Tensor result(DataType::DT_INT32, resultShape);

    if (self.GetDataType() == DT_FP16) {
        // FP16 inputs are reduced on an FP32 copy.
        auto castSelf = CALL(CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(),
            self.GetStorage(), DataType::DT_FP32, CastMode::CAST_NONE);
        CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "ARGMIN", castSelf, result, axis);
    } else {
        CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "ARGMIN", self, result, axis);
    }
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Minimum of `self` along `axis`.
 *
 * Accepts FP16, BF16, INT16, INT32 and FP32 tensors of rank 1 to 4. The reduction is recorded with extent 1
 * at `axis`; ProcessResultShape removes that dimension afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return Tensor with the same datatype as `self`.
 */
Tensor Amin(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();

    std::unordered_set<DataType> allowedDtypes = {DT_FP16, DT_BF16, DT_INT16, DT_INT32, DT_FP32};
    CheckTensorDataType(self.GetStorage(), allowedDtypes, "AMIN");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "AMIN");
    CheckTensorShapeSize(self.GetStorage(), "AMIN");

    if (axis < 0) {
        axis = self.GetShape().size() + axis;
    }
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto reducedShape = self.GetShape();
    reducedShape[axis] = 1;
    Tensor result(self.GetStorage()->Datatype(), reducedShape);
    CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "MIN", self, result, axis);
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Sum of `self` along `axis`.
 *
 * Accepts FP32, BF16, INT32 and INT16 tensors of rank 1 to 4. The reduction is recorded with extent 1 at
 * `axis`; ProcessResultShape removes that dimension afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return Tensor with the same datatype as `self`.
 */
Tensor Sum(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();

    std::unordered_set<DataType> allowedDtypes = {DT_FP32, DT_BF16, DT_INT32, DT_INT16};
    CheckTensorDataType(self.GetStorage(), allowedDtypes, "SUM");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "SUM");
    CheckTensorShapeSize(self.GetStorage(), "SUM");

    if (axis < 0) {
        axis = self.GetShape().size() + axis;
    }
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto reducedShape = self.GetShape();
    reducedShape[axis] = 1;
    Tensor result(self.GetStorage()->Datatype(), reducedShape);
    CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "SUM", self, result, axis);
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Product of `self` along `axis`.
 *
 * Accepts FP32, INT32 and INT16 tensors of rank 1 to 4. The reduction is recorded with extent 1 at `axis`;
 * ProcessResultShape removes that dimension afterwards unless `keepDim` is set.
 *
 * \param self    Input tensor.
 * \param axis    Axis to reduce; negative values count from the last dimension.
 * \param keepDim Keep the reduced dimension (with extent 1) in the returned tensor.
 * \return Tensor with the same datatype as `self`.
 */
Tensor Prod(const Tensor& self, int axis, bool keepDim)
{
    DECLARE_TRACER();

    std::unordered_set<DataType> allowedDtypes = {DT_FP32, DT_INT32, DT_INT16};
    CheckTensorDataType(self.GetStorage(), allowedDtypes, "PROD");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "PROD");
    CheckTensorShapeSize(self.GetStorage(), "PROD");

    if (axis < 0) {
        axis = self.GetShape().size() + axis;
    }
    ValidateReductionAxis(self, axis);

    // Output keeps the input shape except for the reduced axis, which collapses to 1.
    auto reducedShape = self.GetShape();
    reducedShape[axis] = 1;
    Tensor result(self.GetStorage()->Datatype(), reducedShape);
    CALL(ReduceSingle, *Program::GetInstance().GetCurrentFunction(), "PROD", self, result, axis);
    return ProcessResultShape(result, self, axis, keepDim);
}
/**
 * \brief Tile expansion for the row-expand reductions (MAX / SUM) on a 2-D operand.
 *
 * Dimension 0 is walked in steps of vecTile[0]; every step covers all of dimension 1. The last step is
 * clipped to the rows that remain, so the views never extend past the operand when shape[0] is not a
 * multiple of vecTile[0]. Each row tile is handed to TileReduceNew with ReduceType::EXPAND.
 *
 * \param function  Function that receives the generated operations.
 * \param tileShape Supplies the vector tile sizes.
 * \param op        "MAX" or "SUM".
 * \param operand   2-D input tensor.
 * \param result    Output tensor, viewed with the same tile windows as the operand.
 */
void TiledReduceExpand(
    Function& function, const TileShape& tileShape, const std::string& op, const LogicalTensorPtr& operand,
    const LogicalTensorPtr& result)
{
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, op == "MAX" || op == "SUM") << "Not support op:" << op;
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, operand->shape.size() == operand->offset.size())
        << "The shape size of operand and offset should be equal";
    if (operand->shape.size() != 2) {
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unsupported dimension";
    }
    auto& vecTile = tileShape.GetVecTile();
    TileInfo tileInfo({vecTile[0], operand->shape[1]}, std::vector<int64_t>(operand->offset.size()));
    for (int i = 0; i < operand->shape[0]; i += vecTile[0]) {
        tileInfo.offset[0] = i;
        // Clip the tail tile to the remaining rows, as the other tiling loops in this file do.
        tileInfo.shape[0] = std::min(vecTile[0], operand->shape[0] - i);
        auto inputTile = operand->View(function, tileInfo.shape, tileInfo.offset);
        auto resultTile = result->View(function, tileInfo.shape, tileInfo.offset);
        TileReduceNew(function, tileShape, op, npu::tile_fwk::ReduceType::EXPAND, inputTile, resultTile);
    }
}
/**
 * \brief Add a row-expand reduction operation on logical tensors.
 *
 * \param function Function that receives the operation.
 * \param op       "MAX" selects OP_ROWEXPMAX; any other value selects OP_ROWEXPSUM.
 * \param operand  Input tensor.
 * \param result   Output tensor.
 */
[[maybe_unused]] void TensorReduceExpand(
    Function& function, const std::string& op, const LogicalTensorPtr& operand, const LogicalTensorPtr& result)
{
    auto opcode = Opcode::OP_ROWEXPSUM;
    if (op == "MAX") {
        opcode = Opcode::OP_ROWEXPMAX;
    }
    function.AddOperation(opcode, {operand}, {result});
}
/**
 * \brief Validate the arguments and add a row-expand reduction operation on tensors.
 *
 * \param function Function that receives the operation.
 * \param op       "MAX" (OP_ROWEXPMAX) or "SUM" (OP_ROWEXPSUM); anything else fails the assertion.
 * \param operand  Input tensor; its shape and storage offset must have the same rank.
 * \param result   Output tensor.
 */
void TensorReduceExpand(Function& function, const std::string& op, const Tensor& operand, const Tensor& result)
{
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, op == "MAX" || op == "SUM") << "Not support op:" << op;
    ASSERT(VectorErrorCode::ERR_PARAM_INVALID, operand.GetShape().size() == operand.GetStorage()->offset.size())
        << "The shape size of operand and offset must be equal";

    auto opcode = Opcode::OP_ROWEXPSUM;
    if (op == "MAX") {
        opcode = Opcode::OP_ROWEXPMAX;
    }
    function.AddOperation(opcode, {operand.GetStorage()}, {result.GetStorage()});
}
/**
 * \brief Add a "ROW_" + op + "_EXPAND" operation to the program and return its output tensor.
 *
 * The output has the operand's datatype and shape.
 *
 * \param op      Reduction name inserted into the operation name.
 * \param operand Input tensor.
 * \return The newly created result tensor.
 */
[[maybe_unused]] Tensor ReduceExpand(const std::string& op, const Tensor& operand)
{
    Tensor expanded(operand.GetStorage()->tensor->datatype, operand.GetShape());
    Program::GetInstance().AddOperation(
        "ROW_" + op + "_EXPAND", {operand.GetStorage()}, {expanded.GetStorage()});
    return expanded;
}
/**
 * \brief Tile expansion for a row-expand reduction on a 2-D operand (variant without op-name validation).
 *
 * Dimension 0 is walked in steps of vecTile[0]; every step covers all of dimension 1. The last step is
 * clipped to the rows that remain, so the views never extend past the operand when shape[0] is not a
 * multiple of vecTile[0]. Each row tile is handed to TileReduceNew with ReduceType::EXPAND.
 *
 * \param function  Function that receives the generated operations.
 * \param tileShape Supplies the vector tile sizes.
 * \param op        Reduction name forwarded to TileReduceNew.
 * \param operand   2-D input tensor.
 * \param result    Output tensor, viewed with the same tile windows as the operand.
 */
void TiledReduceExpandNew(
    Function& function, const TileShape& tileShape, const std::string& op, const LogicalTensorPtr& operand,
    const LogicalTensorPtr& result)
{
    if (operand->shape.size() != 2) {
        ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unsupported dimension";
    }
    auto& vecTile = tileShape.GetVecTile();
    TileInfo tileInfo({vecTile[0], operand->shape[1]}, std::vector<int64_t>(operand->offset.size()));
    for (int i = 0; i < operand->shape[0]; i += vecTile[0]) {
        tileInfo.offset[0] = i;
        // Clip the tail tile to the remaining rows, as the other tiling loops in this file do.
        tileInfo.shape[0] = std::min(vecTile[0], operand->shape[0] - i);
        auto inputTile = operand->View(function, tileInfo.shape, tileInfo.offset);
        auto resultTile = result->View(function, tileInfo.shape, tileInfo.offset);
        TileReduceNew(function, tileShape, op, npu::tile_fwk::ReduceType::EXPAND, inputTile, resultTile);
    }
}
/**
 * \brief Row-expand SUM: records a "SUM" ReduceExpand on the current function.
 *
 * \param operand Input tensor.
 * \return New tensor with the operand's datatype and shape.
 */
Tensor RowSumExpand(const Tensor& operand)
{
    DECLARE_TRACER();
    // The expand variants keep the operand's full shape for the output.
    Tensor expanded(operand.GetStorage()->Datatype(), operand.GetShape());
    CALL(ReduceExpand, *Program::GetInstance().GetCurrentFunction(), "SUM", operand, expanded);
    return expanded;
}
/**
 * \brief Row-expand MAX: records a "MAX" ReduceExpand on the current function.
 *
 * \param operand Input tensor.
 * \return New tensor with the operand's datatype and shape.
 */
Tensor RowMaxExpand(const Tensor& operand)
{
    DECLARE_TRACER();
    // The expand variants keep the operand's full shape for the output.
    Tensor expanded(operand.GetStorage()->Datatype(), operand.GetShape());
    CALL(ReduceExpand, *Program::GetInstance().GetCurrentFunction(), "MAX", operand, expanded);
    return expanded;
}
/**
 * \brief Tile function registered for OP_ROWMAX_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands a "MAX" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowMaxSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "MAX", source, target, reduceAxis);
}
/**
 * \brief Tile function registered for OP_ROWMIN_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands a "MIN" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowMinSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "MIN", source, target, reduceAxis);
}
/**
 * \brief Tile function registered for OP_ROWSUM_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands a "SUM" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowSumSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "SUM", source, target, reduceAxis);
}
/**
 * \brief Tile function registered for OP_ROWPROD_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands a "PROD" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowProdSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "PROD", source, target, reduceAxis);
}
/**
 * \brief Tile function registered for OP_ROWEXPMAX.
 *
 * Checks the operand lists and expands a "MAX" row-expand reduction from iOperand[0] into oOperand[0].
 * The operation object itself is not consulted.
 */
void RowExpMaxSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceExpand(function, tileShape, "MAX", source, target);
}
/**
 * \brief Tile function registered for OP_ROWEXPSUM.
 *
 * Checks the operand lists and expands a "SUM" row-expand reduction from iOperand[0] into oOperand[0].
 * The operation object itself is not consulted.
 */
void RowExpSumSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceExpand(function, tileShape, "SUM", source, target);
}
/**
 * \brief Tile function registered for OP_ROWARGMAX_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands an "ARGMAX" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowArgMaxSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "ARGMAX", source, target, reduceAxis);
}
/**
 * \brief Tile function registered for OP_ROWARGMIN_SINGLE.
 *
 * Checks the operand lists, reads the AXIS attribute from `op` and expands an "ARGMIN" reduction from
 * iOperand[0] into oOperand[0].
 */
void RowArgMinSingleOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    UnaryOperationOperandCheck(iOperand, oOperand);
    auto reduceAxis = op.GetIntAttribute(OP_ATTR_PREFIX + "AXIS");
    const auto& source = iOperand[0];
    const auto& target = oOperand[0];
    TiledReduceSingle(function, tileShape, "ARGMIN", source, target, reduceAxis);
}
// Associate each reduction opcode with the tile function defined above that expands it.
// Single-axis reductions (these tile functions read the AXIS attribute of the operation):
REGISTER_OPERATION_TILED_FUNC(OP_ROWMAX_SINGLE, Opcode::OP_ROWMAX_SINGLE, RowMaxSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWMIN_SINGLE, Opcode::OP_ROWMIN_SINGLE, RowMinSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWSUM_SINGLE, Opcode::OP_ROWSUM_SINGLE, RowSumSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWPROD_SINGLE, Opcode::OP_ROWPROD_SINGLE, RowProdSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWARGMAX_SINGLE, Opcode::OP_ROWARGMAX_SINGLE, RowArgMaxSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWARGMIN_SINGLE, Opcode::OP_ROWARGMIN_SINGLE, RowArgMinSingleOperationTileFunc);
// Row-expand reductions (these tile functions ignore the operation's attributes):
REGISTER_OPERATION_TILED_FUNC(OP_ROWEXPMAX, Opcode::OP_ROWEXPMAX, RowExpMaxSingleOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_ROWEXPSUM, Opcode::OP_ROWEXPSUM, RowExpSumSingleOperationTileFunc);
}