* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Analysis/SymbolicShape/SymEngineAnalysis.h"
#include "symengine/integer.h"
namespace mfusion {
/// Returns the SymEngine symbol bound to `value`, creating it on first use.
///
/// A value seen before is answered from valueToSymMap_, so repeated queries for
/// the same SSA value yield the same expression. Otherwise `resolver(value)`
/// supplies the symbol name, the new symbol is cached and returned.
///
/// @param value    SSA value to look up or bind.
/// @param resolver Callback producing a symbol name for an unbound value.
/// @return The bound symbol, or failure if `resolver` fails (nothing is cached then).
mlir::FailureOr<SymEngineAnalysis::SymExpr> SymEngineAnalysis::getOrAssignSymbol(mlir::Value value,
                                                                              const SymbolNameResolver &resolver) {
  if (auto cached = valueToSymMap_.find(value); cached != valueToSymMap_.end()) {
    return cached->second;
  }

  auto symbolName = resolver(value);
  if (mlir::failed(symbolName)) {
    return mlir::failure();
  }

  auto freshSymbol = builder_.makeSymbol(*symbolName);
  valueToSymMap_.try_emplace(value, freshSymbol);
  return freshSymbol;
}
/// Translates an MLIR affine expression into a SymEngine expression.
///
/// @param expr          Affine expression to translate.
/// @param dimSymbols    Expressions substituted for `d0, d1, ...`; a dimension
///                      position outside this list makes the conversion fail.
/// @param symbolSymbols Expressions substituted for `s0, s1, ...`; a symbol
///                      position outside this list makes the conversion fail.
/// @return The translated expression, or failure if any sub-expression cannot
///         be translated.
///
/// Binary operators are lowered as:
///   a + b         -> a + b
///   a * b         -> a * b
///   a floordiv b  -> floor(a / b)   (built as -ceil(-(a / b)))
///   a ceildiv b   -> ceil(a / b)
///   a mod b       -> a - b * floor(a / b)
///
/// NOTE(review): this assumes builder_.makeDiv builds exact (rational)
/// division, which is what the ceildiv lowering already relies on.
mlir::FailureOr<SymEngineAnalysis::SymExpr> SymEngineAnalysis::convertAffineExpr(
    mlir::AffineExpr expr, llvm::ArrayRef<SymExpr> dimSymbols, llvm::ArrayRef<SymExpr> symbolSymbols) const {
  if (auto constantExpr = mlir::dyn_cast<mlir::AffineConstantExpr>(expr)) {
    return builder_.makeInteger(constantExpr.getValue());
  }
  if (auto dimExpr = mlir::dyn_cast<mlir::AffineDimExpr>(expr)) {
    unsigned position = dimExpr.getPosition();
    if (position >= dimSymbols.size()) {
      return mlir::failure();
    }
    return dimSymbols[position];
  }
  if (auto symExpr = mlir::dyn_cast<mlir::AffineSymbolExpr>(expr)) {
    unsigned position = symExpr.getPosition();
    if (position >= symbolSymbols.size()) {
      return mlir::failure();
    }
    return symbolSymbols[position];
  }

  auto binExpr = mlir::dyn_cast<mlir::AffineBinaryOpExpr>(expr);
  if (!binExpr) {
    return mlir::failure();
  }

  // Translate operands; give up as soon as one side fails instead of
  // translating the other side needlessly.
  auto lhs = convertAffineExpr(binExpr.getLHS(), dimSymbols, symbolSymbols);
  if (mlir::failed(lhs)) {
    return mlir::failure();
  }
  auto rhs = convertAffineExpr(binExpr.getRHS(), dimSymbols, symbolSymbols);
  if (mlir::failed(rhs)) {
    return mlir::failure();
  }

  // floor(lhs / rhs). Affine floordiv must round toward negative infinity; a
  // bare division would yield a non-integral quotient (e.g. s0/2). The floor is
  // expressed through the ceiling builder using floor(x) == -ceil(-x).
  auto makeFloorDiv = [&]() {
    auto minusOne = builder_.makeInteger(static_cast<int64_t>(-1));
    auto negatedQuotient = builder_.makeMul(minusOne, builder_.makeDiv(*lhs, *rhs));
    return builder_.makeMul(minusOne, builder_.makeCeil(negatedQuotient));
  };

  switch (binExpr.getKind()) {
  case mlir::AffineExprKind::Add:
    return builder_.makeAdd(*lhs, *rhs);
  case mlir::AffineExprKind::Mul:
    return builder_.makeMul(*lhs, *rhs);
  case mlir::AffineExprKind::FloorDiv:
    return makeFloorDiv();
  case mlir::AffineExprKind::CeilDiv: {
    auto divExpr = builder_.makeDiv(*lhs, *rhs);
    return builder_.makeCeil(divExpr);
  }
  case mlir::AffineExprKind::Mod: {
    // lhs mod rhs == lhs - rhs * floor(lhs / rhs); for a positive divisor this
    // lies in [0, rhs), matching affine mod semantics.
    auto minusOne = builder_.makeInteger(static_cast<int64_t>(-1));
    auto multiple = builder_.makeMul(*rhs, makeFloorDiv());
    return builder_.makeAdd(*lhs, builder_.makeMul(minusOne, multiple));
  }
  default:
    return mlir::failure();
  }
}
/// Evaluates a dimension-free affine map symbolically.
///
/// The operands in `symbols` are bound, in order, to the map's `s0, s1, ...`
/// through getOrAssignSymbol (reusing any symbol already bound to a value).
/// Every result expression of the map is then translated with convertAffineExpr.
///
/// @param map      Affine map to evaluate; must have no dimensions.
/// @param symbols  One operand per map symbol.
/// @param resolver Names symbols for operands that are not yet bound.
/// @return One expression per map result. Returns failure when any of these holds:
///         - the map has dimensions;
///         - the operand count differs from the map's symbol count;
///         - an operand cannot be named;
///         - a result cannot be translated.
mlir::FailureOr<llvm::SmallVector<SymEngineAnalysis::SymExpr>> SymEngineAnalysis::applyAffineMap(
    mlir::AffineMap map, mlir::ValueRange symbols, const SymbolNameResolver &resolver) {
  // Only symbol-only maps are handled, and every symbol needs exactly one operand.
  const bool shapeMatches = map.getNumDims() == 0 && symbols.size() == map.getNumSymbols();
  if (!shapeMatches) {
    return mlir::failure();
  }

  llvm::SmallVector<SymExpr> boundSymbols;
  boundSymbols.reserve(symbols.size());
  for (mlir::Value operand : symbols) {
    auto bound = getOrAssignSymbol(operand, resolver);
    if (mlir::failed(bound)) {
      return mlir::failure();
    }
    boundSymbols.push_back(*bound);
  }

  const unsigned resultCount = map.getNumResults();
  llvm::SmallVector<SymExpr> translated;
  translated.reserve(resultCount);
  for (unsigned idx = 0; idx < resultCount; ++idx) {
    auto lowered = convertAffineExpr(map.getResult(idx), /*dimSymbols=*/{}, boundSymbols);
    if (mlir::failed(lowered)) {
      return mlir::failure();
    }
    translated.push_back(*lowered);
  }
  return translated;
}
/// Extracts the value of `expr` when it is a literal SymEngine integer.
///
/// @param expr Expression to inspect.
/// @return The integer value, or failure when `expr` is not a SymEngine::Integer
///         or when Integer::as_int() throws (e.g., presumably, when the value is
///         out of its representable range).
mlir::FailureOr<int64_t> SymEngineAnalysis::tryExtractInt64(const SymExpr &expr) const {
  if (SymEngine::is_a<SymEngine::Integer>(*expr)) {
    const auto &integer = SymEngine::down_cast<const SymEngine::Integer &>(*expr);
    try {
      const int64_t value = static_cast<int64_t>(integer.as_int());
      return value;
    } catch (...) {
      // Any conversion error is reported to the caller as a plain failure.
      return mlir::failure();
    }
  }
  return mlir::failure();
}
/// Returns true when `lhs` and `rhs` are equal according to SymEngine::eq
/// (a structural comparison, not a mathematical-equivalence proof).
bool SymEngineAnalysis::isStructurallyEqual(const SymExpr &lhs, const SymExpr &rhs) const {
  const auto &left = *lhs;
  const auto &right = *rhs;
  return SymEngine::eq(left, right);
}
/// Returns true when `lhs` and `rhs` differ according to SymEngine::neq.
/// A true result means the expression trees differ. It does not prove the two
/// expressions are mathematically unequal.
bool SymEngineAnalysis::isStructurallyNotEqual(const SymExpr &lhs, const SymExpr &rhs) const {
  const auto &left = *lhs;
  const auto &right = *rhs;
  return SymEngine::neq(left, right);
}
/// Drops every value-to-symbol binding recorded by getOrAssignSymbol, so later
/// lookups consult the name resolver again.
void SymEngineAnalysis::reset() {
  valueToSymMap_.clear();
}
}