/* -------------------------------------------------------------------------
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is part of the MindStudio project.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *    http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
*/

#ifndef MSPTI_COMMON_CONCURRENT_MAP_H
#define MSPTI_COMMON_CONCURRENT_MAP_H

#include <unordered_map>
#include <mutex>
#include <stdexcept>
#include <array>

namespace Mspti {
namespace Common {
template <typename K, typename V, typename Hash = std::hash<K>, size_t ConcurrentLevel = 16>
class ConcurrentMap {
    static_assert((ConcurrentLevel & (ConcurrentLevel - 1)) == 0,
                  "ConcurrentLevel must be a power of 2");

    class BucketMap {
    public:
        using iterator = typename std::unordered_map<K, V, Hash>::iterator;
        class Guard {
        public:
            Guard(std::mutex& m, BucketMap& b) : lock(m), bucketMap(b)
            {
                if (isInLock) {
                    throw std::runtime_error("Reentrant lock detected on BucketMap");
                }
                isInLock = true;
            }

            ~Guard()
            {
                isInLock = false;
            }

            Guard(Guard&&) = default;
            Guard& operator=(Guard&&) = default;

            Guard(const Guard&) = delete;
            Guard& operator=(const Guard&) = delete;

            BucketMap* operator->() { return &bucketMap; }

        private:
            std::unique_lock<std::mutex> lock;
            BucketMap& bucketMap;
        };

        std::pair<iterator, bool> UnSafeFind(const K &key)
        {
            auto it = bucket.find(key);
            return {it, it != bucket.end()};
        }

        std::pair<iterator, bool> UnSafeInsert(const K &key, const V &value)
        {
            return bucket.emplace(key, value);
        }

        template<typename... Args>
        std::pair<iterator, bool> UnSafeEmplace(Args&&... args)
        {
            return bucket.emplace(std::forward<Args>(args)...);
        }

        void UnSafeErase(const K &key)
        {
            bucket.erase(key);
        }

        Guard GetGuard()
        {
            return Guard(mapMutex, *this);
        }

        V &operator[](const K &key)
        {
            return bucket[key];
        }

    private:
        inline thread_local static bool isInLock;
        std::mutex mapMutex;
        std::unordered_map<K, V, Hash> bucket;
    };

public:
    using iterator = typename BucketMap::iterator;
    using Guard = typename BucketMap::Guard;

    bool Find(const K& key, V& val)
    {
        auto guard = EnsureBucket(key).GetGuard();
        auto ans = guard->UnSafeFind(key);
        if (ans.second) {
            val = ans.first->second;
        }
        return ans.second;
    }

    void Erase(const K& key) noexcept
    {
        auto guard = EnsureBucket(key).GetGuard();
        guard->UnSafeErase(key);
    }

    std::pair<iterator, bool> Insert(const K& key, const V& value) noexcept
    {
        auto guard = EnsureBucket(key).GetGuard();
        return guard->UnSafeInsert(key, value);
    }

    std::pair<iterator, bool> Insert(const K& key, V&& value) noexcept
    {
        auto guard = EnsureBucket(key).GetGuard();
        return guard->UnSafeInsert(key, std::forward<V>(value));
    }

    template<typename... Args>
    std::pair<iterator, bool> Emplace(K& key, Args&&... args)
    {
        auto guard = EnsureBucket(key).GetGuard();
        return guard->UnSafeEmplace(key, std::forward<Args>(args)...);
    }

    Guard GetGuard(const K &key)
    {
        return EnsureBucket(key).GetGuard();
    }

private:
    static constexpr size_t BucketSize = ConcurrentLevel;
    std::array<BucketMap, BucketSize> buckets;
    Hash hasher;
    BucketMap &EnsureBucket(const K& key)
    {
        const auto index = hasher(key) & (BucketSize - 1);
        return buckets[index];
    }
};
}
}

#endif // MSPTI_COMMON_CONCURRENT_MAP_H