/* -------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is part of the MindStudio project.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *    http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
*/
#include "csrc/activity/ascend/dev_task_manager.h"

#include "csrc/activity/activity_manager.h"
#include "csrc/activity/ascend/channel/channel_pool_manager.h"
#include "csrc/common/context_manager.h"
#include "csrc/common/plog_manager.h"
#include "csrc/common/inject/driver_inject.h"
#include "csrc/common/inject/profapi_inject.h"
#include "csrc/common/utils.h"
#include "securec.h"

namespace Mspti {
namespace Ascend {

std::map<msptiActivityKind, uint64_t> DevTaskManager::datatype_config_map_ = {
    {MSPTI_ACTIVITY_KIND_KERNEL, MSPTI_CONFIG_KERNEL},
    {MSPTI_ACTIVITY_KIND_API, MSPTI_CONFIG_API},
    {MSPTI_ACTIVITY_KIND_COMMUNICATION, MSPTI_CONFIG_COMMUNICATION},
    {MSPTI_ACTIVITY_KIND_RUNTIME_API, MSPTI_CONFIG_RUNTIME_API},
    {MSPTI_ACTIVITY_KIND_ACL_API, MSPTI_CONFIG_API},
    {MSPTI_ACTIVITY_KIND_NODE_API, MSPTI_CONFIG_API}
};

DevTaskManager *DevTaskManager::GetInstance()
{
    static DevTaskManager instance;
    return &instance;
}

DevTaskManager::~DevTaskManager()
{
    for (auto iter = task_map_.begin(); iter != task_map_.end(); iter++) {
        StopAllDevKindProfTask(iter->second);
    }
    task_map_.clear();
}

DevTaskManager::DevTaskManager()
{
    Mspti::Common::ContextManager::GetInstance()->InitHostTimeInfo();
}

msptiResult DevTaskManager::StartAllDevKindProfTask(std::vector<std::unique_ptr<DevProfTask>>& profTasks,
                                                    std::vector<std::unique_ptr<DevProfTask>>& successTasks)
{
    for (auto& profTask : profTasks) {
        if (profTask->Start() != MSPTI_SUCCESS) {
            return MSPTI_ERROR_INNER;
        }
        successTasks.emplace_back(std::move(profTask));
    }
    profTasks.clear();
    return MSPTI_SUCCESS;
}

msptiResult DevTaskManager::StopAllDevKindProfTask(std::vector<std::unique_ptr<DevProfTask>>& profTasks)
{
    msptiResult ret = MSPTI_SUCCESS;
    for (auto& profTask : profTasks) {
        if (profTask->Stop() != MSPTI_SUCCESS) {
            ret = MSPTI_ERROR_INNER;
        }
        profTask.reset(nullptr);
    }
    return ret;
}

msptiResult DevTaskManager::StartDevProfTask(uint32_t deviceId,
                                             const ActivitySwitchType& kinds)
{
    if (!CheckDeviceOnline(deviceId)) {
        MSPTI_LOGE("Device: %u is offline.", deviceId);
        return MSPTI_ERROR_INNER;
    }
    MSPTI_LOGI("Start DevProfTask, deviceId: %u.", deviceId);
    if (Mspti::Ascend::Channel::ChannelPoolManager::GetInstance()->GetAllChannels(deviceId) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Get device: %u channels failed.", deviceId);
        return MSPTI_ERROR_INNER;
    }
    Mspti::Common::ContextManager::GetInstance()->InitDevTimeInfo(deviceId);
    // 根据DeviceId配置项,开启CANN软件栈的Profiling任务
    if (StartCannProfTask(deviceId, kinds) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Start CannProfTask failed. deviceId: %u", deviceId);
        return MSPTI_ERROR_INNER;
    }
    std::lock_guard<std::mutex> lk(task_map_mtx_);
    for (int kindIndex = 0; kindIndex < MSPTI_ACTIVITY_KIND_COUNT; kindIndex++) {
        if (!kinds[kindIndex]) {
            continue;
        }
        auto kind = static_cast<msptiActivityKind>(kindIndex);
        auto iter = task_map_.find({deviceId, kind});
        if (iter == task_map_.end()) {
            auto profTasks = Mspti::Ascend::DevProfTaskFactory::CreateTasks(deviceId, kind);
            decltype(profTasks) successTasks;
            auto ret = StartAllDevKindProfTask(profTasks, successTasks);
            task_map_.emplace(std::make_pair(deviceId, kind), std::move(successTasks));
            if (ret != MSPTI_SUCCESS) {
                MSPTI_LOGE("The device %u start DevProfTask failed.", deviceId);
                return ret;
            }
        } else {
            MSPTI_LOGW("The device: %u, kind: %d is already running.", deviceId, kind);
        }
    }
    return MSPTI_SUCCESS;
}

msptiResult DevTaskManager::StopDevProfTask(uint32_t deviceId,
                                            const ActivitySwitchType& kinds)
{
    if (!CheckDeviceOnline(deviceId)) {
        MSPTI_LOGE("Device: %u is offline.", deviceId);
        return MSPTI_ERROR_INNER;
    }
    MSPTI_LOGI("Stop DevProfTask, deviceId: %u", deviceId);
    if (StopCannProfTask(deviceId) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Stop CannProfTask failed. deviceId: %u", deviceId);
        return MSPTI_ERROR_INNER;
    }
    msptiResult ret = MSPTI_SUCCESS;
    std::lock_guard<std::mutex> lk(task_map_mtx_);
    for (int kindIndex = 0; kindIndex < MSPTI_ACTIVITY_KIND_COUNT; kindIndex++) {
        if (!kinds[kindIndex]) {
            continue;
        }
        auto kind = static_cast<msptiActivityKind>(kindIndex);
        auto iter = task_map_.find({deviceId, kind});
        if (iter != task_map_.end()) {
            ret = StopAllDevKindProfTask(iter->second);
            task_map_.erase(iter);
        }
    }
    return ret;
}

msptiResult DevTaskManager::FlushDevProfData(uint32_t deviceId, msptiActivityKind kind)
{
    if (!CheckDeviceOnline(deviceId)) {
        MSPTI_LOGE("Device: %u is offline.", deviceId);
        return MSPTI_ERROR_INNER;
    }
    MSPTI_LOGI("Flush DevProfData, deviceId: %u, kind: %d", deviceId, kind);
    std::lock_guard<std::mutex> lk(task_map_mtx_);
    auto taskIter = task_map_.find({deviceId, kind});
    if (taskIter == task_map_.end()) {
        MSPTI_LOGW("The device: %u, kind: %d is not running.", deviceId, kind);
        return MSPTI_SUCCESS;
    }
    auto ret = MSPTI_SUCCESS;
    for (auto& profTask : taskIter->second) {
        if (profTask->Flush() != MSPTI_SUCCESS) {
            MSPTI_LOGE("The device %u, kind: %d flush DevProfTask failed.", deviceId, kind);
            ret = MSPTI_ERROR_INNER;
        }
    }
    return ret;
}

bool DevTaskManager::CheckDeviceOnline(uint32_t deviceId)
{
    std::call_once(get_device_flag_, [this] () {
        this->InitDeviceList();
    });
    return device_set_.find(deviceId) != device_set_.end() ? true : false;
}

void DevTaskManager::InitDeviceList()
{
    uint32_t deviceNum = 0;
    auto ret = DrvGetDevNum(&deviceNum);
    constexpr int32_t DEVICE_MAX_NUM = 64;
    if (ret != DRV_ERROR_NONE || deviceNum > DEVICE_MAX_NUM) {
        MSPTI_LOGE("Get device num failed.");
        return;
    }
    uint32_t deviceList[DEVICE_MAX_NUM] = {0};
    ret = DrvGetDevIDs(deviceList, deviceNum);
    if (ret != DRV_ERROR_NONE) {
        MSPTI_LOGE("Get device id list failed.");
        return;
    }
    device_set_.insert(std::begin(deviceList), std::end(deviceList));
}

void DevTaskManager::RegisterReportCallback()
{
    static const std::vector<std::pair<int, VOID_PTR>> CALLBACK_FUNC_LIST = {
        {PROFILE_REPORT_GET_HASH_ID_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiGetHashIdImpl)},
        {PROFILE_HOST_FREQ_IS_ENABLE_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiHostFreqIsEnableImpl)},
        {PROFILE_REPORT_API_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiApiReporterCallbackImpl)},
        {PROFILE_REPORT_EVENT_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiEventReporterCallbackImpl)},
        {PROFILE_REPORT_COMPACT_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiCompactInfoReporterCallbackImpl)},
        {PROFILE_REPORT_ADDITIONAL_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiAddiInfoReporterCallbackImpl)},
        {PROFILE_REPORT_REG_TYPE_INFO_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsptiRegReportTypeInfoImpl)},
        {PROFILE_DEVICE_STATE_C_CALLBACK,
            reinterpret_cast<VOID_PTR>(Mspti::Inject::Detail::MsprofDeviceStateImpl)},
    };
    for (auto iter : CALLBACK_FUNC_LIST) {
        auto ret = Mspti::Inject::MsprofRegisterProfileCallback(iter.first, iter.second, sizeof(VOID_PTR));
        if (ret != MSPTI_SUCCESS) {
            MSPTI_LOGE("Failed to register reporter callback: %d", static_cast<int32_t>(iter.first));
        }
    }
}

void DevTaskManager::UnRegisterReportCallback()
{
    static const std::vector<std::pair<int, VOID_PTR>> CALLBACK_FUNC_LIST = {
        {PROFILE_HOST_FREQ_IS_ENABLE_C_CALLBACK, nullptr},
        {PROFILE_REPORT_API_C_CALLBACK, nullptr},
        {PROFILE_REPORT_EVENT_C_CALLBACK, nullptr},
        {PROFILE_REPORT_COMPACT_CALLBACK, nullptr},
        {PROFILE_REPORT_ADDITIONAL_CALLBACK, nullptr},
        {PROFILE_DEVICE_STATE_C_CALLBACK, nullptr},
    };
    for (auto iter : CALLBACK_FUNC_LIST) {
        auto ret = Mspti::Inject::MsprofRegisterProfileCallback(iter.first, iter.second, sizeof(VOID_PTR));
        if (ret != MSPTI_SUCCESS) {
            MSPTI_LOGE("Failed to unregister reporter callback: %d", static_cast<int32_t>(iter.first));
        }
    }
}

msptiResult DevTaskManager::StartCannProfTask(uint32_t deviceId, const ActivitySwitchType& kinds)
{
    for (int kindIndex = 0; kindIndex < MSPTI_ACTIVITY_KIND_COUNT; kindIndex++) {
        if (!kinds[kindIndex]) {
            continue;
        }
        auto kind = static_cast<msptiActivityKind>(kindIndex);
        auto cfg_iter = datatype_config_map_.find(kind);
        if (cfg_iter == datatype_config_map_.end()) {
            MSPTI_LOGW("Device: %u, kind: %d don't need to start cann profiling task.",
                deviceId, static_cast<int>(kind));
            continue;
        }
        profSwitch_ |= cfg_iter->second;
    }
    if (profSwitch_ == 0) {
        MSPTI_LOGW("profswitch is zero, don't need to enable cann prof.");
        return MSPTI_SUCCESS;
    }
    CommandHandle command;
    if (memset_s(&command, sizeof(command), 0, sizeof(command)) != EOK) {
        MSPTI_LOGE("memset CommandHandle failed.");
        return MSPTI_ERROR_INNER;
    }
    command.profSwitch = profSwitch_;
    command.profSwitchHi = 0;
    command.devNums = 1;
    command.devIdList[0] = deviceId;
    command.modelId = PROF_INVALID_MODE_ID;
    command.type = PROF_COMMANDHANDLE_TYPE_START;
    auto ret = Mspti::Inject::profSetProfCommand(static_cast<VOID_PTR>(&command), sizeof(CommandHandle));
    if (ret != MSPTI_SUCCESS) {
        MSPTI_LOGE("Start Profiling Command failed.");
        return MSPTI_ERROR_INNER;
    }
    return MSPTI_SUCCESS;
}

msptiResult DevTaskManager::StopCannProfTask(uint32_t deviceId)
{
    CommandHandle command;
    if (memset_s(&command, sizeof(command), 0, sizeof(command)) != EOK) {
        MSPTI_LOGE("memset CommandHandle failed.");
        return MSPTI_ERROR_INNER;
    }
    if (profSwitch_ == 0) {
        MSPTI_LOGW("profswitch is zero, don't need to disable cann prof.");
        return MSPTI_SUCCESS;
    }
    command.profSwitch = profSwitch_;
    command.profSwitchHi = 0;
    command.devNums = 1;
    command.devIdList[0] = deviceId;
    command.modelId = PROF_INVALID_MODE_ID;
    command.type = PROF_COMMANDHANDLE_TYPE_STOP;
    auto ret = Mspti::Inject::profSetProfCommand(static_cast<VOID_PTR>(&command), sizeof(CommandHandle));
    if (ret != MSPTI_SUCCESS) {
        MSPTI_LOGE("Stop Profiling Commond failed.");
        return MSPTI_ERROR_INNER;
    }
    return MSPTI_SUCCESS;
}
}  // Ascend
}  // Mspti