* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Dialect/Mfuse/IR/MfuseTraits.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Dialect/Mfuse/Support/SymbolAttrUtils.h"
#include "symengine/integer.h"
namespace mlir::mfuse::detail {

/// Checks that the symbolic shape encoded on each ranked tensor result of `op`
/// agrees with the tensor type's own shape.
///
/// For every result whose type is a RankedTensorType carrying a
/// SymbolicShapeAttr encoding:
///  - the number of symbolic expressions must equal the tensor rank;
///  - a static dimension must be paired with a SymEngine Integer of the same
///    value;
///  - a dynamic dimension must be paired with a non-Integer expression.
/// Results that are not ranked tensors, or that carry no SymbolicShapeAttr,
/// are skipped.
///
/// \returns failure (after emitting an op error) on the first inconsistency,
///          success otherwise.
static mlir::LogicalResult verifyResultShapeConsistency(mlir::Operation *op) {
  for (mlir::Value result : op->getResults()) {
    auto rankedType = llvm::dyn_cast<mlir::RankedTensorType>(result.getType());
    if (!rankedType) {
      continue;
    }
    auto symAttr = mlir::dyn_cast_or_null<mlir::mfuse::SymbolicShapeAttr>(
        SymbolAttrUtils::getSymbolicShapeAttrFromEncoding(rankedType));
    if (!symAttr) {
      continue;
    }
    auto symExprs = symAttr.getSymEngineExprs();
    auto shape = rankedType.getShape();
    if (static_cast<int64_t>(symExprs.size()) != rankedType.getRank()) {
      return op->emitOpError() << "symbolic shape rank (" << symExprs.size() << ") does not match tensor rank ("
                               << rankedType.getRank() << ")";
    }
    for (int64_t i = 0; i < rankedType.getRank(); ++i) {
      bool isDynamic = mlir::ShapedType::isDynamic(shape[i]);
      bool isInteger = SymEngine::is_a<SymEngine::Integer>(*symExprs[i]);

      if (isDynamic) {
        // A dynamic extent must not be pinned to a concrete value symbolically.
        if (isInteger) {
          return op->emitOpError() << "dimension " << i
                                   << " is dynamic but symbolic expression is a constant: " << symExprs[i]->__str__();
        }
        continue;
      }

      if (!isInteger) {
        return op->emitOpError() << "dimension " << i << " is static (" << shape[i]
                                 << ") but symbolic expression is not a constant: " << symExprs[i]->__str__();
      }

      const auto &symInt = SymEngine::down_cast<const SymEngine::Integer &>(*symExprs[i]);
      // Integer::as_int() throws when the value does not fit in a signed long.
      // Check first so an oversized constant yields a diagnostic instead of an
      // exception escaping the verifier.
      if (!SymEngine::mp_fits_slong_p(symInt.as_integer_class())) {
        return op->emitOpError() << "dimension " << i << " is static (" << shape[i]
                                 << ") but symbolic expression value is out of range: " << symExprs[i]->__str__();
      }
      auto symVal = static_cast<int64_t>(symInt.as_int());
      if (symVal != shape[i]) {
        return op->emitOpError() << "dimension " << i << " is static (" << shape[i]
                                 << ") but symbolic expression has different value: " << symVal;
      }
    }
  }
  return mlir::success();
}

/// Emits an error on `op` if any value in `values` is a ranked tensor with a
/// non-static shape that lacks a symbolic shape encoding (as reported by
/// SymbolAttrUtils::hasSymbolicShapeEncoding).
///
/// \param role Plural noun naming the checked values in the diagnostic
///             (e.g. "inputs" or "results").
static mlir::LogicalResult verifyNonStaticValuesHaveSymbolicShape(mlir::Operation *op, mlir::ValueRange values,
                                                                    llvm::StringRef role) {
  for (mlir::Value value : values) {
    auto rankedType = llvm::dyn_cast<mlir::RankedTensorType>(value.getType());
    if (rankedType && !rankedType.hasStaticShape() && !SymbolAttrUtils::hasSymbolicShapeEncoding(rankedType)) {
      return op->emitOpError() << "failed symbolic shape verification: because at least one "
                                  "input has a symbolic shape, all non-static ranked tensor "
                               << role << " must also have a symbolic shape.";
    }
  }
  return mlir::success();
}

/// Verifier for the Mfuse symbolic-shape trait.
///
/// 1. Every ranked tensor result with a symbolic shape must be consistent with
///    its tensor shape (see verifyResultShapeConsistency).
/// 2. If at least one operand type carries a symbolic shape encoding, then
///    every non-static ranked tensor operand and result must carry one too.
///    Ops with no symbolic operand skip this check.
mlir::LogicalResult verifySymbolicShapeTrait(mlir::Operation *op) {
  if (mlir::failed(verifyResultShapeConsistency(op))) {
    return mlir::failure();
  }

  bool hasInputSymbol = false;
  for (mlir::Value operand : op->getOperands()) {
    if (SymbolAttrUtils::hasSymbolicShapeEncoding(operand.getType())) {
      hasInputSymbol = true;
      break;
    }
  }
  if (!hasInputSymbol) {
    return mlir::success();
  }

  // Operands are checked before results so diagnostics keep their original order.
  if (mlir::failed(verifyNonStaticValuesHaveSymbolicShape(op, op->getOperands(), "inputs"))) {
    return mlir::failure();
  }
  return verifyNonStaticValuesHaveSymbolicShape(op, op->getResults(), "results");
}

}  // namespace mlir::mfuse::detail