* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file indexing.cpp
* \brief
*/
#include <climits>
#include <limits>
#include <cmath>
#include "interface/utils/operator_tracer.h"
#include "passes/pass_utils/graph_utils.h"
#include "interface/function/function.h"
#include "interface/program/program.h"
#include "interface/operation/operation_common.h"
#include "tensor_transformation.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
constexpr float FP16_MAX = 65504.0f;
namespace {
bool IsFp8DataType(DataType dtype) { return dtype == DT_FP8E4M3 || dtype == DT_FP8E5M2 || dtype == DT_FP8E8M0; }
void CheckFp8ArchSupport(const Tensor& tensor, const std::string& opName)
{
if (!IsFp8DataType(tensor.GetDataType())) {
return;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, Platform::Instance().GetSoc().GetNPUArch() == NPUArch::DAV_3510)
<< opName << ": DT_FP8E4M3, DT_FP8E5M2 and DT_FP8E8M0 are only supported on DAV_3510 architecture.";
}
}
struct IndexAddPara {
const LogicalTensorPtr& selfInput;
const LogicalTensorPtr& srcInput;
const LogicalTensorPtr& indicesInput;
const LogicalTensorPtr& dstTensor;
const int axis;
const Element& alpha;
};
struct IndexAddTileInfoPara {
TileInfo selfTileInfo;
TileInfo srcTileInfo;
TileInfo indicesTileInfo;
TileInfo dstTileInfo;
};
Shape GetTempShape(Shape shape, size_t axis)
{
Shape newShape(shape.size(), 1);
for (size_t i = axis + 1; i < shape.size(); ++i) {
newShape[i] = shape[i];
}
auto alignSize = BLOCK_SIZE / BytesOf(DT_BF16);
newShape[shape.size() - 1] = (newShape[shape.size() - 1] + alignSize - 1) / alignSize * alignSize;
return newShape;
}
void IndexAddUBExpandFunc(Function& function, const IndexAddPara& indexaddPara, IndexAddTileInfoPara& indexaddTileInfo)
{
const LogicalTensorPtr& selfInput = indexaddPara.selfInput;
const LogicalTensorPtr& srcInput = indexaddPara.srcInput;
const LogicalTensorPtr& indicesInput = indexaddPara.indicesInput;
const LogicalTensorPtr& dstTensor = indexaddPara.dstTensor;
const int axis = indexaddPara.axis;
const Element& alpha = indexaddPara.alpha;
auto dstTile = dstTensor->View(function, indexaddTileInfo.dstTileInfo.shape, indexaddTileInfo.dstTileInfo.offset);
auto selfTile =
selfInput->View(function, indexaddTileInfo.selfTileInfo.shape, indexaddTileInfo.selfTileInfo.offset);
auto srcTile = srcInput->View(function, indexaddTileInfo.srcTileInfo.shape, indexaddTileInfo.srcTileInfo.offset);
indexaddTileInfo.indicesTileInfo.offset = {
indexaddTileInfo.srcTileInfo.offset[axis]};
indexaddTileInfo.indicesTileInfo.shape = {indexaddTileInfo.srcTileInfo.shape[axis]};
auto indexTile =
indicesInput->View(function, indexaddTileInfo.indicesTileInfo.shape, indexaddTileInfo.indicesTileInfo.offset);
Shape tempShape(dstTile->GetShape().size(), 1);
auto alignSize = BLOCK_SIZE / BytesOf(DT_BF16);
tempShape[dstTile->GetShape().size() - 1] =
(tempShape[dstTile->GetShape().size() - 1] + alignSize - 1) / alignSize * alignSize;
auto tempBuffer = std::make_shared<LogicalTensor>(function, DT_BF16, tempShape);
if (selfTile->Datatype() == DT_BF16 || (selfTile->Datatype() == DT_FP16 && indexTile->Datatype() == DT_INT64 &&
(std::abs(alpha.Cast<float>() - 1) < 1e-6f))) {
LogicalTensorPtr selfConvertedTile = std::make_shared<LogicalTensor>(function, DT_FP32, selfTile->GetShape());
Operation& castSelfOp = function.AddOperation(Opcode::OP_CAST, {selfTile}, {selfConvertedTile});
selfConvertedTile->UpdateDynValidShape(selfTile->GetDynValidShape());
castSelfOp.SetAttribute(OP_ATTR_PREFIX + "mode", CastMode::CAST_NONE);
LogicalTensorPtr srcConvertedTile = std::make_shared<LogicalTensor>(function, DT_FP32, srcTile->GetShape());
Operation& castSrcOp = function.AddOperation(Opcode::OP_CAST, {srcTile}, {srcConvertedTile});
srcConvertedTile->UpdateDynValidShape(srcTile->GetDynValidShape());
castSrcOp.SetAttribute(OP_ATTR_PREFIX + "mode", CastMode::CAST_NONE);
LogicalTensorPtr dstConvertedTile = std::make_shared<LogicalTensor>(function, DT_FP32, dstTile->GetShape());
tempBuffer = std::make_shared<LogicalTensor>(function, DT_BF16, GetTempShape(dstTile->GetShape(), axis));
auto& op = function.AddOperation(
Opcode::OP_INDEX_ADD_UB, {selfConvertedTile, srcConvertedTile, indexTile}, {dstConvertedTile, tempBuffer});
dstConvertedTile->UpdateDynValidShape(dstTile->GetDynValidShape());
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
op.SetAttribute(OpAttributeKey::scalar, alpha);
Operation& castDstOp = function.AddOperation(Opcode::OP_CAST, {dstConvertedTile}, {dstTile});
castDstOp.SetAttribute(OP_ATTR_PREFIX + "mode", CastMode::CAST_RINT);
} else {
auto& op =
function.AddOperation(Opcode::OP_INDEX_ADD_UB, {selfTile, srcTile, indexTile}, {dstTile, tempBuffer});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
op.SetAttribute(OpAttributeKey::scalar, alpha);
}
}
void InnerTiledIndexAddUB(
size_t cur, Function& function, const TileShape& tileShape, const IndexAddPara& indexaddPara,
IndexAddTileInfoPara& indexaddTileInfo)
{
if (cur == indexaddPara.dstTensor->shape.size()) {
IndexAddUBExpandFunc(function, indexaddPara, indexaddTileInfo);
return;
}
auto& vecTile = tileShape.GetVecTile();
int64_t tmpTile = vecTile[cur];
if (static_cast<int>(cur) == indexaddPara.axis) {
indexaddTileInfo.dstTileInfo.offset[cur] = 0;
indexaddTileInfo.dstTileInfo.shape[cur] = indexaddPara.dstTensor->GetShape()[cur];
indexaddTileInfo.selfTileInfo.offset[cur] = 0;
indexaddTileInfo.selfTileInfo.shape[cur] = indexaddPara.selfInput->GetShape()[cur];
indexaddTileInfo.srcTileInfo.offset[cur] = 0;
indexaddTileInfo.srcTileInfo.shape[cur] = indexaddPara.srcInput->GetShape()[cur];
InnerTiledIndexAddUB(cur + 1, function, tileShape, indexaddPara, indexaddTileInfo);
return;
}
for (int64_t i = 0; i < indexaddPara.srcInput->GetShape()[cur]; i += tmpTile) {
indexaddTileInfo.dstTileInfo.offset[cur] = i;
indexaddTileInfo.dstTileInfo.shape[cur] = std::min(indexaddPara.dstTensor->GetShape()[cur] - i, tmpTile);
indexaddTileInfo.selfTileInfo.offset[cur] = i;
indexaddTileInfo.selfTileInfo.shape[cur] = std::min(indexaddPara.selfInput->GetShape()[cur] - i, tmpTile);
indexaddTileInfo.srcTileInfo.offset[cur] = i;
indexaddTileInfo.srcTileInfo.shape[cur] = std::min(indexaddPara.srcInput->GetShape()[cur] - i, tmpTile);
InnerTiledIndexAddUB(cur + 1, function, tileShape, indexaddPara, indexaddTileInfo);
}
}
void TiledIndexAddUB(Function& function, const TileShape& tileShape, const IndexAddPara& indexaddPara)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.selfInput->GetShape().size() == indexaddPara.selfInput->GetOffset().size())
<< "The size of indexaddPara selfinput shape and selfinput offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.srcInput->GetShape().size() == indexaddPara.srcInput->GetOffset().size())
<< "The size of indexaddPara srcInput shape and srcInput offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.indicesInput->GetShape().size() == indexaddPara.indicesInput->GetOffset().size())
<< "The size of indexaddPara indicesInput shape and indicesInput offset should be equal";
IndexAddTileInfoPara indexaddTileInfo{
TileInfo(indexaddPara.selfInput->GetShape().size(), indexaddPara.selfInput->GetOffset().size()),
TileInfo(indexaddPara.srcInput->GetShape().size(), indexaddPara.srcInput->GetOffset().size()),
TileInfo(indexaddPara.indicesInput->GetShape().size(), indexaddPara.indicesInput->GetOffset().size()),
TileInfo(indexaddPara.dstTensor->GetShape().size(), indexaddPara.dstTensor->GetOffset().size())};
InnerTiledIndexAddUB(0, function, tileShape, indexaddPara, indexaddTileInfo);
}
void TensorIndexAddUB(Function& function, const IndexAddPara& indexaddPara)
{
auto& op = GraphUtils::AddDynOperation(
function, Opcode::OP_INDEX_ADD_UB, {indexaddPara.selfInput, indexaddPara.srcInput, indexaddPara.indicesInput},
{indexaddPara.dstTensor});
op.SetAttribute(OP_ATTR_PREFIX + "axis", indexaddPara.axis);
op.SetAttribute(OpAttributeKey::scalar, indexaddPara.alpha);
}
bool CheckAlphaOverflow(Element alpha, DataType dtype)
{
double value = alpha.Cast<double>();
if (std::isnan(value) || std::isinf(value))
return true;
switch (dtype) {
case DT_INT8:
return value < std::numeric_limits<int8_t>::min() || value > std::numeric_limits<int8_t>::max();
case DT_INT16:
return value < std::numeric_limits<int16_t>::min() || value > std::numeric_limits<int16_t>::max();
case DT_INT32:
return value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max();
case DT_FP16:
return std::abs(value) > FP16_MAX;
case DT_BF16:
return std::abs(value) > std::numeric_limits<float>::max();
case DT_FP32:
return std::abs(value) > std::numeric_limits<float>::max();
default:
return false;
}
}
void CheckIndexAddParamsInvalid(
const Tensor& self, const Tensor& src, const Tensor& indices, const int axis, const Element& alpha,
const Opcode& opCode)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
axis < static_cast<int>(self.GetShape().size()) && axis >= -static_cast<int>(self.GetShape().size()))
<< "axis out of range of shape size";
CheckTensorDimRange(self.GetStorage(), 1, 5, "INDEXADD");
CheckTensorDimRange(indices.GetStorage(), 1, 1, "INDEXADD");
CheckTensorShapeSize(self.GetStorage(), "INDEXADD");
CheckTensorShapeSize(src.GetStorage(), "INDEXADD");
CheckTensorShapeSize(indices.GetStorage(), "INDEXADD");
std::vector<LogicalTensorPtr> tensors = {self.GetStorage(), src.GetStorage()};
CheckTensorsDimConsistency(tensors, "INDEXADD");
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, src.GetShape()[axis] == indices.GetShape()[0])
<< "src shape[axis] and indices[0] must equal";
for (size_t i = 0; i < self.GetShape().size(); ++i) {
if (static_cast<int>(i) == axis) {
continue;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, src.GetShape()[i] == self.GetShape()[i])
<< "src shape and self shape should be equal";
}
std::unordered_set<DataType> supportedTypes = {DT_FP32, DT_FP16, DT_BF16, DT_INT32, DT_INT16};
if (opCode == Opcode::OP_INDEX_ADD) {
supportedTypes.insert(DT_INT8);
}
CheckTensorDataType(self.GetStorage(), supportedTypes, "INDEXADD");
CheckTensorsDataTypeConsistency(self.GetStorage(), src.GetStorage(), "INDEXADD");
CheckTensorsFormatConsistency(self.GetStorage(), src.GetStorage(), "INDEXADD");
std::unordered_set<DataType> indexSupportedTypes = {DT_INT32, DT_INT64};
CheckTensorDataType(indices.GetStorage(), indexSupportedTypes, "INDEXADD");
if (CheckAlphaOverflow(alpha, self.GetDataType())) {
ASSERT(VectorErrorCode::ERR_RUNTIME_LOGIC, false)
<< "Value cannot be converted to type " << DataType2String(self.GetDataType()) << " without overflow!";
}
}
Tensor IndexAddUB(const Tensor& self, const Tensor& src, const Tensor& indices, int axis, const Element& alpha)
{
DECLARE_TRACER();
CheckAxisRange(self, axis);
CheckIndexAddParamsInvalid(self, src, indices, axis, alpha, Opcode::OP_INDEX_ADD_UB);
DataType selfDataType = self.GetDataType();
Element alpha_ = Element(selfDataType, alpha.Cast<float>());
Tensor result(selfDataType, self.GetShape());
CALL(
IndexAddUB, *Program::GetInstance().GetCurrentFunction(),
{self.GetStorage(), src.GetStorage(), indices.GetStorage(), result.GetStorage(), axis, alpha_});
return result;
}
void IndexAddExpandFunc(
Function& function, const IndexAddPara& indexaddPara, IndexAddTileInfoPara& indexaddTileInfo,
const LogicalTensorPtr& cachedDstTile = nullptr, const LogicalTensorPtr& cachedSelfTile = nullptr)
{
const LogicalTensorPtr& selfInput = indexaddPara.selfInput;
const LogicalTensorPtr& srcInput = indexaddPara.srcInput;
const LogicalTensorPtr& indicesInput = indexaddPara.indicesInput;
const LogicalTensorPtr& dstTensor = indexaddPara.dstTensor;
const int axis = indexaddPara.axis;
auto selfTile =
cachedSelfTile ?
cachedSelfTile :
selfInput->View(function, indexaddTileInfo.selfTileInfo.shape, indexaddTileInfo.selfTileInfo.offset);
auto dstTile =
cachedDstTile ?
cachedDstTile :
dstTensor->View(function, indexaddTileInfo.dstTileInfo.shape, indexaddTileInfo.dstTileInfo.offset);
auto srcTile = srcInput->View(function, indexaddTileInfo.srcTileInfo.shape, indexaddTileInfo.srcTileInfo.offset);
indexaddTileInfo.indicesTileInfo.offset = {indexaddTileInfo.srcTileInfo.offset[axis]};
indexaddTileInfo.indicesTileInfo.shape = {indexaddTileInfo.srcTileInfo.shape[axis]};
auto indexTile =
indicesInput->View(function, indexaddTileInfo.indicesTileInfo.shape, indexaddTileInfo.indicesTileInfo.offset);
Shape tmpShape(2, 1);
auto alignSize = BLOCK_SIZE / BytesOf(srcTile->Datatype());
tmpShape[1] = AlignUp(srcTile->GetShape()[srcTile->GetShape().size() - 1], alignSize);
auto tmpTile = std::make_shared<LogicalTensor>(function, srcTile->Datatype(), tmpShape);
auto& op = function.AddOperation(Opcode::OP_INDEX_ADD, {selfTile, srcTile, indexTile}, {dstTile, tmpTile});
op.SetAttribute(OpAttributeKey::inplaceIdx, 0);
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
op.SetAttribute(OpAttributeKey::scalar, indexaddPara.alpha);
}
using TileCache = std::unordered_map<int64_t, std::pair<LogicalTensorPtr, LogicalTensorPtr>>;
void InnerTiledIndexAdd(
size_t cur, Function& function, const TileShape& tileShape, const IndexAddPara& indexaddPara,
IndexAddTileInfoPara& indexaddTileInfo, TileCache& tileCache, int64_t encodeKey = 0)
{
if (cur == indexaddPara.dstTensor->shape.size()) {
auto it = tileCache.find(encodeKey);
if (it == tileCache.end()) {
auto selfTile = indexaddPara.selfInput->View(
function, indexaddTileInfo.selfTileInfo.shape, indexaddTileInfo.selfTileInfo.offset);
auto dstTile = indexaddPara.dstTensor->View(
function, indexaddTileInfo.dstTileInfo.shape, indexaddTileInfo.dstTileInfo.offset);
it = tileCache.emplace(encodeKey, std::make_pair(dstTile, selfTile)).first;
}
IndexAddExpandFunc(function, indexaddPara, indexaddTileInfo, it->second.first, it->second.second);
return;
}
const auto& vecTile = tileShape.GetVecTile();
int64_t tileStep = vecTile[cur];
const auto& srcShape = indexaddPara.srcInput->GetShape();
const auto& dstShape = indexaddPara.dstTensor->GetShape();
int64_t numTilesInCurDim = (srcShape[cur] + tileStep - 1) / tileStep;
if (static_cast<int>(cur) == indexaddPara.axis) {
indexaddTileInfo.dstTileInfo.offset[cur] = 0;
indexaddTileInfo.dstTileInfo.shape[cur] = dstShape[cur];
indexaddTileInfo.selfTileInfo.offset[cur] = 0;
indexaddTileInfo.selfTileInfo.shape[cur] = dstShape[cur];
for (int i = 0; i < srcShape[cur]; i += tileStep) {
indexaddTileInfo.srcTileInfo.offset[cur] = i;
indexaddTileInfo.srcTileInfo.shape[cur] = std::min(srcShape[cur] - i, tileStep);
InnerTiledIndexAdd(cur + 1, function, tileShape, indexaddPara, indexaddTileInfo, tileCache, encodeKey);
}
} else {
int64_t tileIndex = 0;
for (int i = 0; i < srcShape[cur]; i += tileStep) {
indexaddTileInfo.dstTileInfo.offset[cur] = i;
indexaddTileInfo.dstTileInfo.shape[cur] = std::min(dstShape[cur] - i, tileStep);
indexaddTileInfo.selfTileInfo.offset[cur] = i;
indexaddTileInfo.selfTileInfo.shape[cur] = std::min(dstShape[cur] - i, tileStep);
indexaddTileInfo.srcTileInfo.offset[cur] = i;
indexaddTileInfo.srcTileInfo.shape[cur] = std::min(srcShape[cur] - i, tileStep);
int64_t newKey = encodeKey * numTilesInCurDim + tileIndex;
tileIndex++;
InnerTiledIndexAdd(cur + 1, function, tileShape, indexaddPara, indexaddTileInfo, tileCache, newKey);
}
}
}
void TiledIndexAdd(Function& function, const TileShape& tileShape, const IndexAddPara& indexaddPara)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.selfInput->GetShape().size() == indexaddPara.selfInput->GetOffset().size())
<< "The size of indexaddPara selfinput shape and selfinput offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.srcInput->GetShape().size() == indexaddPara.srcInput->GetOffset().size())
<< "The size of indexaddPara srcInput shape and srcInput offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
indexaddPara.indicesInput->GetShape().size() == indexaddPara.indicesInput->GetOffset().size())
<< "The size of indexaddPara indicesInput shape and indicesInput offset should be equal";
IndexAddTileInfoPara indexaddTileInfo{
TileInfo(indexaddPara.selfInput->GetShape().size(), indexaddPara.selfInput->GetOffset().size()),
TileInfo(indexaddPara.srcInput->GetShape().size(), indexaddPara.srcInput->GetOffset().size()),
TileInfo(indexaddPara.indicesInput->GetShape().size(), indexaddPara.indicesInput->GetOffset().size()),
TileInfo(indexaddPara.dstTensor->GetShape().size(), indexaddPara.dstTensor->GetOffset().size())};
TileCache tileCache;
InnerTiledIndexAdd(0, function, tileShape, indexaddPara, indexaddTileInfo, tileCache);
}
void TensorIndexAdd(Function& function, const IndexAddPara& indexaddPara)
{
auto& op = GraphUtils::AddDynOperation(
function, Opcode::OP_INDEX_ADD, {indexaddPara.selfInput, indexaddPara.srcInput, indexaddPara.indicesInput},
{indexaddPara.dstTensor});
op.SetAttribute(OpAttributeKey::inplaceIdx, 0);
op.SetAttribute(OP_ATTR_PREFIX + "axis", indexaddPara.axis);
op.SetAttribute(OpAttributeKey::scalar, indexaddPara.alpha);
}
void IndexAdd_(Tensor& self, const Tensor& src, const Tensor& indices, int axis, const Element& alpha)
{
DECLARE_TRACER();
CheckAxisRange(self, axis);
CheckIndexAddParamsInvalid(self, src, indices, axis, alpha, Opcode::OP_INDEX_ADD);
DataType selfDataType = self.GetDataType();
Element castedAlpha = Element(selfDataType, alpha.Cast<float>());
Tensor result(selfDataType, self.GetShape());
CALL(
IndexAdd, *Program::GetInstance().GetCurrentFunction(),
{self.GetStorage(), src.GetStorage(), indices.GetStorage(), result.GetStorage(), axis, castedAlpha});
self = result;
}
void TiledGatherOperation(
Function& function, const TileShape& tileShape, size_t cur, Input& paramsInput, Input& indicesInput, int axis,
const LogicalTensorPtr& result, TileInfo& resultTileInfo)
{
if (cur == result->shape.size()) {
auto paramsTile =
paramsInput.tensor.GetStorage()->View(function, paramsInput.tileInfo.shape, paramsInput.tileInfo.offset);
auto indicesTile =
indicesInput.tensor.GetStorage()->View(function, indicesInput.tileInfo.shape, indicesInput.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
if (function.IsStatic()) {
auto& op = function.AddOperation(Opcode::OP_GATHER_FROM_UB, {paramsTile, indicesTile}, {resultTile});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
} else {
auto& op = function.AddOperation(Opcode::OP_GATHER, {paramsTile, indicesTile}, {resultTile});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
}
return;
}
auto& vecTile = tileShape.GetVecTile();
int64_t tmpTile = vecTile[cur];
for (int i = 0; i < result->shape[cur]; i += tmpTile) {
if (cur < static_cast<size_t>(axis)) {
paramsInput.tileInfo.offset[cur] = i % paramsInput.tensor.GetShape()[cur];
paramsInput.tileInfo.shape[cur] =
std::min(paramsInput.tensor.GetShape()[cur] - paramsInput.tileInfo.offset[cur], tmpTile);
} else if (
cur >= static_cast<size_t>(axis) &&
(cur < static_cast<size_t>(axis) + indicesInput.tensor.GetShape().size())) {
paramsInput.tileInfo.offset[axis] = 0;
paramsInput.tileInfo.shape[axis] = paramsInput.tensor.GetShape()[axis];
indicesInput.tileInfo.offset[cur - axis] = i % indicesInput.tensor.GetShape()[cur - axis];
indicesInput.tileInfo.shape[cur - axis] = std::min(
indicesInput.tensor.GetShape()[cur - axis] - indicesInput.tileInfo.offset[cur - axis], tmpTile);
} else {
int paramHighAxis = cur - indicesInput.tensor.GetShape().size() + 1;
paramsInput.tileInfo.offset[paramHighAxis] = i % paramsInput.tensor.GetShape()[paramHighAxis];
paramsInput.tileInfo.shape[paramHighAxis] = std::min(
paramsInput.tensor.GetShape()[paramHighAxis] - paramsInput.tileInfo.offset[paramHighAxis], tmpTile);
}
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], tmpTile);
TiledGatherOperation(function, tileShape, cur + 1, paramsInput, indicesInput, axis, result, resultTileInfo);
}
}
std::vector<int64_t> GatherOperationResultShape(LogicalTensorPtr params, LogicalTensorPtr indices, int axis)
{
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, params->shape.size() == params->offset.size())
<< "The size of params shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices->shape.size() == indices->offset.size())
<< "The size of indices shape and offset should be equal";
int paramsRank = params->shape.size();
if (axis < 0) {
axis = axis + paramsRank;
}
std::vector<int64_t> resultShape = params->shape;
resultShape.erase(resultShape.begin() + axis);
resultShape.insert(resultShape.begin() + axis, indices->shape.begin(), indices->shape.end());
return resultShape;
}
void TiledGatherOperation(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& params, const LogicalTensorPtr& indices,
int axis, const LogicalTensorPtr& result)
{
std::vector<int64_t> expectedShape = GatherOperationResultShape(params, indices, axis);
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, result->shape.size() == expectedShape.size())
<< "The size of result shape and expectedShape should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, result->shape.size() == result->offset.size())
<< "The size of result shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, params->shape.size() == params->offset.size())
<< "The size of params shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices->shape.size() == indices->offset.size())
<< "The size of indices shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_SHAPE_DIM_UNSUPPORTED, result->shape.size() <= NUM_VALUE_5)
<< "Not support shape size of result greater than 5";
ASSERT(VectorErrorCode::ERR_PARAM_SHAPE_DIM_UNSUPPORTED, indices->shape.size() <= NUM_VALUE_2)
<< "Not support shape size of indices greater than 2";
if (axis < 0) {
axis += params->shape.size();
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, axis >= 0 && axis < static_cast<int>(params->shape.size()))
<< "The axis should be greater than or equal to 0 and less than shape size of params";
TileInfo paramsTileInfo(params->shape.size(), params->offset.size());
TileInfo indicesTileInfo(indices->shape.size(), indices->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto paramsInput = Input{params, paramsTileInfo};
auto indicesInput = Input{indices, indicesTileInfo};
TiledGatherOperation(function, tileShape, 0, paramsInput, indicesInput, axis, result, resultTileInfo);
}
LogicalTensorPtr TensorGatherOperation(
Function& function, const LogicalTensorPtr& params, const LogicalTensorPtr& indices, int axis)
{
const auto& paramsDynShape = params->GetDynValidShape();
const auto& indicesDynShape = indices->GetDynValidShape();
const int paramsRank = paramsDynShape.size();
if (axis < 0) {
axis += paramsRank;
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, axis >= 0 && axis < paramsRank)
<< "The configuration of the axis is incorrect";
}
std::vector<int64_t> resultShape = GatherOperationResultShape(params, indices, axis);
auto result = std::make_shared<LogicalTensor>(function, params->Datatype(), resultShape);
std::vector<SymbolicScalar> outValidShape = paramsDynShape;
outValidShape.erase(outValidShape.begin() + axis);
outValidShape.insert(outValidShape.begin() + axis, indicesDynShape.begin(), indicesDynShape.end());
auto& op = GraphUtils::AddDynOperation(function, Opcode::OP_GATHER, {params, indices}, {result}, {outValidShape});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
return result;
}
void TensorGatherMask(
Function& function, const LogicalTensorPtr& self, const LogicalTensorPtr& result, const uint8_t& patternMode)
{
if (patternMode != 0) {
auto& op = function.AddOperation(Opcode::OP_GATHER_MASK_BUILDIN, {self}, {result});
op.SetAttribute(OP_ATTR_PREFIX + "patternMode", patternMode);
return;
}
}
void CheckGatherParamsInvalid(const Tensor& params, const Tensor& indices, int axis, const std::string& opName)
{
static const std::unordered_set<DataType> a2a3Types = {DT_FP32, DT_FP16, DT_BF16, DT_INT32, DT_INT16,
DT_INT8, DT_UINT32, DT_UINT16, DT_UINT8};
static const std::unordered_set<DataType> a5Types = {DT_FP32, DT_FP16, DT_BF16, DT_INT32, DT_INT16,
DT_INT8, DT_UINT32, DT_UINT16, DT_UINT8, DT_BOOL,
DT_FP8E4M3, DT_FP8E5M2, DT_FP8E8M0};
const auto& supportedTypes = GetSupportedDataTypesByArch(a2a3Types, a5Types);
CheckTensorDataType(params.GetStorage(), supportedTypes, opName);
CheckFp8ArchSupport(params, opName);
std::unordered_set<DataType> indexSupportedTypes = {DT_INT32, DT_INT64};
CheckTensorDataType(indices.GetStorage(), indexSupportedTypes, opName);
CheckTensorDimRange(params.GetStorage(), 1, 4, opName);
CheckTensorDimRange(indices.GetStorage(), 1, 2, opName);
CheckTensorShapeSize(params.GetStorage(), opName);
CheckTensorShapeSize(indices.GetStorage(), opName);
CheckAxisRange(params, axis);
CheckTensorsFormatConsistency(params.GetStorage(), indices.GetStorage(), opName);
}
Tensor Gather(const Tensor& params, const Tensor& indices, int axis)
{
DECLARE_TRACER();
CheckGatherParamsInvalid(params, indices, axis, "GATHER");
RETURN_CALL(
GatherOperation, *Program::GetInstance().GetCurrentFunction(), params.GetStorage(), indices.GetStorage(), axis);
}
Tensor TensorIndex(const Tensor& params, const Tensor& indices)
{
DECLARE_TRACER();
CheckGatherParamsInvalid(params, indices, 0, "TENSORINDEX");
RETURN_CALL(
GatherOperation, *Program::GetInstance().GetCurrentFunction(), params.GetStorage(), indices.GetStorage(), 0);
}
void TiledGatherElementOperation(
Function& function, const TileShape& tileShape, size_t cur, Input& paramsInput, Input& indicesInput, int axis,
const LogicalTensorPtr& result, TileInfo& resultTileInfo)
{
if (cur == result->shape.size()) {
auto paramsTile =
paramsInput.tensor.GetStorage()->View(function, paramsInput.tileInfo.shape, paramsInput.tileInfo.offset);
auto indicesTile =
indicesInput.tensor.GetStorage()->View(function, indicesInput.tileInfo.shape, indicesInput.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
Shape tmpShape({indicesTile->GetShape()[indicesTile->GetShape().size() - 1]});
tmpShape[0] = 2 * AlignUp(tmpShape[0], BLOCK_SIZE / BytesOf(indicesTile->Datatype()));
auto tmpBuffer = std::make_shared<LogicalTensor>(function, indicesTile->Datatype(), tmpShape);
auto& op = function.AddOperation(Opcode::OP_GATHER_ELEMENT, {paramsTile, indicesTile}, {resultTile, tmpBuffer});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
return;
}
auto& vecTile = tileShape.GetVecTile();
int64_t tmpTile = vecTile[cur];
for (int i = 0; i < result->shape[cur]; i += tmpTile) {
if (cur == static_cast<size_t>(axis)) {
paramsInput.tileInfo.offset[cur] = 0;
paramsInput.tileInfo.shape[cur] = paramsInput.tensor.GetShape()[cur];
indicesInput.tileInfo.offset[cur] = i % indicesInput.tensor.GetShape()[cur];
indicesInput.tileInfo.shape[cur] =
std::min(indicesInput.tensor.GetShape()[cur] - indicesInput.tileInfo.offset[cur], tmpTile);
} else {
paramsInput.tileInfo.offset[cur] = i % paramsInput.tensor.GetShape()[cur];
paramsInput.tileInfo.shape[cur] =
std::min(paramsInput.tensor.GetShape()[cur] - paramsInput.tileInfo.offset[cur], tmpTile);
indicesInput.tileInfo.offset[cur] = i % indicesInput.tensor.GetShape()[cur];
indicesInput.tileInfo.shape[cur] =
std::min(indicesInput.tensor.GetShape()[cur] - indicesInput.tileInfo.offset[cur], tmpTile);
}
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], tmpTile);
TiledGatherElementOperation(
function, tileShape, cur + 1, paramsInput, indicesInput, axis, result, resultTileInfo);
}
}
void TiledGatherElementOperation(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& params, const LogicalTensorPtr& indices,
int axis, const LogicalTensorPtr& result)
{
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, result->shape.size() == result->offset.size())
<< "The size of result shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, params->shape.size() == params->offset.size())
<< "The size of params shape and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices->shape.size() == indices->offset.size())
<< "The size of indices shape and offset should be equal";
TileInfo paramsTileInfo(params->shape.size(), params->offset.size());
TileInfo indicesTileInfo(indices->shape.size(), indices->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto paramsInput = Input{params, paramsTileInfo};
auto indicesInput = Input{indices, indicesTileInfo};
TiledGatherElementOperation(function, tileShape, 0, paramsInput, indicesInput, axis, result, resultTileInfo);
}
LogicalTensorPtr TensorGatherElementOperation(
Function& function, const LogicalTensorPtr& params, const LogicalTensorPtr& indices, int axis)
{
auto result = std::make_shared<LogicalTensor>(function, params->Datatype(), indices->shape);
std::vector<std::vector<SymbolicScalar>> outValidShape;
outValidShape.push_back(indices->GetDynValidShape());
auto& op =
GraphUtils::AddDynOperation(function, Opcode::OP_GATHER_ELEMENT, {params, indices}, {result}, outValidShape);
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
return result;
}
Tensor GatherElements(const Tensor& params, const Tensor& indices, int axis)
{
DECLARE_TRACER();
std::vector<LogicalTensorPtr> tensors = {params.GetStorage(), indices.GetStorage()};
CheckTensorsDimConsistency(tensors, "GATHERELEMENTS");
CheckAxisRange(params, axis);
for (size_t i = 0; i < params.GetShape().size(); ++i) {
if (static_cast<int>(i) == axis) {
continue;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices.GetShape()[i] <= params.GetShape()[i])
<< "The shape of params and indices should be equal";
}
std::unordered_set<DataType> supportedTypes = {DT_FP32, DT_FP16, DT_BF16, DT_INT32, DT_INT16};
CheckTensorDataType(params.GetStorage(), supportedTypes, "GATHERELEMENTS");
std::unordered_set<DataType> indexSupportedTypes = {DT_INT32, DT_INT64};
CheckTensorDataType(indices.GetStorage(), indexSupportedTypes, "GATHERELEMENTS");
CheckTensorDimRange(params.GetStorage(), 1, 4, "GATHERELEMENTS");
CheckTensorShapeSize(params.GetStorage(), "GATHERELEMENTS");
CheckTensorShapeSize(indices.GetStorage(), "GATHERELEMENTS");
CheckTensorsFormatConsistency(params.GetStorage(), indices.GetStorage(), "GATHERELEMENTS");
RETURN_CALL(
GatherElementOperation, *Program::GetInstance().GetCurrentFunction(), params.GetStorage(), indices.GetStorage(),
axis);
}
struct ScatterElementSPara {
const LogicalTensorPtr& dstTensor;
const LogicalTensorPtr& srcInput;
const LogicalTensorPtr& idxInput;
const Element& scalar;
const int axis;
const int scatterMode;
};
struct ScatterElementSTileInfoPara {
TileInfo srcTileInfo;
TileInfo idxTileInfo;
TileInfo dstTileInfo;
};
void InnerTiledScatterElementS(
size_t cur, Function& function, const TileShape& tileShape, const ScatterElementSPara& scatterPara,
ScatterElementSTileInfoPara& scatterTileInfo)
{
const LogicalTensorPtr& dstTensor = scatterPara.dstTensor;
const LogicalTensorPtr& srcInput = scatterPara.srcInput;
const LogicalTensorPtr& idxInput = scatterPara.idxInput;
const Element& scalar = scatterPara.scalar;
const int axis = scatterPara.axis;
const int mode = scatterPara.scatterMode;
if (cur == dstTensor->shape.size()) {
auto srcTile = srcInput->View(function, scatterTileInfo.srcTileInfo.shape, scatterTileInfo.srcTileInfo.offset);
auto idxTile = idxInput->View(function, scatterTileInfo.idxTileInfo.shape, scatterTileInfo.idxTileInfo.offset);
auto dstTile = dstTensor->View(function, scatterTileInfo.dstTileInfo.shape, scatterTileInfo.dstTileInfo.offset);
auto& op = function.AddOperation(Opcode::OP_SCATTER_ELEMENT, {srcTile, idxTile}, {dstTile});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
op.SetAttribute(OpAttributeKey::scalar, scalar);
op.SetAttribute(OP_ATTR_PREFIX + "scatter_mode", mode);
return;
}
auto& vecTile = tileShape.GetVecTile();
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, vecTile[axis] >= dstTensor->shape[axis])
<< "The axis is not supported for tile splitting";
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, vecTile[axis] >= idxInput->shape[axis])
<< "The axis is not supported for tile splitting";
int64_t tmpTile = vecTile[cur];
if (static_cast<int>(cur) == axis) {
tmpTile = std::max(dstTensor->shape[axis], idxInput->shape[axis]);
}
for (int i = 0; i < idxInput->shape[cur]; i += tmpTile) {
if (static_cast<int>(cur) == axis) {
scatterTileInfo.idxTileInfo.offset[cur] = 0;
scatterTileInfo.idxTileInfo.shape[cur] = idxInput->shape[cur];
scatterTileInfo.dstTileInfo.offset[cur] = 0;
scatterTileInfo.dstTileInfo.shape[cur] = dstTensor->shape[cur];
scatterTileInfo.srcTileInfo.offset[cur] = 0;
scatterTileInfo.srcTileInfo.shape[cur] = srcInput->shape[cur];
} else {
scatterTileInfo.idxTileInfo.offset[cur] = i % idxInput->shape[cur];
scatterTileInfo.idxTileInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxTileInfo.offset[cur], tmpTile);
scatterTileInfo.dstTileInfo.offset[cur] = i;
scatterTileInfo.dstTileInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxTileInfo.offset[cur], tmpTile);
scatterTileInfo.srcTileInfo.offset[cur] = i;
scatterTileInfo.srcTileInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxTileInfo.offset[cur], tmpTile);
}
InnerTiledScatterElementS(cur + 1, function, tileShape, scatterPara, scatterTileInfo);
}
}
void TiledScatterElementS(Function& function, const TileShape& tileShape, const ScatterElementSPara& scatterPara)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.srcInput->shape.size() == scatterPara.srcInput->offset.size())
<< "The size of srcInput shape and offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.idxInput->shape.size() == scatterPara.idxInput->offset.size())
<< "The size of idxInput shape and offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.dstTensor->shape.size() == scatterPara.dstTensor->offset.size())
<< "The size of dst shape and offset should be equal";
ScatterElementSTileInfoPara scatterTileInfo{
TileInfo(scatterPara.srcInput->shape.size(), scatterPara.srcInput->offset.size()),
TileInfo(scatterPara.idxInput->shape.size(), scatterPara.idxInput->offset.size()),
TileInfo(scatterPara.dstTensor->shape.size(), scatterPara.dstTensor->offset.size()),
};
InnerTiledScatterElementS(0, function, tileShape, scatterPara, scatterTileInfo);
}
void TensorScatterElementS(Function& function, const ScatterElementSPara& scatterPara)
{
auto& op = GraphUtils::AddDynOperation(
function, Opcode::OP_SCATTER_ELEMENT, {scatterPara.srcInput, scatterPara.idxInput}, {scatterPara.dstTensor});
op.SetAttribute(OP_ATTR_PREFIX + "axis", scatterPara.axis);
op.SetAttribute(OpAttributeKey::scalar, scatterPara.scalar);
op.SetAttribute(OP_ATTR_PREFIX + "scatter_mode", scatterPara.scatterMode);
std::map<int, int> inplaceInfo = {{0, 0}};
op.SetAttr(OpAttributeKey::inplaceInfo, inplaceInfo);
}
static void CheckScatterElementSParamsInvalid(
const Tensor& self, const Tensor& indices, int axis, const ScatterMode reduce)
{
std::unordered_set<DataType> supportedTypes = {DT_FP32, DT_FP16, DT_BF16, DT_INT8, DT_UINT8, DT_INT16, DT_INT32};
CheckTensorDataType(self.GetStorage(), supportedTypes, "SCATTER");
std::unordered_set<DataType> indexSupportedTypes = {DT_INT32, DT_INT64};
CheckTensorDataType(indices.GetStorage(), indexSupportedTypes, "SCATTER");
std::vector<LogicalTensorPtr> tensors = {self.GetStorage(), indices.GetStorage()};
CheckTensorsDimConsistency(tensors, "SCATTER");
CheckTensorsFormatConsistency(self.GetStorage(), indices.GetStorage(), "SCATTER");
CheckAxisRange(self, axis);
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, reduce <= ScatterMode::UNKNOWN)
<< "The ScatterMode of reduce should be less than UNKNOWN";
for (size_t i = 0; i < self.GetShape().size(); i++) {
if (static_cast<int>(i) == axis) {
continue;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices.GetShape()[i] <= self.GetShape()[i])
<< "The shape of indices and self should be equal";
}
CheckTensorDimRange(self.GetStorage(), 1, 4, "SCATTER");
CheckTensorShapeSize(self.GetStorage(), "SCATTER");
CheckTensorShapeSize(indices.GetStorage(), "SCATTER");
}
Tensor Scatter(const Tensor& self, const Tensor& indices, const Element& src, int axis, ScatterMode reduce)
{
DECLARE_TRACER();
CheckScatterElementSParamsInvalid(self, indices, axis < 0 ? self.GetShape().size() + axis : axis, reduce);
DataType orgDtype = self.GetDataType();
auto operandCast = Tensor(DataType::DT_FP32, self.GetShape());
if ((orgDtype == DataType::DT_FP16 || orgDtype == DataType::DT_BF16) &&
(reduce == ScatterMode::ADD || reduce == ScatterMode::MULTIPLY)) {
operandCast = CALL(
CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(), self.GetStorage(),
DataType::DT_FP32, CastMode::CAST_NONE);
} else {
operandCast = self;
}
axis = axis < 0 ? operandCast.GetShape().size() + axis : axis;
Tensor result(operandCast.GetStorage()->Datatype(), operandCast.GetShape());
CALL(
ScatterElementS, *Program::GetInstance().GetCurrentFunction(),
{result.GetStorage(), operandCast.GetStorage(), indices.GetStorage(), src, axis, static_cast<int>(reduce)});
if ((orgDtype == DataType::DT_FP16 || orgDtype == DataType::DT_BF16) &&
(reduce == ScatterMode::ADD || reduce == ScatterMode::MULTIPLY)) {
RETURN_CALL(
CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(), result.GetStorage(),
orgDtype, CastMode::CAST_RINT);
}
return result;
}
struct ScatterPara {
const LogicalTensorPtr& dstTensor;
const LogicalTensorPtr& selfInput;
const LogicalTensorPtr& idxInput;
const LogicalTensorPtr& srcInput;
const int axis;
const int scatterMode;
};
struct ScatterTileInfoPara {
TileInfo srcInfo;
TileInfo idxInfo;
TileInfo dstInfo;
TileInfo selfInfo;
};
void InnerTiledScatter(
size_t cur, Function& function, const TileShape& tileShape, const ScatterPara& scatterPara,
ScatterTileInfoPara& scatterTileInfo)
{
const LogicalTensorPtr& dstTensor = scatterPara.dstTensor;
const LogicalTensorPtr& selfInput = scatterPara.selfInput;
const LogicalTensorPtr& idxInput = scatterPara.idxInput;
const LogicalTensorPtr& srcInput = scatterPara.srcInput;
const int axis = scatterPara.axis;
const int mode = scatterPara.scatterMode;
if (cur == dstTensor->shape.size()) {
auto selfTile = selfInput->View(function, scatterTileInfo.selfInfo.shape, scatterTileInfo.selfInfo.offset);
auto idxTile = idxInput->View(function, scatterTileInfo.idxInfo.shape, scatterTileInfo.idxInfo.offset);
auto srcTile = srcInput->View(function, scatterTileInfo.srcInfo.shape, scatterTileInfo.srcInfo.offset);
auto dstTile = dstTensor->View(function, scatterTileInfo.dstInfo.shape, scatterTileInfo.dstInfo.offset);
Shape tmpShape({idxTile->GetShape()[idxTile->GetShape().size() - 1]});
auto tmpBuffer = std::make_shared<LogicalTensor>(function, idxTile->Datatype(), tmpShape);
auto& op = function.AddOperation(Opcode::OP_SCATTER, {selfTile, idxTile, srcTile}, {dstTile, tmpBuffer});
op.SetAttribute(OP_ATTR_PREFIX + "axis", axis);
op.SetAttribute(OP_ATTR_PREFIX + "scatter_mode", mode);
return;
}
auto& vecTile = tileShape.GetVecTile();
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, vecTile[axis] >= dstTensor->shape[axis])
<< "The axis is not supported for tile splitting";
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, vecTile[axis] >= idxInput->shape[axis])
<< "The axis is not supported for tile splitting";
int64_t tmpTile = vecTile[cur];
if (static_cast<int>(cur) == axis) {
tmpTile = std::max(dstTensor->shape[axis], idxInput->shape[axis]);
}
for (int i = 0; i < idxInput->shape[cur]; i += tmpTile) {
if (static_cast<int>(cur) == axis) {
scatterTileInfo.idxInfo.offset[cur] = 0;
scatterTileInfo.idxInfo.shape[cur] = idxInput->shape[cur];
scatterTileInfo.dstInfo.offset[cur] = 0;
scatterTileInfo.dstInfo.shape[cur] = dstTensor->shape[cur];
scatterTileInfo.srcInfo.offset[cur] = 0;
scatterTileInfo.srcInfo.shape[cur] = idxInput->shape[cur];
scatterTileInfo.selfInfo.offset[cur] = 0;
scatterTileInfo.selfInfo.shape[cur] = selfInput->shape[cur];
} else {
scatterTileInfo.idxInfo.offset[cur] = i % idxInput->shape[cur];
scatterTileInfo.idxInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxInfo.offset[cur], tmpTile);
scatterTileInfo.dstInfo.offset[cur] = i;
scatterTileInfo.dstInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxInfo.offset[cur], tmpTile);
scatterTileInfo.srcInfo.offset[cur] = i;
scatterTileInfo.srcInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxInfo.offset[cur], tmpTile);
scatterTileInfo.selfInfo.offset[cur] = i;
scatterTileInfo.selfInfo.shape[cur] =
std::min(idxInput->shape[cur] - scatterTileInfo.idxInfo.offset[cur], tmpTile);
}
InnerTiledScatter(cur + 1, function, tileShape, scatterPara, scatterTileInfo);
}
}
void TiledScatter(Function& function, const TileShape& tileShape, const ScatterPara& scatterPara)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.srcInput->shape.size() == scatterPara.srcInput->offset.size())
<< "The shape size of srcInput and offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.idxInput->shape.size() == scatterPara.idxInput->offset.size())
<< "The shape size of idxInput and offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.dstTensor->shape.size() == scatterPara.dstTensor->offset.size())
<< "The shape size of dst and offset should be equal";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, scatterPara.selfInput->shape.size() == scatterPara.selfInput->offset.size())
<< "The shape size of selfInput and offset should be equal";
ScatterTileInfoPara scatterTileInfo{
TileInfo(scatterPara.srcInput->shape.size(), scatterPara.srcInput->offset.size()),
TileInfo(scatterPara.idxInput->shape.size(), scatterPara.idxInput->offset.size()),
TileInfo(scatterPara.dstTensor->shape.size(), scatterPara.dstTensor->offset.size()),
TileInfo(scatterPara.selfInput->shape.size(), scatterPara.selfInput->offset.size()),
};
InnerTiledScatter(0, function, tileShape, scatterPara, scatterTileInfo);
}
void TensorScatter(Function& function, const ScatterPara& scatterPara)
{
auto& op = GraphUtils::AddDynOperation(
function, Opcode::OP_SCATTER, {scatterPara.selfInput, scatterPara.idxInput, scatterPara.srcInput},
{scatterPara.dstTensor});
op.SetAttribute(OP_ATTR_PREFIX + "axis", scatterPara.axis);
op.SetAttribute(OP_ATTR_PREFIX + "scatter_mode", scatterPara.scatterMode);
std::map<int, int> inplaceInfo = {{0, 0}};
op.SetAttr(OpAttributeKey::inplaceInfo, inplaceInfo);
}
static void CheckScatterParamsInvalid(
const Tensor& self, const Tensor& indices, const Tensor& src, int axis, const ScatterMode reduce)
{
std::unordered_set<DataType> supportedTypes = {DT_FP32, DT_FP16, DT_BF16, DT_INT8, DT_UINT8, DT_INT16, DT_INT32};
CheckTensorDataType(self.GetStorage(), supportedTypes, "SCATTER");
CheckTensorsDataTypeConsistency(self.GetStorage(), src.GetStorage(), "SCATTER");
std::unordered_set<DataType> indexSupportedTypes = {DT_INT32, DT_INT64};
CheckTensorDataType(indices.GetStorage(), indexSupportedTypes, "SCATTER");
std::vector<LogicalTensorPtr> tensors = {self.GetStorage(), indices.GetStorage(), src.GetStorage()};
CheckTensorsDimConsistency(tensors, "SCATTER");
CheckTensorsFormatConsistency(self.GetStorage(), indices.GetStorage(), "SCATTER");
CheckTensorsFormatConsistency(self.GetStorage(), src.GetStorage(), "SCATTER");
CheckTensorsFormatConsistency(indices.GetStorage(), src.GetStorage(), "SCATTER");
CheckAxisRange(self, axis);
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, reduce <= ScatterMode::UNKNOWN)
<< "The ScatterMode of reduce should be less than UNKNOWN";
for (size_t i = 0; i < self.GetShape().size(); i++) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices.GetShape()[i] <= src.GetShape()[i])
<< "The shape size of src and indices should be equal";
if (static_cast<int>(i) == axis) {
continue;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices.GetShape()[i] <= self.GetShape()[i])
<< "The shape size of src and indices should be equal";
}
CheckTensorDimRange(self.GetStorage(), 1, 4, "SCATTER");
CheckTensorShapeSize(self.GetStorage(), "SCATTER");
CheckTensorShapeSize(indices.GetStorage(), "SCATTER");
CheckTensorShapeSize(src.GetStorage(), "SCATTER");
}
Tensor Scatter(const Tensor& self, const Tensor& indices, const Tensor& src, int axis, ScatterMode reduce)
{
DECLARE_TRACER();
CheckScatterParamsInvalid(self, indices, src, axis < 0 ? self.GetShape().size() + axis : axis, reduce);
DataType orgDtype = self.GetDataType();
auto operandSelfCast = Tensor(DataType::DT_FP32, self.GetShape());
auto operandSrcCast = Tensor(DataType::DT_FP32, src.GetShape());
if ((orgDtype == DataType::DT_FP16 || orgDtype == DataType::DT_BF16) &&
(reduce == ScatterMode::ADD || reduce == ScatterMode::MULTIPLY)) {
operandSelfCast = CALL(
CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(), self.GetStorage(),
DataType::DT_FP32, CastMode::CAST_NONE);
operandSrcCast = CALL(
CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(), src.GetStorage(),
DataType::DT_FP32, CastMode::CAST_NONE);
} else {
operandSelfCast = self;
operandSrcCast = src;
}
axis = axis < 0 ? operandSelfCast.GetShape().size() + axis : axis;
Tensor result(operandSelfCast.GetStorage()->Datatype(), operandSelfCast.GetShape());
CALL(
Scatter, *Program::GetInstance().GetCurrentFunction(),
{result.GetStorage(), operandSelfCast.GetStorage(), indices.GetStorage(), operandSrcCast.GetStorage(), axis,
static_cast<int>(reduce)});
if ((orgDtype == DataType::DT_FP16 || orgDtype == DataType::DT_BF16) &&
(reduce == ScatterMode::ADD || reduce == ScatterMode::MULTIPLY)) {
RETURN_CALL(
CastOperation<CastOpType::CAST>, *Program::GetInstance().GetCurrentFunction(), result.GetStorage(),
orgDtype, CastMode::CAST_RINT);
}
return result;
}
void TiledScatterUpdate(
size_t cur, Function& function, const TileShape& tileShape, Input& srcInput, Input& indexInput, Input& dstInput,
int axis, const LogicalTensorPtr& dst, TileInfo& dstTileInfo, std::string cacheMode, int blockSize)
{
if (cur == dst->shape.size()) {
auto srcTile = srcInput.tensor.GetStorage()->View(function, srcInput.tileInfo.shape, srcInput.tileInfo.offset);
auto dstTile = dstInput.tensor.GetStorage()->View(function, dstTileInfo.shape, dstTileInfo.offset);
auto resultTile = dst->View(function, dstTileInfo.shape, dstTileInfo.offset);
auto indexTile =
indexInput.tensor.GetStorage()->View(function, indexInput.tileInfo.shape, indexInput.tileInfo.offset);
auto& op = function.AddOperation("TILE_INDEX_OUTCAST", {srcTile, indexTile, dstTile}, {resultTile});
op.SetAttribute("axis", axis);
op.SetAttribute(OpAttributeKey::panzBlockSize, blockSize);
op.SetAttribute(OpAttributeKey::cacheMode, cacheMode);
return;
}
auto& vecTile = tileShape.GetVecTile();
int64_t tmpTile = vecTile[cur];
if (static_cast<int>(cur) == axis) {
tmpTile = dst->shape[cur];
}
for (int i = 0; i < dst->shape[cur]; i += tmpTile) {
if (static_cast<int>(cur) == axis) {
srcInput.tileInfo.offset[cur] = 0;
srcInput.tileInfo.shape[cur] = srcInput.tensor.GetShape()[cur];
if (cur <= 1) {
indexInput.tileInfo.offset[cur] = 0;
indexInput.tileInfo.shape[cur] = indexInput.tensor.GetShape()[cur];
}
dstTileInfo.offset[cur] = 0;
dstTileInfo.shape[cur] = dst->shape[cur];
} else {
srcInput.tileInfo.offset[cur] = i % srcInput.tensor.GetShape()[cur];
srcInput.tileInfo.shape[cur] =
std::min(srcInput.tensor.GetShape()[cur] - srcInput.tileInfo.offset[cur], tmpTile);
if (cur == 0) {
indexInput.tileInfo.offset[cur] = i % indexInput.tensor.GetShape()[cur];
indexInput.tileInfo.shape[cur] =
std::min(indexInput.tensor.GetShape()[cur] - indexInput.tileInfo.offset[cur], tmpTile);
} else {
indexInput.tileInfo.offset[1] = 0;
indexInput.tileInfo.shape[1] = indexInput.tensor.GetShape()[1];
}
dstTileInfo.offset[cur] = i;
dstTileInfo.shape[cur] = std::min(dst->shape[cur] - dstTileInfo.offset[cur], tmpTile);
}
TiledScatterUpdate(
cur + 1, function, tileShape, srcInput, indexInput, dstInput, axis, dst, dstTileInfo, cacheMode, blockSize);
}
}
void TiledIndexScatterUpdate(
size_t cur, Function& function, const TileShape& tileShape, Input& srcInput, Input& indexInput, Input& dstInput,
int axis, const std::shared_ptr<LogicalTensor>& dst, TileInfo& dstTileInfo, std::string cacheMode, int blockSize)
{
if (cur == dst->shape.size()) {
auto srcTile = srcInput.tensor.GetStorage()->View(function, srcInput.tileInfo.shape, srcInput.tileInfo.offset);
auto dstTile = dstInput.tensor.GetStorage()->View(function, dstTileInfo.shape, dstTileInfo.offset);
auto indexTile =
indexInput.tensor.GetStorage()->View(function, indexInput.tileInfo.shape, indexInput.tileInfo.offset);
auto& op = function.AddOperation("TILE_INDEX_OUTCAST", {srcTile, indexTile, dstTile}, {dst});
op.SetAttribute("axis", axis);
op.SetAttribute(OpAttributeKey::panzBlockSize, blockSize);
op.SetAttribute(OpAttributeKey::cacheMode, cacheMode);
return;
}
auto& vecTile = tileShape.GetVecTile();
int64_t tmpTile = vecTile[cur];
if (static_cast<int>(cur) == axis) {
tmpTile = srcInput.tensor.GetShape()[cur];
}
for (int i = 0; i < srcInput.tensor.GetShape()[cur]; i += tmpTile) {
if (static_cast<int>(cur) == axis) {
srcInput.tileInfo.offset[cur] = 0;
srcInput.tileInfo.shape[cur] = srcInput.tensor.GetShape()[cur];
int64_t indexTileLen = vecTile[0];
indexInput.tileInfo.offset[cur] = 0;
indexInput.tileInfo.shape[cur] =
std::min(indexInput.tensor.GetShape()[cur] - indexInput.tileInfo.offset[0], indexTileLen);
indexInput.tileInfo.offset[cur] = indexInput.tileInfo.offset[0];
indexInput.tileInfo.offset[0] = 0;
dstTileInfo.offset[cur] = 0;
dstTileInfo.shape[cur] = dst->shape[cur];
} else {
srcInput.tileInfo.offset[cur] = i % srcInput.tensor.GetShape()[cur];
srcInput.tileInfo.shape[cur] =
std::min(srcInput.tensor.GetShape()[cur] - srcInput.tileInfo.offset[cur], tmpTile);
indexInput.tileInfo.offset[0] = i % indexInput.tensor.GetShape()[1];
indexInput.tileInfo.shape[0] = indexInput.tensor.GetShape()[0];
dstTileInfo.offset[cur] = i;
dstTileInfo.shape[cur] = tmpTile;
}
TiledIndexScatterUpdate(
cur + 1, function, tileShape, srcInput, indexInput, dstInput, axis, dst, dstTileInfo, cacheMode, blockSize);
}
}
void TiledScatterUpdateFor2Dims(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& result, const LogicalTensorPtr& src,
const LogicalTensorPtr& index, const LogicalTensorPtr& dst, int axis, std::string cacheMode, int blockSize)
{
auto& vecTile = tileShape.GetVecTile();
int64_t tileBS = vecTile[NUM_VALUE_0];
int64_t tileD = vecTile[NUM_VALUE_1];
int64_t s = index->shape[1];
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, s != 0) << "1 dim of index is zero.";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, tileBS != 0) << "0 dim of tileshape is zero.";
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, (tileBS <= s && s % tileBS == 0) || (tileBS > s && tileBS % s == 0))
<< "tileshape 0 is invalid, tileshape(" << tileBS << ", " << tileD << ")";
ASSERT(VectorErrorCode::ERR_CONFIG_TILE, tileD == src->shape[NUM_VALUE_1])
<< "The tileD and src shape[0] should be equal";
int64_t tileB = CeilDiv(tileBS, s);
int64_t tileS = tileBS < s ? tileBS : s;
int64_t bsOffset = 0;
for (int64_t bIdx = 0; bIdx < index->shape[0]; bIdx += tileB) {
for (int64_t sIdx = 0; sIdx < index->shape[1]; sIdx += tileS) {
auto indexTile = index->View(
function, {std::min(index->shape[0] - bIdx, tileB), std::min(index->shape[1] - sIdx, tileS)},
{bIdx, sIdx});
for (int64_t j = 0; j < src->shape[1]; j += tileD) {
auto srcTile = src->View(
function, {std::min(src->shape[0] - bsOffset, tileBS), std::min(src->shape[1] - j, tileD)},
{bsOffset, j});
auto& op = function.AddOperation("TILE_INDEX_OUTCAST", {srcTile, indexTile, dst}, {result});
op.SetAttribute("axis", axis);
op.SetAttribute(OpAttributeKey::panzBlockSize, blockSize);
op.SetAttribute(OpAttributeKey::cacheMode, cacheMode);
}
bsOffset += tileBS;
}
}
}
void TiledScatterUpdateFor4Dims(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& result, const LogicalTensorPtr& src,
const LogicalTensorPtr& index, const LogicalTensorPtr& dst, int axis, std::string cacheMode, int blockSize)
{
auto& vecTile = tileShape.GetVecTile();
int64_t tileB = vecTile[NUM_VALUE_0];
int64_t tileS = vecTile[NUM_VALUE_1];
int64_t tileN = vecTile[NUM_VALUE_2];
int64_t tileD = vecTile[NUM_VALUE_3];
for (int64_t i = 0; i < src->shape[0]; i += tileB) {
for (int64_t j = 0; j < src->shape[1]; j += tileS) {
auto indexTile = index->View(
function, {std::min(index->shape[0] - i, tileB), std::min(index->shape[1] - j, tileS)}, {i, j});
for (int64_t n = 0; n < src->shape[2]; n += tileN) {
for (int64_t d = 0; d < src->shape[3]; d += tileD) {
auto srcTile = src->View(
function,
{std::min(src->shape[0] - i, tileB), std::min(src->shape[1] - j, tileS),
std::min(src->shape[2] - n, tileN), std::min(src->shape[3] - d, tileD)},
{i, j, n, d});
auto& op = function.AddOperation("TILE_INDEX_OUTCAST", {srcTile, indexTile, dst}, {result});
op.SetAttribute("axis", axis);
op.SetAttribute(OpAttributeKey::panzBlockSize, blockSize);
op.SetAttribute(OpAttributeKey::cacheMode, cacheMode);
}
}
}
}
}
void TiledScatterUpdate(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& result, const LogicalTensorPtr& src,
const LogicalTensorPtr& index, const LogicalTensorPtr& dst, int axis, std::string cacheMode, int blockSize)
{
if (cacheMode == "PA_BSND") {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, src->shape.size() == NUM_VALUE_2 || src->shape.size() == NUM_VALUE_4)
<< "shape must be 2 or 4";
if (src->shape.size() == NUM_VALUE_2) {
TiledScatterUpdateFor2Dims(function, tileShape, result, src, index, dst, axis, cacheMode, blockSize);
} else if (src->shape.size() == NUM_VALUE_4) {
TiledScatterUpdateFor4Dims(function, tileShape, result, src, index, dst, axis, cacheMode, blockSize);
}
return;
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, result->shape.size() == result->offset.size())
<< "The shape of result and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, src->shape.size() == src->offset.size())
<< "The shape of src and offset should be equal";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, index->shape.size() == index->offset.size())
<< "The shape of index and offset should be equal";
TileInfo srcTileInfo(src->shape.size(), src->offset.size());
TileInfo indexTileInfo(index->shape.size(), index->offset.size());
TileInfo dstTileInfo(dst->shape.size(), dst->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto srcInput = Input{src, srcTileInfo};
auto indexInput = Input{index, indexTileInfo};
auto dstInput = Input{dst, dstTileInfo};
auto& vecTile = tileShape.GetVecTile();
if (axis == 1 && src->shape.size() == NUM_VALUE_2 && vecTile[1] == src->shape[1]) {
TiledIndexScatterUpdate(
0, function, tileShape, srcInput, indexInput, dstInput, axis, result, resultTileInfo, cacheMode, blockSize);
} else {
TiledScatterUpdate(
0, function, tileShape, srcInput, indexInput, dstInput, axis, result, resultTileInfo, cacheMode, blockSize);
}
}
void TensorScatterUpdate(
Function& function, const LogicalTensorPtr& result, const LogicalTensorPtr& dst, const LogicalTensorPtr& index,
const LogicalTensorPtr& src, int axis, std::string cacheMode, int blockSize)
{
std::vector<int> newOffset(src->shape.size(), 0);
auto& op = function.AddOperation(Opcode::OP_INDEX_OUTCAST, {src, index, dst}, {result});
op.SetAttribute("axis", axis);
op.SetAttribute(OpAttributeKey::panzBlockSize, blockSize);
op.SetAttribute(OpAttributeKey::cacheMode, cacheMode);
}
static void CheckScatterUpdateInput(const Tensor& input)
{
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
(input.GetShape().size() == NUM_VALUE_2 &&
(input.GetShape(NUM_VALUE_0) != NUM_VALUE_0 && input.GetShape(NUM_VALUE_1) != NUM_VALUE_0)) ||
(input.GetShape().size() == NUM_VALUE_4 &&
(input.GetShape(NUM_VALUE_0) != NUM_VALUE_0 && input.GetShape(NUM_VALUE_1) != NUM_VALUE_0 &&
input.GetShape(NUM_VALUE_2) != NUM_VALUE_0 && input.GetShape(NUM_VALUE_3) != NUM_VALUE_0)))
<< "The shape of input is invaild";
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
input.GetShape().size() == NUM_VALUE_2 || input.GetShape().size() == NUM_VALUE_4)
<< "The shape size of input is invaild";
CheckTensorDimRange(input.GetStorage(), 2, 4, "SCATTERUPDATE");
}
static void CheckScatterUpdateIndex(const Tensor& index)
{
std::unordered_set<DataType> indexSupportedTypes = {DT_INT64, DT_INT32, DT_INT16};
CheckTensorDataType(index.GetStorage(), indexSupportedTypes, "SCATTERUPDATE");
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID, index.GetShape().size() == NUM_VALUE_2 &&
index.GetShape(NUM_VALUE_0) != NUM_VALUE_0 &&
index.GetShape(NUM_VALUE_1) != NUM_VALUE_0)
<< "The shape of index is invaild";
}
static void CheckScatterUpdateInvalid(const Tensor& dst, const Tensor& index, const Tensor& src)
{
std::vector<LogicalTensorPtr> tensors = {dst.GetStorage(), src.GetStorage()};
CheckTensorsDimConsistency(tensors, "SCATTERUPDATE");
CheckScatterUpdateIndex(index);
CheckScatterUpdateInput(src);
CheckScatterUpdateInput(dst);
}
Tensor ScatterUpdate(
const Tensor& dst, const Tensor& index, const Tensor& src, int axis, std::string cacheMode, int chunkSize)
{
DECLARE_TRACER();
CheckScatterUpdateInvalid(dst, index, src);
CheckAxisRange(dst, axis);
Tensor result(dst.GetStorage()->Datatype(), dst.GetStorage()->GetShape(), "", dst.Format());
if (std::find(dst.GetStorage()->GetShape().begin(), dst.GetStorage()->GetShape().end(), -1) !=
dst.GetStorage()->GetShape().end()) {
Tensor resTmp(dst.GetStorage()->Datatype(), dst.GetStorage()->GetDynValidShape(), "", dst.Format());
result = resTmp;
}
if (cacheMode == "PA_NZ") {
axis = 1;
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, src.GetShape().size() == NUM_VALUE_2)
<< "Only support 2 dim";
Tensor newIndex = Reshape(index, {1, index.GetShape()[0] * index.GetShape()[1]});
CALL(
ScatterUpdate, *Program::GetInstance().GetCurrentFunction(), result.GetStorage(), dst.GetStorage(),
newIndex.GetStorage(), src.GetStorage(), axis, cacheMode, chunkSize);
} else {
CALL(
ScatterUpdate, *Program::GetInstance().GetCurrentFunction(), result.GetStorage(), dst.GetStorage(),
index.GetStorage(), src.GetStorage(), axis, cacheMode, chunkSize);
}
return result;
}
void TiledIndexPut(
Function& function, const TileShape& tileShape, Input& inputSelf, Input& inputValues,
std::vector<Input>& inputIndices, const LogicalTensorPtr result, bool accumulate, size_t cur)
{
size_t selfDim = inputSelf.tileInfo.shape.size();
size_t valuesDim = inputValues.tileInfo.shape.size();
size_t indicesCount = inputIndices.size();
if (cur == valuesDim) {
auto inputSelfTile =
inputSelf.tensor.GetStorage()->View(function, inputSelf.tileInfo.shape, inputSelf.tileInfo.offset);
auto inputValuesTile =
inputValues.tensor.GetStorage()->View(function, inputValues.tileInfo.shape, inputValues.tileInfo.offset);
std::vector<LogicalTensorPtr> inputsTile;
inputsTile.push_back(inputSelfTile);
inputsTile.push_back(inputValuesTile);
for (size_t j = 0; j < indicesCount; j++) {
auto inputIndicesTile = inputIndices[j].tensor.GetStorage()->View(
function, inputIndices[j].tileInfo.shape, inputIndices[j].tileInfo.offset);
inputsTile.push_back(inputIndicesTile);
}
auto& newOp = function.AddOperation(Opcode::OP_INDEX_PUT, inputsTile, {result});
newOp.SetAttribute(OpAttributeKey::inplaceIdx, 0);
newOp.SetAttribute(OpAttributeKey::accumulate, accumulate);
newOp.SetAttribute(OpAttributeKey::indicesSize, static_cast<int>(indicesCount));
return;
}
const auto& vecTile = tileShape.GetVecTile();
int64_t tileSize = inputValues.tensor.GetShape()[cur];
if (cur < vecTile.size()) {
tileSize = vecTile[cur];
}
for (int64_t i = 0, size = inputValues.tensor.GetShape()[cur]; i < size; i += tileSize) {
if (cur != 0) {
size_t selfIndex = selfDim - valuesDim + cur;
inputSelf.tileInfo.shape[selfIndex] = std::min(inputSelf.tensor.GetShape()[selfIndex] - i, tileSize);
inputSelf.tileInfo.offset[selfIndex] = i;
}
inputValues.tileInfo.shape[cur] = std::min(inputValues.tensor.GetShape()[cur] - i, tileSize);
inputValues.tileInfo.offset[cur] = i;
if (cur == 0) {
for (size_t j = 0; j < indicesCount; ++j) {
inputIndices[j].tileInfo.shape[cur] = std::min(inputIndices[j].tensor.GetShape()[cur] - i, tileSize);
inputIndices[j].tileInfo.offset[cur] = i;
}
}
TiledIndexPut(function, tileShape, inputSelf, inputValues, inputIndices, result, accumulate, cur + 1);
}
}
void TiledIndexPut(
Function& function, const TileShape& tileShape, const LogicalTensorPtr& self, const LogicalTensorPtr& values,
const std::vector<LogicalTensorPtr>& indices, const LogicalTensorPtr& result, bool accumulate)
{
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, self->GetShape().size() == self->GetOffset().size());
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, values->GetShape().size() == values->GetOffset().size());
for (size_t i = 0; i < indices.size(); i++) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices[i]->GetShape().size() == indices[i]->GetOffset().size());
}
TileInfo valuesTileInfo(values->shape.size(), values->offset.size());
TileInfo selfTileInfo(self->shape.size(), self->offset.size());
auto inputValues = Input{values, valuesTileInfo};
auto inputSelf = Input{self, selfTileInfo};
for (size_t i = 0, size = self->shape.size(); i < size; ++i) {
inputSelf.tileInfo.shape[i] = self->shape[i];
inputSelf.tileInfo.offset[i] = 0;
}
std::vector<Input> inputIndices;
for (size_t i = 0, size = indices.size(); i < size; ++i) {
TileInfo indicesTileInfoTemp(indices[i]->shape.size(), indices[i]->offset.size());
auto inputIndicesTemp = Input{indices[i], indicesTileInfoTemp};
inputIndices.push_back(inputIndicesTemp);
}
TiledIndexPut(function, tileShape, inputSelf, inputValues, inputIndices, result, accumulate, 0);
}
void TensorIndexPut(
Function& function, const LogicalTensorPtr& self, const LogicalTensors& indices, const LogicalTensorPtr& values,
const LogicalTensorPtr& dst, bool accumulate)
{
LogicalTensors iOperands = indices;
iOperands.insert(iOperands.begin(), {self, values});
auto& op = function.AddOperation(Opcode::OP_INDEX_PUT, iOperands, {dst});
op.SetAttribute(OpAttributeKey::inplaceIdx, 0);
op.SetAttribute(OpAttributeKey::accumulate, accumulate);
op.SetAttribute(OpAttributeKey::indicesSize, static_cast<int>(indices.size()));
function.UpdateTensorDataUsage(op);
}
static void CheckIndexPutParamsInvalid(const Tensor& self, const std::vector<Tensor>& indices, const Tensor& values)
{
std::unordered_set<DataType> supportedTypes = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_UINT32,
DT_INT64, DT_UINT64, DT_BF16, DT_FP16, DT_FP32};
CheckTensorDataType(self.GetStorage(), supportedTypes, "INDEXPUT");
CheckTensorsDataTypeConsistency(self.GetStorage(), values.GetStorage(), "INDEXPUT");
std::unordered_set<DataType> indexSupportedTypes = {DT_INT8, DT_UINT8, DT_INT16, DT_UINT16,
DT_INT32, DT_UINT32, DT_INT64, DT_UINT64};
int indicesShape = -1;
for (size_t i = 0; i < indices.size(); i++) {
CheckTensorDataType(indices[i].GetStorage(), indexSupportedTypes, "INDEXPUT");
CheckTensorDimRange(indices[i].GetStorage(), 1, 1, "INDEXPUT");
if (indicesShape == -1) {
indicesShape = indices[i].GetShape()[0];
} else {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices[i].GetShape()[0] == indicesShape)
<< "Tensors in indices should have the same shape";
}
CheckTensorShapeSize(indices[i].GetStorage(), "INDEXPUT");
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, indices.size() >= 1 && indices.size() <= 4)
<< "indicesSize is out of range [1, 4]";
CheckTensorDimRange(self.GetStorage(), 1, 4, "INDEXPUT");
CheckTensorDimRange(values.GetStorage(), 1, 4, "INDEXPUT");
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, self.GetShape().size() + 1 == indices.size() + values.GetShape().size())
<< "unsupport the inputs shape combination: dimSelf + 1 != indicesSize + dimValues";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, values.GetShape()[0] == indicesShape)
<< "valuesFirstDim should equal to indicesShape";
for (size_t i = 1; i < values.GetShape().size(); i++) {
ASSERT(
VectorErrorCode::ERR_PARAM_INVALID,
self.GetShape()[self.GetShape().size() - i] == values.GetShape()[values.GetShape().size() - i])
<< "valuesShape should match selfShape";
}
CheckTensorShapeSize(self.GetStorage(), "INDEXPUT");
CheckTensorShapeSize(values.GetStorage(), "INDEXPUT");
CheckTensorsFormatConsistency(self.GetStorage(), values.GetStorage(), "INDEXPUT");
for (size_t i = 0; i < indices.size(); i++) {
CheckTensorsFormatConsistency(self.GetStorage(), indices[i].GetStorage(), "INDEXPUT");
CheckTensorsFormatConsistency(values.GetStorage(), indices[i].GetStorage(), "INDEXPUT");
}
}
void IndexPut_(Tensor& self, const std::vector<Tensor>& indices, const Tensor& values, bool accumulate)
{
DECLARE_TRACER();
CheckIndexPutParamsInvalid(self, indices, values);
std::vector<LogicalTensorPtr> indicesLogical;
for (size_t i = 0; i < indices.size(); i++) {
indicesLogical.push_back(indices[i].GetStorage());
}
Tensor dst(self.GetDataType(), self.GetShape());
CALL(
IndexPut, *Program::GetInstance().GetCurrentFunction(), self.GetStorage(), indicesLogical, values.GetStorage(),
dst.GetStorage(), accumulate);
Program::GetInstance().GetCurrentFunction()->SetSameMemId(self.GetStorage(), dst.GetStorage());
self = dst;
}
template <typename T, DataType dataType>
Element GetCurStartElement(Element start, Element step, int id)
{
T startValue;
T stepValue;
if (dataType == DT_INT32 || dataType == DT_INT64) {
startValue = start.GetSignedData();
stepValue = step.GetSignedData();
} else if (dataType == DT_FP32) {
startValue = (float)start.GetFloatData();
stepValue = (float)step.GetFloatData();
}
T curStartValue = startValue + id * stepValue;
Element curStart(dataType, curStartValue);
return curStart;
}
const double EPSILON = (double)1e-12;
template <typename T, DataType dataType>
int64_t GetRangeResSize(Element& start, Element& end, Element& step)
{
int64_t resultSize;
if (dataType == DT_INT32 || dataType == DT_INT64) {
int64_t startValue = start.GetSignedData();
int64_t endValue = end.GetSignedData();
int64_t stepValue = step.GetSignedData();
if (abs(stepValue) <= 0) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "stepValue must not be 0";
}
resultSize = (endValue - startValue) % stepValue ? (endValue - startValue) / stepValue + 1 :
(endValue - startValue) / stepValue;
} else if (dataType == DT_FP32) {
double startValue = start.GetFloatData();
double endValue = end.GetFloatData();
double stepValue = step.GetFloatData();
if (abs(stepValue) <= EPSILON) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "stepValue must not be 0";
}
resultSize = static_cast<int64_t>(std::ceil((endValue - startValue) / stepValue));
}
return resultSize;
}
void TiledRange(
Function& function, const TileShape& tileShape, const Element start, const Element step,
const LogicalTensorPtr& result)
{
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto& vecTile = tileShape.GetVecTile();
for (int64_t i = 0; i < result->shape[0]; i += vecTile[0]) {
resultTileInfo.offset[0] = i;
resultTileInfo.shape[0] = std::min(result->shape[0] - resultTileInfo.offset[0], vecTile[0]);
int64_t curSizeValue = resultTileInfo.shape[0];
Element curSize(DT_INT64, curSizeValue);
Element curStart = start;
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
auto& op = function.AddOperation(Opcode::OP_RANGE, {}, {resultTile});
op.SetAttribute(OP_ATTR_PREFIX + "START", curStart);
op.SetAttribute(OP_ATTR_PREFIX + "SIZE", curSize);
op.SetAttribute(OP_ATTR_PREFIX + "STEP", step);
SymbolicScalar tileIdx(i);
op.SetAttribute(OpAttributeKey::dynScalar, tileIdx);
}
return;
}
LogicalTensorPtr TensorRange(Function& function, LogicalTensorPtr& result, Element& start, Element& step)
{
auto& op = function.AddOperation(Opcode::OP_RANGE, {}, {result});
op.SetAttribute(OP_ATTR_PREFIX + "START", start);
op.SetAttribute(OP_ATTR_PREFIX + "STEP", step);
Element size(DT_INT64, result->shape[0]);
op.SetAttribute(OP_ATTR_PREFIX + "SIZE", size);
return result;
}
Tensor RealRange(Element& start, Element& end, Element& step)
{
DECLARE_TRACER();
std::vector<int64_t> resTensorShape;
int64_t resultSize;
if (start.GetDataType() == DT_INT32) {
resultSize = GetRangeResSize<int32_t, DT_INT32>(start, end, step);
} else if (start.GetDataType() == DT_INT64) {
resultSize = GetRangeResSize<int64_t, DT_INT64>(start, end, step);
} else if (start.GetDataType() == DT_FP32) {
resultSize = GetRangeResSize<float, DT_FP32>(start, end, step);
} else {
ASSERT(VectorErrorCode::ERR_PARAM_DTYPE_UNSUPPORTED, false)
<< "Unsupported DataType " << DataType2String(start.GetDataType());
}
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, resultSize > 0)
<< "The positivity or negativity of the step should be aligned with the end-start";
resTensorShape.push_back(resultSize);
auto resTensor = Tensor(start.GetDataType(), resTensorShape);
RETURN_CALL(Range, *Program::GetInstance().GetCurrentFunction(), resTensor.GetStorage(), start, step);
}
bool IsDataTypeUnsupport(DataType dType)
{
return dType != DT_FP32 && dType != DT_INT64 && dType != DT_INT32 && dType != DT_FP16 && dType != DT_BF16 &&
dType != DT_INT16;
}
DataType GetComputeDataType(const Element& start, const Element& end, const Element& step)
{
DataType startType = start.GetDataType();
DataType endType = end.GetDataType();
DataType stepType = step.GetDataType();
if (IsDataTypeUnsupport(startType)) {
ASSERT(VectorErrorCode::ERR_PARAM_DTYPE_UNSUPPORTED, false)
<< "Unsupported Start DataType " << DataType2String(startType);
}
if (IsDataTypeUnsupport(endType)) {
ASSERT(VectorErrorCode::ERR_PARAM_DTYPE_UNSUPPORTED, false)
<< "Unsupported End DataType " << DataType2String(endType);
}
if (IsDataTypeUnsupport(stepType)) {
ASSERT(VectorErrorCode::ERR_PARAM_DTYPE_UNSUPPORTED, false)
<< "Unsupported Step DataType " << DataType2String(stepType);
}
bool startIsFloat = (startType == DT_FP32 || startType == DT_FP16 || startType == DT_BF16);
bool endIsFloat = (endType == DT_FP32 || endType == DT_FP16 || endType == DT_BF16);
bool stepIsFloat = (stepType == DT_FP32 || stepType == DT_FP16 || stepType == DT_BF16);
if (startIsFloat || endIsFloat || stepIsFloat) {
return DT_FP32;
}
int64_t startValue = start.GetSignedData();
int64_t endValue = end.GetSignedData();
int64_t stepValue = step.GetSignedData();
bool startFlag = startValue <= INT_MAX && startValue >= INT_MIN;
bool endFlag = endValue <= INT_MAX && endValue >= INT_MIN;
bool stepFlag = stepValue <= INT_MAX && stepValue >= INT_MIN;
if (startFlag && endFlag && stepFlag) {
return DT_INT32;
}
return DT_INT64;
}
DataType GetOutputDataType(const Element& start, const Element& end, const Element& step)
{
DataType startType = start.GetDataType();
DataType endType = end.GetDataType();
DataType stepType = step.GetDataType();
if (startType == DT_INT16 || endType == DT_INT16 || stepType == DT_INT16) {
return DT_INT16;
}
if (startType == DT_FP32 || endType == DT_FP32 || stepType == DT_FP32) {
return DT_FP32;
}
if (startType == DT_FP16 || endType == DT_FP16 || stepType == DT_FP16) {
return DT_FP16;
}
if (startType == DT_BF16 || endType == DT_BF16 || stepType == DT_BF16) {
return DT_BF16;
}
return DT_INT32;
}
Element GetElementWithDataType(const Element& element, DataType dataType)
{
DataType elementType = element.GetDataType();
bool elementIsFloat = (elementType == DT_FP32) || (elementType == DT_FP16) || (elementType == DT_BF16);
if (elementIsFloat && dataType == DT_FP32) {
return Element(dataType, element.GetFloatData());
} else if (elementIsFloat && dataType != DT_FP32) {
return Element(dataType, (int64_t)element.GetFloatData());
} else if (!elementIsFloat && dataType == DT_FP32) {
return Element(dataType, (double)element.GetSignedData());
}
return Element(dataType, element.GetSignedData());
}
Tensor Range(const Element& start, const Element& end, const Element& step)
{
DataType dataType = GetComputeDataType(start, end, step);
if (dataType != DT_FP32 && dataType != DT_INT32) {
ASSERT(VectorErrorCode::ERR_PARAM_DTYPE_UNSUPPORTED, false)
<< "Unsupported Output DataType " << DataType2String(dataType);
}
DataType outputDataType = DT_INT32;
outputDataType = GetOutputDataType(start, end, step);
Element realStart = GetElementWithDataType(start, dataType);
Element realEnd = GetElementWithDataType(end, dataType);
Element realStep = GetElementWithDataType(step, dataType);
auto resTensor = RealRange(realStart, realEnd, realStep);
if (outputDataType == DT_BF16) {
return Cast(resTensor, DT_BF16);
}
if (outputDataType == DT_FP16) {
return Cast(resTensor, DT_FP16);
}
if (outputDataType == DT_INT16) {
return Cast(resTensor, DT_INT16);
}
return resTensor;
}
Tensor GatherMask(const Tensor& self, const uint8_t patternMode)
{
DECLARE_TRACER();
std::unordered_set<DataType> supportedTypes = {DT_FP32, DT_FP16, DT_BF16, DT_INT32, DT_INT16, DT_UINT16, DT_UINT32};
CheckTensorDataType(self.GetStorage(), supportedTypes, "GATHERMASK");
CheckTensorDimRange(self.GetStorage(), 1, 4, "GATHERMASK");
CheckTensorShapeSize(self.GetStorage(), "GATHERMASK");
auto shape = self.GetShape();
auto& vecTile = TileShape::Current().GetVecTile();
if (patternMode == 1 || patternMode == 2) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, shape[shape.size() - 1] % 2 == 0)
<< "The last axis of input shape should be divisible by 2 when patternMode is 1 or 2";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, vecTile.tile[vecTile.tile.size() - 1] % 2 == 0)
<< "The last axis of tileshape should be divisible by 2 when patternMode is 1 or 2";
shape[shape.size() - 1] = shape[shape.size() - 1] / 2;
} else if (patternMode == 3 || patternMode == 4 || patternMode == 5 || patternMode == 6) {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, shape[shape.size() - 1] % 4 == 0)
<< "The last axis of input shape should be divisible by 4 when patternMode is 3, 4, 5 or 6";
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, vecTile.tile[vecTile.tile.size() - 1] % 4 == 0)
<< "The last axis of tileshape should be divisible by 4 when patternMode is 3, 4, 5 or 6";
shape[shape.size() - 1] = shape[shape.size() - 1] / 4;
} else {
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, patternMode == 7)
<< "Just support patternMode is 1, 2, 3, 4, 5, 6, 7";
}
auto result = Tensor(self.GetDataType(), shape);
if (!self.GetStorage()->GetDynValidShape().empty()) {
std::vector<SymbolicScalar> outValidShape;
for (auto dim : self.GetStorage()->GetDynValidShape()) {
outValidShape.push_back(dim);
}
if (patternMode == 1 || patternMode == 2) {
outValidShape[outValidShape.size() - 1] = outValidShape[outValidShape.size() - 1] / 2;
} else if (patternMode == 3 || patternMode == 4 || patternMode == 5 || patternMode == 6) {
outValidShape[outValidShape.size() - 1] = outValidShape[outValidShape.size() - 1] / 4;
}
result.GetStorage()->UpdateDynValidShape(outValidShape);
}
CALL(GatherMask, *Program::GetInstance().GetCurrentFunction(), self.GetStorage(), result.GetStorage(), patternMode);
return result;
}
void TiledGatherMaskBuildIn(
Function& function, const TileShape& tileShape, size_t cur, Input& input, const LogicalTensorPtr& result,
TileInfo& resultTileInfo, const uint8_t patternMode)
{
if (cur == input.tensor.GetShape().size()) {
auto inputTile = input.tensor.GetStorage()->View(function, input.tileInfo.shape, input.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
auto& op = function.AddOperation(Opcode::OP_GATHER_MASK, {inputTile}, {resultTile});
op.SetAttribute(OP_ATTR_PREFIX + "patternMode", patternMode);
return;
}
auto& vecTile = tileShape.GetVecTile();
for (int i = 0; i < input.tensor.GetShape()[cur]; i += vecTile[cur]) {
input.tileInfo.offset[cur] = i % input.tensor.GetShape()[cur];
input.tileInfo.shape[cur] = std::min(input.tensor.GetShape()[cur] - input.tileInfo.offset[cur], vecTile[cur]);
if ((cur == input.tensor.GetShape().size() - 1) && (patternMode == 1 || patternMode == 2)) {
resultTileInfo.offset[cur] = i / 2;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur] / 2);
} else if (
(cur == input.tensor.GetShape().size() - 1) &&
(patternMode == 3 || patternMode == 4 || patternMode == 5 || patternMode == 6)) {
resultTileInfo.offset[cur] = i / 4;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur] / 4);
} else {
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur]);
}
TiledGatherMaskBuildIn(function, tileShape, cur + 1, input, result, resultTileInfo, patternMode);
}
}
void TiledGatherMaskBuildIn(
Function& function, const TileShape& tileShape, const LogicalTensorPtr operand, const LogicalTensorPtr resOperand,
const uint8_t patternMode)
{
TileInfo tileInfo(operand->shape.size(), operand->offset.size());
TileInfo resultTileInfo(resOperand->shape.size(), resOperand->offset.size());
tileInfo.shape = operand->shape;
resultTileInfo.shape = resOperand->shape;
auto input = Input{operand, tileInfo};
TiledGatherMaskBuildIn(function, tileShape, 0, input, resOperand, resultTileInfo, patternMode);
}
void IndexAddUBOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
Element alpha = op.GetElementAttribute(OpAttributeKey::scalar);
TiledIndexAddUB(function, tileShape, {iOperand[0], iOperand[1], iOperand[2], oOperand[0], axis, alpha});
}
void IndexAddOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
Element alpha = op.GetElementAttribute(OpAttributeKey::scalar);
TiledIndexAdd(function, tileShape, {iOperand[0], iOperand[1], iOperand[2], oOperand[0], axis, alpha});
}
void GatherOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
TiledGatherOperation(function, tileShape, iOperand[0], iOperand[1], axis, oOperand[0]);
}
void GatherElementOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
TiledGatherElementOperation(function, tileShape, iOperand[0], iOperand[1], axis, oOperand[0]);
}
void ScatterElementSOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
Element scalar = op.GetElementAttribute(OpAttributeKey::scalar);
int scatterMode = op.GetIntAttribute(OP_ATTR_PREFIX + "scatter_mode");
TiledScatterElementS(function, tileShape, {oOperand[0], iOperand[0], iOperand[1], scalar, axis, scatterMode});
}
void ScatterOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute(OP_ATTR_PREFIX + "axis");
int scatterMode = op.GetIntAttribute(OP_ATTR_PREFIX + "scatter_mode");
TiledScatter(function, tileShape, {oOperand[0], iOperand[0], iOperand[1], iOperand[2], axis, scatterMode});
}
void IndexPutOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
std::vector<LogicalTensorPtr> indices = iOperand;
constexpr size_t num2 = 2;
indices.erase(indices.begin(), indices.begin() + num2);
bool accumulate = op.GetBoolAttribute(OpAttributeKey::accumulate);
TiledIndexPut(function, tileShape, iOperand[0], iOperand[1], indices, oOperand[0], accumulate);
}
void IndexOutcastOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
int axis = op.GetIntAttribute("axis");
int blockSize = op.GetIntAttribute(OpAttributeKey::panzBlockSize);
std::string cacheMode = op.GetStringAttribute(OpAttributeKey::cacheMode);
TiledScatterUpdate(
function, tileShape, oOperand[0], iOperand[0], iOperand[1], iOperand[2], axis, cacheMode, blockSize);
}
void RangeOperationTileFunc(
Function& function, const TileShape& tileShape, [[maybe_unused]] const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
Element start = op.GetElementAttribute(OP_ATTR_PREFIX + "START");
Element step = op.GetElementAttribute(OP_ATTR_PREFIX + "STEP");
TiledRange(function, tileShape, start, step, oOperand[0]);
}
void GatherMaskBuildInOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
uint8_t patternMode = op.GetIntAttribute(OP_ATTR_PREFIX + "patternMode");
TiledGatherMaskBuildIn(function, tileShape, iOperand[0], oOperand[0], patternMode);
}
REGISTER_OPERATION_TILED_FUNC(OP_INDEX_ADD_UB, Opcode::OP_INDEX_ADD_UB, IndexAddUBOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_INDEX_ADD, Opcode::OP_INDEX_ADD, IndexAddOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_GATHER, Opcode::OP_GATHER, GatherOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_GATHER_ELEMENT, Opcode::OP_GATHER_ELEMENT, GatherElementOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_SCATTER_ELEMENT, Opcode::OP_SCATTER_ELEMENT, ScatterElementSOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_SCATTER, Opcode::OP_SCATTER, ScatterOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_INDEX_PUT, Opcode::OP_INDEX_PUT, IndexPutOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_INDEX_OUTCAST, Opcode::OP_INDEX_OUTCAST, IndexOutcastOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_RANGE, Opcode::OP_RANGE, RangeOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(
OP_GATHER_MASK_BUILDIN, Opcode::OP_GATHER_MASK_BUILDIN, GatherMaskBuildInOperationTileFunc);
}