/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file compare.cpp
* \brief Tensor-tensor and tensor-scalar compare operations: operation construction and tile expansion.
*/
#include "binary.h"
#include "tensor_transformation.h"
#include "interface/utils/operator_tracer.h"
#include "passes/pass_utils/graph_utils.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
/**
 * \brief Recursively expands a tensor-tensor compare into one OP_CMP per tile.
 *
 * Dimension \p cur of \p result is walked in steps taken from tileShape.GetVecTile(). For every step the
 * output window is written to \p resultTileInfo and the matching windows of both operands are written to
 * their tileInfo, then the function recurses into dimension cur + 1. Once \p cur equals the rank of
 * \p result, views are created for the recorded windows and an OP_CMP operation is added whose outputs
 * are the result view and a 1-D DT_UINT8 scratch tensor.
 *
 * Operand offsets are taken modulo the operand's own extent in the dimension.
 *
 * \param function       Function that receives the tile views and the new operations.
 * \param tileShape      Provides the per-dimension vector tile sizes.
 * \param cur            Dimension currently being tiled.
 * \param input1         First operand and its tile window (updated in place).
 * \param input2         Second operand and its tile window (updated in place).
 * \param result         Output tensor that is being tiled.
 * \param resultTileInfo Window of the current output tile (updated in place).
 * \param operation      Compare operator, stored as attribute "cmp_operation".
 * \param mode           Output mode, stored as attribute "cmp_mode". In OutType::BIT the last dimension
 *                       consumes NUM_VALUE_8 input elements per result element.
 */
void TiledCompareOperationImpl(
    Function& function, const TileShape& tileShape, size_t cur, Input& input1, Input& input2,
    const LogicalTensorPtr& result, TileInfo& resultTileInfo, OpType operation, OutType mode)
{
    if (cur == result->shape.size()) {
        // Every dimension has a window assigned: create the views for this tile.
        auto inputTile1 = input1.tensor.GetStorage()->View(function, input1.tileInfo.shape, input1.tileInfo.offset);
        auto inputTile2 = input2.tensor.GetStorage()->View(function, input2.tileInfo.shape, input2.tileInfo.offset);
        auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);

        // Size of the byte scratch tensor passed as second output of OP_CMP:
        // an aligned bit mask (one bit per element of a COUNT_MODE_SIZE-byte chunk), three aligned
        // chunk-sized arrays and 2 * ALIGN_SIZE additional bytes.
        // NOTE(review): the layout is presumably dictated by the kernel implementation - confirm there.
        const int64_t COUNT_MODE_SIZE = 4096;
        size_t element_size = BytesOf(input1.tensor.GetDataType());
        ASSERT(VectorErrorCode::ERR_RUNTIME_LOGIC, element_size != 0) << "Element size cannot be zero.";
        int64_t elements_per_chunk = COUNT_MODE_SIZE / element_size;
        int64_t vcmp_bits_size = (elements_per_chunk + 7) / 8;
        const size_t ALIGN_SIZE = 32;
        size_t vcmpBitResult_size = ((vcmp_bits_size + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
        size_t array_size = elements_per_chunk * element_size;
        size_t aligned_array_size = ((array_size + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
        size_t total_bytes = vcmpBitResult_size + 3 * aligned_array_size + ALIGN_SIZE * 2;
        std::vector<int64_t> tmp_shape({static_cast<int64_t>(total_bytes)});
        auto tmp_tensor = std::make_shared<LogicalTensor>(function, DT_UINT8, tmp_shape);

        auto& op = function.AddOperation(Opcode::OP_CMP, {inputTile1, inputTile2}, {resultTile, tmp_tensor});
        std::vector<bool> dimMap({true, true});
        op.SetAttr(OpAttributeKey::rowPad, dimMap);
        op.SetAttribute(OP_ATTR_PREFIX + "cmp_operation", static_cast<int64_t>(operation));
        op.SetAttribute(OP_ATTR_PREFIX + "cmp_mode", static_cast<int64_t>(mode));
        return;
    }

    // Only the last dimension of a BIT-mode result is packed: one result element there covers
    // NUM_VALUE_8 input elements. Every other dimension maps 1:1 (scale factor 1).
    const bool packedLastDim = (mode == OutType::BIT) && (cur == result->shape.size() - 1);
    const int64_t inputScale = packedLastDim ? static_cast<int64_t>(NUM_VALUE_8) : 1;

    auto& vecTile = tileShape.GetVecTile();
    int64_t step = vecTile[cur];
    if (packedLastDim) {
        // The tile size is given in input elements; convert to result elements, at least one.
        step = vecTile[cur] / inputScale;
        if (step < 1) {
            step = 1;
        }
    }
    // A non-positive step would never advance the loop below.
    ASSERT(VectorErrorCode::ERR_RUNTIME_LOGIC, result->shape[cur] <= 0 || step > 0) << "Tile step must be positive.";
    const int64_t inputStep = step * inputScale;

    // Records the window of one operand for the tile starting at input position `start`.
    auto placeInputWindow = [cur](Input& in, int64_t start, int64_t span) {
        const int64_t extent = in.tensor.GetShape()[cur];
        in.tileInfo.offset[cur] = start % extent;
        in.tileInfo.shape[cur] = std::min(extent - in.tileInfo.offset[cur], span);
    };

    // 64-bit index: extents and steps are int64_t, a plain int could truncate or overflow.
    for (int64_t i = 0; i < result->shape[cur]; i += step) {
        resultTileInfo.offset[cur] = i;
        resultTileInfo.shape[cur] = std::min(result->shape[cur] - i, step);
        placeInputWindow(input1, i * inputScale, inputStep);
        placeInputWindow(input2, i * inputScale, inputStep);
        TiledCompareOperationImpl(
            function, tileShape, cur + 1, input1, input2, result, resultTileInfo, operation, mode);
    }
}
/**
 * \brief Entry point of the tile expansion for a tensor-tensor compare.
 *
 * Derives the input shape implied by \p result (in OutType::BIT mode the last dimension is multiplied by
 * NUM_VALUE_8), expands each operand whose shape differs from it via Expand(), and then starts the
 * recursive tiling at dimension 0.
 *
 * \param function  Function that receives the generated tensors and operations.
 * \param tileShape Tile configuration forwarded to Expand() and to the recursive tiling.
 * \param operand1  First operand; replaced locally by its expanded tensor when needed.
 * \param operand2  Second operand; replaced locally by its expanded tensor when needed.
 * \param result    Output tensor of the compare.
 * \param operation Compare operator.
 * \param mode      Output mode.
 */
void TiledCompareOperation(
    Function& function, const TileShape& tileShape, LogicalTensorPtr operand1, LogicalTensorPtr operand2,
    const LogicalTensorPtr& result, OpType operation, OutType mode)
{
    // Shape both operands must have. It is the same for both, so compute it once.
    auto inputShape = result->shape;
    if (mode == OutType::BIT && !inputShape.empty()) {
        // Guarded: touching the last element of an empty shape would be out of bounds.
        inputShape.back() *= NUM_VALUE_8;
    }

    auto broadcastOperand = [&](LogicalTensorPtr& operand, LogicalTensorPtr& other) {
        if (operand->shape == inputShape) {
            return;
        }
        auto expanded = std::make_shared<LogicalTensor>(function, operand->Datatype(), inputShape);
        Expand(function, tileShape, operand, {other}, expanded);
        operand = expanded;
    };
    broadcastOperand(operand1, operand2);
    broadcastOperand(operand2, operand1);

    TileInfo tileInfo1(result->shape.size(), result->offset.size());
    TileInfo tileInfo2(result->shape.size(), result->offset.size());
    TileInfo resultTileInfo(result->shape.size(), result->offset.size());
    auto input1 = Input{operand1, tileInfo1};
    auto input2 = Input{operand2, tileInfo2};
    TiledCompareOperationImpl(function, tileShape, 0, input1, input2, result, resultTileInfo, operation, mode);
}
/**
 * \brief Adds the OP_CMP operation for a tensor-tensor compare and returns its result tensor.
 *
 * Operands of different rank are first broadcast to a common shape. The result has type DT_BOOL, or
 * DT_UINT8 in OutType::BIT mode, where the last dimension (and the last valid-shape entry, if present)
 * is divided by NUM_VALUE_8 and must be divisible by it.
 *
 * \param function  Function that receives the operation.
 * \param self      Left-hand operand.
 * \param other     Right-hand operand.
 * \param operation Compare operator, stored as attribute "cmp_operation".
 * \param mode      Output mode, stored as attribute "cmp_mode".
 * \return The newly created result tensor.
 */
LogicalTensorPtr TensorCompareOperation(
    Function& function, const Tensor& self, const Tensor& other, OpType operation, OutType mode)
{
    auto lhs = self.GetStorage();
    auto rhs = other.GetStorage();
    const bool rankMismatch = lhs->shape.size() != rhs->shape.size();
    if (rankMismatch) {
        std::vector<int> broadCastShape = GetBroadCastShape(lhs, rhs);
        lhs = BinaryOperationBroadCast(lhs, broadCastShape);
        rhs = BinaryOperationBroadCast(rhs, broadCastShape);
    }

    std::vector<int64_t> resultShape = BinaryOperationResultShape(lhs, rhs);

    // A valid shape is only derived when both operands carry one; per dimension it comes from the
    // left operand when that operand's extent equals the result extent, otherwise from the right one.
    std::vector<SymbolicScalar> resultValidShape;
    const bool bothHaveValidShape = !lhs->GetDynValidShape().empty() && !rhs->GetDynValidShape().empty();
    if (bothHaveValidShape) {
        for (size_t dim = 0; dim < resultShape.size(); ++dim) {
            const bool takeLhs = (resultShape[dim] == lhs->shape[dim]);
            const auto& source = takeLhs ? lhs : rhs;
            resultValidShape.push_back(source->GetDynValidShape()[dim]);
        }
    }

    auto resultType = DT_BOOL;
    if (mode == OutType::BIT) {
        resultType = DT_UINT8;
        ASSERT(VectorErrorCode::ERR_CONFIG_ALIGNMENT, resultShape.empty() || resultShape.back() % NUM_VALUE_8 == 0)
            << "Last dimension must be divisible by 8 in BIT mode";
        const bool hasDims = !resultShape.empty();
        if (hasDims) {
            resultShape.back() /= NUM_VALUE_8;
        }
        if (hasDims && !resultValidShape.empty()) {
            resultValidShape.back() = resultValidShape.back() / NUM_VALUE_8;
        }
    }

    auto result = std::make_shared<LogicalTensor>(function, resultType, resultShape, resultValidShape);
    auto& cmpOp = function.AddOperation(Opcode::OP_CMP, {lhs, rhs}, {result});
    std::vector<bool> rowPadFlags({true, true});
    cmpOp.SetAttr(OpAttributeKey::rowPad, rowPadFlags);
    cmpOp.SetAttribute(OP_ATTR_PREFIX + "cmp_operation", static_cast<int64_t>(operation));
    cmpOp.SetAttribute(OP_ATTR_PREFIX + "cmp_mode", static_cast<int64_t>(mode));
    return result;
}
/**
 * \brief Adds the OP_CMPS operation for a tensor-scalar compare and returns its result tensor.
 *
 * The result starts with the shape and valid shape of the operand. It has type DT_BOOL, or DT_UINT8 in
 * OutType::BIT mode, where the last dimension (and the last valid-shape entry, if present) is divided
 * by NUM_VALUE_8 and must be divisible by it.
 *
 * \param function  Function that receives the operation.
 * \param operand1  Tensor operand.
 * \param value     Scalar operand, stored as attribute OpAttributeKey::scalar.
 * \param operation Compare operator, stored as attribute "cmp_operation".
 * \param mode      Output mode, stored as attribute "cmp_mode".
 * \return The newly created result tensor.
 */
LogicalTensorPtr TensorCompareOperationScalar(
    Function& function, const Tensor& operand1, const Element& value, OpType operation, OutType mode)
{
    DECLARE_TRACER();
    auto source = operand1.GetStorage();
    std::vector<int64_t> resultShape = source->shape;
    std::vector<SymbolicScalar> resultValidShape = source->GetDynValidShape();

    const bool bitPacked = (mode == OutType::BIT);
    DataType resultType = bitPacked ? DT_UINT8 : DT_BOOL;
    if (bitPacked && !resultShape.empty()) {
        int64_t lastDim = resultShape.back();
        ASSERT(VectorErrorCode::ERR_CONFIG_ALIGNMENT, lastDim % NUM_VALUE_8 == 0)
            << "Last dimension must be divisible by 8 in BIT mode";
        resultShape.back() = lastDim / NUM_VALUE_8;
        if (!resultValidShape.empty()) {
            resultValidShape.back() = resultValidShape.back() / NUM_VALUE_8;
        }
    }

    auto result = std::make_shared<LogicalTensor>(function, resultType, resultShape, resultValidShape);
    auto& cmpsOp = function.AddOperation(Opcode::OP_CMPS, {source}, {result});
    std::vector<bool> rowPadFlags({true});
    cmpsOp.SetAttr(OpAttributeKey::rowPad, rowPadFlags);
    cmpsOp.SetAttribute(OpAttributeKey::scalar, value);
    cmpsOp.SetAttribute(OP_ATTR_PREFIX + "cmp_operation", static_cast<int64_t>(operation));
    cmpsOp.SetAttribute(OP_ATTR_PREFIX + "cmp_mode", static_cast<int64_t>(mode));
    return result;
}
/**
 * \brief Scalar-on-the-left variant of the tensor-scalar compare.
 *
 * Rewrites `value <op> tensor` as `tensor <op'> value` by mirroring the ordering operators
 * (LT <-> GT, LE <-> GE) and delegates to the tensor-first overload. Any other operator is passed
 * through unchanged.
 *
 * \param function  Function that receives the operation.
 * \param value     Scalar operand (left-hand side).
 * \param operand1  Tensor operand (right-hand side).
 * \param operation Compare operator as written with the scalar on the left.
 * \param mode      Output mode.
 * \return The result tensor created by the tensor-first overload.
 */
LogicalTensorPtr TensorCompareOperationScalar(
    Function& function, const Element& value, const Tensor& operand1, OpType operation, OutType mode)
{
    switch (operation) {
        case OpType::LT:
            operation = OpType::GT;
            break;
        case OpType::GT:
            operation = OpType::LT;
            break;
        case OpType::LE:
            operation = OpType::GE;
            break;
        case OpType::GE:
            operation = OpType::LE;
            break;
        default:
            // Remaining operators are forwarded as they are.
            break;
    }
    // The scalar is only read by the callee, so it is forwarded directly instead of being copied first.
    return TensorCompareOperationScalar(function, operand1, value, operation, mode);
}
/**
 * \brief Recursively expands a tensor-scalar compare into one OP_CMPS per tile.
 *
 * Dimension \p cur of \p result is walked in steps taken from tileShape.GetVecTile(). For every step the
 * output window is written to \p resultTileInfo and the matching operand window to input.tileInfo, then
 * the function recurses into dimension cur + 1. Once \p cur equals the rank of \p result, views are
 * created for the recorded windows and an OP_CMPS operation is added whose outputs are the result view
 * and a 1-D DT_UINT8 scratch tensor.
 *
 * The operand offset is taken modulo the operand's own extent in the dimension.
 *
 * \param function       Function that receives the tile views and the new operations.
 * \param tileShape      Provides the per-dimension vector tile sizes.
 * \param cur            Dimension currently being tiled.
 * \param input          Tensor operand and its tile window (updated in place).
 * \param scalar         Scalar operand, stored as attribute OpAttributeKey::scalar.
 * \param result         Output tensor that is being tiled.
 * \param resultTileInfo Window of the current output tile (updated in place).
 * \param operation      Compare operator, stored as attribute "cmp_operation".
 * \param mode           Output mode, stored as attribute "cmp_mode". In OutType::BIT the last dimension
 *                       consumes NUM_VALUE_8 input elements per result element.
 */
void TiledCmpsOperationImpl(
    Function& function, const TileShape& tileShape, size_t cur, Input& input, const Element& scalar,
    const LogicalTensorPtr& result, TileInfo& resultTileInfo, OpType operation, OutType mode)
{
    if (cur == result->shape.size()) {
        // Every dimension has a window assigned: create the views for this tile.
        auto inputTile = input.tensor.GetStorage()->View(function, input.tileInfo.shape, input.tileInfo.offset);
        auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);

        // Size of the byte scratch tensor passed as second output of OP_CMPS:
        // an aligned bit mask (one bit per element of a COUNT_MODE_SIZE-byte chunk), three aligned
        // chunk-sized arrays and ALIGN_SIZE additional bytes.
        // NOTE(review): the layout is presumably dictated by the kernel implementation - confirm there.
        const int64_t COUNT_MODE_SIZE = 4096;
        size_t element_size = BytesOf(input.tensor.GetDataType());
        ASSERT(VectorErrorCode::ERR_RUNTIME_LOGIC, element_size != 0) << "Element size cannot be zero.";
        int64_t elements_per_chunk = COUNT_MODE_SIZE / element_size;
        int64_t vcmp_bits_size = (elements_per_chunk + 8 - 1) / 8;
        const size_t ALIGN_SIZE = 32;
        size_t vcmpBitResult_size = ((vcmp_bits_size + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
        size_t array_size = elements_per_chunk * element_size;
        size_t aligned_array_size = ((array_size + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE;
        size_t total_bytes = vcmpBitResult_size + 3 * aligned_array_size + ALIGN_SIZE;
        std::vector<int64_t> tmp_shape({static_cast<int64_t>(total_bytes)});
        auto tmp_tensor = std::make_shared<LogicalTensor>(function, DT_UINT8, tmp_shape);

        auto& op = function.AddOperation(Opcode::OP_CMPS, {inputTile}, {resultTile, tmp_tensor});
        std::vector<bool> dimMap({true});
        op.SetAttr(OpAttributeKey::rowPad, dimMap);
        op.SetAttribute(OP_ATTR_PREFIX + "cmp_operation", static_cast<int64_t>(operation));
        op.SetAttribute(OP_ATTR_PREFIX + "cmp_mode", static_cast<int64_t>(mode));
        op.SetAttribute(OpAttributeKey::scalar, scalar);
        return;
    }

    // Only the last dimension of a BIT-mode result is packed: one result element there covers
    // NUM_VALUE_8 input elements. Every other dimension maps 1:1 (scale factor 1).
    const bool packedLastDim = (mode == OutType::BIT) && (cur == result->shape.size() - 1);
    const int64_t inputScale = packedLastDim ? static_cast<int64_t>(NUM_VALUE_8) : 1;

    auto& vecTile = tileShape.GetVecTile();
    int64_t step = vecTile[cur];
    if (packedLastDim) {
        // The tile size is given in input elements; convert to result elements, at least one.
        step = vecTile[cur] / inputScale;
        if (step < 1) {
            step = 1;
        }
    }
    // A non-positive step would never advance the loop below.
    ASSERT(VectorErrorCode::ERR_RUNTIME_LOGIC, result->shape[cur] <= 0 || step > 0) << "Tile step must be positive.";
    const int64_t inputStep = step * inputScale;

    // 64-bit index: extents and steps are int64_t, a plain int could truncate or overflow.
    for (int64_t i = 0; i < result->shape[cur]; i += step) {
        resultTileInfo.offset[cur] = i;
        resultTileInfo.shape[cur] = std::min(result->shape[cur] - i, step);
        input.tileInfo.offset[cur] = (i * inputScale) % input.tensor.GetShape()[cur];
        input.tileInfo.shape[cur] = std::min(input.tensor.GetShape()[cur] - input.tileInfo.offset[cur], inputStep);
        TiledCmpsOperationImpl(
            function, tileShape, cur + 1, input, scalar, result, resultTileInfo, operation, mode);
    }
}
/**
 * \brief Entry point of the tile expansion for a tensor-scalar compare.
 *
 * Creates the tile bookkeeping for the operand and the result, both sized from \p result, and starts
 * the recursive tiling at dimension 0.
 *
 * \param function  Function that receives the generated operations.
 * \param tileShape Tile configuration forwarded to the recursive tiling.
 * \param operand   Tensor operand.
 * \param scalar    Scalar operand.
 * \param result    Output tensor of the compare.
 * \param operation Compare operator.
 * \param mode      Output mode.
 */
void TiledCmpsOperation(
    Function& function, const TileShape& tileShape, LogicalTensorPtr operand, const Element& scalar,
    const LogicalTensorPtr& result, OpType operation, OutType mode)
{
    const size_t shapeRank = result->shape.size();
    const size_t offsetRank = result->offset.size();

    TileInfo operandTileInfo(shapeRank, offsetRank);
    TileInfo outputTileInfo(shapeRank, offsetRank);
    Input tiledOperand{operand, operandTileInfo};

    TiledCmpsOperationImpl(function, tileShape, 0, tiledOperand, scalar, result, outputTileInfo, operation, mode);
}
/**
 * \brief Element-wise compare of two tensors.
 *
 * Validates the operands (matching data types, supported data type of \p self, binary-input checks)
 * and then hands them to the RETURN_CALL macro for CompareOperation on the program's current function.
 *
 * \param self  Left-hand operand.
 * \param other Right-hand operand.
 * \param op    Compare operator.
 * \param mode  Output mode.
 */
Tensor Compare(const Tensor& self, const Tensor& other, OpType op, OutType mode)
{
    DECLARE_TRACER();
    CheckTensorsDataTypeConsistency(self.GetStorage(), other.GetStorage(), "COMPARE");

    // Two candidate type sets; GetSupportedDataTypesByArch picks the applicable one
    // (presumably by target architecture - see its definition).
    using DataTypeSet = std::unordered_set<DataType>;
    static const DataTypeSet typesA2A3{DT_FP16, DT_FP32};
    static const DataTypeSet typesA5{DT_FP16, DT_FP32, DT_INT16};
    const auto& allowedTypes = GetSupportedDataTypesByArch(typesA2A3, typesA5);

    CheckTensorDataType(self.GetStorage(), allowedTypes, "COMPARE");
    CheckBinaryInputTensors(self.GetStorage(), other.GetStorage(), "COMPARE");
    RETURN_CALL(CompareOperation, *Program::GetInstance().GetCurrentFunction(), self, other, op, mode);
}
/**
 * \brief Element-wise compare of a tensor (left) with a scalar (right).
 *
 * Validates the tensor (supported data type, dimension range 1..4, shape size) and then hands the
 * operands to the RETURN_CALL macro for CompareOperationScalar on the program's current function.
 *
 * \param self  Tensor operand.
 * \param other Scalar operand.
 * \param op    Compare operator.
 * \param mode  Output mode.
 */
Tensor Compare(const Tensor& self, const Element& other, OpType op, OutType mode)
{
    DECLARE_TRACER();

    // Two candidate type sets; GetSupportedDataTypesByArch picks the applicable one
    // (presumably by target architecture - see its definition).
    using DataTypeSet = std::unordered_set<DataType>;
    static const DataTypeSet typesA2A3{DT_FP16, DT_FP32};
    static const DataTypeSet typesA5{DT_FP16, DT_FP32, DT_INT16};
    const auto& allowedTypes = GetSupportedDataTypesByArch(typesA2A3, typesA5);

    CheckTensorDataType(self.GetStorage(), allowedTypes, "COMPARE");
    CheckTensorDimRange(self.GetStorage(), 1, 4, "COMPARE");
    CheckTensorShapeSize(self.GetStorage(), "COMPARE");
    RETURN_CALL(CompareOperationScalar, *Program::GetInstance().GetCurrentFunction(), self, other, op, mode);
}
/**
 * \brief Element-wise compare of a scalar (left) with a tensor (right).
 *
 * Validates the tensor (supported data type, dimension range 1..4, shape size) and then hands the
 * operands to the RETURN_CALL macro for CompareOperationScalar on the program's current function.
 *
 * \param self  Scalar operand.
 * \param other Tensor operand.
 * \param op    Compare operator.
 * \param mode  Output mode.
 */
Tensor Compare(const Element& self, const Tensor& other, OpType op, OutType mode)
{
    DECLARE_TRACER();

    // Two candidate type sets; GetSupportedDataTypesByArch picks the applicable one
    // (presumably by target architecture - see its definition).
    using DataTypeSet = std::unordered_set<DataType>;
    static const DataTypeSet typesA2A3{DT_FP16, DT_FP32};
    static const DataTypeSet typesA5{DT_FP16, DT_FP32, DT_INT16};
    const auto& allowedTypes = GetSupportedDataTypesByArch(typesA2A3, typesA5);

    CheckTensorDataType(other.GetStorage(), allowedTypes, "COMPARE");
    CheckTensorDimRange(other.GetStorage(), 1, 4, "COMPARE");
    CheckTensorShapeSize(other.GetStorage(), "COMPARE");
    RETURN_CALL(CompareOperationScalar, *Program::GetInstance().GetCurrentFunction(), self, other, op, mode);
}
/**
 * \brief Tile callback for OP_CMP: decodes the compare attributes and runs the tile expansion.
 *
 * \param function  Function that receives the tiled operations.
 * \param tileShape Tile configuration.
 * \param iOperand  Input tensors; elements 0 and 1 are the two compare operands.
 * \param oOperand  Output tensors; element 0 is the compare result.
 * \param op        Operation carrying the "cmp_operation" and "cmp_mode" attributes.
 */
void CompareOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    BinaryOperationOperandCheck(iOperand, oOperand);

    const auto cmpOperation = static_cast<OpType>(op.GetIntAttribute(OP_ATTR_PREFIX + "cmp_operation"));
    const auto outMode = static_cast<OutType>(op.GetIntAttribute(OP_ATTR_PREFIX + "cmp_mode"));

    const LogicalTensorPtr& lhs = iOperand[0];
    const LogicalTensorPtr& rhs = iOperand[1];
    const LogicalTensorPtr& dst = oOperand[0];
    TiledCompareOperation(function, tileShape, lhs, rhs, dst, cmpOperation, outMode);
}
/**
 * \brief Tile callback for OP_CMPS: decodes the compare attributes and runs the tile expansion.
 *
 * \param function  Function that receives the tiled operations.
 * \param tileShape Tile configuration.
 * \param iOperand  Input tensors; element 0 is the tensor operand.
 * \param oOperand  Output tensors; element 0 is the compare result.
 * \param op        Operation carrying the "cmp_operation", "cmp_mode" and scalar attributes.
 */
void CmpsOperationTileFunc(
    Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
    const std::vector<LogicalTensorPtr>& oOperand, const Operation& op)
{
    const auto cmpOperation = static_cast<OpType>(op.GetIntAttribute(OP_ATTR_PREFIX + "cmp_operation"));
    const auto outMode = static_cast<OutType>(op.GetIntAttribute(OP_ATTR_PREFIX + "cmp_mode"));

    const LogicalTensorPtr& src = iOperand[0];
    const LogicalTensorPtr& dst = oOperand[0];
    TiledCmpsOperation(
        function, tileShape, src, op.GetElementAttribute(OpAttributeKey::scalar), dst, cmpOperation, outMode);
}
// Associate the tile callbacks defined above with their opcodes:
// OP_CMP  -> CompareOperationTileFunc (tensor-tensor compare)
// OP_CMPS -> CmpsOperationTileFunc    (tensor-scalar compare)
REGISTER_OPERATION_TILED_FUNC(OP_CMP, Opcode::OP_CMP, CompareOperationTileFunc);
REGISTER_OPERATION_TILED_FUNC(OP_CMPS, Opcode::OP_CMPS, CmpsOperationTileFunc);
}