* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use heck::ToSnakeCase;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
token::{Comma, Unsafe},
Abi, Attribute, FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemForeignMod, ItemMod,
ReturnType, Token, Type, Visibility,
};
/// One function declaration lifted out of the user's `extern` block, together with
/// everything needed to emit its trait method, its panic stub and its exported wrapper.
struct HookedFunction {
    /// Name of the exported `#[no_mangle]` wrapper (the name the declaration used).
    original_ident: Ident,
    /// Name of the corresponding method on the generated trait.
    trait_method_ident: Ident,
    /// Attributes of the declaration, copied onto the exported wrapper.
    attrs: Vec<Attribute>,
    /// Visibility of the declaration, reused for the exported wrapper.
    vis: Visibility,
    /// ABI of the enclosing `extern` block, reused for the exported wrapper.
    abi: Abi,
    /// `unsafe` marker of the declaration, if any; propagated to the trait method and wrapper.
    unsafety: Option<Unsafe>,
    /// Declared parameter list.
    inputs: Punctuated<FnArg, Comma>,
    /// Declared return type.
    output: ReturnType,
}
impl HookedFunction {
fn from_foreign_fn(func: ForeignItemFn, block_abi: &Abi) -> syn::Result<Self> {
let original_ident = func.sig.ident.clone();
let syn::Signature {
ident: trait_method_ident,
unsafety,
inputs,
output,
..
} = func.sig;
Ok(Self {
attrs: func.attrs,
vis: func.vis,
abi: block_abi.clone(),
unsafety,
original_ident,
trait_method_ident,
inputs,
output,
})
}
fn to_trait_method_tokens(&self) -> proc_macro2::TokenStream {
let name = &self.trait_method_ident;
let inputs = &self.inputs;
let output = &self.output;
let unsafety = &self.unsafety;
quote! {
#[allow(clippy::too_many_arguments, non_snake_case)]
#unsafety fn #name(&self, #inputs) #output;
}
}
fn to_panic_impl_tokens(&self) -> proc_macro2::TokenStream {
let name = &self.trait_method_ident;
let inputs = &self.inputs;
let output = &self.output;
let unsafety = &self.unsafety;
let original_name_str = self.original_ident.to_string();
quote! {
#[allow(clippy::too_many_arguments, non_snake_case, unused_variables)]
#unsafety fn #name(&self, #inputs) #output {
unimplemented!(
"Function `{}` is not implemented.",
#original_name_str
);
}
}
}
fn to_hooked_extern_fn_tokens(
&self,
api_dispatcher_struct_name: &Ident,
) -> proc_macro2::TokenStream {
let attrs = &self.attrs;
let vis = &self.vis;
let abi = &self.abi;
let unsafety = &self.unsafety;
let original_name = &self.original_ident;
let trait_method_name = &self.trait_method_ident;
let inputs = &self.inputs;
let output = &self.output;
let arg_pats = self.inputs.iter().filter_map(|arg| {
if let FnArg::Typed(pt) = arg {
Some(&pt.pat)
} else {
None
}
});
quote! {
#(#attrs)*
#[inline(always)]
#[unsafe(no_mangle)]
#vis #unsafety #abi fn #original_name(#inputs) #output {
#api_dispatcher_struct_name::current().#trait_method_name(#(#arg_pats),*)
}
}
}
}
/// Parsed arguments of the `api_hook` attribute, i.e. `TraitName` or
/// `TraitName, backend = SomeType`.
struct ApiHookArgs {
    /// Name given to the trait generated from the `extern` block's declarations.
    trait_name: Ident,
    /// Type the dispatcher forwards to. It must implement the generated trait.
    /// `None` means a panicking stub is generated and used instead.
    backend_type: Option<Type>,
}
impl Parse for ApiHookArgs {
    /// Parses `TraitName` optionally followed by `, backend = Type`.
    ///
    /// A single trailing comma is accepted after either form (`Foo,` and
    /// `Foo, backend = Bar,`).
    ///
    /// # Errors
    ///
    /// Fails if the first token is not an identifier, if the key after the comma is
    /// not `backend`, or if `=` / the type is missing.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let trait_name: Ident = input.parse()?;
        let mut backend_type = None;
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            // A bare trailing comma (`Foo,`) ends the arguments.
            if !input.is_empty() {
                let key: Ident = input.parse()?;
                if key != "backend" {
                    return Err(syn::Error::new(key.span(), "expected `backend`"));
                }
                input.parse::<Token![=]>()?;
                backend_type = Some(input.parse::<Type>()?);
                // Tolerate a trailing comma after the backend type too.
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                }
            }
        }
        Ok(ApiHookArgs {
            trait_name,
            backend_type,
        })
    }
}
/// Detaches the single `extern` block from `module` and returns it. Every other item
/// of the module is left in place.
///
/// # Errors
///
/// Returns an error spanning `module` when it has no inline body, when the body
/// contains no `extern` block, or when it contains more than one.
pub fn find_and_remove_foreign_mod(module: &mut ItemMod) -> syn::Result<ItemForeignMod> {
    let Some((_, items)) = module.content.as_mut() else {
        return Err(syn::Error::new_spanned(
            module,
            "The module cannot be empty; it must contain an `extern` block.",
        ));
    };

    // One scan yields the first two `extern` block positions. The first is the one to
    // extract; the existence of a second is an error.
    let (first, second) = {
        let mut extern_indices = items
            .iter()
            .enumerate()
            .filter(|(_, item)| matches!(item, Item::ForeignMod(_)))
            .map(|(idx, _)| idx);
        (extern_indices.next(), extern_indices.next())
    };

    let Some(index) = first else {
        return Err(syn::Error::new_spanned(
            module,
            "Could not find an `extern` block inside the module.",
        ));
    };
    if second.is_some() {
        return Err(syn::Error::new_spanned(
            module,
            "Found multiple `extern` blocks in the module. The `api_hook` attribute requires exactly one.",
        ));
    }

    match items.remove(index) {
        Item::ForeignMod(foreign_mod) => Ok(foreign_mod),
        _ => unreachable!(),
    }
}
fn generate_api_hook_tokens(
args: ApiHookArgs,
foreign_mod: ItemForeignMod,
) -> syn::Result<proc_macro2::TokenStream> {
let hooked_functions: Vec<HookedFunction> = foreign_mod
.items
.into_iter()
.map(|item| match item {
ForeignItem::Fn(func) => HookedFunction::from_foreign_fn(func, &foreign_mod.abi),
other => Err(syn::Error::new_spanned(
other,
"The `extern` block can only contain function declarations.",
)),
})
.collect::<Result<_, _>>()?;
let trait_name = args.trait_name;
let trait_name_str = trait_name.to_string();
let api_runtime_trait_name = format_ident!("__{}Runtime", trait_name_str);
let api_dispatcher_struct_name = format_ident!("__{}Dispatcher", trait_name_str);
let internal_mod_name = format_ident!("__{}_impl", trait_name_str.to_snake_case());
let trait_methods = hooked_functions
.iter()
.map(HookedFunction::to_trait_method_tokens);
let new_extern_fns = hooked_functions
.iter()
.map(|f| f.to_hooked_extern_fn_tokens(&api_dispatcher_struct_name));
let backend_impl_block = match args.backend_type {
Some(user_impl_type) => {
quote! {
use super::*;
impl #api_runtime_trait_name for #api_dispatcher_struct_name {
type Impl = #user_impl_type;
#[inline(always)]
fn current() -> &'static Self::Impl {
static INSTANCE: #user_impl_type = #user_impl_type;
&INSTANCE
}
}
}
}
None => {
let panic_stub_ident = format_ident!("__{}PanicStub", trait_name_str);
let panic_impl_methods = hooked_functions
.iter()
.map(HookedFunction::to_panic_impl_tokens);
quote! {
use super::*;
pub struct #panic_stub_ident;
impl #trait_name for #panic_stub_ident {
#(#panic_impl_methods)*
}
impl #api_runtime_trait_name for #api_dispatcher_struct_name {
type Impl = #panic_stub_ident;
#[inline(always)]
fn current() -> &'static Self::Impl {
static INSTANCE: #panic_stub_ident = #panic_stub_ident;
&INSTANCE
}
}
}
}
};
let definitions = quote! {
pub trait #trait_name: Send + Sync + 'static { #(#trait_methods)* }
trait #api_runtime_trait_name {
type Impl: #trait_name;
fn current() -> &'static Self::Impl;
}
struct #api_dispatcher_struct_name;
};
Ok(quote! {
#definitions
#(#new_extern_fns)*
mod #internal_mod_name {
#backend_impl_block
}
})
}
/// Entry point of the `#[api_hook(...)]` attribute macro.
///
/// Parses the attribute arguments and the annotated module, then replaces the module's
/// single `extern` block with the generated trait, dispatcher, exported wrappers and
/// backend wiring. The generated items are appended after the module's remaining
/// items. Any failure is returned as a `compile_error!` at the offending span.
pub fn api_hook_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as ApiHookArgs);
    let item = parse_macro_input!(item as Item);
    let mut user_mod = match item {
        Item::Mod(module) => module,
        other => {
            return syn::Error::new_spanned(
                other,
                "The `api_hook` attribute can only be applied to a module containing an `extern` block.",
            )
            .to_compile_error()
            .into();
        }
    };

    let generated_tokens = match find_and_remove_foreign_mod(&mut user_mod)
        .and_then(|foreign_mod| generate_api_hook_tokens(args, foreign_mod))
    {
        Ok(tokens) => tokens,
        Err(e) => return e.to_compile_error().into(),
    };

    // `find_and_remove_foreign_mod` only succeeds for modules with an inline body,
    // so the body is present here.
    if let Some((_, items)) = user_mod.content.as_mut() {
        items.push(Item::Verbatim(generated_tokens));
    }

    // syn keeps both outer and inner (`#![...]`) attributes in `ItemMod::attrs`.
    // Emitting the module through its own `ToTokens` impl puts inner attributes back
    // inside the braces. Re-assembling `#(#attrs)* mod name { ... }` by hand would hoist
    // them in front of `mod`, where inner attributes are rejected.
    quote!(#user_mod).into()
}