/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
#define MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
#include <algorithm>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "debug/tensor_data.h"
#ifdef ONLINE_DBG_MODE
#include "debug/data_dump/dump_json_parser.h"
namespace mindspore {
#endif
class TensorLoader {
public:
TensorLoader() : iter_num_(-1), mem_total_(0), mem_usage_(0) {}
~TensorLoader() { EmptyTensor(); }
void MoveTensorCurrentToPrev(std::string tensor_name) {
auto handle = tensor_list_map_.extract(tensor_name);
if (!handle.empty()) {
MS_LOG(INFO) << "Moving " << tensor_name << " from current map to previous map";
prev_tensor_list_map_.insert(std::move(handle));
}
}
void SwapCurrentPrev() { tensor_list_map_.swap(prev_tensor_list_map_); }
bool TensorExistsInCurrent(std::string tensor_name) const {
return tensor_list_map_.find(tensor_name) != tensor_list_map_.end();
}
bool PrevTensorExistsInCurrent(std::string tensor_name) const { return TensorExistsInCurrent(tensor_name + ":prev"); }
void MoveParametersCurrentToPrev() {
MS_LOG(INFO) << "Moving parameters from current map to previous map";
auto iter = tensor_list_map_.begin();
while (iter != tensor_list_map_.end()) {
auto key = iter->first;
if (PrevTensorExistsInCurrent(key)) {
++iter;
MoveTensorCurrentToPrev(key);
} else {
++iter;
}
}
}
bool IsPrevTensor(std::string tensor_name) const {
const std::string suffix = ":prev";
if (tensor_name.length() <= suffix.length()) return false;
return std::equal(suffix.rbegin(), suffix.rend(), tensor_name.rbegin());
}
bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
lock_.lock();
auto tensor_name = tensor->GetName();
if (keep_prev) {
auto handle = prev_tensor_list_map_.extract(tensor_name);
if (!handle.empty()) {
handle.key() = tensor_name + ":prev";
tensor_list_map_.insert(std::move(handle));
}
}
std::string key_name = tensor_name;
#ifdef OFFLINE_DBG_MODE
key_name += (":" + std::to_string(tensor->GetDeviceId()) + ":" + std::to_string(tensor->GetRootGraphId()) + ":" +
std::to_string(tensor->GetIsOutput()) + ":" + std::to_string(tensor->GetSlot()));
if (tensor_list_map_.find(key_name) != tensor_list_map_.end() &&
tensor->GetIteration() == tensor_list_map_[key_name]->GetIteration() - 1) {
key_name += ":prev";
}
auto iter = tensor_list_map_.find(key_name);
if (iter != tensor_list_map_.end()) {
iter->second->DeleteDataPtr();
}
#endif
tensor_list_map_[key_name] = tensor;
lock_.unlock();
return true;
}
std::vector<std::shared_ptr<TensorData>> GetTensor() {
std::vector<std::shared_ptr<TensorData>> tensor_list;
for (auto &it : tensor_list_map_) {
if (!IsPrevTensor(it.first)) tensor_list.push_back(it.second);
}
return tensor_list;
}
std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) const {
auto iter = tensor_list_map_.find(tensor_name);
if (iter != tensor_list_map_.end()) return iter->second;
return nullptr;
}
std::shared_ptr<TensorData> GetPrevTensor(const std::string &tensor_name) {
if (tensor_list_map_.find(tensor_name + ":prev") != tensor_list_map_.end()) {
return tensor_list_map_[tensor_name + ":prev"];
}
return nullptr;
}
void SearchTensors(const std::vector<std::string> &search_list,
std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> *result_list) {
for (auto i : search_list) {
std::map<std::string, std::shared_ptr<TensorData>>::iterator iter;
iter = tensor_list_map_.find(i);
if (iter != tensor_list_map_.end()) {
result_list->push_back(std::make_tuple(i, iter->second));
} else {
result_list->push_back(std::make_tuple(i, nullptr));
}
}
}
void EmptyTensor() {
std::lock_guard<std::mutex> lg(lock_);
prev_tensor_list_map_.clear();
tensor_list_map_.swap(prev_tensor_list_map_);
}
void EmptyCurrentTensor() { tensor_list_map_.clear(); }
bool EnableMemoryControl() { return mem_total_ > 0; }
void AppendToCacheEvictQueue(const std::string &tensor_name) {
std::lock_guard<std::mutex> lk(mem_lock_);
if (std::find(cache_evict_queue_.begin(), cache_evict_queue_.end(), tensor_name) == cache_evict_queue_.end()) {
cache_evict_queue_.push_back(tensor_name);
evict_cond.notify_one();
}
}
bool CheckMemoryAvailable(const std::string &backend_name, const uint64_t data_size) {
if (data_size > mem_total_) {
MS_LOG(ERROR) << "Failed to load data of tensor " << backend_name << " because the its data size (" << data_size
<< ") exceeds the maximum memory limit (" << mem_total_ << ").";
return false;
}
bool ret = CheckAndEvictTensorCache(data_size);
return ret;
}
bool CheckAndEvictTensorCache(const uint64_t data_size) {
std::string candidate_name;
uint64_t candidates_size;
std::unique_lock<std::mutex> lk(mem_lock_);
while (data_size > mem_total_ - mem_usage_) {
evict_cond.wait(lk, [&] { return !cache_evict_queue_.empty(); });
candidate_name = cache_evict_queue_.front();
candidates_size = tensor_list_map_[candidate_name]->GetByteSize();
lock_.lock();
tensor_list_map_[candidate_name]->DeleteDataPtr();
tensor_list_map_.erase(candidate_name);
lock_.unlock();
cache_evict_queue_.pop_front();
mem_usage_ = std::max(uint64_t(0), mem_usage_ - candidates_size);
MS_LOG(INFO) << "Evict tensor: " << candidate_name;
}
mem_usage_ = std::min(mem_total_, mem_usage_ + data_size);
return true;
}
void SetMemTotal(uint64_t total_mem_size) { this->mem_total_ = total_mem_size; }
#ifdef ONLINE_DBG_MODE
bool DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
TypeId device_type, const std::string &addr_format, size_t slot) {
if (filepath.empty()) {
MS_LOG(ERROR) << "Dump file path is null!";
return false;
}
std::string path = "";
if (trans_flag) {
path = filepath + '.' + host_fmt;
} else {
path = filepath + '.' + addr_format;
}
MS_LOG(INFO) << "Dump path is " << path;
std::string tensor_loader_name = tensor_name + ":" + std::to_string(slot);
auto iter = tensor_list_map_.find(tensor_loader_name);
if (iter != tensor_list_map_.end()) {
std::shared_ptr<TensorData> node = iter->second;
size_t host_size = node->GetByteSize();
return DumpJsonParser::DumpToFile(path, node->GetDataPtr(), host_size, host_shape, host_type);
}
MS_LOG(INFO) << "Tensor name:" << tensor_name << " not found in tensor_list_map_";
return true;
}
#endif
private:
std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map_;
std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map_;
uint32_t iter_num_;
std::mutex lock_;
std::mutex mem_lock_;
uint64_t mem_total_;
uint64_t mem_usage_;
std::deque<std::string> cache_evict_queue_;
std::condition_variable evict_cond;
};
#ifdef ONLINE_DBG_MODE
}
#endif
#endif