/* -------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is part of the MindStudio project.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *    http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
*/

#include "csrc/activity/activity_manager.h"

#include <cstring>
#include <functional>
#include <algorithm>
#include <thread>
#include <chrono>
#include "securec.h"

#include "csrc/activity/ascend/dev_task_manager.h"
#include "csrc/activity/ascend/parser/parser_manager.h"
#include "csrc/activity/ascend/reporter/external_correlation_reporter.h"
#include "csrc/common/plog_manager.h"
#include "csrc/common/utils.h"

namespace Mspti {
namespace Activity {
namespace {
msptiResult IsNeedLdPreload(msptiActivityKind kind)
{
    // Some activity kinds depend on LD_PRELOAD hooking (libmspti.so). If those
    // kinds are enabled but LD_PRELOAD does not contain libmspti.so, return
    // MSPTI_ERROR_WITHOUT_LD_PRELOAD to notify the caller.
    static const std::unordered_set<msptiActivityKind> needLdPreloadKinds = {
        MSPTI_ACTIVITY_KIND_HCCL,
        MSPTI_ACTIVITY_KIND_MEMORY,
        MSPTI_ACTIVITY_KIND_MEMSET,
        MSPTI_ACTIVITY_KIND_MEMCPY
    };

    if (needLdPreloadKinds.find(kind) != needLdPreloadKinds.end()) {
        static const std::string ld = Mspti::Common::Utils::GetEnv("LD_PRELOAD");
        if (ld.find("libmspti.so") == std::string::npos) {
            MSPTI_LOGE("Enable activity kind %d requires libmspti.so in LD_PRELOAD.", static_cast<int>(kind));
            return MSPTI_ERROR_WITHOUT_LD_PRELOAD;
        }
    }
    return MSPTI_SUCCESS;
}

inline bool IsNeededDevTask(msptiActivityKind kind)
{
    static const std::unordered_set<msptiActivityKind> needDevTaskKinds = {
        MSPTI_ACTIVITY_KIND_MARKER,
        MSPTI_ACTIVITY_KIND_KERNEL,
        MSPTI_ACTIVITY_KIND_HCCL,
        MSPTI_ACTIVITY_KIND_COMMUNICATION
    };
    return needDevTaskKinds.find(kind) != needDevTaskKinds.end();
}
}  // namespace

void ActivityBuffer::Init(msptiBuffersCallbackRequestFunc func)
{
    if (func == nullptr) {
        MSPTI_LOGE("The request callback is nullptr.");
        return;
    }
    func(&buf_, &buf_size_, &records_num_);
    const static uint64_t MIN_ACTIVITY_BUFFER_SIZE = 2 * 1024 * 1024;
    if (buf_size_ < MIN_ACTIVITY_BUFFER_SIZE) {
        MSPTI_LOGW("Please malloc the Activity Buffer more than 2MB. Current is %lu Bytes.", buf_size_);
    }
    MSPTI_LOGI("ActivityBuffer init, bufSize: %zu, recordsNum: %zu", buf_size_, records_num_);
}

void ActivityBuffer::UnInit(msptiBuffersCallbackCompleteFunc func)
{
    if (func == nullptr) {
        MSPTI_LOGE("The complete callback is nullptr.");
        return;
    }
    MSPTI_LOGI("CallbackCompleteFunc start, validSize: %zu, bufSize: %zu, recordsNum: %zu",
               valid_size_, buf_size_, records_num_);
    func(buf_, buf_size_, valid_size_);
}

msptiResult ActivityBuffer::Record(msptiActivity *activity, size_t size)
{
    if (activity == nullptr) {
        MSPTI_LOGE("The activity is nullptr, failed to record.");
        return MSPTI_ERROR_INNER;
    }
    if (buf_ == nullptr) {
        MSPTI_LOGE("The ActivityBuffer is nullptr, failed to record activity.");
        return MSPTI_ERROR_INNER;
    }
    if (size > buf_size_ - valid_size_) {
        MSPTI_LOGW("Record is dropped due to insufficient space of Activity Buffer.");
        return MSPTI_ERROR_INNER;
    }
    if (memcpy_s(buf_ + valid_size_, buf_size_ - valid_size_, activity, size) != EOK) {
        return MSPTI_ERROR_INNER;
    }
    valid_size_ += size;
    records_num_++;
    return MSPTI_SUCCESS;
}

size_t ActivityBuffer::BufSize()
{
    return buf_size_;
}

size_t ActivityBuffer::ValidSize()
{
    return valid_size_;
}

const std::set<msptiActivityKind> ActivityManager::supportActivityKinds_ = {
    MSPTI_ACTIVITY_KIND_MARKER, MSPTI_ACTIVITY_KIND_KERNEL, MSPTI_ACTIVITY_KIND_API, MSPTI_ACTIVITY_KIND_HCCL,
    MSPTI_ACTIVITY_KIND_MEMORY, MSPTI_ACTIVITY_KIND_MEMSET, MSPTI_ACTIVITY_KIND_MEMCPY,
    MSPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION, MSPTI_ACTIVITY_KIND_COMMUNICATION, MSPTI_ACTIVITY_KIND_ACL_API, MSPTI_ACTIVITY_KIND_NODE_API, MSPTI_ACTIVITY_KIND_RUNTIME_API,
};

ActivityManager *ActivityManager::GetInstance()
{
    static ActivityManager instance;
    return &instance;
}

ActivityManager::~ActivityManager()
{
    if (thread_run_.load() || t_.joinable()) {
        thread_run_.store(false);
        {
            std::unique_lock<std::mutex> lck(cv_mtx_);
            try {
                cv_.notify_one();
            } catch(...) {
                // Exception occurred during destruction of ActivityManager
            }
        }
        try {
            t_.join();
        } catch(...) {
            // Exception occurred during destruction of ActivityManager
        }
    }
    JoinWorkThreads();
    for (int kindIndex = 0; kindIndex < MSPTI_ACTIVITY_KIND_COUNT; kindIndex++) {
        activity_switch_[kindIndex] = false;
    }
    devices_.clear();
    MSPTI_LOGI("Total activity record: %lu. Total activity drop: %lu",
        total_record_num_.load(), total_drop_num_.load());
}

void ActivityManager::JoinWorkThreads()
{
    for (auto &thread : work_thread_) {
        if (thread.joinable()) {
            thread.join();
        }
    }
    work_thread_.clear();
}

msptiResult ActivityManager::RegisterCallbacks(
    msptiBuffersCallbackRequestFunc funcBufferRequested,
    msptiBuffersCallbackCompleteFunc funcBufferCompleted)
{
    if (funcBufferRequested == nullptr || funcBufferCompleted == nullptr) {
        MSPTI_LOGE("Call msptiActivityRegisterCallbacks failed while request or complete callback is nullptr.");
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    bufferRequested_handle_ = funcBufferRequested;
    bufferCompleted_handle_ = funcBufferCompleted;
    if (!t_.joinable()) {
        t_ = std::thread(std::bind(&ActivityManager::Run, this));
        thread_run_.store(true);
    }
    return MSPTI_SUCCESS;
}


msptiResult ActivityManager::Register(msptiActivityKind kind)
{
    if (supportActivityKinds_.find(kind) == supportActivityKinds_.end()) {
        MSPTI_LOGE("The ActivityKind: %d was not support.", static_cast<int>(kind));
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    if (IsNeedLdPreload(kind) != MSPTI_SUCCESS) {
        return MSPTI_ERROR_WITHOUT_LD_PRELOAD;
    }
    activity_switch_[kind] = true;
    append_only_activity_switch_[kind] = true;
    MSPTI_LOGI("Register Activity kind: %d", static_cast<int>(kind));

    auto localDevices = GetAllValidDevice();
    ActivitySwitchType curOpenSwitch{};
    curOpenSwitch[kind] = true;
    for (auto device : localDevices) {
        Ascend::DevTaskManager::GetInstance()->StartDevProfTask(device, curOpenSwitch);
    }
    Parser::ParserManager::GetInstance()->StartAnalysisTask(kind);
    return MSPTI_SUCCESS;
}

msptiResult ActivityManager::UnRegister(msptiActivityKind kind)
{
    if (supportActivityKinds_.find(kind) == supportActivityKinds_.end()) {
        MSPTI_LOGE("The ActivityKind: %d was not support.", static_cast<int>(kind));
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    if (IsNeededDevTask(kind)) {
        auto localDevices = GetAllValidDevice();
        for (auto device : localDevices) {
            Ascend::DevTaskManager::GetInstance()->FlushDevProfData(device, kind);
        }
        if (!localDevices.empty() &&
            Common::ContextManager::GetInstance()->GetChipType(*localDevices.begin()) == Common::PlatformType::CHIP_V6) {
            constexpr uint32_t sleep_ms = 20;
            std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
        }
    }
    Parser::ParserManager::GetInstance()->StopAnalysisTask(kind);
    activity_switch_[kind] = false;
    MSPTI_LOGI("UnRegister Activity kind: %d", static_cast<int>(kind));
    return MSPTI_SUCCESS;
}

bool ActivityManager::IsActivityKindEnable(msptiActivityKind kind)
{
    return activity_switch_[kind];
}

msptiResult ActivityManager::GetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record)
{
    if (buffer == nullptr) {
        MSPTI_LOGE("The address of Activity Buffer is nullptr.");
        return MSPTI_ERROR_INVALID_PARAMETER;
    }
    static thread_local size_t pos = 0;
    if (pos >= validBufferSizeBytes) {
        pos = 0;
        return MSPTI_ERROR_MAX_LIMIT_REACHED;
    }

    static const std::unordered_map<msptiActivityKind, size_t> activityKindDataSize = {
        {MSPTI_ACTIVITY_KIND_MARKER,                sizeof(msptiActivityMarker)},
        {MSPTI_ACTIVITY_KIND_KERNEL,                sizeof(msptiActivityKernel)},
        {MSPTI_ACTIVITY_KIND_API,                   sizeof(msptiActivityApi)},
        {MSPTI_ACTIVITY_KIND_HCCL,                  sizeof(msptiActivityHccl)},
        {MSPTI_ACTIVITY_KIND_MEMORY,                sizeof(msptiActivityMemory)},
        {MSPTI_ACTIVITY_KIND_MEMSET,                sizeof(msptiActivityMemset)},
        {MSPTI_ACTIVITY_KIND_MEMCPY,                sizeof(msptiActivityMemcpy)},
        {MSPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION,  sizeof(msptiActivityExternalCorrelation)},
        {MSPTI_ACTIVITY_KIND_COMMUNICATION,         sizeof(msptiActivityCommunication)},
        {MSPTI_ACTIVITY_KIND_ACL_API,               sizeof(msptiActivityApi)},
        {MSPTI_ACTIVITY_KIND_NODE_API,              sizeof(msptiActivityApi)},
        {MSPTI_ACTIVITY_KIND_RUNTIME_API,           sizeof(msptiActivityApi)}
    };

    msptiActivityKind *pKind = Common::ReinterpretConvert<msptiActivityKind*>(buffer + pos);
    auto iter = activityKindDataSize.find(*pKind);
    if (iter == activityKindDataSize.end()) {
        MSPTI_LOGE("GetNextRecord failed, invalid kind: %d", *pKind);
        return MSPTI_ERROR_INNER;
    }
    *record = Common::ReinterpretConvert<msptiActivity*>(buffer + pos);
    pos += iter->second;
    return MSPTI_SUCCESS;
}

msptiResult ActivityManager::FlushAll()
{
    std::deque<std::unique_ptr<ActivityBuffer>> flushBuffers;
    {
        std::unique_lock<std::mutex> lck(cv_mtx_);
        flushBuffers = std::move(co_activity_buffers_);
    }
    for (const auto &buffer : flushBuffers) {
        if (buffer) {
            buffer->UnInit(bufferCompleted_handle_);
        }
    }
    {
        std::unique_lock<std::mutex> lck(cv_mtx_);
        JoinWorkThreads();
    }
    {
        std::lock_guard<std::mutex> lk(buf_mtx_);
        if (cur_buf_) {
            auto consumeBuf = std::move(cur_buf_);
            consumeBuf->UnInit(this->bufferCompleted_handle_);
        }
    }
    MSPTI_LOGI("Flush all activity buffer.");
    return MSPTI_SUCCESS;
}

msptiResult ActivityManager::FlushPeriod(uint32_t time)
{
    std::unique_lock<std::mutex> lck(cv_mtx_);
    if (time == 0) {
        flush_period_time_ = DEFAULT_PERIOD_FLUSH_TIME;
        flush_period_ = false;
    } else {
        flush_period_time_ = time;
        flush_period_ = true;
        cv_.notify_one();
    }
    return MSPTI_SUCCESS;
}

msptiResult ActivityManager::Record(msptiActivity *activity, size_t size)
{
    if (activity == nullptr) {
        return MSPTI_ERROR_INNER;
    }
    if (!IsActivityKindEnable(activity->kind)) {
        return MSPTI_SUCCESS;
    }
    static const float ACTIVITY_BUFFER_THRESHOLD = 0.8;
    std::lock_guard<std::mutex> lk(buf_mtx_);
    if (!cur_buf_) {
        Mspti::Common::MsptiMakeUniquePtr(cur_buf_);
        if (!cur_buf_) {
            MSPTI_LOGE("Failed to init Activity Buffer.");
            return MSPTI_ERROR_INNER;
        }
        cur_buf_->Init(bufferRequested_handle_);
    } else if (cur_buf_->ValidSize() >= ACTIVITY_BUFFER_THRESHOLD * cur_buf_->BufSize()) {
        {
            std::unique_lock<std::mutex> lck(cv_mtx_);
            buf_full_ = true;
            co_activity_buffers_.emplace_back(std::move(cur_buf_));
            cv_.notify_one();
        }
        Mspti::Common::MsptiMakeUniquePtr(cur_buf_);
        if (!cur_buf_) {
            MSPTI_LOGE("Failed to init Activity Buffer.");
            return MSPTI_ERROR_INNER;
        }
        cur_buf_->Init(bufferRequested_handle_);
    }
    if (cur_buf_->Record(activity, size) != MSPTI_SUCCESS) {
        MSPTI_LOGE("Failed to record activity.");
        total_drop_num_++;
        return MSPTI_ERROR_INNER;
    }
    total_record_num_++;
    return MSPTI_SUCCESS;
}

void ActivityManager::Run()
{
    pthread_setname_np(pthread_self(), "ActivityManager");
    while (true) {
        {
            std::unique_lock<std::mutex> lk(cv_mtx_);
            bool serveForWaitFor = true;
            cv_.wait_for(lk, std::chrono::milliseconds(flush_period_time_), [&]() {
                serveForWaitFor = !serveForWaitFor;
                return (serveForWaitFor && flush_period_) || buf_full_ || !thread_run_.load();
            });
            if (!thread_run_.load()) {
                break;
            }
            {
                for (auto& activity_buffer : co_activity_buffers_) {
                    work_thread_.emplace_back(std::thread([this] (std::unique_ptr<ActivityBuffer> activity_buffer) {
                        activity_buffer->UnInit(this->bufferCompleted_handle_);
                    }, std::move(activity_buffer)));
                }
                co_activity_buffers_.clear();
                buf_full_ = false;
            }
        }
    }
    JoinWorkThreads();
}

msptiResult ActivityManager::SetDevice(uint32_t deviceId)
{
    MSPTI_LOGI("Set device: %u", deviceId);
    {
        std::lock_guard<std::mutex> lk(devices_mtx_);
        if (devices_.find(deviceId) != devices_.end()) {
            MSPTI_LOGW("Device: %u is already set.", deviceId);
            return MSPTI_SUCCESS;
        }
        devices_.insert(deviceId);
    }
    if (std::find(activity_switch_.begin(), activity_switch_.end(), true) == activity_switch_.end()) {
        return MSPTI_SUCCESS;
    }
    return Mspti::Ascend::DevTaskManager::GetInstance()->StartDevProfTask(deviceId, activity_switch_);
}

msptiResult ActivityManager::ResetAllDevice()
{
    auto ret = MSPTI_SUCCESS;
    std::lock_guard<std::mutex> lk(devices_mtx_);
    for (const auto& device : devices_) {
        MSPTI_LOGI("Reset device: %u", device);
        auto temp =
            Mspti::Ascend::DevTaskManager::GetInstance()->StopDevProfTask(device, append_only_activity_switch_);
        if (temp != MSPTI_SUCCESS) {
            ret = temp;
        }
    }
    return ret;
}

ActivityManager::ActivityManager()
{
    for (auto& kindSwitch : activity_switch_) {
        kindSwitch.store(false);
    }
    for (auto& kindSwitch : append_only_activity_switch_) {
        kindSwitch.store(false);
    }
}

const std::unordered_set<uint32_t> ActivityManager::GetAllValidDevice()
{
    std::lock_guard<std::mutex> lk(devices_mtx_);
    return devices_;
}
}  // Activity
}  // Mspti

msptiResult msptiActivityRegisterCallbacks(
    msptiBuffersCallbackRequestFunc funcBufferRequested, msptiBuffersCallbackCompleteFunc funcBufferCompleted)
{
    return Mspti::Activity::ActivityManager::GetInstance()->RegisterCallbacks(funcBufferRequested, funcBufferCompleted);
}

msptiResult msptiActivityEnable(msptiActivityKind kind)
{
    return Mspti::Activity::ActivityManager::GetInstance()->Register(kind);
}

msptiResult msptiActivityDisable(msptiActivityKind kind)
{
    return Mspti::Activity::ActivityManager::GetInstance()->UnRegister(kind);
}

bool msptiActivityIsEnabled(msptiActivityKind kind)
{
    return Mspti::Activity::ActivityManager::GetInstance()->IsActivityKindEnable(kind);
}

msptiResult msptiActivityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record)
{
    return Mspti::Activity::ActivityManager::GetInstance()->GetNextRecord(buffer, validBufferSizeBytes, record);
}

msptiResult msptiActivityFlushAll(uint32_t flag)
{
    UNUSED(flag);
    return Mspti::Activity::ActivityManager::GetInstance()->FlushAll();
}

msptiResult msptiActivityFlushPeriod(uint32_t time)
{
    return Mspti::Activity::ActivityManager::GetInstance()->FlushPeriod(time);
}

msptiResult msptiActivityPushExternalCorrelationId(msptiExternalCorrelationKind kind, uint64_t id)
{
    return Mspti::Reporter::ExternalCorrelationReporter::GetInstance()->PushExternalCorrelationId(kind, id);
}

msptiResult msptiActivityPopExternalCorrelationId(msptiExternalCorrelationKind kind, uint64_t *lastId)
{
    return Mspti::Reporter::ExternalCorrelationReporter::GetInstance()->PopExternalCorrelationId(kind, lastId);
}