* Copyright 2023-2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <iterator>
#include "akg/Analysis/SymbolicShapeAnalysis.h"
#include "akg/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
#ifndef GEN_PASS_DECL_REMOVESYMBOLIC
#define GEN_PASS_DECL_REMOVESYMBOLIC
#ifndef GEN_PASS_DEF_REMOVESYMBOLIC
#define GEN_PASS_DEF_REMOVESYMBOLIC
#include "akg/Transforms/Passes.h.inc"
#endif
#endif
/// Returns a copy of `type` with the symbolic-shape entry stripped from its
/// attribute dictionary.
///
/// Types for which `SymbolicShapeAnalysis::hasSymbolicShape` reports false are
/// returned unchanged. For the rest, the dictionary is read from the encoding
/// (RankedTensorType) or from the memory space (MemRefType), the entry named
/// `getSymbolShapeAttrName()` is erased, and the type is rebuilt with the same
/// shape and element type (and the same layout for memrefs). When nothing is
/// left in the dictionary, or the encoding / memory space was not a
/// DictionaryAttr to begin with, the rebuilt type carries a null attribute.
/// Any other shaped type is returned as-is.
static Type RemoveTypeSymbolic(Type type) {
  if (!SymbolicShapeAnalysis::getInstance().hasSymbolicShape(type)) {
    return type;
  }

  auto shaped = cast<ShapedType>(type);
  auto dims = shaped.getShape();
  auto elemTy = shaped.getElementType();
  auto *ctx = type.getContext();

  // Classify the type once; both the lookup and the rebuild below need it.
  auto rankedTensorTy = dyn_cast<RankedTensorType>(type);
  auto memRefTy = dyn_cast<MemRefType>(type);

  // The attribute slot that may hold the symbolic-shape dictionary.
  mlir::Attribute carrier;
  if (rankedTensorTy) {
    carrier = rankedTensorTy.getEncoding();
  } else if (memRefTy) {
    carrier = memRefTy.getMemorySpace();
  }

  // Stays null unless some non-symbolic entries survive the erase.
  mlir::Attribute strippedAttr;
  if (auto dict = dyn_cast_or_null<mlir::DictionaryAttr>(carrier)) {
    NamedAttrList remaining(dict);
    (void)remaining.erase(StringAttr::get(ctx, getSymbolShapeAttrName()));
    if (!remaining.empty()) {
      strippedAttr = remaining.getDictionary(ctx);
    }
  }

  if (memRefTy) {
    return MemRefType::get(dims, elemTy, memRefTy.getLayout(), strippedAttr);
  }
  if (rankedTensorTy) {
    return RankedTensorType::get(dims, elemTy, strippedAttr);
  }
  return type;
}
/// Strips symbolic-shape information from a function's signature.
///
/// The types of the function's arguments (as returned by `getArguments()`) are
/// rewritten in place first; then a new FunctionType is built from the
/// cleaned-up argument and result types and installed on `func`.
static void RemoveFuncSymbolic(func::FuncOp &func) {
  for (auto arg : func.getArguments()) {
    arg.setType(RemoveTypeSymbolic(arg.getType()));
  }

  llvm::SmallVector<Type, 4> cleanInputs;
  cleanInputs.reserve(func.getArgumentTypes().size());
  for (Type inputTy : func.getArgumentTypes()) {
    cleanInputs.push_back(RemoveTypeSymbolic(inputTy));
  }

  llvm::SmallVector<Type, 4> cleanResults;
  cleanResults.reserve(func.getResultTypes().size());
  for (Type resultTy : func.getResultTypes()) {
    cleanResults.push_back(RemoveTypeSymbolic(resultTy));
  }

  func.setType(mlir::FunctionType::get(func.getContext(), cleanInputs, cleanResults));
}
namespace {
/// Pass that removes symbolic-shape annotations from every type reachable from
/// the operation it runs on.
///
/// For each visited operation:
///  - a `memref.memory_space_cast` is folded away: its source value has its
///    type cleaned, all uses of the cast result are redirected to the source,
///    and the cast is erased;
///  - a `func.func` gets its argument types and function type rewritten via
///    `RemoveFuncSymbolic`;
///  - every operand and result value has its type rewritten via
///    `RemoveTypeSymbolic`.
struct RemoveSymbolic : public impl::RemoveSymbolicBase<RemoveSymbolic> {
  void runOnOperation() override {
    // NOTE: the visited op is erased inside the callback; this relies on the
    // default (post-order) traversal of Operation::walk.
    (void)getOperation()->walk([](Operation *op) {
      if (auto castOp = dyn_cast<memref::MemorySpaceCastOp>(op)) {
        Value src = castOp.getSource();
        src.setType(RemoveTypeSymbolic(src.getType()));
        castOp.getResult().replaceAllUsesWith(src);
        castOp.erase();
        // `op` is gone now; do not touch it any further.
        return WalkResult::advance();
      }
      // Single dyn_cast instead of isa<> followed by dyn_cast<>.
      if (auto func = dyn_cast<func::FuncOp>(op)) {
        RemoveFuncSymbolic(func);
      }
      for (mlir::Value opnd : op->getOperands()) {
        opnd.setType(RemoveTypeSymbolic(opnd.getType()));
      }
      for (mlir::Value resVal : op->getResults()) {
        resVal.setType(RemoveTypeSymbolic(resVal.getType()));
      }
      return WalkResult::advance();
    });
  }
};
}  // namespace
}
/// Creates a new instance of the symbolic-shape removal pass (`RemoveSymbolic`).
std::unique_ptr<mlir::Pass> mlir::createSymbolicRemovalPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<RemoveSymbolic>();
  return pass;
}