* Copyright (c) 2023 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "constantfold.h"
#include <cfloat>
namespace maple {
namespace {
constexpr uint32 kByteSizeOfBit64 = 8;
constexpr uint32 kBitSizePerByte = 8;
constexpr maple::int32 kMaxOffset = INT_MAX - 8;
enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
{
if (!v1 && !v2) {
return std::nullopt;
}
return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
}
std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
{
if (!v1 && !v2) {
return std::nullopt;
}
if (v1 && v2) {
return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
}
if (v1) {
return v1->TruncOrExtend(PTY_i64);
}
return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
}
std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
{
return AddSub(v1, v2, true);
}
std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
{
return AddSub(v1, v2, false);
}
}
static bool ContiguousBitsOf1(uint64 x)
{
if (x == 0) {
return false;
}
return (~x & (x + 1)) == (x + 1);
}
inline bool IsPowerOf2(uint64 num)
{
if (num == 0) {
return false;
}
return (~(num - 1) & num) == num;
}
BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
BaseNode *rhs) const
{
CHECK_NULL_FATAL(old);
BinaryNode *result = nullptr;
if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
result = old;
} else {
result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
}
return result;
}
UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
{
CHECK_NULL_FATAL(old);
UnaryNode *result = nullptr;
if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
result = old;
} else {
result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
}
return result;
}
BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
{
CHECK_NULL_FATAL(pair.first);
BaseNode *result = pair.first;
if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
return result;
}
if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
static_cast<uint64>(pair.second->GetExtValue()), resultType);
BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
} else {
if ((!pair.second->GetSignBit() &&
pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
pair.second->TruncOrExtend(resultType).IsMinValue() ||
pair.second->GetSXTValue() == INT64_MIN) {
ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
static_cast<uint64>(pair.second->GetExtValue()), resultType);
result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
} else {
ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
}
}
return result;
}
std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
{
return std::make_pair(node, std::nullopt);
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
{
CHECK_NULL_FATAL(node);
if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
return {node, std::nullopt};
}
switch (node->GetOpCode()) {
case OP_abs:
case OP_bnot:
case OP_lnot:
case OP_neg:
case OP_sqrt:
return FoldUnary(static_cast<UnaryNode*>(node));
case OP_ceil:
case OP_floor:
case OP_trunc:
case OP_cvt:
return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
case OP_sext:
case OP_zext:
case OP_extractbits:
return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
case OP_iread:
return FoldIread(static_cast<IreadNode*>(node));
case OP_add:
case OP_ashr:
case OP_band:
case OP_bior:
case OP_bxor:
case OP_div:
case OP_lshr:
case OP_max:
case OP_min:
case OP_mul:
case OP_rem:
case OP_shl:
case OP_sub:
return FoldBinary(static_cast<BinaryNode*>(node));
case OP_eq:
case OP_ne:
case OP_ge:
case OP_gt:
case OP_le:
case OP_lt:
case OP_cmp:
return FoldCompare(static_cast<CompareNode*>(node));
case OP_retype:
return FoldRetype(static_cast<RetypeNode*>(node));
default:
return FoldBase(static_cast<BaseNode*>(node));
}
}
BaseNode *ConstantFold::Negate(BaseNode *node) const
{
CHECK_NULL_FATAL(node);
return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
}
BaseNode *ConstantFold::Negate(UnaryNode *node) const
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
if (node->GetOpCode() == OP_neg) {
result = static_cast<BaseNode*>(node->Opnd(0));
} else {
BaseNode *n = static_cast<BaseNode*>(node);
result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
}
return result;
}
BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
{
CHECK_NULL_FATAL(node);
ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
CHECK_NULL_FATAL(copy);
copy->GetConstVal()->Neg();
return copy;
}
BaseNode *ConstantFold::NegateTree(BaseNode *node) const
{
CHECK_NULL_FATAL(node);
if (node->IsUnaryNode()) {
return Negate(static_cast<UnaryNode*>(node));
} else if (node->GetOpCode() == OP_constval) {
return Negate(static_cast<ConstvalNode*>(node));
} else {
return Negate(static_cast<BaseNode*>(node));
}
}
MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
const MIRIntConst &intConst0,
const MIRIntConst &intConst1) const
{
uint64 result = 0;
bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
switch (opcode) {
case OP_eq: {
result = equal;
break;
}
case OP_ge: {
result = (greater || equal) ? 1 : 0;
break;
}
case OP_gt: {
result = greater;
break;
}
case OP_le: {
result = (less || equal) ? 1 : 0;
break;
}
case OP_lt: {
result = less;
break;
}
case OP_ne: {
result = !equal;
break;
}
case OP_cmp: {
if (greater) {
result = kGreater;
} else if (equal) {
result = kEqual;
} else if (less) {
result = static_cast<uint64>(kLess);
}
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
break;
}
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
return constValue;
}
ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
const ConstvalNode &const0, const ConstvalNode &const1) const
{
const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
CHECK_NULL_FATAL(intConst0);
CHECK_NULL_FATAL(intConst1);
MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(resultType);
resultConst->SetConstVal(constValue);
return resultConst;
}
MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
const MIRIntConst &intConst1)
{
IntVal intVal0 = intConst0.GetValue();
IntVal intVal1 = intConst1.GetValue();
IntVal result(static_cast<uint64>(0), resultType);
switch (opcode) {
case OP_add: {
result = intVal0.Add(intVal1, resultType);
break;
}
case OP_sub: {
result = intVal0.Sub(intVal1, resultType);
break;
}
case OP_mul: {
result = intVal0.Mul(intVal1, resultType);
break;
}
case OP_div: {
result = intVal0.Div(intVal1, resultType);
break;
}
case OP_rem: {
result = intVal0.Rem(intVal1, resultType);
break;
}
case OP_ashr: {
result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
break;
}
case OP_lshr: {
result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
break;
}
case OP_shl: {
result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
break;
}
case OP_max: {
result = Max(intVal0, intVal1, resultType);
break;
}
case OP_min: {
result = Min(intVal0, intVal1, resultType);
break;
}
case OP_band: {
result = intVal0.And(intVal1, resultType);
break;
}
case OP_bior: {
result = intVal0.Or(intVal1, resultType);
break;
}
case OP_bxor: {
result = intVal0.Xor(intVal1, resultType);
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
break;
}
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
MIRIntConst *constValue =
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
return constValue;
}
ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
const ConstvalNode &const1) const
{
const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
CHECK_NULL_FATAL(intConst0);
CHECK_NULL_FATAL(intConst1);
MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(resultType);
resultConst->SetConstVal(constValue);
return resultConst;
}
ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
const ConstvalNode &const1) const
{
DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
const MIRDoubleConst *doubleConst0 = nullptr;
const MIRDoubleConst *doubleConst1 = nullptr;
const MIRFloatConst *floatConst0 = nullptr;
const MIRFloatConst *floatConst1 = nullptr;
bool useDouble = (const0.GetPrimType() == PTY_f64);
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(resultType);
if (useDouble) {
doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
CHECK_NULL_FATAL(doubleConst0);
CHECK_NULL_FATAL(doubleConst1);
} else {
floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
CHECK_NULL_FATAL(floatConst0);
CHECK_NULL_FATAL(floatConst1);
}
float constValueFloat = 0.0;
double constValueDouble = 0.0;
switch (opcode) {
case OP_add: {
if (useDouble) {
constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
} else {
constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
}
break;
}
case OP_sub: {
if (useDouble) {
constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
} else {
constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
}
break;
}
case OP_mul: {
if (useDouble) {
constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
} else {
constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
}
break;
}
case OP_div: {
if (useDouble) {
constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
} else {
constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
}
break;
}
case OP_max: {
if (useDouble) {
constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
: doubleConst1->GetValue();
} else {
constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
: floatConst1->GetValue();
}
break;
}
case OP_min: {
if (useDouble) {
constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
: doubleConst1->GetValue();
} else {
constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
: floatConst1->GetValue();
}
break;
}
case OP_rem:
case OP_ashr:
case OP_lshr:
case OP_shl:
case OP_band:
case OP_bior:
case OP_bxor: {
DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
break;
}
if (resultType == PTY_f64) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
} else {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
}
return resultConst;
}
bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
{
return (leftValue == rightValue);
}
bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
{
auto result = fabs(leftValue - rightValue);
return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
}
bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
{
auto result = fabs(leftValue - rightValue);
return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
}
template<typename T>
bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
{
if (std::isinf(leftValue) && std::isinf(rightValue)) {
return true;
} else {
return ConstValueEqual(leftValue, rightValue);
}
}
template<typename T>
int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
{
DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
typename T::value_type leftValue = leftConst->GetValue();
DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
typename T::value_type rightValue = rightConst->GetValue();
int64 result = 0;
switch (op) {
case OP_eq: {
result = FullyEqual(leftValue, rightValue);
break;
}
case OP_ge: {
result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
break;
}
case OP_gt: {
result = (leftValue > rightValue);
break;
}
case OP_le: {
result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
break;
}
case OP_lt: {
result = (leftValue < rightValue);
break;
}
case OP_ne: {
result = !FullyEqual(leftValue, rightValue);
break;
}
[[clang::fallthrough]];
case OP_cmp: {
if (leftValue > rightValue) {
result = kGreater;
} else if (FullyEqual(leftValue, rightValue)) {
result = kEqual;
} else {
result = kLess;
}
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for Comparison");
break;
}
return result;
}
MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
const MIRConst &leftConst, const MIRConst &rightConst) const
{
int64 result = 0;
bool useDouble = (opndType == PTY_f64);
if (useDouble) {
result =
ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
} else {
result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
}
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
return resultConst;
}
ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
const ConstvalNode &const0, const ConstvalNode &const1) const
{
DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(resultType);
resultConst->SetConstVal(
FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
return resultConst;
}
MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
const MIRConst &const0, const MIRConst &const1) const
{
MIRConst *returnValue = nullptr;
if (IsPrimitiveInteger(opndType)) {
const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
ASSERT_NOT_NULL(intConst0);
ASSERT_NOT_NULL(intConst1);
returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
} else if (opndType == PTY_f32 || opndType == PTY_f64) {
returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
} else {
DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
}
return returnValue;
}
ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
const ConstvalNode &const0, const ConstvalNode &const1) const
{
ConstvalNode *returnValue = nullptr;
if (IsPrimitiveInteger(opndType)) {
returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
} else if (opndType == PTY_f32 || opndType == PTY_f64) {
returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
} else {
DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
}
return returnValue;
}
CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
BaseNode &l, BaseNode &r) const
{
CompareNode *result = nullptr;
Opcode op = opcode;
switch (opcode) {
case OP_gt: {
op = OP_lt;
break;
}
case OP_lt: {
op = OP_gt;
break;
}
case OP_ge: {
op = OP_le;
break;
}
case OP_le: {
op = OP_ge;
break;
}
case OP_eq: {
break;
}
case OP_ne: {
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
break;
}
result =
mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
return result;
}
ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
const ConstvalNode &const1) const
{
ConstvalNode *returnValue = nullptr;
if (IsPrimitiveInteger(resultType)) {
returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
} else if (resultType == PTY_f32 || resultType == PTY_f64) {
returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
} else {
DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
}
return returnValue;
}
MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
{
CHECK_NULL_FATAL(constNode);
IntVal result = constNode->GetValue().TruncOrExtend(resultType);
switch (opcode) {
case OP_abs: {
if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
result = -result;
}
break;
}
case OP_bnot: {
result = ~result;
break;
}
case OP_lnot: {
uint64 resultInt = result == 0 ? 1 : 0;
result = {resultInt, resultType};
break;
}
case OP_neg: {
result = -result;
break;
}
case OP_sext:
case OP_zext:
case OP_extractbits:
case OP_sqrt: {
DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
break;
}
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
MIRIntConst *constValue =
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
return constValue;
}
template <typename T>
ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
{
CHECK_NULL_FATAL(constNode);
double constValue = 0;
T *fpCst = static_cast<T*>(constNode->GetConstVal());
switch (opcode) {
case OP_neg: {
constValue = typename T::value_type(-fpCst->GetValue());
break;
}
case OP_abs: {
constValue = typename T::value_type(fabs(fpCst->GetValue()));
break;
}
case OP_sqrt: {
constValue = typename T::value_type(sqrt(fpCst->GetValue()));
break;
}
case OP_bnot:
case OP_lnot:
case OP_sext:
case OP_zext:
case OP_extractbits: {
DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
break;
}
default:
DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
break;
}
auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(resultType);
if (resultType == PTY_f32) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
} else if (resultType == PTY_f64) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
} else {
CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
}
return resultConst;
}
ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
{
ConstvalNode *returnValue = nullptr;
if (IsPrimitiveInteger(resultType)) {
const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
returnValue->SetPrimType(resultType);
returnValue->SetConstVal(constValue);
} else if (resultType == PTY_f32) {
returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
} else if (resultType == PTY_f64) {
returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
} else {
DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
}
return returnValue;
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = node;
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
if (node->Opnd(0) != p.first) {
RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
result = newRetNode;
}
return std::make_pair(result, std::nullopt);
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
std::optional<IntVal> sum = std::nullopt;
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
if (cst != nullptr) {
result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
} else {
bool isInt = IsPrimitiveInteger(node->GetPrimType());
if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
result = NegateTree(p.first);
if (result->GetOpCode() == OP_neg) {
PrimType origPtyp = node->GetPrimType();
PrimType newPtyp = result->GetPrimType();
if (newPtyp == origPtyp) {
if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
result = node;
}
} else {
if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
PairToExpr(node->Opnd(0)->GetPrimType(), p));
return std::make_pair(result, std::nullopt);
} else {
result->SetPrimType(origPtyp);
}
}
}
if (p.second) {
sum = -(*p.second);
}
} else {
result =
NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
}
}
return std::make_pair(result, sum);
}
static bool FloatToIntOverflow(float fval, PrimType totype)
{
static const float safeFloatMaxToInt32 = 2147483520.0f;
static const float safeFloatMinToInt32 = -2147483520.0f;
static const float safeFloatMaxToInt64 = 9223372036854775680.0f;
static const float safeFloatMinToInt64 = -9223372036854775680.0f;
if (!std::isfinite(fval)) {
return true;
}
if (totype == PTY_i64 || totype == PTY_u64) {
if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
return true;
}
} else {
if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
return true;
}
}
return false;
}
static bool DoubleToIntOverflow(double dval, PrimType totype)
{
static const double safeDoubleMaxToInt32 = 2147482624.0;
static const double safeDoubleMinToInt32 = -2147482624.0;
static const double safeDoubleMaxToInt64 = 9223372036854774784.0;
static const double safeDoubleMinToInt64 = -9223372036854774784.0;
if (!std::isfinite(dval)) {
return true;
}
if (totype == PTY_i64 || totype == PTY_u64) {
if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
return true;
}
} else {
if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
return true;
}
}
return false;
}
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(toType);
MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
if (fromType == PTY_f32) {
const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
ASSERT_NOT_NULL(constValue);
float floatValue = ceil(constValue->GetValue());
if (IsPrimitiveFloat(toType)) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
} else if (FloatToIntOverflow(floatValue, toType)) {
return nullptr;
} else {
resultConst->SetConstVal(
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType));
}
} else {
const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
ASSERT_NOT_NULL(constValue);
double doubleValue = ceil(constValue->GetValue());
if (IsPrimitiveFloat(toType)) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
} else if (DoubleToIntOverflow(doubleValue, toType)) {
return nullptr;
} else {
resultConst->SetConstVal(
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType));
}
}
return resultConst;
}
template <class T>
T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
{
DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
bool isSigned = IsSignedInteger(resultType.GetPrimType());
int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
if (isSigned && (value > max)) {
return static_cast<T>(max);
} else if (!isSigned && (value > umax)) {
return static_cast<T>(umax);
} else if (value < min) {
return static_cast<T>(min);
}
return value;
}
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
if (fromType == PTY_f32) {
const auto &constValue = static_cast<const MIRFloatConst&>(cst);
float floatValue = constValue.GetValue();
if (isFloor) {
floatValue = floor(constValue.GetValue());
}
if (IsPrimitiveFloat(toType)) {
return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
}
if (FloatToIntOverflow(floatValue, toType)) {
return nullptr;
}
floatValue = CalIntValueFromFloatValue(floatValue, resultType);
return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType);
} else {
const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
double doubleValue = constValue.GetValue();
if (isFloor) {
doubleValue = floor(constValue.GetValue());
}
if (IsPrimitiveFloat(toType)) {
return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
}
if (DoubleToIntOverflow(doubleValue, toType)) {
return nullptr;
}
doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
}
}
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(toType);
resultConst->SetConstVal(FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType));
return resultConst;
}
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
if (fromType == PTY_f32) {
const auto &constValue = static_cast<const MIRFloatConst&>(cst);
float floatValue = round(constValue.GetValue());
if (FloatToIntOverflow(floatValue, toType)) {
return nullptr;
}
return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
} else if (fromType == PTY_f64) {
const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
double doubleValue = round(constValue.GetValue());
if (DoubleToIntOverflow(doubleValue, toType)) {
return nullptr;
}
return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
} else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
const auto &constValue = static_cast<const MIRIntConst&>(cst);
if (IsSignedInteger(fromType)) {
int64 fromValue = constValue.GetExtValue();
float floatValue = round(static_cast<float>(fromValue));
if (static_cast<int64>(floatValue) == fromValue) {
return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
}
} else {
uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
float floatValue = round(static_cast<float>(fromValue));
if (static_cast<uint64>(floatValue) == fromValue) {
return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
}
}
} else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
const auto &constValue = static_cast<const MIRIntConst&>(cst);
if (IsSignedInteger(fromType)) {
int64 fromValue = constValue.GetExtValue();
double doubleValue = round(static_cast<double>(fromValue));
if (static_cast<int64>(doubleValue) == fromValue) {
return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
}
} else {
uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
double doubleValue = round(static_cast<double>(fromValue));
if (static_cast<uint64>(doubleValue) == fromValue) {
return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
}
}
}
return nullptr;
}
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(toType);
resultConst->SetConstVal(FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType));
return resultConst;
}
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
resultConst->SetPrimType(toType);
MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
if (fromType == PTY_f32) {
const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
CHECK_NULL_FATAL(constValue);
float floatValue = trunc(constValue->GetValue());
if (IsPrimitiveFloat(toType)) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
} else if (FloatToIntOverflow(floatValue, toType)) {
return nullptr;
} else {
resultConst->SetConstVal(
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType));
}
} else {
const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
CHECK_NULL_FATAL(constValue);
double doubleValue = trunc(constValue->GetValue());
if (IsPrimitiveFloat(toType)) {
resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
} else if (DoubleToIntOverflow(doubleValue, toType)) {
return nullptr;
} else {
resultConst->SetConstVal(
GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType));
}
}
return resultConst;
}
MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
MIRConst *toConst = nullptr;
uint32 fromSize = GetPrimTypeBitSize(fromType);
uint32 toSize = GetPrimTypeBitSize(toType);
if (fromType == PTY_u1) {
fromSize = 1;
}
if (toType == PTY_u1) {
toSize = 1;
}
if (toSize > fromSize) {
Opcode op = OP_zext;
if (IsSignedInteger(fromType)) {
op = OP_sext;
}
const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
ASSERT_NOT_NULL(constVal);
toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
constVal->GetValue().TruncOrExtend(fromType));
} else {
const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
ASSERT_NOT_NULL(constVal);
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
static_cast<uint64>(constVal->GetExtValue()), type);
}
return toConst;
}
if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
MIRConst *toConst = nullptr;
if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64");
const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
ASSERT_NOT_NULL(fromValue);
float floatValue = static_cast<float>(fromValue->GetValue());
MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
toConst = toValue;
} else {
DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64");
const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
ASSERT_NOT_NULL(fromValue);
double doubleValue = static_cast<double>(fromValue->GetValue());
MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
toConst = toValue;
}
return toConst;
}
if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
return FoldFloorMIRConst(cst, fromType, toType, false);
}
if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
return FoldRoundMIRConst(cst, fromType, toType);
}
CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
return nullptr;
}
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
MIRConst *toConstValue = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
if (toConstValue == nullptr) {
return nullptr;
}
ConstvalNode *toConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
toConst->SetPrimType(toConstValue->GetType().GetPrimType());
toConst->SetConstVal(toConstValue);
return toConst;
}
PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
{
bool isSigned = IsSignedInteger(ptyp);
bool isFloat = IsPrimitiveFloat(ptyp);
if (bitSize == 1) {
return PTY_u1;
}
if (bitSize <= 8) {
return isSigned ? PTY_i8 : PTY_u8;
}
if (bitSize <= 16) {
return isSigned ? PTY_i16 : PTY_u16;
}
if (bitSize <= 32) {
return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
}
if (bitSize <= 64) {
return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
}
return ptyp;
}
size_t GetIntPrimTypeMax(PrimType ptyp)
{
switch (ptyp) {
case PTY_u1:
return 1;
case PTY_u8:
return UINT8_MAX;
case PTY_i8:
return INT8_MAX;
case PTY_u16:
return UINT16_MAX;
case PTY_i16:
return INT16_MAX;
case PTY_u32:
return UINT32_MAX;
case PTY_i32:
return INT32_MAX;
case PTY_u64:
return UINT64_MAX;
case PTY_i64:
return INT64_MAX;
default:
CHECK_FATAL(false, "NYI");
}
}
ssize_t GetIntPrimTypeMin(PrimType ptyp)
{
if (IsUnsignedInteger(ptyp)) {
return 0;
}
switch (ptyp) {
case PTY_i8:
return INT8_MIN;
case PTY_i16:
return INT16_MIN;
case PTY_i32:
return INT32_MIN;
case PTY_i64:
return INT64_MIN;
default:
CHECK_FATAL(false, "NYI");
}
}
static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
{
if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
return false;
}
if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
return false;
}
return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
(IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
(IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
return {node, std::nullopt};
}
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
PrimType destPtyp = node->GetPrimType();
PrimType fromPtyp = node->FromType();
if (cst != nullptr) {
switch (node->GetOpCode()) {
case OP_ceil: {
result = FoldCeil(*cst, fromPtyp, destPtyp);
break;
}
case OP_cvt: {
result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
break;
}
case OP_floor: {
result = FoldFloor(*cst, fromPtyp, destPtyp);
break;
}
case OP_trunc: {
result = FoldTrunc(*cst, fromPtyp, destPtyp);
break;
}
default:
DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
break;
}
} else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
}
if (result == nullptr) {
BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
if (e != node->Opnd(0)) {
result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
} else {
result = node;
}
}
return std::make_pair(result, std::nullopt);
}
MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
{
uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
return constValue;
}
ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
const ConstvalNode &cst) const
{
ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
ASSERT_NOT_NULL(intCst);
IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
resultConst->SetPrimType(toConst->GetType().GetPrimType());
resultConst->SetConstVal(toConst);
return resultConst;
}
static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
{
if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
return false;
}
BaseNode *opnd = x.Opnd(0);
MIRType *mirType = nullptr;
if (opnd->GetOpCode() == OP_dread) {
DreadNode *dread = static_cast<DreadNode*>(opnd);
MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
ASSERT_NOT_NULL(sym);
mirType = sym->GetType();
} else if (opnd->GetOpCode() == OP_iread) {
IreadNode *iread = static_cast<IreadNode*>(opnd);
MIRPtrType *ptrType =
dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
if (ptrType == nullptr) {
return false;
}
mirType = ptrType->GetPointedType();
} else if (opnd->GetOpCode() == OP_extractbits &&
x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
IsUnsignedInteger(opnd->GetPrimType()));
} else {
return false;
}
return IsPrimitiveInteger(mirType->GetPrimType()) &&
((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
(x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
mirType->GetPrimType() == x.GetPrimType();
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
uint8 offset = node->GetBitsOffset();
uint8 size = node->GetBitsSize();
Opcode opcode = node->GetOpCode();
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
return std::make_pair(result, std::nullopt);
}
BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
if (e != node->Opnd(0)) {
result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
size, e);
} else {
result = node;
}
BaseNode *opnd = result->Opnd(0);
DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
Opcode opndOp = opnd->GetOpCode();
if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
if (offset == opndOffset && size == opndSize) {
result->SetOpnd(opnd->Opnd(0), 0);
}
}
if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
return std::make_pair(result->Opnd(0), std::nullopt);
}
}
return std::make_pair(result, std::nullopt);
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
{
CHECK_NULL_FATAL(node);
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
node->SetOpnd(e, 0);
BaseNode *result = node;
if (e->GetOpCode() != OP_addrof) {
return std::make_pair(result, std::nullopt);
}
AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
DEBUG_ASSERT(msy != nullptr, "nullptr check");
TyIdx typeId = msy->GetTyIdx();
CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
if (ptrType->GetPointedType() != msyType) {
return std::make_pair(result, std::nullopt);
}
Opcode op = node->GetOpCode();
if (op == OP_iread) {
result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
node->GetFieldID() + addrofNode->GetFieldID());
}
return std::make_pair(result, std::nullopt);
}
bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
{
switch (op) {
case OP_add: {
int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
if (IsUnsignedInteger(primType)) {
return static_cast<uint64>(res) < static_cast<uint64>(cstA);
}
auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
(static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
}
case OP_sub: {
if (IsUnsignedInteger(primType)) {
return cstA < cstB;
}
int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
(static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
}
default: {
return false;
}
}
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
std::optional<IntVal> sum = std::nullopt;
Opcode op = node->GetOpCode();
PrimType primType = node->GetPrimType();
PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
BaseNode *l = lp.first;
BaseNode *r = rp.first;
ASSERT_NOT_NULL(r);
ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
bool isInt = IsPrimitiveInteger(primType);
if (lConst != nullptr && rConst != nullptr) {
MIRConst *lConstVal = lConst->GetConstVal();
MIRConst *rConstVal = rConst->GetConstVal();
ASSERT_NOT_NULL(lConstVal);
ASSERT_NOT_NULL(rConstVal);
if ((op == OP_div || op == OP_rem) && isInt &&
!IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
result = NewBinaryNode(node, op, primType, lConst, rConst);
} else {
result = FoldConstBinary(op, primType, *lConst, *rConst);
}
} else if (lConst != nullptr && isInt) {
MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
ASSERT_NOT_NULL(mcst);
PrimType cstTyp = mcst->GetType().GetPrimType();
IntVal cst = mcst->GetValue();
if (op == OP_add) {
if (IsSignedInteger(cstTyp) && rp.second &&
IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
} else {
sum = cst + rp.second;
result = r;
}
} else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
sum = cst - rp.second;
if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
}
result = NegateTree(r);
} else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
op == OP_band) &&
cst == 0) {
result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
} else if (op == OP_mul && cst == 1) {
sum = rp.second;
result = r;
} else if (op == OP_bior && cst == -1) {
result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
} else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
sum = cst * rp.second;
if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
}
result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
} else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
sum = rp.second;
result = r;
} else {
result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
}
if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
}
} else if (rConst != nullptr && isInt) {
MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
ASSERT_NOT_NULL(mcst);
PrimType cstTyp = mcst->GetType().GetPrimType();
IntVal cst = mcst->GetValue();
if (op == OP_add) {
if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
} else {
result = l;
sum = lp.second + cst;
}
} else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
result = l;
sum = lp.second - cst;
} else if ((op == OP_mul || op == OP_band) && cst == 0) {
result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
} else if ((op == OP_mul || op == OP_div) && cst == 1) {
sum = lp.second;
result = l;
} else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
BaseNode *x = l->Opnd(0);
BaseNode *y = l->Opnd(1);
ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
bool foldMulDiv = false;
if (yConst != nullptr && xConst == nullptr &&
IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
ASSERT_NOT_NULL(yCst);
IntVal mulCst = yCst->GetValue();
if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
mulCst.GetExtValue() == cst.GetExtValue()) {
foldMulDiv = true;
result = x;
}
} else if (xConst != nullptr && yConst == nullptr &&
IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
ASSERT_NOT_NULL(xCst);
IntVal mulCst = xCst->GetValue();
if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
mulCst.GetExtValue() == cst.GetExtValue()) {
foldMulDiv = true;
result = y;
}
}
if (!foldMulDiv) {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
}
} else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
sum = lp.second * cst;
if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
}
if (lp.first->GetOpCode() == OP_neg && cst == -1) {
result = lp.first->Opnd(0);
} else {
result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
}
} else if (op == OP_band && cst == -1) {
sum = lp.second;
result = l;
} else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
(!lp.second.has_value() || lp.second == 0)) {
bool fold2extractbits = false;
if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
BinaryNode *shrNode = static_cast<BinaryNode *>(l);
if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
uint64 ucst = cst.GetZXTValue();
uint32 bsize = 0;
do {
bsize++;
ucst >>= 1;
} while (ucst != 0);
if (shrAmt + static_cast<int64>(bsize) <=
static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
fold2extractbits = true;
result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
sum = std::nullopt;
}
}
}
if (!fold2extractbits) {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
sum = std::nullopt;
}
} else if (op == OP_bior && cst == -1) {
result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
} else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
sum = lp.second;
result = l;
} else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
BaseNode *base = cvtNode->Opnd(0);
BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
}
}
} else if (op == OP_rem && cst == 1) {
result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
} else {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
}
if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
}
} else if (isInt && (op == OP_add || op == OP_sub)) {
if (op == OP_add) {
result = NewBinaryNode(node, op, primType, l, r);
sum = lp.second + rp.second;
} else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
result = node;
} else {
result = NewBinaryNode(node, op, primType, l, r);
sum = lp.second - rp.second;
}
} else {
result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
}
return std::make_pair(result, sum);
}
BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
{
if (isRConstval) {
ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
}
} else {
ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
if (isGtOrLt) {
return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
} else {
return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
}
}
}
return &node;
}
BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
{
CompareNode *node = &compareNode;
BaseNode *result = node;
BaseNode *l = node->Opnd(0);
BaseNode *r = node->Opnd(1);
if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
(r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
} else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
BaseNode *opnd = l;
do {
if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
result = opnd;
break;
} else if (opnd->GetOpCode() == OP_cvt) {
TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
opnd = cvtNode->Opnd(0);
} else {
opnd = nullptr;
}
} while (opnd != nullptr);
}
} else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
(l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
}
}
} else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
(r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
result = SimplifyDoubleConstvalCompare(*node,
(l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
}
}
return result;
}
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
{
CHECK_NULL_FATAL(node);
BaseNode *result = nullptr;
std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
Opcode opcode = node->GetOpCode();
if (lConst != nullptr && rConst != nullptr) {
result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
} else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
lConst->GetConstVal()->GetKind() == kConstInt) {
BaseNode *l = lp.first;
BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
} else {
BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
if (l != node->Opnd(0) || r != node->Opnd(1)) {
result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
} else {
result = node;
}
auto *compareNode = static_cast<CompareNode*>(result);
CHECK_NULL_FATAL(compareNode);
result = SimplifyDoubleCompare(*compareNode);
}
return std::make_pair(result, std::nullopt);
}
BaseNode *ConstantFold::Fold(BaseNode *node)
{
if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
return nullptr;
}
std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
BaseNode *result = PairToExpr(node->GetPrimType(), p);
if (result == node) {
result = nullptr;
}
return result;
}
}