/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "atb/kernel_cache/kernel_cache.h"
#include <functional>
#include <securec.h>
#include <map>
#include <asdops/params/params.h>
#include <mki/utils/time/timer.h>
#include "atb/utils/log.h"
#include "atb/utils/tensor_util.h"
#include "atb/utils/param_compare.h"
#include "atb/utils/statistic.h"

namespace atb {
// Capacity reserved up front for each cache item's tiling buffer in CacheSlot::Init.
// NOTE(review): the unit is the buffer's element count; presumably bytes, since AddTiling copies
// raw uint8_t data into the same buffer -- confirm against the tilingBuffer declaration.
constexpr uint64_t DEFAULT_TILING_SIZE = 10240;
// Sizes the slot to hold cacheItemCount entries and reserves DEFAULT_TILING_SIZE elements of
// tiling storage in each entry, so that later AddTiling calls up to that size need not reallocate.
void CacheSlot::Init(uint32_t cacheItemCount)
{
    cachedItems.resize(cacheItemCount);
    for (auto itemIt = cachedItems.begin(); itemIt != cachedItems.end(); ++itemIt) {
        itemIt->tilingBuffer.reserve(DEFAULT_TILING_SIZE);
    }
}

// Stores one tiling result in this slot. Entries are written at replacePos, which wraps back to
// position 0 once it reaches the number of items, so the slot overwrites its entries in a ring.
//
// tilingData   source of the tiling bytes; may be nullptr only when tilingSize is 0.
// tilingSize   number of bytes to copy from tilingData.
// launchParam  launch parameters saved with the entry; lookups compare against them.
// kernel       optional kernel; when non-null a clone of it is stored with the entry.
//
// The call is a no-op when the slot has no items, or when tilingData is nullptr while
// tilingSize is non-zero (the latter is logged as an error).
void CacheSlot::AddTiling(uint8_t *tilingData, uint64_t tilingSize, const Mki::LaunchParam &launchParam,
                          const Mki::Kernel *kernel)
{
    // A slot initialised with zero items has nowhere to store the entry; return instead of
    // letting cachedItems.at(0) throw std::out_of_range.
    if (cachedItems.empty()) {
        return;
    }
    // Reject unusable input before touching the entry, so a failed add never leaves a
    // searchable entry whose tiling buffer does not match its launch parameters.
    const bool missingTilingData = (tilingData == nullptr) && (tilingSize > 0);
    ATB_LOG_IF(missingTilingData, ERROR) << "AddTiling Error! tilingData is nullptr, tilingSize: " << tilingSize;
    if (missingTilingData) {
        return;
    }

    if (replacePos >= cachedItems.size()) {
        replacePos = 0;
    }
    auto &cachedItem = cachedItems.at(replacePos);
    cachedItem.launchParam = launchParam;
    if (kernel != nullptr) {
        cachedItem.kernel.reset(kernel->Clone());
    } else {
        // Drop any kernel left over from the entry being overwritten; otherwise a later hit on
        // the new launch parameters would hand back a kernel cloned for the old ones.
        cachedItem.kernel.reset();
    }
    cachedItem.tilingBuffer.resize(tilingSize);
    // An empty buffer offers no valid destination (zero size, data() may be null), which
    // memcpy_s would report as an error, so the copy is skipped when there is nothing to copy.
    if (tilingSize > 0) {
        int ret = memcpy_s(cachedItem.tilingBuffer.data(), tilingSize, tilingData, tilingSize);
        ATB_LOG_IF(ret != EOK, ERROR) << "memcpy_s Error! Error Code: " << ret;
    }

    replacePos++;
    // validSize tracks the high-water mark of written positions; it never shrinks on wrap-around.
    validSize = replacePos > validSize ? replacePos : validSize;
}

// Checks the entry at index against launchParam.
// On a match: writes the entry's stored kernel into `kernel` (only if one is stored), records the
// position in hitPos, and returns a pointer to the entry's tiling buffer.
// On a mismatch: returns nullptr and leaves `kernel` and hitPos untouched.
// The time spent comparing and fetching is added to the op setup statistics.
// cachedItems.at() throws std::out_of_range when index is outside the item list.
TilingBufferPtr CacheSlot::GetTilingByIndex(const size_t index, const Mki::LaunchParam &launchParam,
                                            const Mki::Kernel* &kernel)
{
    auto &entry = cachedItems.at(index);

    Mki::Timer compareTimer;
    const bool matched = IsLaunchParamEqual(entry.launchParam, launchParam);
    GetOpSetupStatistic().kernelCacheCompareRunInfoTime += compareTimer.ElapsedMicroSecond();
    if (!matched) {
        return nullptr;
    }

    Mki::Timer fetchTimer;
    if (entry.kernel != nullptr) {
        kernel = entry.kernel.get();
    }
    GetOpSetupStatistic().kernelCacheGetRunInfoTime += fetchTimer.ElapsedMicroSecond();

    // Remember the hit so the next GetTiling starts its scan here.
    hitPos = index;
    return &entry.tilingBuffer;
}

// Searches the written entries for one whose launch parameters equal launchParam.
// The scan starts at the last hit position and runs to validSize, then covers the positions
// before the last hit. Returns the matching tiling buffer, or nullptr when nothing matches.
TilingBufferPtr CacheSlot::GetTiling(const Mki::LaunchParam &launchParam, const Mki::Kernel* &kernel)
{
    // hitPos is only rewritten by GetTilingByIndex on a match, and a match returns immediately,
    // so reading it once here is equivalent to re-reading it in each loop condition.
    const size_t scanStart = hitPos;

    // Probes positions [from, to) in order and returns the first matching buffer, if any.
    auto probeRange = [&](size_t from, size_t to) -> TilingBufferPtr {
        for (size_t pos = from; pos < to; ++pos) {
            TilingBufferPtr found = GetTilingByIndex(pos, launchParam, kernel);
            if (found != nullptr) {
                return found;
            }
        }
        return nullptr;
    };

    TilingBufferPtr result = probeRange(scanStart, validSize);
    if (result == nullptr) {
        result = probeRange(0, scanStart);
    }
    return result;
}

// Constructor with an empty body: no slots are sized here; KernelCache::Init does that.
KernelCache::KernelCache() noexcept {}

// Destructor with an empty body: it performs no explicit cleanup of its own.
KernelCache::~KernelCache() {}

// Creates kernelCount cache slots, each initialised to hold cacheItemCount entries.
// Does nothing when slots already exist, so only the first call that produces a non-empty
// slot list takes effect.
void KernelCache::Init(uint64_t kernelCount, uint32_t cacheItemCount)
{
    if (!cachedSlots_.empty()) {
        return;
    }

    cachedSlots_.resize(kernelCount);
    for (auto slotIt = cachedSlots_.begin(); slotIt != cachedSlots_.end(); ++slotIt) {
        slotIt->Init(cacheItemCount);
    }
}

// Forwards a tiling result to the slot belonging to kernelIndex.
// A kernelIndex outside the slot list is ignored without any error.
void KernelCache::AddTiling(size_t kernelIndex, uint8_t *tilingData, uint64_t tilingSize,
                            const Mki::LaunchParam &launchParam, const Mki::Kernel *kernel)
{
    if (!IsValid(kernelIndex)) {
        return;
    }
    cachedSlots_.at(kernelIndex).AddTiling(tilingData, tilingSize, launchParam, kernel);
}

// Looks up a cached tiling buffer for launchParam in the slot belonging to kernelIndex.
// Returns nullptr when kernelIndex is outside the slot list or the slot has no matching entry;
// otherwise returns whatever the slot's own GetTiling yields (which may also update `kernel`).
TilingBufferPtr KernelCache::GetTiling(size_t kernelIndex, const Mki::LaunchParam &launchParam,
                                       const Mki::Kernel* &kernel)
{
    if (!IsValid(kernelIndex)) {
        return nullptr;
    }
    return cachedSlots_.at(kernelIndex).GetTiling(launchParam, kernel);
}

// Reports whether kernelIndex addresses an existing slot, i.e. lies below the slot count.
bool KernelCache::IsValid(size_t kernelIndex) const
{
    const auto slotCount = cachedSlots_.size();
    const auto requested = static_cast<uint64_t>(kernelIndex);
    return requested < slotCount;
}
} // namespace atb