* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Analysis/Split/SplitModelInitPattern.h"
#include <algorithm>
#include <numeric>
#include "llvm/Support/Debug.h"
#include "mfusion/Analysis/Split/FusePattern.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#define DEBUG_TYPE "split"
namespace mlir {
namespace mfuse {
namespace split {
namespace ascend {
// Size limit handed to the dvm::FuseReduceFwd depth/width matchers, which skip input areas larger than it.
constexpr size_t kReduceFusionDepth{10};
// Argument handed to the FuseElemwiseBroadcastBwd depth/width matchers (how it is used is defined outside this file).
constexpr size_t kBroadcastFusionDepth{6};
// Largest area size dvm::FuseReduceBwd accepts for the area that consumes a reduce.
constexpr size_t kReduceBwdMaxAreaSize{10};
// Largest number of area outputs dvm::FuseReduceBwd accepts for the area that consumes a reduce.
constexpr size_t kReduceBwdMaxAreaOutputs{3};
/// Backward fuse pattern anchored on an area whose `dom()` op is `mfuse.slice` or `mfuse.strided_slice`.
/// Matching collects the user areas that may be merged with the slice area.
class FuseSlice : public FusePattern {
 public:
  FuseSlice() : FusePattern("slice", FuseDirection::BACKWARD) {}
  ~FuseSlice() = default;

 protected:
  /// @return true when the area has a dom node with an op named `mfuse.slice` or `mfuse.strided_slice`.
  bool check(const AreaPtr &area) override {
    if (!area->dom()) {
      return false;
    }
    auto sliceOp = area->dom()->op();
    if (!sliceOp) {
      return false;
    }
    const auto opName = sliceOp->getName().getStringRef();
    return opName == "mfuse.slice" || opName == "mfuse.strided_slice";
  }

  /// Collects every user whose pattern is strictly below BROADCAST, that is linked through an INJECTIVE
  /// edge, and for which hasCircle() (inherited helper, presumably a cycle check) reports false.
  /// @return true when at least one user area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &[user, relation] : area->usersWithRelation()) {
      if (!(user->pattern() < NodePattern::BROADCAST)) {
        continue;
      }
      if (relation != EdgeRelation::INJECTIVE) {
        continue;
      }
      if (hasCircle(area, user)) {
        continue;
      }
      fused_areas_.push_back(user);
    }
    return !fused_areas_.empty();
  }
};
/// Fuse pattern anchored on an area whose `dom()` op is `mfuse.elem_any`.
/// Matching collects the input areas that may be merged with it. The fuse direction is whatever
/// FusePattern uses by default, since none is passed to the base constructor.
class FuseElemAny : public FusePattern {
 public:
  FuseElemAny() : FusePattern("elemany_addn") {}
  ~FuseElemAny() = default;

 protected:
  /// @return true when the area has a dom node with an op named `mfuse.elem_any`.
  bool check(const AreaPtr &area) override {
    if (!area->dom()) {
      return false;
    }
    auto anchorOp = area->dom()->op();
    if (!anchorOp) {
      return false;
    }
    return anchorOp->getName().getStringRef() == "mfuse.elem_any";
  }

  /// Collects every input whose pattern is at most BROADCAST, that is linked through an INJECTIVE edge,
  /// and for which hasCircle() (inherited helper, presumably a cycle check) reports false.
  /// @return true when at least one input area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &[input, relation] : area->inputsWithRelation()) {
      if (!(input->pattern() <= NodePattern::BROADCAST)) {
        continue;
      }
      if (relation != EdgeRelation::INJECTIVE) {
        continue;
      }
      if (hasCircle(area, input)) {
        continue;
      }
      fused_areas_.push_back(input);
    }
    return !fused_areas_.empty();
  }
};
}
namespace dvm {
/// Forward fuse pattern anchored on a REDUCE area; it collects eligible input areas of the reduce.
/// It comes in two flavours selected by FuseType (kDepth / kWidth), reflected in the pattern name
/// as the suffix "_depth" or "_width".
class FuseReduceFwd : public FusePattern {
 public:
  /// @param fuseType   depth or width flavour of the matcher.
  /// @param sizeLimit  input areas whose size() exceeds this value are never collected.
  FuseReduceFwd(FuseType fuseType, size_t sizeLimit)
      : FusePattern("reduce_fwd", FuseDirection::FORWARD), fuseType_(fuseType), sizeLimit_(sizeLimit) {
    if (fuseType == FuseType::kWidth) {
      name_ += "_width";
    } else {
      name_ += "_depth";
    }
  }
  ~FuseReduceFwd() = default;

  /// Builds the depth flavour with the given size limit.
  static std::shared_ptr<FuseReduceFwd> createDepthMatcher(size_t sizeLimit) {
    return std::make_shared<FuseReduceFwd>(FuseType::kDepth, sizeLimit);
  }

  /// Builds the width flavour with the given size limit.
  static std::shared_ptr<FuseReduceFwd> createWidthMatcher(size_t sizeLimit) {
    return std::make_shared<FuseReduceFwd>(FuseType::kWidth, sizeLimit);
  }

 protected:
  /// The anchor must be a REDUCE area; the depth flavour additionally requires exactly one input.
  bool check(const AreaPtr &area) override {
    const bool isReduce = area->pattern() == NodePattern::REDUCE;
    if (!isReduce) {
      return false;
    }
    if (fuseType_ == FuseType::kWidth) {
      return true;
    }
    return area->inputNum() == 1;
  }

  /// Walks the inputs of the reduce area and collects those that pass every filter below.
  /// @return true when at least one input area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &[input, relation] : area->inputsWithRelation()) {
      // Depth flavour: the input must feed exactly one user.
      if (fuseType_ == FuseType::kDepth && input->userNum() != 1) {
        continue;
      }
      if (input->size() > sizeLimit_) {
        continue;
      }
      // Only inputs whose pattern is at most BROADCAST are candidates.
      if (!(input->pattern() <= NodePattern::BROADCAST)) {
        continue;
      }
      // A non-injective edge is tolerated only for a single-user input that is not an output.
      if (relation != EdgeRelation::INJECTIVE && (input->userNum() != 1 || input->isOutput())) {
        continue;
      }
      // hasCircle() is consulted for the width flavour only.
      if (fuseType_ == FuseType::kWidth && hasCircle(area, input)) {
        continue;
      }
      fused_areas_.emplace_back(input);
    }
    return !fused_areas_.empty();
  }

 private:
  FuseType fuseType_;
  size_t sizeLimit_;
};
class FuseReduceBwd : public FusePattern {
public:
FuseReduceBwd() : FusePattern("reduce_bwd", FuseDirection::BACKWARD) {}
~FuseReduceBwd() = default;
protected:
static bool CheckReduceArea(const AreaPtr &area) {
if (!area || !area->isAlive() || area->pattern() != NodePattern::REDUCE || !area->dom() || !area->dom()->op()) {
return false;
}
auto reduce = mlir::dyn_cast<mlir::mfuse::ReduceSumOp>(area->dom()->op());
if (!reduce || !reduce.getKeepdim()) {
return false;
}
auto inputType = mlir::dyn_cast<mlir::RankedTensorType>(reduce.getInput().getType());
if (!inputType || !inputType.hasRank()) {
return false;
}
auto dimensions = reduce.getDimensions();
if (dimensions.empty()) {
return false;
}
const int64_t rank = inputType.getRank();
std::vector<int64_t> dims;
dims.reserve(dimensions.size());
for (auto dimAttr : dimensions.getValue()) {
auto dim = mlir::cast<mlir::IntegerAttr>(dimAttr).getValue().getSExtValue();
if (dim < 0 || dim >= rank) {
return false;
}
dims.push_back(dim);
}
std::sort(dims.begin(), dims.end());
return std::adjacent_find(dims.begin(), dims.end()) == dims.end();
}
bool check(const AreaPtr &area) override {
return area->userNum() == 1 && CheckReduceArea(area);
}
static bool CheckPostReduceArea(const AreaPtr &area, EdgeRelation relation) {
if (!area || !area->isAlive()) {
return false;
}
if (relation != EdgeRelation::INJECTIVE && relation != EdgeRelation::BROADCAST) {
return false;
}
if (area->pattern() != NodePattern::RESHAPE && area->pattern() != NodePattern::ELEMWISE &&
area->pattern() != NodePattern::BROADCAST) {
return false;
}
if (area->areaOutputs().size() > ascend::kReduceBwdMaxAreaOutputs) {
return false;
}
return area->size() <= ascend::kReduceBwdMaxAreaSize;
}
static bool IsTerminalSingleReshapeArea(const AreaPtr &area) {
static constexpr llvm::StringLiteral kReshapeOpName = "mfuse.reshape";
return area && area->size() == 1 && area->userNum() == 0 && area->dom() && area->dom()->op() &&
area->dom()->op()->getName().getStringRef() == kReshapeOpName;
}
bool match(const AreaPtr &area) override {
for (const auto &[a, r] : area->usersWithRelation()) {
if (IsTerminalSingleReshapeArea(a) || hasCircle(area, a) || !CheckPostReduceArea(a, r)) {
continue;
}
fused_areas_.push_back(a);
}
return fused_areas_.size() == 1;
}
};
class FuseMatMul : public FusePattern {
public:
FuseMatMul() : FusePattern("matmul_depth", FuseDirection::BACKWARD) {}
~FuseMatMul() = default;
protected:
bool check(const AreaPtr &area) override {
if (area->size() == 1) {
if (!area->dom() || !area->dom()->op()) {
return false;
}
auto op = area->dom()->op();
auto opName = op->getName().getStringRef();
return opName == "mfuse.matmul" || opName == "mfuse.batch_matmul";
}
return false;
}
bool isSameShapeSize(int64_t size, const std::vector<Node *> &output_nodes) {
for (auto &node : output_nodes) {
if (std::accumulate(mlir::cast<mlir::ShapedType>(node->op()->getResult(0).getType()).getShape().begin(),
mlir::cast<mlir::ShapedType>(node->op()->getResult(0).getType()).getShape().end(),
static_cast<int64_t>(1), std::multiplies<int64_t>()) != size) {
return false;
}
}
return true;
}
bool match(const AreaPtr &dom) override {
constexpr size_t MAX_FUSE_NUM = 5;
size_t current_size = 0;
if (dom->nodes().empty()) {
return false;
}
auto output_op = dom->nodes().back()->op();
if (!output_op) {
return false;
}
auto output_shape = mlir::cast<mlir::ShapedType>(output_op->getResult(0).getType()).getShape();
int64_t matmul_output_size =
std::accumulate(output_shape.begin(), output_shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
if (output_shape.back() == 1) {
return false;
}
auto opName = output_op->getName().getStringRef();
auto users = dom->users();
std::sort(users.begin(), users.end(),
[](const AreaPtr &a, const AreaPtr &b) { return a->areaOutputs().size() < b->areaOutputs().size(); });
for (auto &a : users) {
if (current_size + a->areaOutputs().size() > MAX_FUSE_NUM) {
break;
}
if (a->size() == 1 && a->dom()->op()->getName().getStringRef() == "mfuse.reshape") {
continue;
}
bool fuse_flag = (opName == "mfuse.matmul" || opName == "mfuse.batch_matmul" || opName == "mfuse.grouped_matmul");
if (std::any_of(a->nodes().begin(), a->nodes().end(),
[](const Node *node) { return node->op()->getName().getStringRef() == "mfuse.reshape"; })) {
fuse_flag = fuse_flag && (a->pattern() < NodePattern::BROADCAST);
} else {
fuse_flag = fuse_flag && (a->pattern() <= NodePattern::BROADCAST);
}
if (fuse_flag && !hasCircle(dom, a) && isSameShapeSize(matmul_output_size, a->areaOutputs())) {
fused_areas_.push_back(a);
current_size += a->areaOutputs().size();
}
}
return !fused_areas_.empty();
}
};
/// Forward fuse pattern anchored on a single-node `mfuse.allreduce` area.
/// It collects single-node `mfuse.matmul` input areas that feed only the allreduce.
class FuseAllReduceFwd : public FusePattern {
 public:
  FuseAllReduceFwd() : FusePattern("allreduce_fwd", FuseDirection::FORWARD) {}
  ~FuseAllReduceFwd() = default;

 protected:
  /// @return true when the area holds exactly one node and its dom op is named `mfuse.allreduce`.
  bool check(const AreaPtr &area) override {
    if (!area->dom() || !area->dom()->op()) {
      return false;
    }
    auto op = area->dom()->op();
    return area->size() == 1 && op->getName().getStringRef() == "mfuse.allreduce";
  }

  /// Collects every input that has exactly one user, is a single `mfuse.matmul` node, is linked
  /// through an INJECTIVE edge, and for which hasCircle(input, area) reports false.
  /// @return true when at least one input area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &[a, r] : area->inputsWithRelation()) {
      if (a->userNum() != 1) {
        continue;
      }
      // An input without a dom node cannot be a matmul; skip it instead of dereferencing null.
      if (!a->dom()) {
        continue;
      }
      auto op = a->dom()->op();
      if (!hasCircle(a, area) && r == EdgeRelation::INJECTIVE && a->size() == 1 && op &&
          op->getName().getStringRef() == "mfuse.matmul") {
        fused_areas_.push_back(a);
      }
    }
    return !fused_areas_.empty();
  }
};
/// Backward fuse pattern anchored on any area that contains an `mfuse.allreduce` node.
/// It collects the user areas whose pattern is strictly below BROADCAST.
class FuseAllReduceBwd : public FusePattern {
 public:
  FuseAllReduceBwd() : FusePattern("allreduce_bwd", FuseDirection::BACKWARD) {}
  ~FuseAllReduceBwd() = default;

 protected:
  /// @return true when at least one node of the area has an op named `mfuse.allreduce`.
  bool check(const AreaPtr &area) override {
    static constexpr llvm::StringLiteral kAllReduceOpName = "mfuse.allreduce";
    // Bind by const reference: the list is only scanned, so there is no need to copy it.
    const auto &nodes = area->nodes();
    return std::any_of(nodes.begin(), nodes.end(), [](const auto &node) {
      return node->op() && node->op()->getName().getStringRef() == kAllReduceOpName;
    });
  }

  /// Collects every user whose pattern is strictly below BROADCAST and for which hasCircle()
  /// (inherited helper, presumably a cycle check) reports false.
  /// @return true when at least one user area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &a : area->users()) {
      if (a->pattern() < NodePattern::BROADCAST && !hasCircle(area, a)) {
        fused_areas_.push_back(a);
      }
    }
    return !fused_areas_.empty();
  }
};
/// Forward fuse pattern anchored on an area with the VIRTUAL pattern.
/// It collects every input area that has a dom op, unless that op is `mfuse.permute`.
class FuseVirtualNode : public FusePattern {
 public:
  FuseVirtualNode() : FusePattern("fuse_virtual_node", FuseDirection::FORWARD) {}
  ~FuseVirtualNode() = default;

 protected:
  /// @return true when the area's pattern is VIRTUAL.
  bool check(const AreaPtr &area) override {
    const bool isVirtual = area->pattern() == NodePattern::VIRTUAL;
    return isVirtual;
  }

  /// @return true when at least one input area was collected into fused_areas_.
  bool match(const AreaPtr &area) override {
    for (const auto &input : area->inputs()) {
      // Inputs without a dom node or without an op are ignored.
      if (!input->dom() || !input->dom()->op()) {
        continue;
      }
      // Permute inputs are left out of the fusion.
      if (input->dom()->op()->getName().getStringRef() == "mfuse.permute") {
        continue;
      }
      fused_areas_.push_back(input);
    }
    return !fused_areas_.empty();
  }
};
}
/// Registers the fuse patterns used by the DVM split model.
/// The calls below are issued in a fixed order, and every pattern is added with `true` as the
/// second argument (the meaning of that flag is defined by addPattern, outside this file).
void DVMSplitModel::initFusePatterns() {
  // Local aliases for the limits shared by the depth and width flavours of a matcher.
  constexpr auto kReduceLimit = ascend::kReduceFusionDepth;
  constexpr auto kBroadcastLimit = ascend::kBroadcastFusionDepth;

  // Virtual nodes and reshapes.
  addPattern(std::make_shared<dvm::FuseVirtualNode>(), true);
  addPattern(std::make_shared<FuseReshape>(), true);

  // Forward matchers: elementwise/broadcast, then reduce.
  addPattern(FuseElemwiseBroadcastFwd::createDepthMatcher(), true);
  addPattern(FuseElemwiseBroadcastFwd::createWidthMatcher(), true);
  addPattern(dvm::FuseReduceFwd::createDepthMatcher(kReduceLimit), true);
  addPattern(dvm::FuseReduceFwd::createWidthMatcher(kReduceLimit), true);

  // Backward matchers: elementwise/broadcast, then reduce.
  addPattern(FuseElemwiseBroadcastBwd::createDepthMatcher(kBroadcastLimit), true);
  addPattern(FuseElemwiseBroadcastBwd::createWidthMatcher(kBroadcastLimit), true);
  addPattern(std::make_shared<dvm::FuseReduceBwd>(), true);

  // Op-specific patterns.
  addPattern(std::make_shared<ascend::FuseElemAny>(), true);
  addPattern(std::make_shared<ascend::FuseSlice>(), true);
  addPattern(std::make_shared<dvm::FuseMatMul>(), true);
  addPattern(std::make_shared<dvm::FuseAllReduceFwd>(), true);
  addPattern(std::make_shared<dvm::FuseAllReduceBwd>(), true);
}
/// Registers the fuse patterns used by the AKG split model.
/// The calls below are issued in a fixed order, and every pattern is added with `true` as the
/// second argument (the meaning of that flag is defined by addPattern, outside this file).
void AKGSplitModel::initFusePatterns() {
  // Limit shared by the depth and width flavours of the backward broadcast matcher.
  constexpr auto kBroadcastLimit = ascend::kBroadcastFusionDepth;

  // Virtual nodes and reshapes.
  addPattern(std::make_shared<dvm::FuseVirtualNode>(), true);
  addPattern(std::make_shared<FuseReshape>(), true);

  // Elementwise/broadcast matchers, forward then backward.
  addPattern(FuseElemwiseBroadcastFwd::createDepthMatcher(), true);
  addPattern(FuseElemwiseBroadcastFwd::createWidthMatcher(), true);
  addPattern(FuseElemwiseBroadcastBwd::createDepthMatcher(kBroadcastLimit), true);
  addPattern(FuseElemwiseBroadcastBwd::createWidthMatcher(kBroadcastLimit), true);
}
/// Registers the fuse patterns used by the Bisheng split model: only the forward
/// elementwise/broadcast matchers, depth flavour first and width flavour second.
/// Both are added with `true` as the second argument (the meaning of that flag is defined by
/// addPattern, outside this file).
void BishengSplitModel::initFusePatterns() {
addPattern(FuseElemwiseBroadcastFwd::createDepthMatcher(), true);
addPattern(FuseElemwiseBroadcastFwd::createWidthMatcher(), true);
}
}
}
}