* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Analysis/Split/Node.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace mfuse {
namespace split {
std::string Node::toString() const {
std::string result;
llvm::raw_string_ostream os(result);
op_->print(os, mlir::OpPrintingFlags().assumeVerified().useLocalScope());
return result;
}
/// Appends `new_input` as the next input of this node and registers this node
/// as a user of it at that slot. A null input is ignored.
void Node::addInput(Node *new_input) {
  if (new_input == nullptr) {
    return;
  }
  // The new input will live at the next free slot; record the back-reference
  // before appending it.
  const size_t slot = inputs_.size();
  new_input->addUser(this, slot);
  inputs_.push_back(new_input);
}
void Node::setInput(size_t i, Node *new_input) {
if (!new_input) {
return;
}
if (i >= inputs_.size()) {
std::string err_msg =
"The index " + std::to_string(i) + " is out of the inputs range [0, " + std::to_string(inputs_.size()) + ")";
llvm::report_fatal_error(llvm::StringRef(err_msg));
}
auto &old_input = inputs_[i];
old_input->removeUser(this, i);
new_input->addUser(this, i);
inputs_[i] = new_input;
}
/// Replaces all inputs of this node with `inputs`, in order.
/// Existing inputs are detached first (clearInputs()). Each new entry is then
/// attached through addInput(), so null entries are skipped.
void Node::setInputs(const std::vector<Node *> &inputs) {
  // clearInputs() empties inputs_. If `inputs` is inputs_ itself (e.g. a
  // caller writes node->setInputs(node->inputs()) through a reference-returning
  // accessor), iterating it afterwards would see an empty list and silently
  // drop every input. Take a snapshot in that case. The comparison goes
  // through const void* so it compiles whatever container type inputs_ is.
  std::vector<Node *> snapshot;
  const std::vector<Node *> *source = &inputs;
  if (static_cast<const void *>(&inputs) == static_cast<const void *>(&inputs_)) {
    snapshot = inputs;
    source = &snapshot;
  }
  clearInputs();
  inputs_.reserve(source->size());
  for (Node *inp : *source) {
    addInput(inp);
  }
}
/// Detaches this node from every current input (removing the matching user
/// back-reference on each input) and then empties the input list.
void Node::clearInputs() noexcept {
  for (size_t slot = 0, count = inputs_.size(); slot < count; ++slot) {
    inputs_[slot]->removeUser(this, slot);
  }
  // Clearing an already-empty list is a no-op.
  inputs_.clear();
}
/// Redirects every recorded use of this node to `other_node`. Each
/// (user, slot) pair is rewired with setInput(). Does nothing if `other_node`
/// is null or this node has no users.
void Node::replaceWith(Node *other_node) {
  if (other_node == nullptr || users_.empty()) {
    return;
  }
  // setInput() calls removeUser() on this node, which mutates users_, so walk
  // a copy instead of the live map.
  const auto snapshot = users_;
  for (const auto &entry : snapshot) {
    const auto &user = entry.first;
    for (const size_t idx : entry.second) {
      user->setInput(idx, other_node);
    }
  }
}
/// Removes the record that `user` consumes this node at input slot `index`.
/// The `user` entry is dropped entirely once it has no remaining slots.
/// Unknown users are ignored.
void Node::removeUser(Node *user, size_t index) {
  auto found = users_.find(user);
  if (found == users_.end()) {
    return;
  }
  auto &indices = found->second;
  indices.erase(index);
  if (indices.empty()) {
    users_.erase(found);
  }
}
}
}
}