* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::{
cell::{Cell, UnsafeCell},
ffi::c_int,
mem,
};
use cudax::runtime::{cudaError_t, cudaSuccess};
/// State handed out by [`context`]: the selected device ordinal and the last
/// recorded runtime error.
///
/// All fields are `Cell`s, so every method takes `&self` and the type is
/// `!Sync` (a `&Context` cannot be shared across threads).
#[derive(Debug)]
pub struct Context {
    /// Device ordinal reported by `get_device`; starts at 0.
    device: Cell<c_int>,
    /// Last error recorded by `set_error`; starts at `cudaSuccess`.
    last_error: Cell<cudaError_t>,
}

impl Context {
    /// Creates a context with device 0 selected and no pending error.
    #[cold]
    fn new() -> Self {
        Self {
            device: Cell::new(0),
            last_error: Cell::new(cudaSuccess),
        }
    }

    /// Returns the currently selected device ordinal.
    #[inline(always)]
    pub fn get_device(&self) -> c_int {
        self.device.get()
    }

    /// Stores `id` as the selected device ordinal. No validation is done here.
    #[inline(always)]
    pub fn set_device(&self, id: c_int) {
        self.device.set(id);
    }

    /// Records `err` as the last error, overwriting any previous value.
    #[inline(always)]
    pub fn set_error(&self, err: cudaError_t) {
        self.last_error.set(err);
    }

    /// Returns the last recorded error and resets it to `cudaSuccess`.
    #[inline(always)]
    pub fn get_error(&self) -> cudaError_t {
        // `replace` reads and resets in a single step. Unlike `take`, it needs
        // no `Default` bound and never stores an intermediate default value.
        self.last_error.replace(cudaSuccess)
    }

    /// Returns the last recorded error without resetting it.
    #[inline(always)]
    pub fn peek_error(&self) -> cudaError_t {
        self.last_error.get()
    }
}
/// Returns the calling thread's [`Context`], creating it on first use.
///
/// Each thread gets its own independent `Context`. Repeated calls on the same
/// thread return the same instance.
#[inline(always)]
pub fn context() -> &'static Context {
    thread_local! {
        // Lazily initialised on first access. `Context` mutates only through
        // `Cell`s, so it is stored directly. A `&mut` to the slot is never
        // formed, which would otherwise alias references returned by earlier
        // calls.
        static CONTEXT: Context = Context::new();
    }
    CONTEXT.with(|ctx| {
        // SAFETY: the thread-local outlives every use on this thread, and
        // `&Context` is `!Send` (`Context` holds `Cell`s), so the reference
        // cannot reach another thread. NOTE(review): callers must still not
        // use it after this thread's TLS is torn down (e.g. from another
        // thread-local's destructor).
        unsafe { &*(ctx as *const Context) }
    })
}