* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "vm/vmimpl.h"
#include <algorithm>
#include <exception>
#include <vector>
#include <memory>
#include "frontend/operator/ops.h"
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "utils/convert_utils.h"
#include "utils/primitive_utils.h"
namespace mindspore {
namespace compile {
// Marker returned by the dispatch code when a graph/closure call needs its own frame.
// `frame` is the callee frame that the evaluation loop in VM::Evaluate should run next.
struct CallWrap : public Base {
  explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {}
  VMFramePtr frame;
};
using CallWrapPtr = std::shared_ptr<CallWrap>;

// Marker returned when the `return` primitive executes; `value` is what the frame returns.
struct ReturnWrap : public Base {
  explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {}
  BaseRef value;
};
using ReturnWrapPtr = std::shared_ptr<ReturnWrap>;
// Create a frame for one graph invocation. The work list is stored in reverse topological
// order so the next node to execute is always at the back of `todo_`.
VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values,
                 const AnfNodePtrToBaseRefMap &closure)
    : values_(values), todo_(nodes.rbegin(), nodes.rend()), closure_(closure) {}
// Look up the runtime value of `node` from this frame. Lookup order: values computed in this frame,
// then values captured in the closure, then the constant carried by a ValueNode.
// Any other node has no value here and raises an exception.
const BaseRef VMFrame::operator[](const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto local = values_.find(node);
  if (local != values_.end()) {
    return local->second;
  }
  auto captured = closure_.find(node);
  if (captured != closure_.end()) {
    return captured->second;
  }
  if (!node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "ValueError " << node->type_name();
  }
  return GetValueNode(node);
}
// A function graph bundled with the values it captured from enclosing frames.
Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values)
    : func_graph_(graph), values_(values) {}

// Evaluate the wrapped graph on the VM this closure was bound to (via set_vm), supplying the
// captured values as the closure environment. Throws if no VM has been bound.
BaseRef Closure::operator()(const VectorRef &args) {
  MS_LOG(DEBUG) << "Start closure";
  MS_EXCEPTION_IF_NULL(vm_);
  return vm_->Evaluate(func_graph_, args, values_);
}
// A callable `fn` with leading arguments `args` already bound, executed on `vm`.
Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm)
    : fn_(fn), args_(args), vm_(vm) {}

// Call the wrapped function with the bound arguments followed by `nodes`.
BaseRef Partial::operator()(const VectorRef &nodes) {
  VectorRef full_args;
  // Bound arguments go first, then the ones supplied by this call.
  (void)full_args.insert(full_args.end(), args_.begin(), args_.end());
  (void)full_args.insert(full_args.end(), nodes.begin(), nodes.end());
  MS_EXCEPTION_IF_NULL(vm_);
  return vm_->Call(fn_, full_args);
}
// Collect the free variables of `graph` as a set of nodes.
// A free variable that is itself a FuncGraph is replaced by the value nodes of that graph which
// hold the graph itself. Every other free variable is inserted as-is.
SetRef VM::ComputeFvs(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  SetRef free_vars;
  for (const auto &fv_entry : graph->free_variables_total()) {
    if (!utils::isa<FuncGraphPtr>(fv_entry.first)) {
      (void)free_vars.insert(fv_entry.first);
      continue;
    }
    auto fv_graph = utils::cast<FuncGraphPtr>(fv_entry.first);
    for (const auto &vn_entry : fv_graph->value_nodes()) {
      const auto &const_node = vn_entry.first;
      if (GetValueNode(const_node) == fv_graph) {
        (void)free_vars.insert(const_node);
      }
    }
  }
  return free_vars;
}
// Register `graph` with this VM's manager and cache the free variables of every graph the
// manager now tracks. This is a no-op if `graph` already has an entry in vars_.
// Throws if `graph`, the VM manager, or the graph's manager is null.
void VM::AcquireGraph(const FuncGraphPtr &graph) {
  // `graph` is dereferenced below (graph->manager()); fail loudly instead of crashing.
  MS_EXCEPTION_IF_NULL(graph);
  if (vars_.find(graph) != vars_.end()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(manager_);
  manager_->AddFuncGraph(graph);
  MS_EXCEPTION_IF_NULL(graph->manager());
  // Iterate a local copy (`auto` is not a reference) of the manager's graph set.
  auto graphs = graph->manager()->func_graphs();
  for (const auto &g : graphs) {
    vars_[g] = ComputeFvs(g);
  }
}
// Export every element of `seq` (see VM::Export) and return the results in the same order.
VectorRef VM::ExportSequence(const VectorRef &seq) {
  std::vector<BaseRef> exported;
  exported.reserve(seq.size());
  for (const auto &item : seq) {
    exported.push_back(Export(item));
  }
  return VectorRef(exported);
}
// Bind an existing closure to this VM and hand it back. Throws if `clos` is null.
ClosurePtr VM::ExportClosure(const ClosurePtr &clos) {
  MS_EXCEPTION_IF_NULL(clos);
  clos->set_vm(shared_from_this());
  return clos;
}

// Wrap a bare graph in a closure with no captured values, bound to this VM.
ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) {
  // std::make_shared never yields null (it throws on allocation failure).
  auto closure = std::make_shared<Closure>(g, AnfNodePtrToBaseRefMap());
  closure->set_vm(shared_from_this());
  return closure;
}

// Plain values are exported unchanged.
BaseRef VM::ExportObj(const BaseRef &obj) const {
  return obj;
}
// Convert a runtime value into the form returned to VM callers:
// - graphs (bare or wrapped in a Value) become closures bound to this VM;
// - primitives go through ExportPrimitive;
// - closures are re-bound to this VM;
// - VectorRefs are exported element-wise;
// - anything else goes through ExportObj.
BaseRef VM::Export(const BaseRef &value) {
  if (utils::isa<ValuePtr>(value)) {
    auto as_value = utils::cast<ValuePtr>(value);
    if (as_value->isa<FuncGraph>()) {
      return ExportGraph(as_value->cast<FuncGraphPtr>());
    }
    if (as_value->isa<Primitive>()) {
      return ExportPrimitive(as_value->cast<PrimitivePtr>());
    }
  }
  if (utils::isa<FuncGraphPtr>(value)) {
    return ExportGraph(utils::cast<FuncGraphPtr>(value));
  }
  if (utils::isa<ClosurePtr>(value)) {
    return ExportClosure(utils::cast<ClosurePtr>(value));
  }
  if (utils::isa<PrimitivePtr>(value)) {
    return ExportPrimitive(utils::cast<PrimitivePtr>(value));
  }
  if (utils::isa<VectorRef>(value)) {
    return ExportSequence(utils::cast<VectorRef>(value));
  }
  return ExportObj(value);
}
// Interpret `graph` on this VM.
// - `args` are bound positionally to the graph's parameters.
// - `closure` supplies values captured from enclosing graphs.
// The function keeps an explicit stack of frames. When HandleNode returns a CallWrap, the callee frame
// is pushed (or, in tail position, replaces the current frame). When it returns a ReturnWrap, the
// current frame is popped and its value is handed to the caller frame. The value returned by the
// outermost frame is passed through Export().
// Throws on an argument-count mismatch.
BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) {
MS_EXCEPTION_IF_NULL(graph);
AcquireGraph(graph);
MS_LOG(DEBUG) << "Evalue arg size: " << args.size();
if (args.size() != graph->parameters().size()) {
MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got "
<< args.size();
}
// Nodes in dependency order, rooted at the return node. VMFrame stores them reversed, so the next node
// to run is at the back of its work list.
auto nodes = TopoSort(graph->get_return(), SuccVm(graph));
AnfNodePtrToBaseRefMap values;
for (size_t i = 0; i < args.size(); i++) {
values[graph->parameters()[i]] = args[i];
}
VMFramePtrList frames{std::make_shared<VMFrame>(nodes, values, closure)};
while (!frames.empty()) {
auto frame = frames[frames.size() - 1];
MS_EXCEPTION_IF_NULL(frame);
// NOTE(review): `auto` deduces a non-reference type, so `todo` is a local copy of the frame's work list.
// The erase() calls below therefore do not shrink the frame's own list. When a nested frame later returns,
// this loop re-reads frame->todo() from the beginning. If todo() returns the frame's stored list, the
// caller's nodes (including the call node) would be re-run. Confirm todo()'s return type in vmimpl.h;
// `auto &todo` may have been intended.
auto todo = frame->todo();
while (!todo.empty()) {
auto except = HandleNode(todo[todo.size() - 1], frame);
if (utils::isa<CallWrapPtr>(except)) {
// Presumably the one entry left after the call node is this frame's return node. In that case the
// callee frame replaces the current frame (a tail call) instead of being stacked on top of it.
if (todo.size() == 2) {
frames[frames.size() - 1] = utils::cast<CallWrapPtr>(except)->frame;
} else {
frames.push_back(utils::cast<CallWrapPtr>(except)->frame);
}
break;
}
if (utils::isa<ReturnWrapPtr>(except)) {
// Pop the finished frame.
(void)frames.erase(frames.begin() + (static_cast<ssize_t>(frames.size()) - 1));
if (frames.size() > 0) {
auto top = frames[frames.size() - 1];
MS_EXCEPTION_IF_NULL(top);
// NOTE(review): same copy concern as `todo` above; the erase() below only affects the local `td`.
auto td = top->todo();
if (td.empty()) {
MS_LOG(EXCEPTION) << "The td is empty";
}
// The caller's pending node (the call that created the popped frame) receives the returned value.
top->values()[td[td.size() - 1]] = utils::cast<ReturnWrapPtr>(except)->value;
(void)td.erase(td.begin() + (static_cast<ssize_t>(td.size()) - 1));
} else {
// The outermost frame returned: this is the result of the whole evaluation.
return Export(utils::cast<ReturnWrapPtr>(except)->value);
}
break;
}
(void)todo.erase(todo.begin() + (static_cast<ssize_t>(todo.size()) - 1));
}
}
MS_LOG(EXCEPTION) << "VM Evaluate error";
}
// Build the successor function that TopoSort uses to order the nodes of `graph`.
// - For a CNode, the successors are its inputs that belong to the same graph, plus FuncGraph
//   value-node inputs whose parent is `graph`.
// - A FuncGraph value node whose parent is `graph` additionally has, as successors, the nodes
//   cached in vars_ for that graph (computed by ComputeFvs in AcquireGraph).
SuccFunc VM::SuccVm(const FuncGraphPtr &graph) {
  // The returned function outlives this call, so capture `graph` by value. Capturing the reference
  // parameter (`[&]`) would dangle if the caller's FuncGraphPtr died before the SuccFunc is invoked.
  auto fn = [graph, this](const AnfNodePtr &node) -> AnfNodePtrList {
    MS_EXCEPTION_IF_NULL(node);
    AnfNodePtrList ret;
    if (node->isa<CNode>()) {
      auto &inputs = node->cast<CNodePtr>()->inputs();
      for (auto &i : inputs) {
        if (i->func_graph() == node->func_graph() ||
            (IsValueNode<FuncGraph>(i) && GetValueNode<FuncGraphPtr>(i)->parent() == graph)) {
          ret.push_back(i);
        }
      }
    }
    if (IsValueNode<FuncGraph>(node) && GetValueNode<FuncGraphPtr>(node)->parent() == graph) {
      auto fvs = utils::cast<SetRef>(vars_[GetValueNode<FuncGraphPtr>(node)]);
      (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret),
                           [](const BaseRef &value) -> AnfNodePtr { return utils::cast<AnfNodePtr>(value); });
    }
    return ret;
  };
  return fn;
}
// Invoke `fn` with `args` and return its result.
// - A primitive is run immediately via RunOperation.
// - A graph or closure is evaluated on this VM; a closure's captured values are passed as the
//   closure environment.
// Any other callee raises an exception.
BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<PrimitivePtr>(fn)) {
    auto prim = utils::cast<PrimitivePtr>(fn);
    return RunOperation(prim, args);
  }
  if (utils::isa<FuncGraphPtr>(fn)) {
    auto callee = utils::cast<FuncGraphPtr>(fn);
    return Evaluate(callee, args);
  }
  if (utils::isa<ClosurePtr>(fn)) {
    auto closure = utils::cast<ClosurePtr>(fn);
    return Evaluate(closure->func_graph(), args, closure->values());
  }
  MS_LOG(EXCEPTION) << "Can't call fn";
}
// Prepare, but do not run, a call to `graph` with `args`.
// `graph` may be a Closure (whose captured values become the new frame's closure environment), a
// FuncGraph, or a Value holding a FuncGraph. Returns a CallWrap carrying the new frame for the
// evaluation loop. Throws on a non-graph callee or an argument-count mismatch.
BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) {
  AnfNodePtrToBaseRefMap clos;
  auto func_graph = graph;
  if (utils::isa<ClosurePtr>(func_graph)) {
    clos = utils::cast<ClosurePtr>(func_graph)->values();
    func_graph = utils::cast<ClosurePtr>(func_graph)->func_graph();
  }
  if (utils::isa<ValuePtr>(func_graph)) {
    func_graph = utils::cast<ValuePtr>(func_graph)->cast<FuncGraphPtr>();
  }
  if (!utils::isa<FuncGraphPtr>(func_graph)) {
    MS_LOG(EXCEPTION) << "Graph type error";
  }
  auto graphPtr = utils::cast<FuncGraphPtr>(func_graph);
  // Value::cast<FuncGraphPtr>() yields null for a non-graph value; it is dereferenced below.
  MS_EXCEPTION_IF_NULL(graphPtr);
  // AcquireGraph already returns early when the graph is known, so no separate vars_ lookup is needed.
  AcquireGraph(graphPtr);
  if (args.size() != graphPtr->parameters().size()) {
    MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size()
                      << ", but got " << args.size();
  }
  auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr));
  AnfNodePtrToBaseRefMap values;
  for (size_t i = 0; i < args.size(); i++) {
    values[graphPtr->parameters()[i]] = args[i];
  }
  return std::make_shared<CallWrap>(std::make_shared<VMFrame>(nodes, values, clos));
}
// Build a closure for `graph` by reading each of its cached free variables (vars_[graph]) from `frame`.
// Throws if `frame` is null.
ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) {
  MS_EXCEPTION_IF_NULL(frame);
  AnfNodePtrToBaseRefMap captured;
  auto free_vars = utils::cast<SetRef>(vars_[graph]);
  for (auto &fv : free_vars) {
    auto fv_node = utils::cast<AnfNodePtr>(fv);
    captured[fv_node] = (*frame)[fv_node];
  }
  return std::make_shared<Closure>(graph, captured);
}
// Execute the call described by CNode `node` in `frame`. `fn` is the evaluated callee and `args` its
// evaluated arguments. Returns one of:
// - a ReturnWrap for the `return` primitive;
// - a CallWrap when a graph/closure must run in a new frame;
// - an empty BaseRef once the result has been stored in frame->values()[node].
// Throws for an uncallable `fn`.
BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<Primitive>()) {
    auto fnval = utils::cast<ValuePtr>(fn)->cast<PrimitivePtr>();
    MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true);
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Args is empty";
    }
    if (fnval == prim::kPrimReturn) {
      MS_LOG(DEBUG) << "Return args:" << args.size();
      return std::make_shared<ReturnWrap>(args[0]);
    }
    MS_EXCEPTION_IF_NULL(frame);
    if (fnval == prim::kPrimMakeTuple) {
      frame->values()[node] = args;
      return BaseRef();
    }
    if (fnval == prim::kPrimPartial) {
      // args[0] is the function; the rest are the arguments to bind.
      VectorRef partial_args(args.begin() + 1, args.end());
      frame->values()[node] = (std::make_shared<Partial>(args[0], partial_args, shared_from_this()));
      return BaseRef();
    }
    frame->values()[node] = RunOperation(fnval, args);
    return BaseRef();
  }
  if (utils::isa<PartialPtr>(fn)) {
    auto fnPtr = utils::cast<PartialPtr>(fn);
    VectorRef arglist;
    (void)arglist.insert(arglist.end(), fnPtr->args().begin(), fnPtr->args().end());
    (void)arglist.insert(arglist.end(), args.begin(), args.end());
    // The inner dispatch has two non-throwing outcomes, and both are complete:
    // - it returns a CallWrap/ReturnWrap for the evaluation loop, or
    // - it stores the result for `node` in `frame` and returns an empty BaseRef (e.g. a partial primitive).
    // Falling through to the checks below would reject the Partial as "Invalid fn to call" even though
    // its value was already computed, so return the inner result directly.
    return DispatchCall(node, frame, fnPtr->fn(), arglist);
  }
  if ((utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<FuncGraph>()) || utils::isa<ClosurePtr>(fn)) {
    auto ret = _Call(fn, args);
    if (utils::isa<CallWrapPtr>(ret) || utils::isa<ReturnWrapPtr>(ret)) {
      return ret;
    }
  }
  MS_LOG(EXCEPTION) << "Invalid fn to call";
}
// Execute one node of `frame`'s work list. Returns whatever DispatchCall returns for a CNode
// (a CallWrap, a ReturnWrap, or an empty BaseRef), and an empty BaseRef otherwise.
// - Parameter: nothing to do; values are bound when the frame is created.
// - ValueNode: must hold a FuncGraph. If that graph has cached free variables, a closure over them is
//   stored for the node. Otherwise VMFrame::operator[] falls back to the ValueNode's own value.
// - CNode: evaluate all inputs from the frame, then dispatch input 0 as the callee.
BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<Parameter>()) {
    return BaseRef();
  }
  if (node->isa<ValueNode>()) {
    if (!IsValueNode<FuncGraph>(node)) {
      MS_LOG(EXCEPTION) << "We only visit valuenode graphs ";
    }
    auto g = GetValueNode<FuncGraphPtr>(node);
    MS_EXCEPTION_IF_NULL(frame);
    auto iterG = vars_.find(g);
    if (iterG != vars_.end() && utils::cast<SetRef>(iterG->second).size() != 0) {
      frame->values()[node] = MakeClosure(g, frame);
    }
    return BaseRef();
  }
  if (node->isa<CNode>()) {
    // `frame` is dereferenced while reading the inputs; check it as the ValueNode branch does.
    MS_EXCEPTION_IF_NULL(frame);
    auto &inputs = node->cast<CNodePtr>()->inputs();
    std::vector<BaseRef> fnArgs;
    fnArgs.reserve(inputs.size());
    (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs),
                         [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; });
    if (fnArgs.empty()) {
      MS_LOG(EXCEPTION) << "Function arguments is empty";
    }
    // fnArgs[0] is the callee; the remaining entries are its arguments.
    auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end());
    return DispatchCall(node, frame, fnArgs[0], args);
  }
  MS_LOG(EXCEPTION) << "Unknown node type";
}
// Run `g` on `args` with a freshly created manager and return its result as a VectorRef.
// A non-sequence result is wrapped in a one-element VectorRef.
VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
  manager_ = Manage(g);
  auto entry = utils::cast<ClosurePtr>(Export(g));
  auto result = (*entry)(args);
  if (!utils::isa<VectorRef>(result)) {
    return VectorRef({result});
  }
  return utils::cast<VectorRef>(result);
}
// Execute primitive `prim` on `args`. Three compute paths are tried in order:
//   1. the primitive's own RunComputeFunction;
//   2. RunComputeFunctionWithoutPyObj;
//   3. the free RunComputeFunction(prim, args).
// The first non-null result from steps 1-2 is returned; otherwise step 3's result is returned as-is.
// Throws if `prim` is null.
BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
  // Check before logging: the debug log below dereferences `prim`.
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Operation start " << prim->name();
  auto result = prim->RunComputeFunction(args);
  if (result.is_null()) {
    result = RunComputeFunctionWithoutPyObj(prim, args);
  }
  if (result.is_null()) {
    return RunComputeFunction(prim, args);
  }
  return result;
}
}
}