// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::{
    cell::{Cell, UnsafeCell},
    ffi::c_int,
    mem,
};

use cudax::runtime::{cudaError_t, cudaSuccess};

/// Mutable bookkeeping state: a device id and the most recently recorded error.
///
/// Both fields are wrapped in [`Cell`], so they can be updated through a shared
/// reference. As a consequence the type is `!Sync`. The `context()` function in
/// this file keeps one instance per thread in thread-local storage.
#[derive(Debug)]
pub struct Context {
    /// Stored device id (presumably the CUDA device ordinal — confirm against
    /// callers). `Context::new` initialises it to `0`.
    device: Cell<c_int>,
    /// Last error recorded via `set_error`; `cudaSuccess` when nothing is pending.
    last_error: Cell<cudaError_t>,
}

impl Context {
    /// Creates a context with device id `0` and no pending error.
    ///
    /// Marked `#[cold]` because construction is expected to be rare compared
    /// with the accessors below.
    #[cold]
    fn new() -> Self {
        Self {
            device: Cell::new(0),
            last_error: Cell::new(cudaSuccess),
        }
    }

    /// Returns the stored device id.
    #[inline(always)]
    pub fn get_device(&self) -> c_int {
        self.device.get()
    }

    /// Stores `id` as the device id.
    ///
    /// Only the stored value is updated; no validation is performed here.
    #[inline(always)]
    pub fn set_device(&self, id: c_int) {
        self.device.set(id);
    }

    /// Records `err` as the last error, overwriting any previously stored value.
    #[inline(always)]
    pub fn set_error(&self, err: cudaError_t) {
        self.last_error.set(err);
    }

    /// Returns the last recorded error and resets the stored value to
    /// `cudaSuccess`.
    ///
    /// Use [`Context::peek_error`] to read the error without clearing it.
    #[inline(always)]
    pub fn get_error(&self) -> cudaError_t {
        // `replace` reads the old value and stores the reset value in a single
        // step, without requiring `cudaError_t: Default` the way `take` does.
        self.last_error.replace(cudaSuccess)
    }

    /// Returns the last recorded error without clearing it.
    #[inline(always)]
    pub fn peek_error(&self) -> cudaError_t {
        self.last_error.get()
    }
}

/// Returns the calling thread's [`Context`], creating it on first access.
///
/// Every thread owns an independent instance kept in thread-local storage;
/// repeated calls on the same thread return a reference to the same instance.
///
/// # Lifetime caveat
///
/// The `'static` lifetime is an over-approximation: the storage is only valid
/// while the calling thread is alive. The reference must therefore not be
/// used after this thread's thread-local storage has been torn down.
#[inline(always)]
pub fn context() -> &'static Context {
    thread_local! {
        static CONTEXT: UnsafeCell<Option<Context>> = const { UnsafeCell::new(None) };
    }

    CONTEXT.with(|cell| {
        let slot: *mut Option<Context> = cell.get();

        // SAFETY:
        // 1. `slot` comes from a thread-local `UnsafeCell`, so it is valid, aligned and
        //    only ever accessed from the owning thread (no data races).
        // 2. The slot is written only while it is still `None`. References are handed out
        //    only to the `Context` inside `Some`, and the slot never reverts to `None`,
        //    so no reference into the slot exists at the time of the write.
        // 3. `Context::new()` is fully evaluated before the write, and the write goes
        //    through the raw pointer, so no `&mut` is ever created that could alias a
        //    `&Context` returned by an earlier call.
        unsafe {
            if matches!(*slot, None) {
                slot.write(Some(Context::new()));
            }
        }

        // SAFETY: only a shared reference is created here. It may coexist with shared
        // references returned by earlier calls; all mutation of `Context` goes through
        // its `Cell` fields.
        let initialised: &Option<Context> = unsafe { &*slot };

        match initialised {
            // SAFETY:
            // 1. The referent lives in thread-local storage, which stays valid for as
            //    long as the owning thread is alive.
            // 2. `transmute` only extends the lifetime to `'static`, because the borrow
            //    checker ties the reference to the `.with()` closure scope.
            // 3. This is sound provided the reference is not used after the owning
            //    thread has terminated.
            Some(ctx) => unsafe { mem::transmute::<&Context, &'static Context>(ctx) },
            None => unreachable!("thread-local context is initialised above"),
        }
    })
}