* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file bitwise_shift.cpp
* \brief
*/
#include "binary.h"
#include "tensor_transformation.h"
#include "interface/utils/operator_tracer.h"
#include "interface/utils/common.h"
#include "interface/configs/config_manager.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk {
enum class BitwiseShiftOpType {
BITWISERIGHTSHIFT,
BITWISELEFTSHIFT,
SBITWISERIGHTSHIFT,
SBITWISELEFTSHIFT,
};
template <BitwiseShiftOpType T>
std::string GetBitwiseShiftOpName()
{
switch (T) {
case BitwiseShiftOpType::BITWISERIGHTSHIFT:
return "BITWISERIGHTSHIFT";
case BitwiseShiftOpType::BITWISELEFTSHIFT:
return "BITWISELEFTSHIFT";
case BitwiseShiftOpType::SBITWISERIGHTSHIFT:
return "SBITWISERIGHTSHIFT";
case BitwiseShiftOpType::SBITWISELEFTSHIFT:
return "SBITWISELEFTSHIFT";
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
return "";
}
}
template <BitwiseShiftOpType T, bool WithElement = false>
Opcode GetBitwiseShiftOpNameCode()
{
if constexpr (WithElement) {
#define CASE(X) \
case BitwiseShiftOpType::X: \
return Opcode::OP_##X##S
switch (T) {
CASE(BITWISERIGHTSHIFT);
CASE(BITWISELEFTSHIFT);
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
}
#undef CASE
}
#define CASE(X) \
case BitwiseShiftOpType::X: \
return Opcode::OP_##X
switch (T) {
CASE(BITWISERIGHTSHIFT);
CASE(BITWISELEFTSHIFT);
CASE(SBITWISERIGHTSHIFT);
CASE(SBITWISELEFTSHIFT);
default:
ASSERT(VectorErrorCode::ERR_PARAM_INVALID, false) << "unknown binary op type";
}
#undef CASE
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperation(
Function& function, const TileShape& tileShape, size_t cur, LogicalInput& input1, LogicalInput& input2,
const LogicalTensorPtr& result, TileInfo& resultTileInfo)
{
if (cur == input1.tensor->GetShape().size()) {
auto inputTile1 = input1.tensor->View(function, input1.tileInfo.shape, input1.tileInfo.offset);
auto inputTile2 = input2.tensor->View(function, input2.tileInfo.shape, input2.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
Shape tmpShape = resultTileInfo.shape;
auto alignSize = BLOCK_SIZE / static_cast<int64_t>(BytesOf(result->Datatype()));
tmpShape[tmpShape.size() - 1] = AlignUp(tmpShape[tmpShape.size() - 1], alignSize);
auto tmpTile = std::make_shared<LogicalTensor>(function, result->Datatype(), tmpShape);
function.AddOperation(GetBitwiseShiftOpNameCode<T, false>(), {inputTile1, inputTile2}, {resultTile, tmpTile});
return;
}
auto& vecTile = tileShape.GetVecTile();
for (int i = 0; i < result->shape[cur]; i += vecTile[cur]) {
resultTileInfo.offset[cur] = i;
input1.tileInfo.offset[cur] = i % input1.tensor->GetShape()[cur];
input2.tileInfo.offset[cur] = i % input2.tensor->GetShape()[cur];
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur]);
input1.tileInfo.shape[cur] =
std::min(input1.tensor->GetShape()[cur] - input1.tileInfo.offset[cur], vecTile[cur]);
input2.tileInfo.shape[cur] =
std::min(input2.tensor->GetShape()[cur] - input2.tileInfo.offset[cur], vecTile[cur]);
TiledBitwiseShiftOperation<T>(function, tileShape, cur + 1, input1, input2, result, resultTileInfo);
}
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperation(
Function& function, const TileShape& tileShape, LogicalTensorPtr operand1, LogicalTensorPtr operand2,
const LogicalTensorPtr& result)
{
CheckBinOpOperandsValid(operand1, operand2);
BroadcastOperandTensor(operand1, operand2, result, function, tileShape);
BroadcastOperandTensor(operand2, operand1, result, function, tileShape);
TileInfo tileInfo1(result->shape.size(), result->offset.size());
TileInfo tileInfo2(result->shape.size(), result->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto input1 = LogicalInput{operand1, tileInfo1};
auto input2 = LogicalInput{operand2, tileInfo2};
TiledBitwiseShiftOperation<T>(function, tileShape, 0, input1, input2, result, resultTileInfo);
}
template <BitwiseShiftOpType T>
LogicalTensorPtr TensorBitwiseShiftOperation(Function& function, const Tensor& self, const Tensor& other)
{
auto operand1 = self.GetStorage();
auto operand2 = other.GetStorage();
if (operand1->shape.size() != operand2->shape.size()) {
std::vector<int> broadCastShape = GetBroadCastShape(operand1, operand2);
operand1 = BinaryOperationBroadCast(operand1, broadCastShape);
operand2 = BinaryOperationBroadCast(operand2, broadCastShape);
}
auto opName = GetBitwiseShiftOpName<T>();
CheckBinaryInputTensors(operand1, operand2, opName);
CheckTensorsDataTypeConsistency(operand1, operand2, opName);
std::vector<SymbolicScalar> resultValidShape;
std::vector<int64_t> resultShape = BinaryOperationResultShape(operand1, operand2);
if ((!operand1->GetDynValidShape().empty()) && (!operand2->GetDynValidShape().empty())) {
for (size_t i = 0; i < resultShape.size(); ++i) {
if (resultShape[i] == operand1->shape[i]) {
resultValidShape.push_back(self.GetStorage()->GetDynValidShape()[i]);
} else {
resultValidShape.push_back(other.GetStorage()->GetDynValidShape()[i]);
}
}
}
auto result = std::make_shared<LogicalTensor>(
function, operand1->Datatype(), resultShape, resultValidShape, operand1->Format());
function.AddOperation(GetBitwiseShiftOpNameCode<T>(), {operand1, operand2}, {result});
return result;
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperationScalar(
Function& function, const TileShape& tileShape, size_t cur, LogicalInput& self, Element& value,
const LogicalTensorPtr& result, TileInfo& resultTileInfo, bool reverseOperand)
{
if (cur == self.tensor->GetShape().size()) {
auto inputTile = self.tensor->View(function, self.tileInfo.shape, self.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
auto& op = function.AddOperation(GetBitwiseShiftOpNameCode<T, true>(), {inputTile}, {resultTile});
op.SetAttribute(OpAttributeKey::scalar, value);
op.SetAttribute(OP_ATTR_PREFIX + "reverseOperand", reverseOperand);
return;
}
auto& vecTile = tileShape.GetVecTile();
for (int i = 0; i < result->shape[cur]; i += vecTile[cur]) {
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur]);
self.tileInfo.offset[cur] = i % self.tensor->GetShape()[cur];
self.tileInfo.shape[cur] = std::min(self.tensor->GetShape()[cur] - self.tileInfo.offset[cur], vecTile[cur]);
TiledBitwiseShiftOperationScalar<T>(
function, tileShape, cur + 1, self, value, result, resultTileInfo, reverseOperand);
}
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperationScalar(
Function& function, const TileShape& tileShape, LogicalTensorPtr operand, Element value,
const LogicalTensorPtr& result, bool reverseOperand = false)
{
TileInfo tileInfo(result->shape.size(), result->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto self = LogicalInput{operand, tileInfo};
TiledBitwiseShiftOperationScalar<T>(function, tileShape, 0, self, value, result, resultTileInfo, reverseOperand);
}
template <BitwiseShiftOpType T>
LogicalTensorPtr TensorBitwiseShiftOperationScalar(
Function& function, const LogicalTensorPtr& self, const Element& other)
{
auto opName = GetBitwiseShiftOpName<T>();
CheckTensorDimRange(self, MIN_TENSOR_DIM, MAX_TENSOR_DIM, opName);
CheckTensorShapeSize(self, opName);
auto result = std::make_shared<LogicalTensor>(function, self->Datatype(), self->shape, self->GetDynValidShape());
auto& op = function.AddOperation(GetBitwiseShiftOpNameCode<T, true>(), {self}, {result});
op.SetAttribute(OpAttributeKey::scalar, other);
return result;
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperationSelfScalar(
Function& function, const TileShape& tileShape, size_t cur, Element& value, LogicalInput& other,
const LogicalTensorPtr& result, TileInfo& resultTileInfo, bool reverseOperand)
{
if (cur == other.tensor->GetShape().size()) {
auto inputTile = other.tensor->View(function, other.tileInfo.shape, other.tileInfo.offset);
auto resultTile = result->View(function, resultTileInfo.shape, resultTileInfo.offset);
Shape tmpShape = resultTileInfo.shape;
auto alignSize = BLOCK_SIZE / static_cast<int64_t>(BytesOf(result->Datatype()));
tmpShape[tmpShape.size() - 1] = AlignUp(tmpShape[tmpShape.size() - 1], alignSize);
auto tmpTile = std::make_shared<LogicalTensor>(function, result->Datatype(), tmpShape);
auto& op = function.AddOperation(GetBitwiseShiftOpNameCode<T>(), {inputTile}, {resultTile, tmpTile});
op.SetAttribute(OpAttributeKey::scalar, value);
op.SetAttribute(OP_ATTR_PREFIX + "reverseOperand", reverseOperand);
return;
}
auto& vecTile = tileShape.GetVecTile();
for (int i = 0; i < result->shape[cur]; i += vecTile[cur]) {
resultTileInfo.offset[cur] = i;
resultTileInfo.shape[cur] = std::min(result->shape[cur] - resultTileInfo.offset[cur], vecTile[cur]);
other.tileInfo.offset[cur] = i % other.tensor->GetShape()[cur];
other.tileInfo.shape[cur] = std::min(other.tensor->GetShape()[cur] - other.tileInfo.offset[cur], vecTile[cur]);
TiledBitwiseShiftOperationSelfScalar<T>(
function, tileShape, cur + 1, value, other, result, resultTileInfo, reverseOperand);
}
}
template <BitwiseShiftOpType T>
void TiledBitwiseShiftOperationSelfScalar(
Function& function, const TileShape& tileShape, Element value, LogicalTensorPtr operand,
const LogicalTensorPtr& result, bool reverseOperand = false)
{
TileInfo tileInfo(result->shape.size(), result->offset.size());
TileInfo resultTileInfo(result->shape.size(), result->offset.size());
auto other = LogicalInput{operand, tileInfo};
TiledBitwiseShiftOperationSelfScalar<T>(
function, tileShape, 0, value, other, result, resultTileInfo, reverseOperand);
}
template <BitwiseShiftOpType T>
LogicalTensorPtr TensorBitwiseShiftOperationSelfScalar(
Function& function, const Element& self, const LogicalTensorPtr& other)
{
auto opName = GetBitwiseShiftOpName<T>();
CheckTensorDimRange(other, MIN_TENSOR_DIM, MAX_TENSOR_DIM, opName);
CheckTensorShapeSize(other, opName);
auto result = std::make_shared<LogicalTensor>(function, other->Datatype(), other->shape, other->GetDynValidShape());
auto& op = function.AddOperation(GetBitwiseShiftOpNameCode<T>(), {other}, {result});
op.SetAttribute(OpAttributeKey::scalar, self);
return result;
}
namespace {
const std::unordered_set<DataType> BITWISESHIFT_A2A3_TYPES = {DT_INT16, DT_UINT16};
const std::unordered_set<DataType> BITWISESHIFT_A5_TYPES = {
DT_INT16, DT_UINT16, DT_INT32, DT_UINT32, DT_INT8, DT_UINT8};
}
Tensor BitwiseRightShift(const Tensor& self, const Tensor& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(self.GetStorage(), supportedTypes, "BitwiseRightShift");
RETURN_CALL(
BitwiseShiftOperation<BitwiseShiftOpType::BITWISERIGHTSHIFT>, *Program::GetInstance().GetCurrentFunction(),
self, other);
}
Tensor BitwiseRightShift(const Tensor& self, const Element& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(self.GetStorage(), supportedTypes, "BitwiseRightShift");
Element newOther = other;
if (self.GetDataType() != other.GetDataType()) {
newOther = Element(self.GetDataType(), other.Cast<int32_t>());
}
RETURN_CALL(
BitwiseShiftOperationScalar<BitwiseShiftOpType::BITWISERIGHTSHIFT>,
*Program::GetInstance().GetCurrentFunction(), self.GetStorage(), newOther);
}
Tensor BitwiseRightShift(const Element& self, const Tensor& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(other.GetStorage(), supportedTypes, "BitwiseRightShift");
Element newSelf = self;
if (self.GetDataType() != other.GetDataType()) {
newSelf = Element(other.GetDataType(), self.Cast<int32_t>());
}
RETURN_CALL(
BitwiseShiftOperationSelfScalar<BitwiseShiftOpType::SBITWISERIGHTSHIFT>,
*Program::GetInstance().GetCurrentFunction(), newSelf, other.GetStorage());
}
Tensor BitwiseLeftShift(const Tensor& self, const Tensor& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(self.GetStorage(), supportedTypes, "BitwiseLeftShift");
RETURN_CALL(
BitwiseShiftOperation<BitwiseShiftOpType::BITWISELEFTSHIFT>, *Program::GetInstance().GetCurrentFunction(), self,
other);
}
Tensor BitwiseLeftShift(const Tensor& self, const Element& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(self.GetStorage(), supportedTypes, "BitwiseLeftShift");
Element newOther = other;
if (self.GetDataType() != other.GetDataType()) {
newOther = Element(self.GetDataType(), other.Cast<int32_t>());
}
RETURN_CALL(
BitwiseShiftOperationScalar<BitwiseShiftOpType::BITWISELEFTSHIFT>, *Program::GetInstance().GetCurrentFunction(),
self.GetStorage(), newOther);
}
Tensor BitwiseLeftShift(const Element& self, const Tensor& other)
{
DECLARE_TRACER();
const auto& supportedTypes = GetSupportedDataTypesByArch(BITWISESHIFT_A2A3_TYPES, BITWISESHIFT_A5_TYPES);
CheckTensorDataType(other.GetStorage(), supportedTypes, "BitwiseLeftShift");
Element newSelf = self;
if (self.GetDataType() != other.GetDataType()) {
newSelf = Element(other.GetDataType(), self.Cast<int32_t>());
}
RETURN_CALL(
BitwiseShiftOperationSelfScalar<BitwiseShiftOpType::SBITWISELEFTSHIFT>,
*Program::GetInstance().GetCurrentFunction(), newSelf, other.GetStorage());
}
template <BitwiseShiftOpType T>
void BitwiseShiftOperationTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
BinaryOperationOperandCheck(iOperand, oOperand);
TiledBitwiseShiftOperation<T>(function, tileShape, iOperand[0], iOperand[1], oOperand[0]);
}
template <BitwiseShiftOpType T>
void BitwiseShiftOperationScalarTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
TiledBitwiseShiftOperationScalar<T>(
function, tileShape, iOperand[0], op.GetElementAttribute(OpAttributeKey::scalar), oOperand[0]);
}
template <BitwiseShiftOpType T>
void BitwiseShiftOperationSelfScalarTileFunc(
Function& function, const TileShape& tileShape, const std::vector<LogicalTensorPtr>& iOperand,
const std::vector<LogicalTensorPtr>& oOperand, [[maybe_unused]] const Operation& op)
{
TiledBitwiseShiftOperationSelfScalar<T>(
function, tileShape, op.GetElementAttribute(OpAttributeKey::scalar), iOperand[0], oOperand[0]);
}
REGISTER_OPERATION_TILED_FUNC(
OP_BITWISERIGHTSHIFT, Opcode::OP_BITWISERIGHTSHIFT,
BitwiseShiftOperationTileFunc<BitwiseShiftOpType::BITWISERIGHTSHIFT>);
REGISTER_OPERATION_TILED_FUNC(
OP_BITWISELEFTSHIFT, Opcode::OP_BITWISELEFTSHIFT,
BitwiseShiftOperationTileFunc<BitwiseShiftOpType::BITWISELEFTSHIFT>);
REGISTER_OPERATION_TILED_FUNC(
OP_BITWISERIGHTSHIFTS, Opcode::OP_BITWISERIGHTSHIFTS,
BitwiseShiftOperationScalarTileFunc<BitwiseShiftOpType::BITWISERIGHTSHIFT>);
REGISTER_OPERATION_TILED_FUNC(
OP_BITWISELEFTSHIFTS, Opcode::OP_BITWISELEFTSHIFTS,
BitwiseShiftOperationScalarTileFunc<BitwiseShiftOpType::BITWISELEFTSHIFT>);
REGISTER_OPERATION_TILED_FUNC(
OP_SBITWISERIGHTSHIFT, Opcode::OP_SBITWISERIGHTSHIFT,
BitwiseShiftOperationSelfScalarTileFunc<BitwiseShiftOpType::SBITWISERIGHTSHIFT>);
REGISTER_OPERATION_TILED_FUNC(
OP_SBITWISELEFTSHIFT, Opcode::OP_SBITWISELEFTSHIFT,
BitwiseShiftOperationSelfScalarTileFunc<BitwiseShiftOpType::SBITWISELEFTSHIFT>);
}