/**
 * Copyright 2025 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/NPUAutoTiling.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include "akg/Analysis/LoopTiling.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "bishengir/Dialect/HACC/IR/HACC.h"
#define DEBUG_TYPE "npu-auto-tiling"
// Pull the TableGen-generated declaration and definition (base class) of the
// NPUAutoTiling pass into namespace mlir.
namespace mlir {
#define GEN_PASS_DECL_NPUAUTOTILING
#define GEN_PASS_DEF_NPUAUTOTILING
#include "akg/Transforms/Passes.h.inc"
}  // namespace mlir
namespace mlir {
namespace mockattr {
// Attribute name set (as a unit attribute) on every generated device function.
constexpr const char *kEnableAutoMarkBufferSize = "enable_auto_mark_buffer_size";
}  // namespace mockattr
namespace {
// Import HACC dialect names so they can be written without the hacc:: prefix.
using hacc::BlockDimAttr;
using hacc::HACCFuncType;
using hacc::HACCFuncTypeAttr;
using hacc::HACCToLLVMIRTranslateAttr;
using hacc::HostFuncType;
using hacc::HostFuncTypeAttr;
using hacc::InputIdxAttr;
using hacc::KernelArgType;
using hacc::KernelArgTypeAttr;
using hacc::OutputIdxAttr;
using hacc::TilingFunctionAttr;
// Options shared by all tiling schedulers created by the pass.
struct AutoTilingOptions {
  // Value written to the block-dim attribute of generated device functions.
  unsigned blockDim{40};
};
// Per-kernel facts captured before tiling starts.
struct KernelInfo {
  // The function being tiled.
  func::FuncOp originalKernel;
  // Symbol name of the original kernel; used as the prefix of generated symbols.
  std::string baseKernelName;
  // Block dimension for generated device functions.
  unsigned blockDim{40};
};
// Marks two arguments of `func` with the HACC kernel-arg-type attribute:
// argument `keyIdx` becomes the tiling key and argument `tilingDataIdx` the
// tiling struct. Any other attributes already present on those arguments are
// kept; an existing kernel-arg-type entry is overwritten.
static void setTilingKeyAndDataArgAttrs(func::FuncOp func, unsigned keyIdx, unsigned tilingDataIdx) {
  MLIRContext *context = func.getContext();
  StringAttr attrName = StringAttr::get(context, KernelArgTypeAttr::name);
  // setArgAttr replaces the named entry when present and appends it otherwise.
  const auto tagArgument = [&](unsigned argIdx, KernelArgType argKind) {
    func.setArgAttr(argIdx, attrName, KernelArgTypeAttr::get(context, argKind));
  };
  tagArgument(keyIdx, KernelArgType::kTilingKey);
  tagArgument(tilingDataIdx, KernelArgType::kTilingStruct);
}
// Bookkeeping for the tiling functions and kernels generated for one kernel.
struct TilingInfo {
// Host tiling function; null until setHostTilingFunc is called.
func::FuncOp hostTilingFunc;
// Tiling key -> per-case tiling function.
DenseMap<int64_t, func::FuncOp> perCaseTilingFuncs;
// Tiling key -> kernel function registered through recordKernelFunc.
DenseMap<int64_t, func::FuncOp> tilingKey2Kernel;
// Set once a single constant tiling key has been determined.
std::optional<int64_t> constTilingKey;
// Result of analyzeInputShapeStatic for the original kernel.
bool isStaticShape = false;
// Size of the tiling struct; 0 means "not computed yet".
unsigned tilingStructSize = 0;
// Stores `f` as the host tiling function.
void setHostTilingFunc(func::FuncOp f) {
  hostTilingFunc = f;
}

// Returns the stored host tiling function (null if none was set).
func::FuncOp getHostTilingFunc() const {
  return hostTilingFunc;
}

// Registers (or overwrites) the per-case tiling function for `key`.
void setPerCaseTilingFunc(int64_t key, func::FuncOp f) {
  perCaseTilingFuncs[key] = f;
}
// Returns the per-case tiling function registered for `key`, or a null FuncOp
// when no function is registered for that key.
func::FuncOp getPerCaseTilingFunc(int64_t key) const {
  // DenseMap::lookup yields a default-constructed value for a missing key.
  return perCaseTilingFuncs.lookup(key);
}
// Returns every tiling key that has a per-case tiling function, in ascending
// order.
SmallVector<int64_t> getAllKeys() const {
  SmallVector<int64_t> sortedKeys;
  sortedKeys.reserve(perCaseTilingFuncs.size());
  for (const auto &entry : perCaseTilingFuncs) {
    sortedKeys.push_back(entry.first);
  }
  llvm::sort(sortedKeys);
  return sortedKeys;
}
// Registers (or overwrites) the kernel function for tiling key `key`.
void recordKernelFunc(int64_t key, func::FuncOp func) {
  tilingKey2Kernel[key] = func;
}

// Returns a copy of the tiling-key -> kernel map.
DenseMap<int64_t, func::FuncOp> getTilingKey2KernelMap() const {
  return tilingKey2Kernel;
}

// Alias of getAllKeys().
SmallVector<int64_t> getTilingCases() const {
  return getAllKeys();
}
// Returns true when no memref-typed argument of `kernel` has a dynamic shape.
// Arguments that are not memrefs are ignored.
bool analyzeInputShapeStatic(func::FuncOp kernel) {
  for (auto arg : kernel.getArguments()) {
    auto memrefType = dyn_cast<MemRefType>(arg.getType());
    if (memrefType && !memrefType.hasStaticShape()) return false;
  }
  return true;
}
// Returns tiling key 0 as soon as one statically shaped memref argument is
// found on `kernel`; returns std::nullopt when there is none.
std::optional<int64_t> computeStaticTilingKey(func::FuncOp kernel) {
  for (auto arg : kernel.getArguments()) {
    auto memrefType = dyn_cast<MemRefType>(arg.getType());
    if (memrefType && memrefType.hasStaticShape()) return 0;
  }
  return std::nullopt;
}
// Tries to pin the tiling key to a constant.
// Does nothing when no host tiling function has been created. When all memref
// inputs of `originalKernel` are static and a static key can be computed, the
// host tiling function is rewritten for that key; otherwise the dynamic
// simplification path is used.
LogicalResult trySimplifyTilingFunc(func::FuncOp originalKernel) {
  if (!hostTilingFunc) return success();

  isStaticShape = analyzeInputShapeStatic(originalKernel);
  std::optional<int64_t> staticKey;
  if (isStaticShape) staticKey = computeStaticTilingKey(originalKernel);
  if (!staticKey) return trySimplifyDynamicTilingFunc();

  constTilingKey = *staticKey;
  LLVM_DEBUG(llvm::dbgs() << "Static shape detected, tiling key = " << *staticKey << "\n");
  simplifyHostTilingFuncForStaticShape(*staticKey);
  return success();
}
void simplifyHostTilingFuncForStaticShape(int64_t selectedKey) {
auto &body = hostTilingFunc.getBody();
while (!body.empty()) {
body.front().erase();
}
Block *entry = hostTilingFunc.addEntryBlock();
OpBuilder b(entry, entry->begin());
Location loc = hostTilingFunc.getLoc();
unsigned numArgs = hostTilingFunc.getNumArguments();
unsigned keyIdx = numArgs - 2;
unsigned tilingDataIdx = numArgs - 1;
SmallVector<Value> args(entry->args_begin(), entry->args_end());
Value keyPtr = args[keyIdx];
Value dataMem = args[tilingDataIdx];
(void)keyPtr;
(void)dataMem;
func::FuncOp perCaseFunc = getPerCaseTilingFunc(selectedKey);
if (perCaseFunc) {
SmallVector<Value> callArgs(args.begin(), args.begin() + keyIdx + 2);
b.create<func::CallOp>(loc, perCaseFunc.getSymName(), TypeRange{}, callArgs);
}
b.create<func::ReturnOp>(loc);
}
// Records a constant tiling key when one can be extracted from the host
// tiling function. Always reports success, whether or not a key was found.
LogicalResult trySimplifyDynamicTilingFunc() {
  std::optional<int64_t> constKey = tryExtractConstantFromHostFunc();
  if (constKey) {
    constTilingKey = *constKey;
    LLVM_DEBUG(llvm::dbgs() << "Constant tiling key extracted from host func: " << *constKey << "\n");
  }
  return success();
}
// Walks the host tiling function looking for a constant tiling key.
// Candidates are the first operand of a func.return (when it is an integer
// constant) and scf.if ops accepted by analyzeIfOpForConst. The candidate
// visited last by the walk wins; std::nullopt is returned when none is found.
std::optional<int64_t> tryExtractConstantFromHostFunc() {
  std::optional<int64_t> found;
  hostTilingFunc.walk([&](Operation *op) {
    std::optional<int64_t> candidate;
    if (auto returnOp = dyn_cast<func::ReturnOp>(op)) {
      if (returnOp.getNumOperands() > 0) candidate = extractConstantIntValue(returnOp.getOperand(0));
    } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
      candidate = analyzeIfOpForConst(ifOp);
    }
    if (candidate) found = *candidate;
  });
  return found;
}
// Returns the constant yielded as first result by both branches of `ifOp`
// when the two constants are equal; std::nullopt otherwise (including when
// the op has no results).
std::optional<int64_t> analyzeIfOpForConst(scf::IfOp ifOp) {
  if (ifOp.getNumResults() == 0) return std::nullopt;

  auto thenYield = dyn_cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
  auto elseYield = dyn_cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
  if (!thenYield || !elseYield) return std::nullopt;
  if (thenYield.getNumOperands() == 0 || elseYield.getNumOperands() == 0) return std::nullopt;

  std::optional<int64_t> thenKey = extractConstantIntValue(thenYield.getOperand(0));
  std::optional<int64_t> elseKey = extractConstantIntValue(elseYield.getOperand(0));
  if (!thenKey || !elseKey) return std::nullopt;
  if (*thenKey != *elseKey) return std::nullopt;
  return thenKey;
}
// Returns the integer held by `val` when it is defined by an arith.constant
// carrying an IntegerAttr; std::nullopt otherwise.
// arith::ConstantIntOp and arith::ConstantIndexOp are both views over
// arith::ConstantOp, so a single check on ConstantOp covers integer and index
// constants alike.
std::optional<int64_t> extractConstantIntValue(Value val) {
  auto constOp = val.getDefiningOp<arith::ConstantOp>();
  if (!constOp) return std::nullopt;
  if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
    return intAttr.getInt();
  }
  return std::nullopt;
}
// True once a constant tiling key has been recorded.
bool isTilingFullyStatic() const {
  return constTilingKey != std::nullopt;
}
void pruneUnusedTilingFuncs(int64_t selectedKey, ModuleOp module) {
SmallVector<func::FuncOp> toErase;
for (auto &it : perCaseTilingFuncs) {
if (it.first != selectedKey) {
toErase.push_back(it.second);
}
}
for (auto &it : tilingKey2Kernel) {
if (it.first != selectedKey) {
toErase.push_back(it.second);
}
}
for (auto func : toErase) func.erase();
DenseMap<int64_t, func::FuncOp> newPerCaseMap;
auto it = perCaseTilingFuncs.find(selectedKey);
if (it != perCaseTilingFuncs.end()) newPerCaseMap[selectedKey] = it->second;
perCaseTilingFuncs = std::move(newPerCaseMap);
DenseMap<int64_t, func::FuncOp> newKernelMap;
auto kit = tilingKey2Kernel.find(selectedKey);
if (kit != tilingKey2Kernel.end()) newKernelMap[selectedKey] = kit->second;
tilingKey2Kernel = std::move(newKernelMap);
}
};
// Drives the tiling of one kernel function: creates tiling functions, device
// kernels and the dispatching wrapper.
class TilingBase {
 public:
  // Binds the scheduler to kernel `f` and caches its parent module (a null
  // ModuleOp when `f` itself is null).
  explicit TilingBase(func::FuncOp f)
      : originalKernel_(f),
        module_(f ? f->getParentOfType<ModuleOp>() : ModuleOp()),
        kernelInfo_(std::make_unique<KernelInfo>()),
        tilingInfo_() {
    if (kernelInfo_) {
      kernelInfo_->originalKernel = f;
    }
  }

  virtual ~TilingBase() = default;
LogicalResult runOnOperation(OpBuilder &builder) {
if (failed(runPreTilingProcedure(builder))) return failure();
if (failed(runTilingProcedure(builder))) return failure();
if (failed(runPostTilingProcedure(builder))) return failure();
if (failed(createOrGetGetTilingStructSizeFunction(builder))) return failure();
return success();
}
static void setAutoTilingOptions(const AutoTilingOptions &opt) { options_ = opt; }
protected:
LogicalResult runPreTilingProcedure(OpBuilder &builder) {
auto *ctx = originalKernel_.getContext();
ctx->getOrLoadDialect<hacc::HACCDialect>();
ctx->getOrLoadDialect<scf::SCFDialect>();
ctx->getOrLoadDialect<arith::ArithDialect>();
ctx->getOrLoadDialect<memref::MemRefDialect>();
ctx->getOrLoadDialect<LLVM::LLVMDialect>();
kernelInfo_->baseKernelName = originalKernel_.getSymName().str();
kernelInfo_->blockDim = options_.blockDim;
originalKernel_->setAttr(HACCFuncTypeAttr::name, HACCFuncTypeAttr::get(ctx, HACCFuncType::HOST));
return success();
}
LogicalResult runTilingProcedure(OpBuilder &builder) {
llvm::DenseMap<int64_t, mlir::func::FuncOp> tilingFuncMap;
if (failed(mlir::autotiling::createTilingFunctions(originalKernel_, builder, tilingFuncMap))) return failure();
for (const auto &it : tilingFuncMap) {
int64_t key = it.getFirst();
mlir::func::FuncOp f = it.getSecond();
tilingInfo_.setPerCaseTilingFunc(key, f);
}
if (failed(computeTilingStructSizeFromTilingFuncs(builder, tilingFuncMap))) return failure();
if (failed(createHostTilingFunction(builder))) return failure();
if (failed(tilingInfo_.trySimplifyTilingFunc(originalKernel_))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to simplify host tiling func\n");
} else {
LLVM_DEBUG(llvm::dbgs() << "Successfully simplified host tiling func\n");
}
SmallVector<int64_t> tilingKeys;
if (tilingInfo_.isTilingFullyStatic()) {
int64_t selectedKey = *tilingInfo_.constTilingKey;
LLVM_DEBUG(llvm::dbgs() << "Selected tiling key (constant folded): " << selectedKey << "\n");
tilingKeys.push_back(selectedKey);
} else {
tilingKeys = tilingInfo_.getAllKeys();
}
if (tilingKeys.empty()) {
originalKernel_.emitError("no per-case tiling functions recorded in TilingInfo");
return failure();
}
bool allOk = std::all_of(tilingKeys.begin(), tilingKeys.end(),
[&](int64_t key) { return succeeded(initTilingKernel(key, builder)); });
if (!allOk) return failure();
if (tilingInfo_.isTilingFullyStatic()) {
tilingInfo_.pruneUnusedTilingFuncs(*tilingInfo_.constTilingKey, module_);
}
if (failed(applyTilingImpl(builder))) return failure();
if (failed(fixCallSitesAndCaller(builder))) return failure();
return success();
}
LogicalResult computeTilingStructSizeFromTilingFuncs(
OpBuilder &builder, const llvm::DenseMap<int64_t, mlir::func::FuncOp> &tilingFuncMap) {
auto it0 = tilingFuncMap.find(0);
if (it0 == tilingFuncMap.end()) {
originalKernel_.emitError("tilingFuncMap does not contain key=0");
return failure();
}
func::FuncOp tilingFunc0 = it0->second;
auto *ctx = builder.getContext();
auto katName = StringAttr::get(ctx, KernelArgTypeAttr::name);
unsigned structSize = 0;
for (auto [idx, arg] : llvm::enumerate(tilingFunc0.getArguments())) {
DictionaryAttr dict = tilingFunc0.getArgAttrDict(idx);
if (!dict) continue;
Attribute attr = dict.get(katName);
if (!attr) continue;
auto katAttr = dyn_cast<KernelArgTypeAttr>(attr);
if (!katAttr) continue;
if (katAttr.getArgType() == KernelArgType::kTilingStruct) {
auto memrefTy = dyn_cast<MemRefType>(arg.getType());
if (!memrefTy || memrefTy.getRank() != 1 || !memrefTy.getElementType().isInteger(64)) {
originalKernel_.emitError("tiling struct argument must be rank-1 memref<i64>");
return failure();
}
int64_t dim0 = memrefTy.getDimSize(0);
if (dim0 <= 0) {
originalKernel_.emitError("tiling struct size must be a positive static dimension");
return failure();
}
structSize = static_cast<unsigned>(dim0);
break;
}
}
if (structSize == 0) {
structSize = 1;
}
tilingInfo_.tilingStructSize = structSize;
return success();
}
LogicalResult runPostTilingProcedure(OpBuilder &) { return success(); }
void copyHaccIOAttrsFrom(func::FuncOp orig, func::FuncOp dst) {
if (std::optional<ArrayAttr> maybeArray = orig.getArgAttrs()) {
ArrayAttr arr = *maybeArray;
unsigned n = std::min<unsigned>(arr.size(), dst.getNumArguments());
for (unsigned i = 0; i < n; ++i) {
if (auto dict = dyn_cast_or_null<DictionaryAttr>(arr[i])) {
SmallVector<NamedAttribute, 4> attrs;
attrs.reserve(dict.size());
std::copy(dict.begin(), dict.end(), std::back_inserter(attrs));
for (const auto &na : attrs) dst.setArgAttr(i, na.getName(), na.getValue());
}
}
}
}
void copyAttrsForDeviceFromHost(func::FuncOp host, func::FuncOp device) {
for (auto &na : host->getAttrs()) {
auto name = na.getName().strref();
if (name == "hacc.function_kind" || name == "hacc.host_func_type" || name == "hacc.tiling_function" ||
name == "mindspore_kernel") {
continue;
}
if (name == "OperatorType" || name == "compute_capability" || name == "process") {
device->setAttr(na.getName(), na.getValue());
}
}
}
func::FuncOp createHostTilingFuncOpHeader(OpBuilder &builder, MLIRContext *ctx, Location loc,
ArrayRef<Type> argTypes) {
auto llvmPtrTy = LLVM::LLVMPointerType::get(ctx);
auto i64Ty = builder.getI64Type();
unsigned sz = tilingInfo_.tilingStructSize;
if (sz == 0) sz = 1;
auto memrefTy = MemRefType::get({static_cast<int64_t>(sz)}, i64Ty);
SmallVector<Type> fullArgs(argTypes.begin(), argTypes.end());
fullArgs.push_back(llvmPtrTy);
fullArgs.push_back(memrefTy);
SmallVector<Type> resTypes;
std::string name = kernelInfo_->baseKernelName + "_tiling_function";
auto funcTy = FunctionType::get(ctx, fullArgs, resTypes);
auto host = builder.create<func::FuncOp>(loc, name, funcTy);
host.addEntryBlock();
copyHaccIOAttrsFrom(originalKernel_, host);
host->setAttr(HACCFuncTypeAttr::name, HACCFuncTypeAttr::get(ctx, HACCFuncType::HOST));
host->setAttr(HostFuncTypeAttr::name, HostFuncTypeAttr::get(ctx, HostFuncType::kTilingFunction));
unsigned numArgs = host.getNumArguments();
unsigned keyIdx = numArgs - 2;
unsigned tilingDataIdx = numArgs - 1;
setTilingKeyAndDataArgAttrs(host, keyIdx, tilingDataIdx);
return host;
}
Value selectKeyByDimension(OpBuilder &b, Location loc, ArrayRef<Value> args) {
auto k0 = b.create<arith::ConstantIntOp>(loc, 0, 64);
return k0;
}
LogicalResult createHostTilingFunction(OpBuilder &builder) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(originalKernel_);
auto *ctx = builder.getContext();
auto loc = originalKernel_.getLoc();
auto origTy = originalKernel_.getFunctionType();
SmallVector<Type> argTypes(origTy.getInputs().begin(), origTy.getInputs().end());
func::FuncOp host = createHostTilingFuncOpHeader(builder, ctx, loc, argTypes);
if (!host) return failure();
OpBuilder bodyBuilder(&host.getBody().front(), host.getBody().front().end());
SmallVector<Value> args(host.getArguments().begin(), host.getArguments().end());
unsigned numArgs = args.size();
if (numArgs < 2) {
host.emitError("host tiling function expects at least 2 tail args");
return failure();
}
unsigned keyIdx = numArgs - 2;
unsigned tilingDataIdx = numArgs - 1;
SmallVector<Value> computeArgs(args.begin(), args.begin() + keyIdx);
Value keyPtr = args[keyIdx];
Value dataMem = args[tilingDataIdx];
Value logicalKeyI64 = selectKeyByDimension(bodyBuilder, loc, computeArgs);
Value keyIndex = bodyBuilder.create<arith::IndexCastUIOp>(loc, bodyBuilder.getIndexType(), logicalKeyI64);
SmallVector<int64_t> caseKeys = tilingInfo_.getAllKeys();
if (caseKeys.empty()) {
host.emitError("no tiling cases recorded in TilingInfo for host tiling function");
return failure();
}
auto switchOp = bodyBuilder.create<scf::IndexSwitchOp>(loc, TypeRange{}, keyIndex, ArrayRef<int64_t>(caseKeys),
caseKeys.size());
for (unsigned i = 0; i < caseKeys.size(); ++i) {
int64_t key = caseKeys[i];
Region ® = switchOp.getCaseRegions()[i];
Block *blk = new Block();
reg.push_back(blk);
OpBuilder cb(blk, blk->begin());
func::FuncOp perCaseF = tilingInfo_.getPerCaseTilingFunc(key);
if (!perCaseF) {
host.emitError() << "missing per-case tiling function for key=" << key;
return failure();
}
SmallVector<Value> callArgs;
callArgs.append(computeArgs.begin(), computeArgs.end());
callArgs.push_back(keyPtr);
callArgs.push_back(dataMem);
cb.create<func::CallOp>(loc, perCaseF.getSymName(), perCaseF.getFunctionType().getResults(), callArgs);
cb.create<scf::YieldOp>(loc);
}
{
Region &defaultReg = switchOp.getDefaultRegion();
Block *blk = new Block();
defaultReg.push_back(blk);
OpBuilder db(blk, blk->begin());
Value falseVal = db.create<arith::ConstantIntOp>(loc, 0, db.getI1Type());
auto msgAttr = db.getStringAttr("Invalid tiling key");
db.create<cf::AssertOp>(loc, falseVal, msgAttr);
db.create<scf::YieldOp>(loc);
}
bodyBuilder.create<func::ReturnOp>(loc);
tilingInfo_.setHostTilingFunc(host);
return success();
}
LogicalResult collectDeviceSignature(func::FuncOp orig, SmallVector<Type> &devInputs, SmallVector<Type> &devResults,
int64_t ) {
auto origTy = orig.getFunctionType();
devInputs.clear();
devInputs.reserve(origTy.getNumInputs());
std::copy(origTy.getInputs().begin(), origTy.getInputs().end(), std::back_inserter(devInputs));
devResults.clear();
devResults.reserve(origTy.getNumResults());
std::copy(origTy.getResults().begin(), origTy.getResults().end(), std::back_inserter(devResults));
if (tilingInfo_.isTilingFullyStatic()) return success();
auto *ctx = orig.getContext();
auto i64Ty = IntegerType::get(ctx, 64);
auto llvmPtrTy = LLVM::LLVMPointerType::get(ctx);
unsigned sz = tilingInfo_.tilingStructSize;
if (sz == 0) sz = 1;
auto memrefTy = MemRefType::get({static_cast<int64_t>(sz)}, i64Ty);
devInputs.push_back(llvmPtrTy);
devInputs.push_back(memrefTy);
return success();
}
func::FuncOp createAndAnnotateDeviceFunc(OpBuilder &builder, Location loc, StringRef name, FunctionType devTy,
FunctionType , unsigned blockDim, func::FuncOp hostTiling) {
auto deviceFunc = builder.create<func::FuncOp>(loc, name, devTy);
deviceFunc.addEntryBlock();
copyHaccIOAttrsFrom(originalKernel_, deviceFunc);
copyAttrsForDeviceFromHost(originalKernel_, deviceFunc);
deviceFunc->setAttr(mockattr::kEnableAutoMarkBufferSize, builder.getUnitAttr());
deviceFunc->setAttr(HACCFuncTypeAttr::name, HACCFuncTypeAttr::get(builder.getContext(), HACCFuncType::DEVICE));
deviceFunc->setAttr(BlockDimAttr::name, builder.getI64IntegerAttr(blockDim));
deviceFunc->setAttr(
StringAttr::get(builder.getContext(), stringifyHACCToLLVMIRTranslateAttr(HACCToLLVMIRTranslateAttr::ENTRY)),
builder.getUnitAttr());
if (hostTiling) {
deviceFunc->setAttr(
TilingFunctionAttr::name,
TilingFunctionAttr::get(builder.getContext(), FlatSymbolRefAttr::get(hostTiling.getSymNameAttr())));
}
if (!tilingInfo_.isTilingFullyStatic()) {
unsigned numInputs = devTy.getNumInputs();
unsigned keyIdx = numInputs - 2;
unsigned tilingDataIdx = numInputs - 1;
setTilingKeyAndDataArgAttrs(deviceFunc, keyIdx, tilingDataIdx);
}
return deviceFunc;
}
SmallVector<Value> cloneKernelBodyToDeviceFunc(func::FuncOp originalKernel, func::FuncOp deviceFunc) {
SmallVector<Value> returnedValues;
if (originalKernel.empty()) return returnedValues;
IRMapping map;
Block &oldEntry = originalKernel.front();
Block &newEntry = deviceFunc.getBody().front();
unsigned numArgs = std::min<unsigned>(oldEntry.getNumArguments(), newEntry.getNumArguments());
for (unsigned i = 0; i < numArgs; ++i) map.map(oldEntry.getArgument(i), newEntry.getArgument(i));
func::ReturnOp oldRet = nullptr;
SmallVector<Operation *> toClone;
toClone.reserve(oldEntry.getOperations().size());
for (Operation &op : oldEntry) {
if (auto r = dyn_cast<func::ReturnOp>(op)) {
oldRet = r;
continue;
}
toClone.push_back(&op);
}
OpBuilder b = OpBuilder::atBlockEnd(&newEntry);
for (Operation *op : toClone) b.clone(*op, map);
if (oldRet) {
unsigned numRet = oldRet.getNumOperands();
SmallVector<Value> retVals;
retVals.reserve(numRet);
returnedValues.resize(numRet);
for (unsigned i = 0; i < numRet; ++i) {
Value mapped = map.lookupOrNull(oldRet.getOperand(i));
returnedValues[i] = mapped;
if (mapped) retVals.push_back(mapped);
}
b.create<func::ReturnOp>(deviceFunc.getLoc(), retVals);
} else {
b.create<func::ReturnOp>(deviceFunc.getLoc(), ValueRange{});
}
return returnedValues;
}
LogicalResult initTilingKernel(int64_t key, OpBuilder &builder) {
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(originalKernel_);
SmallVector<Type> devInputs, devResults;
if (failed(collectDeviceSignature(originalKernel_, devInputs, devResults, key))) return failure();
std::string keyStr = std::to_string(key);
if (keyStr.size() == 1) keyStr = "0" + keyStr;
std::string name = kernelInfo_->baseKernelName + "_" + keyStr;
auto devTy = FunctionType::get(builder.getContext(), devInputs, devResults);
auto origTy = originalKernel_.getFunctionType();
auto deviceFunc = createAndAnnotateDeviceFunc(builder, originalKernel_.getLoc(), name, devTy, origTy,
kernelInfo_->blockDim, tilingInfo_.getPerCaseTilingFunc(key));
(void)cloneKernelBodyToDeviceFunc(originalKernel_, deviceFunc);
tilingInfo_.recordKernelFunc(key, deviceFunc);
return success();
}
LogicalResult applyTilingImpl(OpBuilder &builder) {
for (const auto &it : tilingInfo_.getTilingKey2KernelMap()) {
int64_t key = it.getFirst();
mlir::func::FuncOp f = it.getSecond();
if (failed(mlir::autotiling::applyTilingFromTilingFunc(f, builder, tilingInfo_.isStaticShape))) return failure();
tilingInfo_.recordKernelFunc(key, f);
}
return success();
}
LogicalResult fixCallSitesAndCaller(OpBuilder &builder) {
auto *ctx = builder.getContext();
auto oldTy = originalKernel_.getFunctionType();
unsigned oldNumInputs = oldTy.getNumInputs();
if (tilingInfo_.isTilingFullyStatic()) {
return rewriteOriginalKernelBodyStatic(builder, ctx, oldTy, oldNumInputs);
}
auto i64Ty = builder.getI64Type();
auto llvmPtrTy = LLVM::LLVMPointerType::get(ctx);
unsigned sz = tilingInfo_.tilingStructSize;
if (sz == 0) sz = 1;
auto tilingdataTy = MemRefType::get({static_cast<int64_t>(sz)}, i64Ty);
SmallVector<Type> newInputs(oldTy.getInputs().begin(), oldTy.getInputs().end());
newInputs.push_back(llvmPtrTy);
newInputs.push_back(tilingdataTy);
auto newTy = FunctionType::get(ctx, newInputs, oldTy.getResults());
originalKernel_.setType(newTy);
if (failed(updateOriginalKernelArgAttrs(oldTy, oldNumInputs))) return failure();
return rewriteOriginalKernelBodyDynamic(builder, ctx, oldTy, oldNumInputs);
}
LogicalResult updateOriginalKernelArgAttrs(FunctionType oldTy, unsigned oldNumInputs) {
(void)oldTy;
std::optional<ArrayAttr> maybeArr = originalKernel_.getArgAttrs();
if (maybeArr) {
ArrayAttr oldArr = *maybeArr;
unsigned copyNum = std::min<unsigned>(oldArr.size(), oldNumInputs);
for (unsigned i = 0; i < copyNum; ++i) {
if (auto dict = dyn_cast_or_null<DictionaryAttr>(oldArr[i])) {
originalKernel_.setArgAttrs(i, dict);
}
}
}
unsigned keyIdx = oldNumInputs;
unsigned memrefIdx = oldNumInputs + 1;
setTilingKeyAndDataArgAttrs(originalKernel_, keyIdx, memrefIdx);
return success();
}
LogicalResult rewriteOriginalKernelBodyStatic(OpBuilder &builder, MLIRContext *ctx, FunctionType oldTy,
unsigned oldNumInputs) {
while (!originalKernel_.getBody().empty()) originalKernel_.getBody().front().erase();
Block *entry = originalKernel_.addEntryBlock();
OpBuilder b = OpBuilder::atBlockEnd(entry);
Location loc = originalKernel_.getLoc();
SmallVector<Value> args(entry->args_begin(), entry->args_end());
if (args.size() != oldNumInputs) {
originalKernel_.emitError("static tiling: unexpected number of arguments");
return failure();
}
if (!tilingInfo_.isTilingFullyStatic() || !tilingInfo_.constTilingKey.has_value()) {
originalKernel_.emitError("static tiling: constTilingKey not available");
return failure();
}
int64_t constKey = *tilingInfo_.constTilingKey;
std::string keyStr = std::to_string(constKey);
if (keyStr.size() == 1) keyStr = "0" + keyStr;
std::string devName = kernelInfo_->baseKernelName + "_" + keyStr;
auto sym = SymbolTable::lookupSymbolIn(module_, StringAttr::get(ctx, devName));
if (!sym) {
originalKernel_.emitError() << "cannot find device kernel " << devName << " for static tiling";
return failure();
}
auto devFunc = dyn_cast<func::FuncOp>(sym);
if (!devFunc) {
originalKernel_.emitError() << devName << " is not func.func (static tiling)";
return failure();
}
auto call = b.create<func::CallOp>(loc, devFunc.getSymName(), TypeRange(oldTy.getResults()), args);
b.create<func::ReturnOp>(loc, call.getResults());
return success();
}
LogicalResult rewriteOriginalKernelBodyDynamic(OpBuilder &builder, MLIRContext *ctx, FunctionType oldTy,
unsigned oldNumInputs) {
while (!originalKernel_.getBody().empty()) originalKernel_.getBody().front().erase();
Block *entry = originalKernel_.addEntryBlock();
OpBuilder b = OpBuilder::atBlockEnd(entry);
Location loc = originalKernel_.getLoc();
SmallVector<Value> args(entry->args_begin(), entry->args_end());
Value keyPtr = args[oldNumInputs];
Value dataMem = args[oldNumInputs + 1];
auto i64Ty = b.getI64Type();
auto indexTy = b.getIndexType();
Value keyI64 = b.create<LLVM::LoadOp>(loc, i64Ty, keyPtr);
Value keyIndex = b.create<arith::IndexCastUIOp>(loc, indexTy, keyI64);
SmallVector<int64_t> caseKeys = tilingInfo_.getAllKeys();
if (caseKeys.empty()) {
originalKernel_.emitError("no tiling cases recorded in TilingInfo for original kernel switch");
return failure();
}
auto switchOp = b.create<scf::IndexSwitchOp>(loc, TypeRange(oldTy.getResults()), keyIndex,
ArrayRef<int64_t>(caseKeys), caseKeys.size());
for (unsigned i = 0; i < caseKeys.size(); ++i) {
int64_t key = caseKeys[i];
Region ® = switchOp.getCaseRegions()[i];
Block *blk = new Block();
reg.push_back(blk);
OpBuilder cb(blk, blk->begin());
SmallVector<Value> callArgs;
callArgs.reserve(oldNumInputs + 2);
for (unsigned a = 0; a < oldNumInputs; ++a) callArgs.push_back(args[a]);
callArgs.push_back(keyPtr);
callArgs.push_back(dataMem);
std::string keyStr = std::to_string(key);
if (keyStr.size() == 1) keyStr = "0" + keyStr;
std::string devName = kernelInfo_->baseKernelName + "_" + keyStr;
auto sym = SymbolTable::lookupSymbolIn(module_, StringAttr::get(ctx, devName));
if (!sym) {
originalKernel_.emitError() << "cannot find device kernel " << devName;
return failure();
}
auto devFunc = dyn_cast<func::FuncOp>(sym);
if (!devFunc) {
originalKernel_.emitError() << devName << " is not func.func";
return failure();
}
auto call = cb.create<func::CallOp>(loc, devFunc.getSymName(), TypeRange(oldTy.getResults()), callArgs);
cb.create<scf::YieldOp>(loc, ValueRange(call.getResults()));
}
{
Region &defaultReg = switchOp.getDefaultRegion();
Block *blk = new Block();
defaultReg.push_back(blk);
OpBuilder db(blk, blk->begin());
Value falseVal = db.create<arith::ConstantIntOp>(loc, 0, db.getI1Type());
auto msgAttr = db.getStringAttr("Invalid tiling key");
db.create<mlir::cf::AssertOp>(loc, falseVal, msgAttr);
SmallVector<Value> outputArgs;
auto maybeArgAttrArray = originalKernel_.getArgAttrs();
ArrayAttr argAttrArray;
if (maybeArgAttrArray) argAttrArray = *maybeArgAttrArray;
auto katName = StringAttr::get(ctx, KernelArgTypeAttr::name);
for (BlockArgument arg : entry->getArguments()) {
unsigned argIdx = arg.getArgNumber();
if (!argAttrArray || argIdx >= argAttrArray.size()) continue;
auto dict = dyn_cast_or_null<DictionaryAttr>(argAttrArray[argIdx]);
if (!dict) continue;
Attribute attr = dict.get(katName);
if (!attr) continue;
if (auto katAttr = dyn_cast<KernelArgTypeAttr>(attr)) {
if (katAttr.getArgType() == KernelArgType::kOutput) {
outputArgs.push_back(arg);
}
}
}
SmallVector<Value> defaultResults;
defaultResults.reserve(oldTy.getNumResults());
for (unsigned i = 0, e = oldTy.getNumResults(); i < e; ++i) {
defaultResults.push_back(outputArgs[i]);
}
db.create<scf::YieldOp>(loc, defaultResults);
}
b.create<func::ReturnOp>(loc, switchOp.getResults());
return success();
}
LogicalResult createOrGetGetTilingStructSizeFunction(OpBuilder &builder) {
ModuleOp module = originalKernel_->getParentOfType<ModuleOp>();
if (!module) {
originalKernel_.emitError("cannot find parent ModuleOp for original kernel");
return failure();
}
std::string base = (kernelInfo_ ? kernelInfo_->baseKernelName : originalKernel_.getSymName().str());
std::string hostName = base + "_get_tiling_struct_size_function";
if (auto sym = SymbolTable::lookupSymbolIn(module, StringAttr::get(module.getContext(), hostName))) {
if (isa<func::FuncOp>(sym)) return success();
}
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(module.getBody());
auto funcTy = FunctionType::get(module.getContext(), TypeRange{}, TypeRange{builder.getI64Type()});
auto host = builder.create<func::FuncOp>(originalKernel_.getLoc(), hostName, funcTy);
host.setVisibility(SymbolTable::Visibility::Public);
auto *ctx = builder.getContext();
host->setAttr(HACCFuncTypeAttr::name, HACCFuncTypeAttr::get(ctx, HACCFuncType::HOST));
host->setAttr(HostFuncTypeAttr::name, HostFuncTypeAttr::get(ctx, HostFuncType::kGetTilingStructSizeFunction));
Block *entry = host.addEntryBlock();
OpBuilder b = OpBuilder::atBlockEnd(entry);
int64_t sz = static_cast<int64_t>(tilingInfo_.tilingStructSize);
if (sz <= 0) sz = 1;
auto sizeConst = b.create<arith::ConstantIntOp>(host.getLoc(), sz, 64);
b.create<func::ReturnOp>(host.getLoc(), ValueRange{sizeConst});
return success();
}
protected:
// Kernel this scheduler was constructed with.
func::FuncOp originalKernel_;
// Parent module of originalKernel_ (null when the kernel is null).
ModuleOp module_;
// Name and block dimension captured in runPreTilingProcedure.
std::unique_ptr<KernelInfo> kernelInfo_;
// Tiling functions, kernels and key state produced while tiling.
TilingInfo tilingInfo_;
// Options shared by all instances; set through setAutoTilingOptions.
static AutoTilingOptions options_;
};
// Out-of-class definition of the shared options; default-initialised.
AutoTilingOptions TilingBase::options_;
// Scheduler used by the pass; inherits all behaviour and the constructors of
// TilingBase without adding anything.
class PureElemwiseTiling : public TilingBase {
public:
using TilingBase::TilingBase;
};
// Module pass that runs auto tiling on every function of the module that is
// not marked as a DEVICE function.
struct NPUAutoTiling : public mlir::impl::NPUAutoTilingBase<NPUAutoTiling> {
  NPUAutoTiling() = default;

  // Collects the candidate functions first (tiling adds new functions to the
  // module), then tiles them one by one. The pass fails on the first kernel
  // whose tiling fails.
  void runOnOperation() override {
    ModuleOp module = getOperation();
    if (!module) {
      signalPassFailure();
      return;
    }

    // This file writes the function kind as a HACCFuncTypeAttr, so compare
    // against that attribute. Attributes are uniqued, hence equality is
    // enough. A plain string "DEVICE" is still accepted for compatibility.
    const Attribute deviceKind = HACCFuncTypeAttr::get(module.getContext(), HACCFuncType::DEVICE);
    SmallVector<func::FuncOp> kernels;
    module.walk([&](func::FuncOp f) {
      Attribute kind = f->getAttr(HACCFuncTypeAttr::name);
      if (kind) {
        if (kind == deviceKind) return;
        auto kindStr = dyn_cast<StringAttr>(kind);
        if (kindStr && kindStr.getValue() == "DEVICE") return;
      }
      kernels.push_back(f);
    });

    AutoTilingOptions opts;
    TilingBase::setAutoTilingOptions(opts);
    OpBuilder builder(module.getContext());
    for (func::FuncOp k : kernels) {
      PureElemwiseTiling sch(k);
      if (failed(sch.runOnOperation(builder))) {
        k.emitError("memref-based auto-tiling failed");
        signalPassFailure();
        return;
      }
    }
  }
};
}
}
// Factory for the NPU auto-tiling module pass.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> mlir::createNPUAutoTilingPass() {
return std::make_unique<NPUAutoTiling>();
}