* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file expand_exp_dif.cpp
* \brief
*/
#include "binary.h"
#include "interface/utils/operator_tracer.h"
#include "passes/pass_utils/graph_utils.h"
#include "tensor_transformation.h"
namespace npu::tile_fwk {
void TiledExpandExpDifOperation(
Function& function, const TileShape& tileShape, size_t cur, LogicalInput& input1, LogicalInput& input2,
const LogicalTensorPtr& result, TileInfo& resultTileInfo)
{
if (cur == input1.tensor->GetShape().size()) {
auto inputTile1 = input1.tensor->View(function, input1.tileInfo.shape, input1.tileInfo.offset);
auto inputTile2 = input2.tensor->View(function, input2.tileInfo.shape, input2.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
auto opName = GetBinaryOpNameCode<BinaryOpType::EXPANDEXPDIF>();
auto& op = function.AddOperation(opName, {inputTile1, inputTile2}, {resultTile});
size_t shapeSize = input1.tensor->GetShape().size();
std::vector<int64_t> brcOperand(shapeSize, 0);
size_t brcAxesCount = 0;
for (size_t i = 0; i < shapeSize; i++) {
brcOperand[i] = BrcAxisBinaryOp(input1.tensor, input2.tensor, i);
if (brcOperand[i] != 0) {
brcAxesCount++;
}
}
if (brcAxesCount > 0) {
op.SetAttribute(OpAttributeKey::brcOperand, brcOperand);
}
} else {
auto& vecTile = tileShape.GetVecTile();
for (int i = 0; i < result->shape[cur]; i += vecTile[cur]) {
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur]);
input1.tileInfo.offset[cur] = i % input1.tensor->GetShape()[cur];
input1.tileInfo.shape[cur] =
std::min(input1.tensor->GetShape()[cur] - input1.tileInfo.offset[cur], vecTile[cur]);
input2.tileInfo.offset[cur] = i % input2.tensor->GetShape()[cur];
input2.tileInfo.shape[cur] =
std::min(input2.tensor->GetShape()[cur] - input2.tileInfo.offset[cur], vecTile[cur]);
TiledExpandExpDifOperation(function, tileShape, cur + 1, input1, input2, result, resultTileInfo);
}
}
}
void TiledExpandExpDifOperation(
Function& function, const TileShape& tileShape, LogicalTensorPtr operand1, LogicalTensorPtr operand2,
const LogicalTensorPtr& result)
{
CheckBinOpOperandsValid(operand1, operand2);
TileInfo tileInfo1(result->shape.size(), result->offset.size());
TileInfo tileInfo2(result->shape.size(), result->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto input1 = LogicalInput{operand1, tileInfo1};
auto input2 = LogicalInput{operand2, tileInfo2};
TiledExpandExpDifOperation(function, tileShape, 0, input1, input2, result, resultTileInfo);
}
Tensor ExpandExpDif(const Tensor& input, const Tensor& other)
{
DECLARE_TRACER();
CheckTensorsDataTypeConsistency(input.GetStorage(), other.GetStorage(), "EXPANDEXPDIF");
std::unordered_set<DataType> supportedTypes = {DT_FP16, DT_BF16, DT_FP32};
CheckTensorDataType(input.GetStorage(), supportedTypes, "EXPANDEXPDIF");
CheckBinaryInputTensors(input.GetStorage(), other.GetStorage(), "EXPANDEXPDIF");
config::SetOperationOption(KEY_COMBINE_AXIS, true);
RETURN_CALL(
BinaryOperation<BinaryOpType::EXPANDEXPDIF>, *Program::GetInstance().GetCurrentFunction(), input, other);
}
void ExpandExpDifOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
BinaryOperationOperandCheck(iOperand, oOperand);
TiledExpandExpDifOperation(function, tileShape, iOperand[0], iOperand[1], oOperand[0]);
}
REGISTER_OPERATION_TILED_FUNC(OP_EXPANDEXPDIF, Opcode::OP_EXPANDEXPDIF, ExpandExpDifOperationTileFunc);
}