* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "command_handle.h"
#include "common/profiling/profiling_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/omg/omg_inner_types.h"
#include "graph/load/model_manager/model_manager.h"
#include "aprof_pub.h"
#include "graph/ge_context.h"
#include "acl/acl_rt.h"
namespace ge {
namespace {
// Index of the comma-separated device id list inside the start/stop parameter vector built by
// TransProfConfigToParam, whose layout is {"devNums", <count>, "devIdList", <id list>}.
constexpr size_t kDeviceListIndex = 3U;
// Number of ProfCommandHandleType enumerators. HandleCtrlSwitch rejects any raw command type >= this
// value before casting it to the enum; the static_assert below keeps the two in sync.
constexpr uint32_t kCommandNum = 6U;
// Largest device count accepted in a single start/stop command.
constexpr uint32_t kMaxDevNum = 64U;
// Parameter keys placed in Command::cmd_params for start/stop commands.
const std::string kDeviceNums = "devNums";
const std::string kDeviceIdList = "devIdList";
// Command names forwarded as Command::cmd_type (looked up through kProfCommandTypeMap).
const std::string kProfilingInit = "prof_init";
const std::string kProfilingFinalize = "prof_finalize";
const std::string kProfilingStart = "prof_start";
const std::string kProfilingStop = "prof_stop";
const std::string kProfilingModelSubscribe = "prof_model_subscribe";
const std::string kProfilingModelUnsubscribe = "prof_model_cancel_subscribe";
// Parameter key placed in Command::cmd_params for model subscribe/unsubscribe commands.
const std::string kProfilingModelId = "modelId";
// Generic failure value returned by the handlers in this file.
constexpr int32_t RT_ERROR = -1;
// Values of MsprofCommandHandle::type; HandleCtrlSwitch casts the raw field directly to this enum,
// so the numeric order of the enumerators is significant.
enum class ProfCommandHandleType : uint32_t {
  kProfCommandHandleInit = 0,
  kProfCommandHandleStart,
  kProfCommandHandleStop,
  kProfCommandHandleFinalize,
  kProfCommandHandleModelSubscribe,
  kProfCommandHandleModelUnsubscribe
};
// Adding an enumerator without bumping kCommandNum (or vice versa) would let HandleCtrlSwitch either reject
// a valid type or cast an out-of-range value to the enum; fail the build instead.
static_assert(kCommandNum ==
                  (static_cast<uint32_t>(ProfCommandHandleType::kProfCommandHandleModelUnsubscribe) + 1U),
              "kCommandNum must equal the number of ProfCommandHandleType enumerators");
// Maps each command type to the command name handed to ModelManager::HandleCommand.
// Declared after the kProfiling* strings so they are already initialized when this map copies them.
const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap {
  {ProfCommandHandleType::kProfCommandHandleInit, kProfilingInit},
  {ProfCommandHandleType::kProfCommandHandleStart, kProfilingStart},
  {ProfCommandHandleType::kProfCommandHandleStop, kProfilingStop},
  {ProfCommandHandleType::kProfCommandHandleFinalize, kProfilingFinalize},
  {ProfCommandHandleType::kProfCommandHandleModelSubscribe, kProfilingModelSubscribe},
  {ProfCommandHandleType::kProfCommandHandleModelUnsubscribe, kProfilingModelUnsubscribe}
};
bool IsProfConfigValid(const uint32_t deviceid_list[], const uint32_t device_nums) {
if ((device_nums == 0U) || (device_nums > kMaxDevNum)) {
GELOGE(PARAM_INVALID, "[Check][DeviceNums]Invalid, device nums: %u", device_nums);
REPORT_INNER_ERR_MSG("E19999", "DeviceNums %u check invalid", device_nums);
return false;
}
uint32_t dev_count = 0;
const aclError rt_err = aclrtGetDeviceCount(&dev_count);
if (rt_err != ACL_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Get][DeviceCount]Failed, error_code %d", rt_err);
REPORT_INNER_ERR_MSG("E19999", "Get device count failed, error_code %d", rt_err);
return false;
}
if (device_nums > dev_count) {
GELOGE(PARAM_INVALID, "[Check][Param]Device num %u is not in range [1,%u]", device_nums, dev_count);
REPORT_INNER_ERR_MSG("E19999", "Device num %u check invalid, it is not in range [1,%u]", device_nums, dev_count);
return false;
}
std::set<uint32_t> record;
for (uint32_t i = 0U; i < device_nums; ++i) {
const uint32_t dev_id = deviceid_list[i];
if (!record.insert(dev_id).second) {
GELOGE(PARAM_INVALID, "[Check][DeviceId]Device id %u is duplicatedly set", dev_id);
REPORT_INNER_ERR_MSG("E19999", "Device id %u is not unique, duplicatedly set", dev_id);
return false;
}
}
return true;
}
bool TransProfConfigToParam(const MsprofCommandHandle &prof_command_handle,
std::vector<std::string> &prof_config_params) {
prof_config_params.clear();
prof_config_params.emplace_back(kDeviceNums);
prof_config_params.emplace_back(std::to_string(prof_command_handle.devNums));
prof_config_params.emplace_back(kDeviceIdList);
std::string dev_id;
if (prof_command_handle.devNums == 0U) {
GELOGE(FAILED, "[Check][Param]The device num is invalid.");
return false;
}
for (uint32_t i = 0U; i < prof_command_handle.devNums; i++) {
(void)dev_id.append(std::to_string(prof_command_handle.devIdList[i]));
if (i != (prof_command_handle.devNums - 1U)) {
(void)dev_id.append(",");
}
}
prof_config_params.push_back(dev_id);
return true;
}
/// Builds the parameter list for a model-unsubscribe command; any other command type is a no-op.
/// For unsubscribe, prof_params becomes {"modelId", <id>}. The id is the model id mapped from graph_id
/// when a subscription is recorded in ProfilingProperties, otherwise graph_id itself.
/// @param type        command type being handled
/// @param graph_id    id carried in MsprofCommandHandle::modelId
/// @param prof_params output parameter list (only touched for unsubscribe)
/// @return SUCCESS, or the error from ProfilingManager::GetModelIdFromGraph when graph_id is unknown
Status NeedUnsubscribe(const ProfCommandHandleType type, const uint32_t graph_id,
                       std::vector<std::string> &prof_params) {
  if (type != ProfCommandHandleType::kProfCommandHandleModelUnsubscribe) {
    return SUCCESS;
  }
  prof_params.clear();
  prof_params.emplace_back(kProfilingModelId);

  if (!ProfilingProperties::Instance().GetSubscribeInfo().is_subscribe) {
    prof_params.emplace_back(std::to_string(graph_id));
    return SUCCESS;
  }

  // A subscription was recorded: translate the graph id into the model id it was loaded as.
  uint32_t model_id = 0U;
  const auto ret = ProfilingManager::Instance().GetModelIdFromGraph(graph_id, model_id);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Get][GraphId]graph_id:%u not found", graph_id);
    return ret;
  }
  prof_params.emplace_back(std::to_string(model_id));
  return SUCCESS;
}
/// For start/stop commands, validates the device selection and serializes it into prof_params.
/// Other command types are accepted unchanged.
/// @return SUCCESS; FAILED when the device configuration is rejected; PARAM_INVALID when serialization fails
Status NeedHandleStartEnd(const ProfCommandHandleType type, const MsprofCommandHandle &prof_command_handle,
                          std::vector<std::string> &prof_params) {
  const bool is_start_or_stop = (type == ProfCommandHandleType::kProfCommandHandleStart) ||
                                (type == ProfCommandHandleType::kProfCommandHandleStop);
  if (!is_start_or_stop) {
    return SUCCESS;
  }
  if (!IsProfConfigValid(prof_command_handle.devIdList, prof_command_handle.devNums)) {
    return FAILED;
  }
  if (TransProfConfigToParam(prof_command_handle, prof_params)) {
    return SUCCESS;
  }
  GELOGE(PARAM_INVALID, "[Check][Param]Transfer profilerConfig to std::string vector failed");
  REPORT_INNER_ERR_MSG("E19999", "Transfer profilerConfig to std::string vector failed");
  return PARAM_INVALID;
}
void SubscribeInfoToParam(const ProfCommandHandleType type, const MsprofCommandHandle &prof_command_handle,
std::vector<std::string> &prof_params) {
if (type == ProfCommandHandleType::kProfCommandHandleModelSubscribe) {
prof_params.clear();
prof_params.push_back(kProfilingModelId);
prof_params.push_back(std::to_string(prof_command_handle.modelId));
}
}
/// Packs a profiling command into a Command and dispatches it to ModelManager::HandleCommand.
/// @param type                command type; must be present in kProfCommandTypeMap
/// @param prof_command_handle source of cacheFlag and profSwitch (profSwitch is not applied for Finalize)
/// @param prof_params         prepared parameter list, copied into Command::cmd_params
/// @return RT_ERROR_NONE on success, RT_ERROR for an unknown type or when HandleCommand fails
rtError_t ExecuteCommand(const ProfCommandHandleType type,
                         const MsprofCommandHandle &prof_command_handle,
                         const std::vector<std::string> &prof_params) {
  // Start/stop params are laid out as {"devNums", <count>, "devIdList", <id list>} (see TransProfConfigToParam):
  // the device count value is at index 1, not index 0 (index 0 is the key string).
  constexpr size_t kDeviceNumsValueIndex = 1U;

  const auto it = kProfCommandTypeMap.find(type);
  if (it == kProfCommandTypeMap.end()) {
    GELOGE(PARAM_INVALID, "[Check][Param]The prof command type is invalid.");
    return RT_ERROR;
  }
  const std::string &cmd_name = it->second;

  // Value-initialize so module_index has a defined value even for Finalize, where it is not assigned
  // below but is still logged.
  Command command{};
  command.cmd_type = cmd_name;
  command.cmd_params = prof_params;
  command.cache_flag = prof_command_handle.cacheFlag;
  if (type != ProfCommandHandleType::kProfCommandHandleFinalize) {
    command.module_index = prof_command_handle.profSwitch;
  }
  GELOGI("Command Type: %s, data type config: 0x%" PRIx64, cmd_name.c_str(), command.module_index);

  if ((type == ProfCommandHandleType::kProfCommandHandleStart) ||
      (type == ProfCommandHandleType::kProfCommandHandleStop)) {
    if (prof_params.size() > kDeviceListIndex) {
      GELOGI("Profiling device nums:%s, deviceId:%s", prof_params[kDeviceNumsValueIndex].c_str(),
             prof_params[kDeviceListIndex].c_str());
    } else {
      GELOGW("Profiling input param[size=%zu] may invalid", prof_params.size());
    }
  }

  const Status ret = ModelManager::GetInstance().HandleCommand(command);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Handle][Command]Handle profiling command failed, command type %s, error_code %u",
           cmd_name.c_str(), ret);
    REPORT_INNER_ERR_MSG("E19999", "Handle profiling command failed, command type %s, error_code %u",
                         cmd_name.c_str(), ret);
    return RT_ERROR;
  }
  GELOGI("Successfully execute profiling command type: %d, command 0x%" PRIx64 ".",
         static_cast<int32_t>(type), command.module_index);
  return RT_ERROR_NONE;
}
/// Handles an RT_PROF_CTRL_SWITCH request: validates the command type, builds its parameter list and
/// dispatches it through ExecuteCommand.
/// When OPTION_GRAPH_RUN_MODE is set to a non-empty value, a model-subscribe is only recorded in
/// ProfilingProperties and is not forwarded as a command.
/// @return RT_ERROR_NONE on success, RT_ERROR on any validation or dispatch failure
rtError_t HandleCtrlSwitch(const MsprofCommandHandle &prof_command_handle) {
  // Range-check the raw value before casting so the enum never holds an undeclared value.
  if (prof_command_handle.type >= kCommandNum) {
    GELOGE(PARAM_INVALID, "[Check][Type]Type %u is invalid", prof_command_handle.type);
    return RT_ERROR;
  }
  GELOGD("Type is %u", prof_command_handle.type);
  std::vector<std::string> prof_params;
  const auto type = static_cast<ProfCommandHandleType>(prof_command_handle.type);
  Status ret = NeedHandleStartEnd(type, prof_command_handle, prof_params);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Handle][Command]Handle command failed, the command type is %d.", static_cast<int32_t>(type));
    return RT_ERROR;
  }

  std::string run_mode;
  if ((type == ProfCommandHandleType::kProfCommandHandleModelSubscribe) &&
      (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS) && (!run_mode.empty())) {
    GELOGD("Subscribe in training.");
    ProfilingProperties::Instance().SetSubscribeInfo(prof_command_handle.profSwitch, prof_command_handle.modelId, true);
    return RT_ERROR_NONE;
  }

  SubscribeInfoToParam(type, prof_command_handle, prof_params);
  const uint32_t graph_id = prof_command_handle.modelId;
  ret = NeedUnsubscribe(type, graph_id, prof_params);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Check][Param]graph_id:%u not found", graph_id);
    // The report takes raw c_str() pointers; keep the formatted id in a named local so the pointer stays
    // valid regardless of how the macro expands (a temporary would die at the end of its full-expression).
    const std::string graph_id_str = std::to_string(graph_id);
    REPORT_PREDEFINED_ERR_MSG(
        "E10001", std::vector<const char_t *>({"value", "parameter", "reason"}),
        std::vector<const char_t *>({graph_id_str.c_str(), "GraphToModelMap", "Graph_id does not exist."}));
    return RT_ERROR;
  }
  return ExecuteCommand(type, prof_command_handle, prof_params);
}
/// Handles a PROF_CTRL_STEPINFO request: records the step index in ProfilingManager and emits a step-trace
/// record through gert::GlobalProfilingWrapper::ProfileStepTrace.
/// @return RT_ERROR_NONE on success; RT_ERROR when no current device can be obtained or the trace fails
rtError_t HandleCtrlSetStepInfo(const ProfStepInfoCmd_t &prof_set_stepinfo) {
  // The device id itself is not used below; the call only checks that a current device is available.
  int32_t current_device = 0;
  const aclError get_dev_ret = aclrtGetDevice(&current_device);
  if (get_dev_ret != ACL_SUCCESS) {
    GELOGE(ge::RT_FAILED, "[Get][LogicDeviceId]Failed, ret %d", get_dev_ret);
    REPORT_INNER_ERR_MSG("E19999", "Get logic device id failed, ret %d", get_dev_ret);
    return RT_ERROR;
  }

  auto &profiling_manager = ge::ProfilingManager::Instance();
  const uint64_t index_id = prof_set_stepinfo.index_id;
  const uint16_t tag = prof_set_stepinfo.tag_id;
  GELOGI("[Cann Profiling] set step info, step id is %" PRIu64 ", tag id is %u", index_id, static_cast<uint32_t>(tag));
  profiling_manager.SetStepInfoIndex(static_cast<int64_t>(index_id));

  const Status trace_ret =
      gert::GlobalProfilingWrapper::ProfileStepTrace(index_id, ge::kInvalidModelId, tag, prof_set_stepinfo.stream);
  if (trace_ret != SUCCESS) {
    return RT_ERROR;
  }
  return RT_ERROR_NONE;
}
}
/// Entry point for profiling control requests.
/// @param ctrl_type RT_PROF_CTRL_SWITCH (ctrl_data is an MsprofCommandHandle) or
///                  PROF_CTRL_STEPINFO (ctrl_data is a ProfStepInfoCmd_t)
/// @param ctrl_data pointer to the request payload
/// @param data_len  size of the payload in bytes; must cover the struct implied by ctrl_type
/// @return result of the type-specific handler; RT_ERROR for a null/undersized payload or an unhandled ctrl_type
rtError_t ProfCtrlHandle(const uint32_t ctrl_type, void *const ctrl_data, const uint32_t data_len) {
  if ((ctrl_data == nullptr) || (data_len == 0U)) {
    GELOGE(PARAM_INVALID, "[Check][Param]The prof command is invalid.");
    return RT_ERROR;
  }
  if (ctrl_type == RT_PROF_CTRL_SWITCH) {
    // The payload is read as a whole struct; refuse buffers too small to hold one.
    if (data_len < sizeof(MsprofCommandHandle)) {
      GELOGE(PARAM_INVALID, "[Check][Param]Data len %u is smaller than MsprofCommandHandle size %zu.", data_len,
             sizeof(MsprofCommandHandle));
      return RT_ERROR;
    }
    const MsprofCommandHandle *const prof_command_handle = PtrToPtr<void, MsprofCommandHandle>(ctrl_data);
    return HandleCtrlSwitch(*prof_command_handle);
  }
  if (ctrl_type == PROF_CTRL_STEPINFO) {
    if (data_len < sizeof(ProfStepInfoCmd_t)) {
      GELOGE(PARAM_INVALID, "[Check][Param]Data len %u is smaller than ProfStepInfoCmd_t size %zu.", data_len,
             sizeof(ProfStepInfoCmd_t));
      return RT_ERROR;
    }
    const ProfStepInfoCmd_t *const prof_step_info = PtrToPtr<void, ProfStepInfoCmd_t>(ctrl_data);
    return HandleCtrlSetStepInfo(*prof_step_info);
  }
  // Other control types are not handled here.
  return RT_ERROR;
}
}