#include <set>
#include <numeric>
#include <algorithm>

#include "xsched/utils/xassert.h"
#include "xsched/sched/policy/wrr.h"

using namespace std::chrono;
using namespace xsched::sched;

void WeightedRoundRobinPolicy::Sched(const Status &status)
{
    for (auto it = weights_.begin(); it != weights_.end();) {
        XQueueHandle xqueue = it->first;
        if (status.xqueue_status.find(xqueue) == status.xqueue_status.end()) {
            it = weights_.erase(it);
        } else {
            ++it;
        }
    }

    for (auto it = run_queue_.begin(); it != run_queue_.end();) {
        if (status.xqueue_status.find(*it) == status.xqueue_status.end()) {
            in_queue_.erase(*it);
            it = run_queue_.erase(it);
        } else {
            ++it;
        }
    }

    if (cur_running_ != 0 &&
        status.xqueue_status.find(cur_running_) == status.xqueue_status.end()) {
        cur_running_ = 0;
    }

    std::set<XQueueHandle> new_handles;
    for (const auto &xq : status.xqueue_status) {
        XQueueHandle handle = xq.first;
        if (in_queue_.find(handle) != in_queue_.end()) continue;
        if (!xq.second->ready) continue;
        new_handles.insert(handle);
    }
    for (auto handle : new_handles) {
        run_queue_.push_back(handle);
        in_queue_.insert(handle);
    }

    if (run_queue_.empty()) return;

    if (cur_running_ == 0) {
        for (auto it = run_queue_.begin(); it != run_queue_.end(); ++it) {
            auto xit = status.xqueue_status.find(*it);
            if (xit == status.xqueue_status.end()) continue;
            if (!xit->second->ready) continue;
            SwitchTo(*it, status);
            return;
        }
        return;
    }

    auto xit = status.xqueue_status.find(cur_running_);
    if (xit == status.xqueue_status.end()) {
        cur_running_ = 0;
        for (auto it = run_queue_.begin(); it != run_queue_.end(); ++it) {
            auto xit2 = status.xqueue_status.find(*it);
            if (xit2 == status.xqueue_status.end()) continue;
            if (!xit2->second->ready) continue;
            SwitchTo(*it, status);
            return;
        }
        return;
    }

    auto now = system_clock::now();
    if (now < cur_end_ && xit->second->ready) return;

    auto cur_it = std::find(run_queue_.begin(), run_queue_.end(), cur_running_);
    if (cur_it != run_queue_.end()) {
        run_queue_.erase(cur_it);
    }
    run_queue_.push_back(cur_running_);

    for (size_t i = 0; i < run_queue_.size(); ++i) {
        XQueueHandle handle = run_queue_.front();
        auto xit2 = status.xqueue_status.find(handle);
        if (xit2 == status.xqueue_status.end()) {
            run_queue_.pop_front();
            in_queue_.erase(handle);
            continue;
        }
        if (!xit2->second->ready) {
            run_queue_.push_back(handle);
            run_queue_.pop_front();
            continue;
        }
        SwitchTo(handle, status);
        return;
    }

    cur_running_ = 0;
}

/// Applies a scheduling hint to this policy.
///
/// Supported hints:
///   - kHintTypeWeight: sets the weight of one XQueue. Non-positive weights
///     are rejected with a warning.
///   - kHintTypeTimeslice: sets the base quantum (microseconds) that
///     GetQuantumUs() divides among queued XQueues. Non-positive values are
///     rejected with a warning.
/// Any other hint type is ignored with a warning.
void WeightedRoundRobinPolicy::RecvHint(std::shared_ptr<const Hint> hint)
{
    switch (hint->Type())
    {
    case kHintTypeWeight:
    {
        auto h = std::dynamic_pointer_cast<const WeightHint>(hint);
        XASSERT(h != nullptr, "hint type not match");
        int32_t weight = h->Weight();
        if (weight <= 0) {
            XWARN("invalid weight %d for XQueue 0x" FMT_64X ", must be positive", weight, h->Handle());
            break;
        }
        weights_[h->Handle()] = weight;
        XINFO("weight of XQueue 0x" FMT_64X " set to %d", h->Handle(), weight);
        break;
    }
    case kHintTypeTimeslice:
    {
        auto h = std::dynamic_pointer_cast<const TimesliceHint>(hint);
        XASSERT(h != nullptr, "hint type not match");
        // A non-positive base quantum would make GetQuantumUs() clamp every
        // slice to 1 us, i.e. a context switch on practically every timer
        // tick. Reject it, consistent with the weight validation above.
        const int64_t ts = static_cast<int64_t>(h->Ts());
        if (ts <= 0) {
            XWARN("invalid timeslice " FMT_64D " us, must be positive", ts);
            break;
        }
        base_quantum_us_ = ts;
        XINFO("base quantum set to " FMT_64D " us", base_quantum_us_);
        break;
    }
    default:
        // Explicit cast: the value is passed through varargs to a "%d" format.
        XWARN("unsupported hint type: %d", static_cast<int>(hint->Type()));
        break;
    }
}

/// Returns the weight recorded for `handle` by a weight hint, or 1 when no
/// weight has been recorded for it.
int32_t WeightedRoundRobinPolicy::GetWeight(XQueueHandle handle)
{
    const auto found = weights_.find(handle);
    return found == weights_.end() ? 1 : found->second;
}

/// Computes the time slice, in microseconds, granted to `handle`.
///
/// The base quantum is split in proportion to weight: `handle` receives
/// base_quantum_us_ * weight(handle) / sum of weights of every XQueue in the
/// run queue. Every queued XQueue is counted, whether or not it is ready.
/// The result is clamped to at least 1 us. If the weight sum is zero (for
/// example, an empty run queue), the base quantum is returned unchanged.
int64_t WeightedRoundRobinPolicy::GetQuantumUs(XQueueHandle handle)
{
    const int64_t total_weight = std::accumulate(
        run_queue_.begin(), run_queue_.end(), int64_t{0},
        [this](int64_t acc, XQueueHandle queued) { return acc + GetWeight(queued); });
    if (total_weight == 0) return base_quantum_us_;

    const int64_t share = base_quantum_us_ * static_cast<int64_t>(GetWeight(handle)) / total_weight;
    return std::max<int64_t>(share, 1);
}

/// Makes `handle` the only running XQueue and starts its time slice.
///
/// Suspends every other XQueue listed in `status` and resumes `handle`. It then
/// records `handle` as the current runner, sets the slice end to now plus
/// GetQuantumUs(handle), and registers a timer for that instant.
void WeightedRoundRobinPolicy::SwitchTo(XQueueHandle handle, const Status &status)
{
    for (const auto &entry : status.xqueue_status) {
        if (entry.first != handle) this->Suspend(entry.first);
    }
    this->Resume(handle);

    cur_running_ = handle;
    const auto slice = microseconds(GetQuantumUs(handle));
    cur_end_ = system_clock::now() + slice;
    this->AddTimer(cur_end_);
}