#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#define DEBUG_TYPE "enable-arm-streaming"
namespace mlir {
namespace arm_sme {
#define GEN_PASS_DEF_ENABLEARMSTREAMING
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
}
}
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
/// Name of the function attribute that opts a function out of this pass: if a
/// function carries it, no streaming/ZA attributes are attached to it.
constexpr StringLiteral kEnableArmStreamingIgnoreAttr{
    "enable_arm_streaming_ignore"};
template <typename... Ops>
constexpr auto opList() {
return std::array{TypeID::get<Ops>()...};
}
/// Predicate used with `llvm::any_of` over operand/result types: true only
/// for a `VectorType` whose `isScalable()` reports true; any other type
/// (including non-vector types) yields false.
bool isScalableVector(Type type) {
  auto asVector = dyn_cast<VectorType>(type);
  return asVector && asVector.isScalable();
}
/// Pass that attaches the Arm streaming-mode attribute (and, optionally, a ZA
/// attribute) to the operation it runs on, as configured by its options.
///
/// Attributes are attached only when all of the following hold:
///  - the function does not carry `kEnableArmStreamingIgnoreAttr`;
///  - `streamingMode` is not `ArmStreamingMode::Disabled`;
///  - with `ifRequiredByOps`, some nested op implements
///    `ArmSMETileOpInterface`;
///  - with `ifScalableAndSupported`, some nested op has a scalable-vector
///    operand or result, and no nested op is a `vector.gather` or
///    `vector.scatter`.
/// Setting both `ifRequiredByOps` and `ifScalableAndSupported` is an error.
struct EnableArmStreamingPass
    : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
                         bool ifRequiredByOps, bool ifScalableAndSupported) {
    this->streamingMode = streamingMode;
    this->zaMode = zaMode;
    this->ifRequiredByOps = ifRequiredByOps;
    this->ifScalableAndSupported = ifScalableAndSupported;
  }

  void runOnOperation() override {
    auto function = getOperation();

    // The two filters are mutually exclusive. Check this first so a bad
    // configuration is always diagnosed, even for functions that would
    // otherwise be skipped.
    if (ifRequiredByOps && ifScalableAndSupported) {
      function->emitOpError(
          "enable-arm-streaming: `if-required-by-ops` and "
          "`if-scalable-and-supported` are mutually exclusive");
      return signalPassFailure();
    }

    // Cheap early exits. Nothing is attached in these cases, so skip the
    // read-only IR walks below entirely.
    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
        streamingMode == ArmStreamingMode::Disabled)
      return;

    if (ifRequiredByOps && !containsTileOp(function.getOperation()))
      return;

    if (ifScalableAndSupported &&
        !isCompatibleScalableFunction(function.getOperation()))
      return;

    auto unitAttr = UnitAttr::get(&getContext());
    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);

    // Because of the early return above, a ZA attribute is only ever added
    // together with a streaming-mode attribute.
    if (zaMode != ArmZaMode::Disabled)
      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
  }

private:
  /// Returns true if `root` or any op nested in it implements
  /// `ArmSMETileOpInterface`. The walk stops at the first match.
  static bool containsTileOp(Operation *root) {
    return root
        ->walk([](Operation *op) -> WalkResult {
          return llvm::isa<ArmSMETileOpInterface>(op)
                     ? WalkResult::interrupt()
                     : WalkResult::advance();
        })
        .wasInterrupted();
  }

  /// Returns true if some op under `root` has a scalable-vector operand or
  /// result, and no op under `root` is a `vector.gather` or `vector.scatter`.
  /// Gathers/scatters are conservatively treated as unsupported here.
  static bool isCompatibleScalableFunction(Operation *root) {
    auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
    bool usesScalableVectors = false;
    WalkResult result = root->walk([&](Operation *op) -> WalkResult {
      // Any disallowed op rejects the function; no need to look further.
      if (llvm::is_contained(disallowedOperations, op->getName().getTypeID()))
        return WalkResult::interrupt();
      if (!usesScalableVectors &&
          (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
           llvm::any_of(op->getResultTypes(), isScalableVector)))
        usesScalableVectors = true;
      return WalkResult::advance();
    });
    return !result.wasInterrupted() && usesScalableVectors;
  }
};
}
/// Builds an `EnableArmStreamingPass` configured with the given streaming
/// mode, ZA mode and filter flags, returned as a generic `Pass`.
std::unique_ptr<Pass>
mlir::arm_sme::createEnableArmStreamingPass(const ArmStreamingMode streamingMode,
                                            const ArmZaMode zaMode,
                                            bool ifRequiredByOps,
                                            bool ifScalableAndSupported) {
  auto pass = std::make_unique<EnableArmStreamingPass>(
      streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
  return pass;
}