#include <cerrno>
#include <chrono>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <random>
#include <thread>
#include <cuda_runtime.h>

#include "xsched/xsched.h"
#include "xsched/cuda/hal.h"

#define VECTOR_SIZE (1 << 25) // 2^25 = 32M float elements, i.e. 128 MiB per vector
#define N 100    // Number of vector_add launches per task
#define M 10000  // Number of tasks, large enough that the program (almost) never stops

// Host (h_*) and device (d_*) buffers of VECTOR_SIZE floats each.
// They are allocated in prepare() and released in cleanup().
float *h_A, *h_B, *h_C;
float *d_A, *d_B, *d_C;

// The CUDA stream all kernels are launched on. prepare() creates hwq from this
// stream and xq from hwq via the XSched API.
cudaStream_t stream;
HwQueueHandle hwq;
XQueueHandle xq;

// Element-wise sum: C[idx] = A[idx] + B[idx] for every idx in [0, n).
// Expects a 1-D launch with at least n threads in total. Each thread handles
// one element, and threads whose global index is >= n do nothing.
__global__ void vector_add(const float* A, const float* B, float* C, int n)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        C[idx] = A[idx] + B[idx];
    }
}

// Allocates and initializes the host and device vectors, creates the CUDA
// stream, and wraps it in an XSched XQueue with the requested priority.
//
// priority: value handed to XHintPriority() for the created XQueue.
//
// Exits the process if host allocation, device allocation, the host-to-device
// copies, or stream creation fail, since no task can run without them.
void prepare(int priority)
{
    const size_t size = VECTOR_SIZE * sizeof(float);

    // Abort with a readable message on any CUDA runtime failure instead of
    // continuing with invalid pointers or an invalid stream.
    auto check_cuda = [](cudaError_t err, const char *what) {
        if (err != cudaSuccess) {
            fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(err));
            exit(EXIT_FAILURE);
        }
    };

    // Allocate host memory (h_C is only allocated here and freed in cleanup()).
    h_A = (float*)malloc(size);
    h_B = (float*)malloc(size);
    h_C = (float*)malloc(size);
    if (h_A == nullptr || h_B == nullptr || h_C == nullptr) {
        fprintf(stderr, "Failed to allocate %zu bytes of host memory per vector\n", size);
        exit(EXIT_FAILURE);
    }

    // Initialize host vectors with pseudo-random values in [0, 1].
    for (int i = 0; i < VECTOR_SIZE; ++i) {
        h_A[i] = static_cast<float>(rand()) / RAND_MAX;
        h_B[i] = static_cast<float>(rand()) / RAND_MAX;
    }

    // Allocate device memory.
    check_cuda(cudaMalloc(&d_A, size), "cudaMalloc(d_A)");
    check_cuda(cudaMalloc(&d_B, size), "cudaMalloc(d_B)");
    check_cuda(cudaMalloc(&d_C, size), "cudaMalloc(d_C)");

    // Copy the input vectors to the device (blocking copies).
    check_cuda(cudaMemcpy(d_A, h_A, size, cudaMemcpyHostToDevice), "cudaMemcpy(d_A)");
    check_cuda(cudaMemcpy(d_B, h_B, size, cudaMemcpyHostToDevice), "cudaMemcpy(d_B)");

    check_cuda(cudaStreamCreate(&stream), "cudaStreamCreate");

    // Create the XQueue through the XSched API: a hardware queue built on
    // `stream`, then an XQueue on top of it.
    // NOTE(review): the XSched calls' return values are not checked; add checks
    // once the XSched result type/success value is confirmed.
    CudaQueueCreate(&hwq, stream);
    XQueueCreate(&xq, hwq, kPreemptLevelBlock, kQueueCreateFlagNone);
    XQueueSetLaunchConfig(xq, 8, 4);
    // Set the priority through the XHint API rather than environment variables.
    XHintPriority(xq, priority);
}

// Runs one task: launches vector_add N times back to back on `stream`, each
// launch covering all VECTOR_SIZE elements, then blocks until they finish.
// Exits the process if a launch or the stream synchronization reports an error,
// so a failed task is never reported as a (fast) successful one.
void run_task()
{
    const int block_size = 256;
    // Ceil-divide so the last, partially filled block still covers the tail.
    const int grid_size = (VECTOR_SIZE + block_size - 1) / block_size;

    for (int i = 0; i < N; ++i) {
        vector_add<<<grid_size, block_size, 0, stream>>>(d_A, d_B, d_C, VECTOR_SIZE);
        // Kernel launches return nothing; launch errors surface via cudaGetLastError().
        const cudaError_t launch_err = cudaGetLastError();
        if (launch_err != cudaSuccess) {
            fprintf(stderr, "vector_add launch %d failed: %s\n", i, cudaGetErrorString(launch_err));
            exit(EXIT_FAILURE);
        }
    }

    // Faults raised while the kernels execute are reported by this call.
    const cudaError_t sync_err = cudaStreamSynchronize(stream);
    if (sync_err != cudaSuccess) {
        fprintf(stderr, "cudaStreamSynchronize failed: %s\n", cudaGetErrorString(sync_err));
        exit(EXIT_FAILURE);
    }
}

// Releases the device buffers and then the host buffers allocated by prepare().
// NOTE(review): the CUDA stream and the XSched queues (hwq, xq) created in
// prepare() are not released here.
void cleanup()
{
    float *const device_buffers[] = {d_A, d_B, d_C};
    for (float *buf : device_buffers) {
        cudaFree(buf);
    }

    float *const host_buffers[] = {h_A, h_B, h_C};
    for (float *buf : host_buffers) {
        free(buf);
    }
}

// Entry point: runs M timed tasks on an XQueue whose priority comes from the
// command line, sleeping a random 30-50 ms between tasks.
//
// Usage: <program> <priority>, where <priority> must be a base-10 integer that
// fits in an int. Returns 1 on bad usage or an invalid priority, and 0 after all
// tasks complete.
int main(int argc, char *argv[])
{
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <priority>\n", argv[0]);
        return 1;
    }

    // Parse strictly. atoi() silently turns "abc" into 0 and "5x" into 5, and
    // its behavior on out-of-range input is undefined.
    char *parse_end = nullptr;
    errno = 0;
    const long parsed = strtol(argv[1], &parse_end, 10);
    if (parse_end == argv[1] || *parse_end != '\0' || errno == ERANGE ||
        parsed < INT_MIN || parsed > INT_MAX) {
        fprintf(stderr, "Invalid priority '%s': expected an integer\n", argv[1]);
        return 1;
    }
    const int priority = static_cast<int>(parsed);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dis(30, 50);

    prepare(priority);

    // steady_clock is monotonic. high_resolution_clock may alias the adjustable
    // system_clock, which could make measured task durations jump.
    for (int i = 0; i < M; ++i) {
        const auto t_start = std::chrono::steady_clock::now();
        run_task();
        const auto t_end = std::chrono::steady_clock::now();
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start);
        // milliseconds::rep is only guaranteed to be some signed integer type
        // (often long), so cast explicitly to match %lld.
        printf("Task %d completed in %lld ms\n", i, static_cast<long long>(elapsed.count()));

        // Sleep for a random interval between tasks.
        std::this_thread::sleep_for(std::chrono::milliseconds(dis(gen)));
    }

    cleanup();
    return 0;
}