/**
* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Analysis/Split/Area.h"
#include <algorithm>
#include <numeric>
#include <unordered_set>
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Analysis/Split/OpRegister.h"
#include "mfusion/Dialect/Mfuse/IR/MfuseDialect.h"
#include "mfusion/Dialect/Dvm/IR/Dvm.h"
#include "mfusion/Analysis/SymbolicShape/SymExprBuilder.h"
#include "mfusion/Analysis/SymbolicShape/SymEngineAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace mfuse {
namespace split {
namespace {
bool isDynamic(const Type &type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasRank() && !shapedType.hasStaticShape();
}
bool symExprEqual(const SymExpr &a, const SymExpr &b) {
mfusion::SymEngineAnalysis analysis;
return analysis.isStructurallyEqual(a, b);
}
/// Compares the shapes of two nodes dimension by dimension.
///
/// The node with fewer dimensions is aligned against the trailing dimensions of the other one.
/// If the ranks differ, the shapes can only be equal when `skip_leading_one` is set and every extra
/// leading dimension of the longer shape is exactly 1. For each aligned pair of dimensions:
///  - if either value is negative, the pair is compared through the nodes' `sym_shape` entries, which
///    is only possible when both `sym_shape` vectors have the same length as their `shape`;
///  - otherwise the two values must be numerically identical.
///
/// Returns false when either pointer is null.
bool shapeEqual(Node *a, Node *b, bool skip_leading_one = true) {
  if (!a || !b) {
    return false;
  }
  const bool a_is_shorter = a->shape.size() < b->shape.size();
  auto longer = a_is_shorter ? b : a;
  auto shorter = a_is_shorter ? a : b;
  const auto &long_dims = longer->shape;
  const auto &short_dims = shorter->shape;
  const auto &long_syms = longer->sym_shape;
  const auto &short_syms = shorter->sym_shape;
  const auto rank_gap = long_dims.size() - short_dims.size();
  // Symbolic comparison is only usable when both nodes carry one symbol per dimension.
  const bool sym_shapes_usable =
    (long_syms.size() == long_dims.size()) && (short_syms.size() == short_dims.size());

  if (rank_gap != 0 && !skip_leading_one) {
    return false;
  }
  // Every extra leading dimension of the longer shape has to be a literal 1.
  for (size_t lead = 0; lead < rank_gap; ++lead) {
    if (long_dims[lead] != 1) {
      return false;
    }
  }
  for (size_t short_idx = 0; short_idx < short_dims.size(); ++short_idx) {
    const auto long_idx = short_idx + rank_gap;
    const bool needs_symbolic = long_dims[long_idx] < 0 || short_dims[short_idx] < 0;
    if (!needs_symbolic) {
      if (long_dims[long_idx] != short_dims[short_idx]) {
        return false;
      }
      continue;
    }
    if (!(sym_shapes_usable && symExprEqual(long_syms[long_idx], short_syms[short_idx]))) {
      return false;
    }
  }
  return true;
}
/// Classifies the edge from `input` to `node` as INJECTIVE or BROADCAST.
///
/// - Aborts via llvm::report_fatal_error when either node or its underlying op is null.
/// - If the pattern registered for the consumer op is not ELEMWISE, the edge is BROADCAST for a
///   BROADCAST pattern and INJECTIVE for every other pattern.
/// - An ELEMWISE op with exactly one operand is INJECTIVE.
/// - If the input's first result is a dynamically shaped ranked tensor and every operand of the
///   consumer is defined by the input op, the edge is INJECTIVE.
/// - Otherwise the decision falls back to shapeEqual(node, input).
EdgeRelation getRelation(Node *node, Node *input) {
  if (!node || !input || !node->op() || !input->op()) {
    llvm::report_fatal_error("op or input is nullptr");
  }
  auto consumer = node->op();
  std::string consumer_name = consumer->getName().getStringRef().str();
  NodePattern consumer_pattern = OpRegistry::Instance().GetPattern(consumer_name);
  if (consumer_pattern == NodePattern::BROADCAST) {
    return EdgeRelation::BROADCAST;
  }
  if (consumer_pattern != NodePattern::ELEMWISE) {
    return EdgeRelation::INJECTIVE;
  }
  if (consumer->getNumOperands() == 1) {
    return EdgeRelation::INJECTIVE;
  }

  auto producer_type = mlir::dyn_cast<mlir::RankedTensorType>(input->op()->getResult(0).getType());
  if (producer_type && isDynamic(producer_type)) {
    // Check whether every operand of the consumer is produced by the very same input op.
    bool all_from_input = true;
    for (Value operand : consumer->getOperands()) {
      if (operand.getDefiningOp() != input->op()) {
        all_from_input = false;
        break;
      }
    }
    if (all_from_input) {
      return EdgeRelation::INJECTIVE;
    }
  }

  if (shapeEqual(node, input)) {
    return EdgeRelation::INJECTIVE;
  }
  return EdgeRelation::BROADCAST;
}
/// Two (area, relation) entries are "the same" when they refer to the same area; the relation is ignored.
bool sameArea(const AreaWithRelation &a, const AreaWithRelation &b) {
  const auto &lhs_area = a.first;
  const auto &rhs_area = b.first;
  return lhs_area == rhs_area;
}
/// Ordering used when sorting (area, relation) entries: entries of different areas are ordered by
/// ascending area id; entries of the same area are ordered by descending relation value.
bool areaWithRelationCmp(const AreaWithRelation &a, const AreaWithRelation &b) {
  if (!sameArea(a, b)) {
    return a.first->id() < b.first->id();
  }
  return a.second > b.second;
}
}
/// Builds an area that initially holds the single `node`.
///
/// The handle's pattern starts as the pattern registered for the node's op name. For every input of
/// the node the edge relation is computed; inputs that have an entry in `node_area_map` are recorded
/// as input areas together with that relation. An ELEMWISE node that has at least one BROADCAST edge
/// is promoted to BROADCAST, unless all inputs but one are statically shaped single-element ranked
/// tensors: in that case the pattern is restored to ELEMWISE and the relation of the first recorded
/// input is forced to INJECTIVE.
///
/// \param id            unique id of the area.
/// \param node          the node wrapped by this area.
/// \param is_output     whether the area is an output.
/// \param node_area_map lookup from already-processed nodes to their areas.
Area::Area(size_t id, Node *node, bool is_output, const std::unordered_map<Node *, AreaPtr> &node_area_map)
    : unique_id_(id), is_output_(is_output), nodes_(1, node) {
  std::string node_op_name = node->op()->getName().getStringRef().str();
  NodePattern init_pattern = OpRegistry::Instance().GetPattern(node_op_name);
  hd_ = std::make_shared<NodeHandle>(this, init_pattern);
  const bool starts_elemwise = (init_pattern == NodePattern::ELEMWISE);

  for (auto input : node->inputs()) {
    auto relation = getRelation(node, input);
    if (starts_elemwise && relation == EdgeRelation::BROADCAST) {
      hd_->setPattern(NodePattern::BROADCAST);
    }
    auto found = node_area_map.find(input);
    if (found != node_area_map.end()) {
      inputs_with_relation_.emplace_back(std::make_pair(found->second, relation));
    }
  }

  if (starts_elemwise && hd_->pattern() == NodePattern::BROADCAST) {
    // Count the inputs whose first result is a static ranked tensor with exactly one element.
    size_t input_num = node->inputNum();
    size_t scalar_input_num = 0;
    for (auto input : node->inputs()) {
      auto input_type = mlir::dyn_cast<RankedTensorType>(input->op()->getResult(0).getType());
      const bool is_single_element =
        input_type && input_type.hasStaticShape() && input_type.getNumElements() == 1;
      if (is_single_element) {
        ++scalar_input_num;
      }
    }
    if (scalar_input_num + 1 == input_num) {
      hd_->setPattern(NodePattern::ELEMWISE);
      // NOTE(review): only the first recorded input is reset to INJECTIVE; presumably the single
      // non-scalar input is expected at index 0 -- confirm against callers.
      if (!inputs_with_relation_.empty()) {
        inputs_with_relation_[0].second = EdgeRelation::INJECTIVE;
      }
    }
  }
  makeUniqueAndSyncInputs();
}
std::vector<AreaPtr> Area::inputs() const {
std::vector<AreaPtr> result;
result.reserve(inputs_with_relation_.size());
std::transform(inputs_with_relation_.begin(), inputs_with_relation_.end(), std::back_inserter(result),
[](const auto &pair) { return pair.first; });
return result;
}
std::vector<AreaPtr> Area::users() const {
std::vector<AreaPtr> result;
(void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &user) {
Node *node = user.first;
return node->as<NodeHandle>()->area();
});
return result;
}
std::vector<AreaWithRelation> Area::usersWithRelation() const {
std::vector<AreaWithRelation> result;
(void)std::transform(hd_->users().begin(), hd_->users().end(), std::back_inserter(result), [](const auto &u) {
Node *node = u.first;
auto area = node->as<NodeHandle>()->area();
const auto relation = area->inputRelation(*(u.second.begin()));
return std::make_pair(area, relation);
});
return result;
}
int64_t Area::computeSize() const {
auto op = dom()->op();
auto op_type = op->getResult(0).getType();
auto op_shape = mlir::dyn_cast<RankedTensorType>(op_type).getShape();
return std::accumulate(op_shape.begin(), op_shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
}
/// Tells whether this area and `other` compute results of the same size.
///
/// Returns false when `other` is null or either area holds no nodes. When the first result types of
/// both dominant ops are non-null and neither is dynamic, the element counts from computeSize() are
/// compared; otherwise the dominant nodes are compared with shapeEqual().
bool Area::computeSizeEqual(const AreaPtr &other) const {
  const bool comparable = other && !nodes_.empty() && !other->nodes_.empty();
  if (!comparable) {
    return false;
  }
  auto self_dom = dom();
  auto other_dom = other->dom();
  auto self_type = self_dom->op()->getResult(0).getType();
  auto other_type = other_dom->op()->getResult(0).getType();
  const bool both_static = self_type && other_type && !isDynamic(self_type) && !isDynamic(other_type);
  if (!both_static) {
    return shapeEqual(self_dom, other_dom);
  }
  return computeSize() == other->computeSize();
}
std::string Area::toString() const {
std::string result;
llvm::raw_string_ostream os(result);
bool is_first = true;
os << "<";
for (auto node : nodes_) {
if (is_first) {
is_first = false;
os << id() << ":";
} else {
os << "-";
}
os << "op" << node->id();
}
os << ">";
return result;
}
void Area::fuseInput(const AreaPtr &input_area) {
auto iter = std::find_if(inputs_with_relation_.begin(), inputs_with_relation_.end(),
[&input_area](const AreaWithRelation &a) { return a.first == input_area; });
if (iter == inputs_with_relation_.end()) {
std::string err_msg =
"The area " + input_area->toString() + " should be the input of area " + this->toString() + "\n";
llvm::report_fatal_error(llvm::StringRef(err_msg));
}
auto input_idx = iter - inputs_with_relation_.begin();
if (input_area->is_output_) {
is_output_ = true;
}
if (pattern() < input_area->pattern()) {
nodes_.swap(input_area->nodes_);
}
(void)nodes_.insert(nodes_.cend(), input_area->nodes_.cbegin(), input_area->nodes_.cend());
NodePattern new_pattern = std::max(pattern(), input_area->pattern());
if ((new_pattern == NodePattern::ELEMWISE) && (inputRelation(input_idx) == EdgeRelation::BROADCAST)) {
new_pattern = NodePattern::BROADCAST;
}
hd_->setPattern(new_pattern);
inputs_with_relation_.erase(iter);
inputs_with_relation_.insert(inputs_with_relation_.cend(), input_area->inputs_with_relation_.cbegin(),
input_area->inputs_with_relation_.cend());
makeUniqueAndSyncInputs();
updateUsersRelation(input_area);
input_area->nodes_.clear();
input_area->inputs_with_relation_.clear();
input_area->hd_->clearInputs();
}
void Area::makeUniqueAndSyncInputs() {
std::sort(inputs_with_relation_.begin(), inputs_with_relation_.end(), areaWithRelationCmp);
auto last = std::unique(inputs_with_relation_.begin(), inputs_with_relation_.end(), sameArea);
inputs_with_relation_.erase(last, inputs_with_relation_.cend());
this->hd_->clearInputs();
std::for_each(inputs_with_relation_.begin(), inputs_with_relation_.end(),
[this](const AreaWithRelation &pair) { this->hd_->addInput(pair.first->hd_.get()); });
}
void Area::updateUsersRelation(const AreaPtr &input_area) {
auto &user_node_with_index = input_area->hd_->users();
std::vector<AreaPtr> user_areas;
for (auto &[user_hd, index] : user_node_with_index) {
user_areas.emplace_back(user_hd->as<NodeHandle>()->area());
const auto idx = *(index.begin());
user_areas.back()->inputs_with_relation_[idx].first = this->shared_from_this();
}
for (auto user : user_areas) {
user->makeUniqueAndSyncInputs();
}
}
}
}
}