* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file symbolic_scalar_simplify.cpp
* \brief Algebraic simplification for SymbolicScalar expressions.
*
* Uses PYPTO_TRY_REWRITE-style pattern matching (the SYM_TRY_REWRITE macros), adapted for RawSymbolicScalar.
* Visit dispatch follows the SymbolicScalar definition: first SymbolicScalarKind, then SymbolicOpcode.
*/
#pragma once
#include "symbolic_scalar.h"
#include <algorithm>
#include <optional>
#include <utility>
namespace npu::tile_fwk {
// Short local aliases for the raw symbolic node types used throughout this file.
using RawPtr = RawSymbolicScalarPtr;
using Imm = RawSymbolicImmediate;
using Expr = RawSymbolicExpression;

/// Wraps the immediate value \p v in a freshly allocated immediate node.
static inline RawPtr MakeConst(ScalarImmediateType v)
{
    RawPtr node = std::make_shared<Imm>(v);
    return node;
}

/// Returns the value of \p n when it is a non-null immediate node, std::nullopt otherwise.
static inline std::optional<ScalarImmediateType> GetConstVal(const RawPtr& n)
{
    if (!n || !n->IsImmediate()) {
        return std::nullopt;
    }
    return std::static_pointer_cast<Imm>(n)->Immediate();
}
/**
 * \brief CRTP base shared by every symbolic pattern node.
 *
 * Derived must provide InitMatch_() (reset any captured state) and Match_(value) (the structural
 * test). Match() always resets before testing, so one pattern object can be reused for successive
 * attempts. Only Derived may construct this base.
 */
template <typename Derived>
class SPattern {
    friend Derived;
    SPattern() = default;

public:
    // How a parent pattern stores this pattern: by value unless Derived shadows this alias.
    using Nested = Derived;

    /// Resets captured state, then reports whether \p value matches this pattern.
    template <typename NodeType>
    [[nodiscard]] bool Match(const NodeType& value) const
    {
        const Derived& self = derived();
        self.InitMatch_();
        return self.Match_(value);
    }

    /// Like Match(value), but additionally requires \p cond() to hold. \p cond is only invoked
    /// after a successful structural match, so it may inspect the captured variables.
    template <typename NodeType, typename Condition>
    [[nodiscard]] bool Match(const NodeType& value, Condition cond) const
    {
        if (!Match(value)) {
            return false;
        }
        return static_cast<bool>(cond());
    }

    /// Downcast to the concrete pattern type.
    [[nodiscard]] const Derived& derived() const
    {
        return static_cast<const Derived&>(*this);
    }
};
/**
 * \brief Structural equality used when a pattern variable or literal is compared to a node.
 *
 * Two nodes are equal when they are the same object; otherwise both must be non-null and of the
 * same kind. Immediates are compared by value; every other kind is compared by its Dump() text.
 */
class SEqualChecker {
public:
    static bool Equal(const RawPtr& lhs, const RawPtr& rhs)
    {
        const bool sameObject = (lhs.get() == rhs.get());
        if (sameObject) {
            return true;
        }
        const bool comparable = lhs && rhs && lhs->Kind() == rhs->Kind();
        if (!comparable) {
            return false;
        }
        if (!lhs->IsImmediate()) {
            return lhs->Dump() == rhs->Dump();
        }
        const auto lhsImm = std::static_pointer_cast<Imm>(lhs);
        const auto rhsImm = std::static_pointer_cast<Imm>(rhs);
        return lhsImm->Immediate() == rhsImm->Immediate();
    }
};
/**
 * \brief Pattern variable that binds to any node.
 *
 * The first Match_() after InitMatch_() captures the node. Later occurrences of the same variable
 * within one match attempt must be SEqualChecker-equal to the captured node. Parent patterns hold
 * it by const reference (see Nested), so the caller's variable object receives the capture.
 */
class PVarRaw : public SPattern<PVarRaw> {
public:
    using Nested = const PVarRaw&;

    void InitMatch_() const { filled_ = false; }

    [[nodiscard]] bool Match_(const RawPtr& value) const
    {
        if (filled_) {
            // Repeated occurrence: must agree with the first capture.
            return SEqualChecker::Equal(value_, value);
        }
        value_ = value;
        filled_ = true;
        return true;
    }

    /// The node captured during matching.
    [[nodiscard]] RawPtr Eval() const { return value_; }

protected:
    mutable RawPtr value_;
    mutable bool filled_{false};
};
/**
 * \brief Pattern variable that binds only to immediate (constant) nodes.
 *
 * Behaves like PVarRaw, but rejects null and non-immediate nodes. Val() exposes the captured
 * constant and must only be called after a successful match (it dereferences the capture).
 */
class PVarImm : public SPattern<PVarImm> {
public:
    using Nested = const PVarImm&;

    void InitMatch_() const { filled_ = false; }

    [[nodiscard]] bool Match_(const RawPtr& value) const
    {
        const bool isImmediate = value && value->IsImmediate();
        if (!isImmediate) {
            return false;
        }
        if (filled_) {
            return SEqualChecker::Equal(value_, value);
        }
        value_ = value;
        filled_ = true;
        return true;
    }

    /// The captured immediate node.
    [[nodiscard]] RawPtr Eval() const { return value_; }

    /// The captured immediate's value.
    [[nodiscard]] ScalarImmediateType Val() const
    {
        const auto immediate = std::static_pointer_cast<Imm>(value_);
        return immediate->Immediate();
    }

protected:
    mutable RawPtr value_;
    mutable bool filled_{false};
};
/**
 * \brief Pattern that matches one fixed node (compared with SEqualChecker) and evaluates to it.
 */
class SLiteral : public SPattern<SLiteral> {
public:
    explicit SLiteral(RawPtr literal) : value_{std::move(literal)} {}

    // A literal captures nothing, so there is no state to reset.
    void InitMatch_() const {}

    [[nodiscard]] bool Match_(const RawPtr& candidate) const
    {
        return SEqualChecker::Equal(value_, candidate);
    }

    [[nodiscard]] RawPtr Eval() const { return value_; }

private:
    RawPtr value_;
};
/**
 * \brief Pattern matching a two-operand expression node with a fixed opcode.
 *
 * Matches only when the node is an expression whose opcode is exactly \p Opcode and whose operand
 * list has exactly two entries. Operand 0 is matched against TA and operand 1 against TB; order
 * matters, so commutative variants have to be written as separate rules. Eval() rebuilds
 * `Opcode(a.Eval(), b.Eval())` from whatever the sub-patterns captured.
 */
template <SymbolicOpcode Opcode, typename TA, typename TB>
class SBinaryPattern : public SPattern<SBinaryPattern<Opcode, TA, TB>> {
public:
    SBinaryPattern(const TA& a, const TB& b) : a_(a), b_(b) {}

    void InitMatch_() const
    {
        a_.InitMatch_();
        b_.InitMatch_();
    }

    [[nodiscard]] bool Match_(const RawPtr& value) const
    {
        if (!value || !value->IsExpression()) {
            return false;
        }
        // Borrow the node instead of copying the shared_ptr: Match_ runs for every rule that is
        // tried, so skipping the atomic refcount increment/decrement here is free.
        auto* expr = static_cast<Expr*>(value.get());
        if (expr->Opcode() != Opcode) {
            return false;
        }
        const auto& operands = expr->OperandList();
        // Arity is checked explicitly: MOP (min/max) nodes are presumably not guaranteed to be
        // binary — confirm against RawSymbolicExpression.
        if (operands.size() != 0x2) {
            return false;
        }
        return a_.Match_(operands[0]) && b_.Match_(operands[1]);
    }

    [[nodiscard]] RawPtr Eval() const { return Expr::Create(Opcode, {a_.Eval(), b_.Eval()}); }

private:
    typename TA::Nested a_;
    typename TB::Nested b_;
};
/**
 * \brief Pattern matching a one-operand expression node with a fixed opcode.
 *
 * Matches when the node is an expression with opcode \p Opcode and exactly one operand, which is
 * matched against TA. Eval() rebuilds `Opcode(a.Eval())` from the sub-pattern's capture.
 */
template <SymbolicOpcode Opcode, typename TA>
class SUnaryPattern : public SPattern<SUnaryPattern<Opcode, TA>> {
public:
    explicit SUnaryPattern(const TA& a) : a_(a) {}

    void InitMatch_() const { a_.InitMatch_(); }

    [[nodiscard]] bool Match_(const RawPtr& value) const
    {
        if (!value || !value->IsExpression()) {
            return false;
        }
        // Borrow the node rather than copying the shared_ptr (no refcount traffic per rule tried).
        auto* expr = static_cast<Expr*>(value.get());
        if (expr->Opcode() != Opcode) {
            return false;
        }
        const auto& operands = expr->OperandList();
        return operands.size() == 1 && a_.Match_(operands[0]);
    }

    [[nodiscard]] RawPtr Eval() const { return Expr::Create(Opcode, {a_.Eval()}); }

private:
    typename TA::Nested a_;
};
// Pattern builders for addition (T_BOP_ADD). Operand order is kept as written; a plain
// ScalarImmediateType operand is wrapped into an SLiteral holding that constant.
template <typename TA, typename TB>
inline SBinaryPattern<SymbolicOpcode::T_BOP_ADD, TA, TB> operator+(const SPattern<TA>& a, const SPattern<TB>& b)
{
    return {a.derived(), b.derived()};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_ADD, TA, SLiteral> operator+(const SPattern<TA>& a, ScalarImmediateType b)
{
    return {a.derived(), SLiteral(MakeConst(b))};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_ADD, SLiteral, TA> operator+(ScalarImmediateType b, const SPattern<TA>& a)
{
    return {SLiteral(MakeConst(b)), a.derived()};
}
// Pattern builders for subtraction (T_BOP_SUB). Operand order is kept as written; a plain
// ScalarImmediateType operand is wrapped into an SLiteral holding that constant.
template <typename TA, typename TB>
inline SBinaryPattern<SymbolicOpcode::T_BOP_SUB, TA, TB> operator-(const SPattern<TA>& a, const SPattern<TB>& b)
{
    return {a.derived(), b.derived()};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_SUB, TA, SLiteral> operator-(const SPattern<TA>& a, ScalarImmediateType b)
{
    return {a.derived(), SLiteral(MakeConst(b))};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_SUB, SLiteral, TA> operator-(ScalarImmediateType b, const SPattern<TA>& a)
{
    return {SLiteral(MakeConst(b)), a.derived()};
}
// Pattern builders for multiplication (T_BOP_MUL). Operand order is kept as written; a plain
// ScalarImmediateType operand is wrapped into an SLiteral holding that constant.
template <typename TA, typename TB>
inline SBinaryPattern<SymbolicOpcode::T_BOP_MUL, TA, TB> operator*(const SPattern<TA>& a, const SPattern<TB>& b)
{
    return {a.derived(), b.derived()};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_MUL, TA, SLiteral> operator*(const SPattern<TA>& a, ScalarImmediateType b)
{
    return {a.derived(), SLiteral(MakeConst(b))};
}
template <typename TA>
inline SBinaryPattern<SymbolicOpcode::T_BOP_MUL, SLiteral, TA> operator*(ScalarImmediateType b, const SPattern<TA>& a)
{
    return {SLiteral(MakeConst(b)), a.derived()};
}
// SYM_PATTERN_BINARY_NAMED(FuncName, Opcode) defines five overloads of a named binary pattern
// builder, each returning SBinaryPattern<Opcode, ...> with operands kept in the written order:
//   FuncName(pattern, pattern)
//   FuncName(pattern, ScalarImmediateType)  -- constant wrapped as SLiteral(MakeConst(b))
//   FuncName(ScalarImmediateType, pattern)
//   FuncName(pattern, RawPtr)               -- node wrapped as SLiteral(b), e.g. MakeConst(0)
//   FuncName(RawPtr, pattern)
// (Comments cannot be placed inside the macro body: a // comment would swallow the line splice.)
#define SYM_PATTERN_BINARY_NAMED(FuncName, Opcode) \
template <typename TA, typename TB> \
inline SBinaryPattern<Opcode, TA, TB> FuncName(const SPattern<TA>& a, const SPattern<TB>& b) \
{ \
return SBinaryPattern<Opcode, TA, TB>(a.derived(), b.derived()); \
} \
template <typename TA> \
inline SBinaryPattern<Opcode, TA, SLiteral> FuncName(const SPattern<TA>& a, ScalarImmediateType b) \
{ \
return SBinaryPattern<Opcode, TA, SLiteral>(a.derived(), SLiteral(MakeConst(b))); \
} \
template <typename TA> \
inline SBinaryPattern<Opcode, SLiteral, TA> FuncName(ScalarImmediateType b, const SPattern<TA>& a) \
{ \
return SBinaryPattern<Opcode, SLiteral, TA>(SLiteral(MakeConst(b)), a.derived()); \
} \
template <typename TA> \
inline SBinaryPattern<Opcode, TA, SLiteral> FuncName(const SPattern<TA>& a, const RawPtr& b) \
{ \
return SBinaryPattern<Opcode, TA, SLiteral>(a.derived(), SLiteral(b)); \
} \
template <typename TA> \
inline SBinaryPattern<Opcode, SLiteral, TA> FuncName(const RawPtr& b, const SPattern<TA>& a) \
{ \
return SBinaryPattern<Opcode, SLiteral, TA>(SLiteral(b), a.derived()); \
}
// Named builders for the binary opcodes that are not exposed through C++ operator overloads here.
SYM_PATTERN_BINARY_NAMED(sym_div, SymbolicOpcode::T_BOP_DIV)
SYM_PATTERN_BINARY_NAMED(sym_mod, SymbolicOpcode::T_BOP_MOD)
SYM_PATTERN_BINARY_NAMED(sym_min, SymbolicOpcode::T_MOP_MIN)
SYM_PATTERN_BINARY_NAMED(sym_max, SymbolicOpcode::T_MOP_MAX)
SYM_PATTERN_BINARY_NAMED(sym_eq, SymbolicOpcode::T_BOP_EQ)
SYM_PATTERN_BINARY_NAMED(sym_ne, SymbolicOpcode::T_BOP_NE)
SYM_PATTERN_BINARY_NAMED(sym_lt, SymbolicOpcode::T_BOP_LT)
SYM_PATTERN_BINARY_NAMED(sym_le, SymbolicOpcode::T_BOP_LE)
SYM_PATTERN_BINARY_NAMED(sym_gt, SymbolicOpcode::T_BOP_GT)
SYM_PATTERN_BINARY_NAMED(sym_ge, SymbolicOpcode::T_BOP_GE)
/// Pattern for arithmetic negation of \p a (T_UOP_NEG).
template <typename TA>
inline SUnaryPattern<SymbolicOpcode::T_UOP_NEG, TA> sym_neg(const SPattern<TA>& a)
{
    using Pattern = SUnaryPattern<SymbolicOpcode::T_UOP_NEG, TA>;
    return Pattern{a.derived()};
}
/// Pattern for `!a` (T_UOP_NOT).
template <typename TA>
inline SUnaryPattern<SymbolicOpcode::T_UOP_NOT, TA> operator!(const SPattern<TA>& a)
{
    using Pattern = SUnaryPattern<SymbolicOpcode::T_UOP_NOT, TA>;
    return Pattern{a.derived()};
}
/// Turns a rewrite result into a node: pattern objects are materialised through their Eval().
/// This overload is removed by SFINAE for types without Eval(), such as RawPtr.
template <typename T>
inline auto SPatternEval(T&& pattern) -> decltype(pattern.Eval())
{
    return pattern.Eval();
}
/// Plain nodes (e.g. MakeConst(0)) are already results and are returned unchanged.
inline RawPtr SPatternEval(const RawPtr& node)
{
    return node;
}
// Rewrite-rule helpers for the SymbolicScalarSimplify::Visit* methods.
//
// Both macros expect two names in the enclosing scope: `ret` (the node being simplified) and
// `RecursiveRewrite` (applied to the rewritten node). If SrcExpr matches `ret` (and, for the _IF
// form, Cond holds afterwards), ResExpr is evaluated from the captured pattern variables, simplified
// again via RecursiveRewrite and returned from the enclosing function; otherwise control falls
// through to the next rule.
//
// The bodies are wrapped in `do { ... } while (0)` so every use is one statement that needs its
// trailing semicolon and cannot capture a following `else`. The local result uses an unusual name
// so that ResExpr cannot accidentally refer to (and self-initialise from) it.
#define SYM_TRY_REWRITE(SrcExpr, ResExpr)                              \
    do {                                                               \
        if ((SrcExpr).Match(ret)) {                                    \
            auto symTryRewriteResult_ = SPatternEval(ResExpr);         \
            return RecursiveRewrite(symTryRewriteResult_);             \
        }                                                              \
    } while (0)
#define SYM_TRY_REWRITE_IF(SrcExpr, ResExpr, Cond)                     \
    do {                                                               \
        if ((SrcExpr).Match(ret) && (Cond)) {                          \
            auto symTryRewriteResult_ = SPatternEval(ResExpr);         \
            return RecursiveRewrite(symTryRewriteResult_);             \
        }                                                              \
    } while (0)
class SymbolicScalarSimplify {
public:
SymbolicScalarSimplify() = default;
RawPtr Simplify(const RawPtr& node) { return Visit(node); }
private:
static constexpr int kMaxRecursiveDepth = 5;
int recursive_depth_ = 0;
RawPtr RecursiveRewrite(const RawPtr& node)
{
if (recursive_depth_ >= kMaxRecursiveDepth || !node) {
return node;
}
++recursive_depth_;
RawPtr result = Visit(node);
--recursive_depth_;
return result;
}
RawPtr Visit(const RawPtr& node)
{
if (!node) {
return node;
}
switch (node->Kind()) {
case SymbolicScalarKind::T_SCALAR_SYMBOLIC_IMMEDIATE:
return node;
case SymbolicScalarKind::T_SCALAR_SYMBOLIC_SYMBOL:
return node;
case SymbolicScalarKind::T_SCALAR_SYMBOLIC_EXPRESSION:
return VisitExpression(node);
default:
return node;
}
}
RawPtr VisitExpression(const RawPtr& node)
{
auto e = std::static_pointer_cast<Expr>(node);
auto& ops = e->OperandList();
std::vector<RawPtr> newOps;
newOps.reserve(ops.size());
for (auto& op : ops) {
newOps.push_back(Visit(op));
}
if (Expr::AllImmediate(newOps)) {
auto immList = Expr::ToImmediateList(newOps);
return MakeConst(Expr::FoldAllImmediate(e->Opcode(), immList));
}
SymbolicOpcode opcode = e->Opcode();
if (opcode == SymbolicOpcode::T_UOP_NEG) {
return VisitNeg(newOps[0]);
}
if (opcode == SymbolicOpcode::T_UOP_POS) {
return newOps[0];
}
if (opcode == SymbolicOpcode::T_UOP_NOT) {
return VisitNot(newOps[0]);
}
if (opcode == SymbolicOpcode::T_BOP_ADD) {
return VisitAdd(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_SUB) {
return VisitSub(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_MUL) {
return VisitMul(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_DIV) {
return VisitDiv(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_MOD) {
return VisitMod(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_MOP_MIN) {
return VisitMin(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_MOP_MAX) {
return VisitMax(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_EQ) {
return VisitEq(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_NE) {
return VisitNe(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_LT) {
return VisitLt(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_LE) {
return VisitLe(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_BOP_GT) {
return VisitLt(newOps[1], newOps[0]);
}
if (opcode == SymbolicOpcode::T_BOP_GE) {
return VisitLe(newOps[1], newOps[0]);
}
if (opcode == SymbolicOpcode::T_MOP_MIN && newOps.size() == 0x2) {
return VisitMin(newOps[0], newOps[1]);
}
if (opcode == SymbolicOpcode::T_MOP_MAX && newOps.size() == 0x2) {
return VisitMax(newOps[0], newOps[1]);
}
return Expr::Create(opcode, newOps);
}
RawPtr VisitNeg(const RawPtr& a)
{
RawPtr ret = Expr::CreateUopNeg(a);
PVarRaw x, y;
SYM_TRY_REWRITE(sym_neg(sym_neg(x)), x);
SYM_TRY_REWRITE(sym_neg(x - y), y - x);
return ret;
}
RawPtr VisitNot(const RawPtr& a)
{
RawPtr ret = Expr::CreateUopNot(a);
PVarRaw x, y;
SYM_TRY_REWRITE(!(!x), x);
SYM_TRY_REWRITE(!(sym_lt(x, y)), sym_le(y, x));
SYM_TRY_REWRITE(!(sym_le(x, y)), sym_lt(y, x));
SYM_TRY_REWRITE(!(sym_eq(x, y)), sym_ne(x, y));
SYM_TRY_REWRITE(!(sym_ne(x, y)), sym_eq(x, y));
return ret;
}
RawPtr VisitAdd(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopAdd(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE((x + c1) + c2, x + (c1.Val() + c2.Val()));
SYM_TRY_REWRITE((c1 + x) + c2, x + (c1.Val() + c2.Val()));
SYM_TRY_REWRITE((x - y) + y, x);
SYM_TRY_REWRITE(x + (y - x), y);
SYM_TRY_REWRITE((x - y) + (y - z), x - z);
SYM_TRY_REWRITE((x - y) + (z - x), z - y);
SYM_TRY_REWRITE(x + x, x * 2);
SYM_TRY_REWRITE(x * y + x, (y + 1) * x);
SYM_TRY_REWRITE(x + x * y, (y + 1) * x);
SYM_TRY_REWRITE(y * x + x, (y + 1) * x);
SYM_TRY_REWRITE(x + y * x, (y + 1) * x);
SYM_TRY_REWRITE(x * y + x * z, (y + z) * x);
SYM_TRY_REWRITE(y * x + x * z, (y + z) * x);
SYM_TRY_REWRITE(x * y + z * x, (y + z) * x);
SYM_TRY_REWRITE(y * x + z * x, (y + z) * x);
SYM_TRY_REWRITE(sym_min(x, y - z) + z, sym_min(x + z, y));
SYM_TRY_REWRITE(sym_min(x - z, y) + z, sym_min(x, y + z));
SYM_TRY_REWRITE(z + sym_min(x, y - z), sym_min(x + z, y));
SYM_TRY_REWRITE(sym_max(x, y - z) + z, sym_max(x + z, y));
SYM_TRY_REWRITE(sym_max(x - z, y) + z, sym_max(x, y + z));
SYM_TRY_REWRITE(z + sym_max(x, y - z), sym_max(x + z, y));
SYM_TRY_REWRITE(sym_max(x, y) + sym_min(x, y), x + y);
SYM_TRY_REWRITE(sym_min(x, y) + sym_max(x, y), x + y);
SYM_TRY_REWRITE(sym_max(x, y) + sym_min(y, x), x + y);
SYM_TRY_REWRITE(sym_min(x, y) + sym_max(y, x), x + y);
SYM_TRY_REWRITE(c1 + x, x + c1);
SYM_TRY_REWRITE(x + (c1 - y), (x - y) + c1);
SYM_TRY_REWRITE((c1 - y) + x, (x - y) + c1);
SYM_TRY_REWRITE((x + c1) + y, (x + y) + c1);
return ret;
}
RawPtr VisitSub(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopSub(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(x - x, MakeConst(0));
SYM_TRY_REWRITE((x + y) - y, x);
SYM_TRY_REWRITE((y + x) - x, y);
SYM_TRY_REWRITE((x + y) - x, y);
SYM_TRY_REWRITE((y + x) - y, x);
SYM_TRY_REWRITE(x - (x + y), sym_neg(y));
SYM_TRY_REWRITE(x - (y + x), sym_neg(y));
SYM_TRY_REWRITE((x + c1) - c2, x + (c1.Val() - c2.Val()));
SYM_TRY_REWRITE((c1 + x) - c2, x + (c1.Val() - c2.Val()));
SYM_TRY_REWRITE(c1 - (c2 - x), x + (c1.Val() - c2.Val()));
SYM_TRY_REWRITE(c1 - (x + c2), (c1.Val() - c2.Val()) - x);
SYM_TRY_REWRITE(c1 - (c2 + x), (c1.Val() - c2.Val()) - x);
SYM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1.Val() - c2.Val()));
SYM_TRY_REWRITE((x - y) - (x - z), z - y);
SYM_TRY_REWRITE((x + y) - (x + z), y - z);
SYM_TRY_REWRITE((y + x) - (x + z), y - z);
SYM_TRY_REWRITE((x + y) - (z + x), y - z);
SYM_TRY_REWRITE((y + x) - (z + x), y - z);
SYM_TRY_REWRITE(x * y - x, (y - 1) * x);
SYM_TRY_REWRITE(y * x - x, (y - 1) * x);
SYM_TRY_REWRITE(x - x * y, (1 - y) * x);
SYM_TRY_REWRITE(x - y * x, (1 - y) * x);
SYM_TRY_REWRITE(x * y - x * z, (y - z) * x);
SYM_TRY_REWRITE(y * x - x * z, (y - z) * x);
SYM_TRY_REWRITE(x * y - z * x, (y - z) * x);
SYM_TRY_REWRITE(y * x - z * x, (y - z) * x);
SYM_TRY_REWRITE(sym_min(x, y) - x, sym_min(MakeConst(0), y - x));
SYM_TRY_REWRITE(sym_min(x, y) - y, sym_min(x - y, MakeConst(0)));
SYM_TRY_REWRITE(sym_max(x, y) - x, sym_max(MakeConst(0), y - x));
SYM_TRY_REWRITE(sym_max(x, y) - y, sym_max(x - y, MakeConst(0)));
SYM_TRY_REWRITE(x - sym_min(x, y), sym_max(MakeConst(0), x - y));
SYM_TRY_REWRITE(x - sym_max(x, y), sym_min(MakeConst(0), x - y));
SYM_TRY_REWRITE(sym_min(x + y, z) - x, sym_min(y, z - x));
SYM_TRY_REWRITE(sym_min(y + x, z) - x, sym_min(y, z - x));
SYM_TRY_REWRITE(sym_min(z, x + y) - x, sym_min(z - x, y));
SYM_TRY_REWRITE(sym_min(z, y + x) - x, sym_min(z - x, y));
SYM_TRY_REWRITE(sym_max(x + y, z) - x, sym_max(y, z - x));
SYM_TRY_REWRITE(sym_max(y + x, z) - x, sym_max(y, z - x));
SYM_TRY_REWRITE(sym_max(z, x + y) - x, sym_max(z - x, y));
SYM_TRY_REWRITE(sym_max(z, y + x) - x, sym_max(z - x, y));
SYM_TRY_REWRITE(sym_min(x, y) - sym_min(y, x), MakeConst(0));
SYM_TRY_REWRITE(sym_max(x, y) - sym_max(y, x), MakeConst(0));
SYM_TRY_REWRITE(x - (y + c1), (x - y) + (0 - c1.Val()));
SYM_TRY_REWRITE((x + c1) - y, (x - y) + c1);
SYM_TRY_REWRITE(x - (y - z), (x + z) - y);
return ret;
}
RawPtr VisitMul(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopMul(a, b);
PVarRaw x, y;
PVarImm c1, c2;
SYM_TRY_REWRITE((x * c1) * c2, x * (c1.Val() * c2.Val()));
SYM_TRY_REWRITE((c1 * x) * c2, x * (c1.Val() * c2.Val()));
SYM_TRY_REWRITE(sym_min(x, y) * sym_max(x, y), x * y);
SYM_TRY_REWRITE(sym_max(x, y) * sym_min(x, y), x * y);
SYM_TRY_REWRITE(c1 * x, x * c1);
SYM_TRY_REWRITE(x * (c1 * y), (x * y) * c1);
SYM_TRY_REWRITE((x + c1) * c2, x * c2 + c1.Val() * c2.Val());
SYM_TRY_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1.Val()), c1.Val() < 0);
return ret;
}
RawPtr VisitDiv(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopDiv(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE_IF(sym_div(x * c1, c2), x * (c1.Val() / c2.Val()), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_div(c1 * x, c2), x * (c1.Val() / c2.Val()), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_div(sym_div(x, c1), c2), sym_div(x, c1.Val() * c2.Val()), c1.Val() > 0 && c2.Val() > 0);
SYM_TRY_REWRITE_IF(
sym_div(x + c1, c2), sym_div(x, c2) + (c1.Val() / c2.Val()), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(
sym_div(c1 + x, c2), sym_div(x, c2) + (c1.Val() / c2.Val()), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(
sym_div(x * c1 + y, c2), x * (c1.Val() / c2.Val()) + sym_div(y, c2),
c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(
sym_div(c1 * x + y, c2), x * (c1.Val() / c2.Val()) + sym_div(y, c2),
c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(
sym_div(y + x * c1, c2), sym_div(y, c2) + x * (c1.Val() / c2.Val()),
c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(
sym_div(y + c1 * x, c2), sym_div(y, c2) + x * (c1.Val() / c2.Val()),
c2.Val() > 0 && c1.Val() % c2.Val() == 0);
return ret;
}
RawPtr VisitMod(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopMod(a, b);
PVarRaw x, y;
PVarImm c1, c2;
SYM_TRY_REWRITE_IF(sym_mod(x * c1, c2), MakeConst(0), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(c1 * x, c2), MakeConst(0), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(x + c1, c2), sym_mod(x, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(c1 + x, c2), sym_mod(x, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(x * c1 + y, c2), sym_mod(y, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(c1 * x + y, c2), sym_mod(y, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(y + x * c1, c2), sym_mod(y, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
SYM_TRY_REWRITE_IF(sym_mod(y + c1 * x, c2), sym_mod(y, c2), c2.Val() > 0 && c1.Val() % c2.Val() == 0);
return ret;
}
RawPtr VisitMin(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::Create(SymbolicOpcode::T_MOP_MIN, {a, b});
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(sym_min(x, x), x);
SYM_TRY_REWRITE(sym_min(sym_min(x, c1), c2), sym_min(x, std::min(c1.Val(), c2.Val())));
SYM_TRY_REWRITE(sym_min(c1, sym_min(x, c2)), sym_min(x, std::min(c1.Val(), c2.Val())));
SYM_TRY_REWRITE(sym_min(y - x, z - x), sym_min(y, z) - x);
SYM_TRY_REWRITE(sym_min(x - y, x - z), x - sym_max(y, z));
SYM_TRY_REWRITE(sym_min(x + y, x + z), x + sym_min(y, z));
SYM_TRY_REWRITE(sym_min(y + x, x + z), x + sym_min(y, z));
SYM_TRY_REWRITE(sym_min(x + y, z + x), x + sym_min(y, z));
SYM_TRY_REWRITE(sym_min(y + x, z + x), x + sym_min(y, z));
SYM_TRY_REWRITE(sym_min(x + c1, x + c2), x + std::min(c1.Val(), c2.Val()));
SYM_TRY_REWRITE(sym_min(sym_min(x, y), x), sym_min(x, y));
SYM_TRY_REWRITE(sym_min(sym_min(x, y), y), sym_min(x, y));
SYM_TRY_REWRITE(sym_min(x, sym_min(x, y)), sym_min(x, y));
SYM_TRY_REWRITE(sym_min(y, sym_min(x, y)), sym_min(x, y));
SYM_TRY_REWRITE(sym_min(sym_max(x, y), y), y);
SYM_TRY_REWRITE(sym_min(sym_max(y, x), x), x);
SYM_TRY_REWRITE(sym_min(y, sym_max(x, y)), y);
SYM_TRY_REWRITE(sym_min(sym_max(x, y), x), x);
SYM_TRY_REWRITE(sym_min(x, sym_max(x, y)), x);
SYM_TRY_REWRITE(sym_min(x, sym_max(y, x)), x);
SYM_TRY_REWRITE(sym_min(sym_max(x, y), sym_max(x, z)), sym_max(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_max(x, y), sym_max(z, x)), sym_max(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_max(y, x), sym_max(x, z)), sym_max(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_max(y, x), sym_max(z, x)), sym_max(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_min(x, y), sym_min(x, z)), sym_min(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_min(x, y), sym_min(z, x)), sym_min(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_min(y, x), sym_min(x, z)), sym_min(sym_min(y, z), x));
SYM_TRY_REWRITE(sym_min(sym_min(y, x), sym_min(z, x)), sym_min(sym_min(y, z), x));
SYM_TRY_REWRITE_IF(sym_min(x * c1, y * c1), sym_min(x, y) * c1, c1.Val() > 0);
SYM_TRY_REWRITE_IF(sym_min(x * c1, y * c1), sym_max(x, y) * c1, c1.Val() < 0);
SYM_TRY_REWRITE(sym_min(c1, x), sym_min(x, c1));
SYM_TRY_REWRITE(sym_min(sym_min(x, c1), y), sym_min(sym_min(x, y), c1));
return ret;
}
RawPtr VisitMax(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::Create(SymbolicOpcode::T_MOP_MAX, {a, b});
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(sym_max(x, x), x);
SYM_TRY_REWRITE(sym_max(sym_max(x, c1), c2), sym_max(x, std::max(c1.Val(), c2.Val())));
SYM_TRY_REWRITE(sym_max(c1, sym_max(x, c2)), sym_max(x, std::max(c1.Val(), c2.Val())));
SYM_TRY_REWRITE(sym_max(y - x, z - x), sym_max(y, z) - x);
SYM_TRY_REWRITE(sym_max(x - y, x - z), x - sym_min(y, z));
SYM_TRY_REWRITE(sym_max(x + y, x + z), x + sym_max(y, z));
SYM_TRY_REWRITE(sym_max(y + x, x + z), x + sym_max(y, z));
SYM_TRY_REWRITE(sym_max(x + y, z + x), x + sym_max(y, z));
SYM_TRY_REWRITE(sym_max(y + x, z + x), x + sym_max(y, z));
SYM_TRY_REWRITE(sym_max(x + c1, x + c2), x + std::max(c1.Val(), c2.Val()));
SYM_TRY_REWRITE(sym_max(sym_max(x, y), x), sym_max(x, y));
SYM_TRY_REWRITE(sym_max(sym_max(x, y), y), sym_max(x, y));
SYM_TRY_REWRITE(sym_max(x, sym_max(x, y)), sym_max(x, y));
SYM_TRY_REWRITE(sym_max(y, sym_max(x, y)), sym_max(x, y));
SYM_TRY_REWRITE(sym_max(sym_min(x, y), y), y);
SYM_TRY_REWRITE(sym_max(sym_min(y, x), x), x);
SYM_TRY_REWRITE(sym_max(y, sym_min(x, y)), y);
SYM_TRY_REWRITE(sym_max(sym_min(x, y), x), x);
SYM_TRY_REWRITE(sym_max(x, sym_min(x, y)), x);
SYM_TRY_REWRITE(sym_max(x, sym_min(y, x)), x);
SYM_TRY_REWRITE(sym_max(sym_min(x, y), sym_min(x, z)), sym_min(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_min(x, y), sym_min(z, x)), sym_min(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_min(y, x), sym_min(x, z)), sym_min(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_min(y, x), sym_min(z, x)), sym_min(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_max(x, y), sym_max(x, z)), sym_max(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_max(x, y), sym_max(z, x)), sym_max(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_max(y, x), sym_max(x, z)), sym_max(sym_max(y, z), x));
SYM_TRY_REWRITE(sym_max(sym_max(y, x), sym_max(z, x)), sym_max(sym_max(y, z), x));
SYM_TRY_REWRITE_IF(sym_max(x * c1, y * c1), sym_max(x, y) * c1, c1.Val() > 0);
SYM_TRY_REWRITE_IF(sym_max(x * c1, y * c1), sym_min(x, y) * c1, c1.Val() < 0);
SYM_TRY_REWRITE(sym_max(c1, x), sym_max(x, c1));
SYM_TRY_REWRITE(sym_max(sym_max(x, c1), y), sym_max(sym_max(x, y), c1));
return ret;
}
RawPtr VisitEq(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopEq(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(sym_eq(x, x), MakeConst(1));
SYM_TRY_REWRITE(sym_eq(x + y, x + z), sym_eq(y, z));
SYM_TRY_REWRITE(sym_eq(y + x, x + z), sym_eq(y, z));
SYM_TRY_REWRITE(sym_eq(x + y, z + x), sym_eq(y, z));
SYM_TRY_REWRITE(sym_eq(y + x, z + x), sym_eq(y, z));
SYM_TRY_REWRITE(sym_eq(x - c1, c2), sym_eq(x, c1.Val() + c2.Val()));
SYM_TRY_REWRITE(sym_eq(x + c1, c2), sym_eq(x, c2.Val() - c1.Val()));
SYM_TRY_REWRITE(sym_eq(c1, x), sym_eq(x, c1));
return ret;
}
RawPtr VisitNe(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopNe(a, b);
PVarRaw x, y, z;
PVarImm c1;
SYM_TRY_REWRITE(sym_ne(x, x), MakeConst(0));
SYM_TRY_REWRITE(sym_ne(x + y, x + z), sym_ne(y, z));
SYM_TRY_REWRITE(sym_ne(y + x, x + z), sym_ne(y, z));
SYM_TRY_REWRITE(sym_ne(x + y, z + x), sym_ne(y, z));
SYM_TRY_REWRITE(sym_ne(y + x, z + x), sym_ne(y, z));
SYM_TRY_REWRITE(sym_ne(c1, x), sym_ne(x, c1));
return ret;
}
RawPtr VisitLt(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopLt(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(sym_lt(x, x), MakeConst(0));
SYM_TRY_REWRITE(sym_lt(x + y, x + z), sym_lt(y, z));
SYM_TRY_REWRITE(sym_lt(y + x, x + z), sym_lt(y, z));
SYM_TRY_REWRITE(sym_lt(x + y, z + x), sym_lt(y, z));
SYM_TRY_REWRITE(sym_lt(y + x, z + x), sym_lt(y, z));
SYM_TRY_REWRITE(sym_lt(y - x, z - x), sym_lt(y, z));
SYM_TRY_REWRITE(sym_lt(x - y, x - z), sym_lt(z, y));
SYM_TRY_REWRITE(sym_lt(x, x + z), sym_lt(MakeConst(0), z));
SYM_TRY_REWRITE(sym_lt(x, z + x), sym_lt(MakeConst(0), z));
SYM_TRY_REWRITE(sym_lt(x, x - z), sym_lt(z, MakeConst(0)));
SYM_TRY_REWRITE(sym_lt(x + c1, c2), sym_lt(x, c2.Val() - c1.Val()));
SYM_TRY_REWRITE(sym_lt(x - c1, c2), sym_lt(x, c2.Val() + c1.Val()));
SYM_TRY_REWRITE_IF(sym_lt(x * c1, y * c1), sym_lt(x, y), c1.Val() > 0);
SYM_TRY_REWRITE_IF(sym_lt(x * c1, y * c1), sym_lt(y, x), c1.Val() < 0);
SYM_TRY_REWRITE_IF(sym_lt(sym_div(x, c1), c2), sym_lt(x, c1.Val() * c2.Val()), c1.Val() > 0);
return ret;
}
RawPtr VisitLe(const RawPtr& a, const RawPtr& b)
{
RawPtr ret = Expr::CreateBopLe(a, b);
PVarRaw x, y, z;
PVarImm c1, c2;
SYM_TRY_REWRITE(sym_le(x, x), MakeConst(1));
SYM_TRY_REWRITE(sym_le(x + y, x + z), sym_le(y, z));
SYM_TRY_REWRITE(sym_le(y + x, x + z), sym_le(y, z));
SYM_TRY_REWRITE(sym_le(x + y, z + x), sym_le(y, z));
SYM_TRY_REWRITE(sym_le(y + x, z + x), sym_le(y, z));
SYM_TRY_REWRITE(sym_le(y - x, z - x), sym_le(y, z));
SYM_TRY_REWRITE(sym_le(x - y, x - z), sym_le(z, y));
SYM_TRY_REWRITE(sym_le(x + c1, c2), sym_le(x, c2.Val() - c1.Val()));
SYM_TRY_REWRITE(sym_le(x - c1, c2), sym_le(x, c2.Val() + c1.Val()));
SYM_TRY_REWRITE_IF(sym_le(x * c1, y * c1), sym_le(x, y), c1.Val() > 0);
SYM_TRY_REWRITE_IF(sym_le(x * c1, y * c1), sym_le(y, x), c1.Val() < 0);
return ret;
}
};
}