* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "tensor_monitor.h"
namespace MemScope {
void TensorMonitor::AddWatchTensor(MonitoredTensor& tensorInfo)
{
std::lock_guard<std::mutex> lock(mapMutex_);
uint64_t ptr = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(tensorInfo.data));
auto it = pythonWatchedTensorsMap_.find(ptr);
if (it != pythonWatchedTensorsMap_.end()) {
pythonWatchedTensorsMap_[ptr] = tensorInfo;
} else {
pythonWatchedTensorsMap_.insert({ptr, tensorInfo});
}
}
void TensorMonitor::AddWatchTensor(const std::vector<MonitoredTensor>& tensorInfoLists, uint32_t outputId)
{
std::lock_guard<std::mutex> lock(mapMutex_);
outputId_ = outputId;
for (auto& tensorInfo : tensorInfoLists) {
uint64_t ptr = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(tensorInfo.data));
auto it = cmdWatchedTensorsMap_.find(ptr);
if (it != cmdWatchedTensorsMap_.end()) {
cmdWatchedTensorsMap_[ptr] = tensorInfo;
} else {
cmdWatchedTensorsMap_.insert({ptr, tensorInfo});
}
}
}
std::unordered_map<uint64_t, MonitoredTensor> TensorMonitor::GetCmdWatchedTensorsMap()
{
std::lock_guard<std::mutex> lock(mapMutex_);
return cmdWatchedTensorsMap_;
}
uint32_t TensorMonitor::GetCmdWatchedOutputId() const
{
return outputId_;
}
std::unordered_map<uint64_t, MonitoredTensor> TensorMonitor::GetPythonWatchedTensorsMap()
{
std::lock_guard<std::mutex> lock(mapMutex_);
return pythonWatchedTensorsMap_;
}
void TensorMonitor::DeleteWatchTensor(MonitoredTensor& tensorInfo)
{
std::lock_guard<std::mutex> lock(mapMutex_);
uint64_t ptr = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(tensorInfo.data));
auto it = pythonWatchedTensorsMap_.find(ptr);
if (it != pythonWatchedTensorsMap_.end()) {
pythonWatchedTensorsMap_.erase(ptr);
} else {
LOG_WARN("Failed to delete the tensor. The tensor ptr of %llu is not watched.", ptr);
}
}
void TensorMonitor::ClearCmdWatchTensor()
{
std::lock_guard<std::mutex> lock(mapMutex_);
cmdWatchedTensorsMap_.clear();
}
bool TensorMonitor::IsInMonitoring()
{
if (cmdWatchedTensorsMap_.empty() && pythonWatchedTensorsMap_.empty()) {
return false;
} else {
return true;
}
}
}