* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#include <condition_variable>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <atomic>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/comm_manager.h"
#include "utils/contract.h"
namespace mindspore {
namespace session {
// Discriminator stored in Task::type_ identifying which concrete Task subclass an object is.
// Values are spelled out explicitly; they are identical to the implicit 0..9 numbering.
enum TaskType {
  kUnKnown = 0,           // default value of Task::type_ (a bare Task)
  kExit = 1,              // ExitTask
  kCompileNodes = 2,      // CompileNodesTask
  kCompileGraph = 3,      // CompileGraphTask
  kBuildGraph = 4,        // BuildGraphTask
  kRunGraph = 5,          // RunGraphTask
  kRunOp = 6,             // RunOpTask
  kCreateCommGroup = 7,   // CreateCommGroupTask
  kDestroyCommGroup = 8,  // DestroyCommGroupTask
  kRunOpsInGraph = 9      // RunOpsInGraphTask
};
// Base class for a unit of work handled by the Executor.
// Concrete subclasses assign type_ in their constructors and override Run();
// the base Run() does nothing.
class Task {
 public:
  Task() = default;
  virtual ~Task() = default;

  // Session the task operates on; null until the creator of the task assigns it.
  SessionPtr session_ = nullptr;
  // Identifies the concrete subclass; stays kUnKnown for a plain Task.
  TaskType type_ = kUnKnown;
  // NOTE(review): presumably marks tasks submitted through a synchronous path — confirm in executor.cc.
  bool sync_run_ = false;

  // Performs the task's work. Intentionally a no-op in the base class.
  virtual void Run() {}
};
class CompileNodesTask : public Task {
public:
CompileNodesTask() { type_ = kCompileNodes; }
~CompileNodesTask() override = default;
void Run() override;
GraphSegmentPtr segment_;
AnfNodePtrList output_nodes_;
GraphId graph_id_{0};
};
class CompileGraphTask : public Task {
public:
CompileGraphTask() { type_ = kCompileGraph; }
~CompileGraphTask() override = default;
void Run() override;
FuncGraphPtr func_graph_{nullptr};
GraphId graph_id_{0};
};
class BuildGraphTask : public Task {
public:
BuildGraphTask() { type_ = kBuildGraph; }
~BuildGraphTask() override = default;
void Run() override;
GraphId graph_id_{0};
};
class RunGraphTask : public Task {
public:
RunGraphTask() { type_ = kRunGraph; }
~RunGraphTask() override = default;
void Run() override;
std::vector<tensor::TensorPtr> input_tensors_;
std::vector<tensor::TensorPtr> input_need_wait_tensors_;
std::vector<tensor::TensorPtr> input_need_lock_tensors_;
VectorRef outputs_;
GraphId graph_id_{0};
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
};
class RunOpsInGraphTask : public Task {
public:
RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
~RunOpsInGraphTask() override = default;
void Run() override;
std::vector<tensor::TensorPtr> input_tensors_;
VectorRef outputs_;
GraphId graph_id_{0};
};
class RunOpTask : public Task {
public:
RunOpTask() { type_ = kRunOp; }
~RunOpTask() override = default;
void Run() override;
OpRunInfo *op_run_info_{nullptr};
GraphInfo graph_info_;
std::vector<tensor::TensorPtr> *input_tensors_{nullptr};
VectorRef outputs_;
std::vector<int64_t> tensors_mask_;
};
class CreateCommGroupTask : public Task {
public:
CreateCommGroupTask() { type_ = kCreateCommGroup; }
~CreateCommGroupTask() override = default;
void Run() override;
std::string group_name_;
std::vector<uint32_t> ranks_;
bool result_{false};
};
class DestroyCommGroupTask : public Task {
public:
DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
~DestroyCommGroupTask() override = default;
void Run() override;
std::string group_name_;
bool result_{false};
};
class ExitTask : public Task {
public:
ExitTask() { type_ = kExit; }
~ExitTask() override = default;
};
// Events passed to Executor::OnEvent. Values match the implicit 0..2 numbering.
// NOTE(review): each presumably maps to the private handler of the same name
// (OnClear / OnRunGraphFinished / OnException) — confirm in executor.cc.
enum class ExecutorEvent {
  kClear = 0,
  kRunGraphFinished = 1,
  kException = 2
};
// Accepts Task objects from sessions and hands them to a worker thread.
//
// NOTE(review): the threading model described in these comments is inferred from the
// declarations (worker_, WorkerLoop, the ready/pending/done task containers and their
// mutexes); confirm against the implementation in executor.cc.
class Executor {
 public:
  Executor(const std::string &device_name, uint32_t device_id);
  ~Executor();

  // Executor holds std::mutex / std::condition_variable members, which are neither copyable
  // nor movable, so the implicitly-declared copy/move operations were already unusable.
  // Deleting them explicitly states that intent and gives clearer compiler diagnostics.
  Executor(const Executor &) = delete;
  Executor &operator=(const Executor &) = delete;
  Executor(Executor &&) = delete;
  Executor &operator=(Executor &&) = delete;

  // Body of the worker thread (presumably the function worker_ runs).
  void WorkerLoop();
  // Presumably asks the worker to exit and joins worker_.
  void WorkerJoin();

  // Compiles a graph segment with the given output nodes; returns the resulting graph id.
  GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  // Compiles a whole function graph; returns the resulting graph id.
  GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
  // Builds graph `graphId` (presumably one returned by CompileGraph).
  void BuildGraph(const SessionPtr &session, GraphId graphId);

  // Runs graph `graph_id` on `inputs`; results are written through `outputs`.
  void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                VectorRef *outputs);
  // Asynchronous counterpart of RunGraph.
  // NOTE(review): presumably returns before execution completes — confirm how callers wait on *outputs.
  void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                     VectorRef *outputs);
  // Runs a single operator. `op_run_info` and `input_tensors` are raw, non-owning pointers;
  // presumably they must stay valid until the call returns.
  void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
             std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
             const std::vector<int64_t> &tensors_mask);
  // Runs graph `graph_id` (presumably op by op); results are written through `outputs`.
  void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                     VectorRef *outputs);

  // Creates / destroys a communication group; presumably returns true on success.
  bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks);
  bool DestroyCommGroup(const std::string &group_name);

  // Reacts to an executor event (presumably dispatching to the private On* handlers below).
  void OnEvent(const ExecutorEvent &event);

 private:
  // Submits `task` for execution. NOTE(review): `sync` presumably makes the caller wait for
  // completion and `long_run` selects a different waiting strategy — confirm.
  void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
  // Returns the RunGraphTasks from pending_tasks_ that are ready to run.
  std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
  void OnWorkerExit();
  void OnClear();
  void OnRunGraphFinished();
  void OnException();
  void ClearDoneTasks();

  // Defaulted so the value is never indeterminate, even if a constructor path forgets it.
  uint32_t device_id_{0};
  std::string device_name_;

  // NOTE(review): presumably task_mutex_ guards ready_tasks_ and sync_run_task_finished_,
  // done_task_mutex_ guards done_tasks_, pending_task_mutex_ guards pending_tasks_ — confirm.
  std::mutex task_mutex_;
  std::mutex done_task_mutex_;
  std::mutex pending_task_mutex_;
  std::mutex reenter_mutex_;
  std::condition_variable task_cond_var_;
  std::condition_variable sync_cond_var_;
  std::condition_variable reenter_cond_var_;

  std::queue<std::shared_ptr<Task>> ready_tasks_;
  std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
  std::vector<std::shared_ptr<Task>> done_tasks_;
  std::shared_ptr<std::thread> worker_;
  bool sync_run_task_finished_{false};
};
}
}
#endif