// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use heck::ToSnakeCase;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    punctuated::Punctuated,
    token::{Comma, Unsafe},
    Abi, Attribute, FnArg, ForeignItem, ForeignItemFn, Ident, Item, ItemForeignMod, ItemMod,
    ReturnType, Token, Type, Visibility,
};

/// Models a function from an `extern` block that needs to be hooked.
///
/// Each instance holds the pieces of one foreign function declaration that the
/// generator methods on this type splice into the API trait method, the panic
/// stub method and the exported FFI entry point.
struct HookedFunction {
    /// Attributes written on the foreign function declaration; re-emitted on the
    /// generated FFI entry point.
    attrs: Vec<Attribute>,
    /// Visibility of the foreign function declaration; applied to the generated
    /// FFI entry point.
    vis: Visibility,
    /// ABI of the enclosing `extern` block (cloned from the block, not taken
    /// from the function signature).
    abi: Abi,
    /// `unsafe` qualifier from the function signature, if any; propagated to the
    /// trait method, the panic stub method and the FFI entry point.
    unsafety: Option<Unsafe>,
    /// Name of the foreign function; used as the name of the generated FFI entry
    /// point and in the panic stub's message.
    original_ident: Ident,
    /// Name of the corresponding trait method; taken from the same signature
    /// identifier as `original_ident`.
    trait_method_ident: Ident,
    /// Declared parameters of the foreign function.
    inputs: Punctuated<FnArg, Comma>,
    /// Declared return type of the foreign function.
    output: ReturnType,
}

impl HookedFunction {
    /// Creates a `HookedFunction` by taking ownership of a `ForeignItemFn` to avoid clones.
    fn from_foreign_fn(func: ForeignItemFn, block_abi: &Abi) -> syn::Result<Self> {
        let original_ident = func.sig.ident.clone();
        let syn::Signature {
            ident: trait_method_ident,
            unsafety,
            inputs,
            output,
            ..
        } = func.sig;

        Ok(Self {
            attrs: func.attrs,
            vis: func.vis,
            abi: block_abi.clone(),
            unsafety,
            original_ident,
            trait_method_ident,
            inputs,
            output,
        })
    }

    /// Generates the method signature for the public API trait.
    fn to_trait_method_tokens(&self) -> proc_macro2::TokenStream {
        let name = &self.trait_method_ident;
        let inputs = &self.inputs;
        let output = &self.output;
        let unsafety = &self.unsafety;

        quote! {
            #[allow(clippy::too_many_arguments, non_snake_case)]
            #unsafety fn #name(&self, #inputs) #output;
        }
    }

    /// Generates the method implementation for the default `PanicStub`.
    fn to_panic_impl_tokens(&self) -> proc_macro2::TokenStream {
        let name = &self.trait_method_ident;
        let inputs = &self.inputs;
        let output = &self.output;
        let unsafety = &self.unsafety;
        let original_name_str = self.original_ident.to_string();

        quote! {
            #[allow(clippy::too_many_arguments, non_snake_case, unused_variables)]
            #unsafety fn #name(&self, #inputs) #output {
                unimplemented!(
                    "Function `{}` is not implemented.",
                    #original_name_str
                );
            }
        }
    }

    /// Generates the FFI entry point function that delegates to the trait implementation.
    fn to_hooked_extern_fn_tokens(
        &self,
        api_dispatcher_struct_name: &Ident,
    ) -> proc_macro2::TokenStream {
        let attrs = &self.attrs;
        let vis = &self.vis;
        let abi = &self.abi;
        let unsafety = &self.unsafety;
        let original_name = &self.original_ident;
        let trait_method_name = &self.trait_method_ident;
        let inputs = &self.inputs;
        let output = &self.output;

        // Lazily create an iterator for argument patterns.
        let arg_pats = self.inputs.iter().filter_map(|arg| {
            if let FnArg::Typed(pt) = arg {
                Some(&pt.pat)
            } else {
                None
            }
        });

        quote! {
            #(#attrs)*
            #[inline(always)]
            #[unsafe(no_mangle)]
            #vis #unsafety #abi fn #original_name(#inputs) #output {
                #api_dispatcher_struct_name::current().#trait_method_name(#(#arg_pats),*)
            }
        }
    }
}

/// Parses the macro arguments, e.g., `(MyApi, backend = MyImpl)`.
struct ApiHookArgs {
    /// Name of the public API trait that the macro generates.
    trait_name: Ident,
    /// Type given through the optional `backend = <Type>` argument. When present
    /// it is used as the implementation the dispatcher returns; when absent a
    /// panicking stub implementation is generated instead.
    backend_type: Option<Type>,
}

impl Parse for ApiHookArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let trait_name: Ident = input.parse()?;
        let mut backend_type = None;

        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;

            let key = input.parse::<Ident>()?;
            if key != "backend" {
                return Err(syn::Error::new(key.span(), "expected `backend`"));
            }

            input.parse::<Token![=]>()?;

            let kind = input.parse::<Type>()?;
            backend_type = Some(kind);
        }

        Ok(ApiHookArgs {
            trait_name,
            backend_type,
        })
    }
}

/// Finds, removes, and returns the unique `extern` block from the module's items.
///
/// # Errors
///
/// Returns an error spanning the whole module when the module has no inline
/// content, when it contains no `extern` block, or when it contains more than
/// one.
pub fn find_and_remove_foreign_mod(module: &mut ItemMod) -> syn::Result<ItemForeignMod> {
    let Some((_, items)) = &mut module.content else {
        return Err(syn::Error::new_spanned(
            module,
            "The module cannot be empty; it must contain an `extern` block.",
        ));
    };

    // Indices of every `extern` block, in source order.
    let mut extern_positions = items
        .iter()
        .enumerate()
        .filter(|(_, item)| matches!(item, Item::ForeignMod(_)))
        .map(|(position, _)| position);

    let Some(index) = extern_positions.next() else {
        return Err(syn::Error::new_spanned(
            module,
            "Could not find an `extern` block inside the module.",
        ));
    };

    // A second hit means the block is not unique.
    if extern_positions.next().is_some() {
        return Err(syn::Error::new_spanned(
            module,
            "Found multiple `extern` blocks in the module. The `api_hook` attribute requires exactly one.",
        ));
    }

    match items.remove(index) {
        Item::ForeignMod(foreign_mod) => Ok(foreign_mod),
        // `index` was selected by matching on `Item::ForeignMod` above.
        _ => unreachable!(),
    }
}

/// Generates the core API trait, runtime trait, and FFI functions.
fn generate_api_hook_tokens(
    args: ApiHookArgs,
    foreign_mod: ItemForeignMod,
) -> syn::Result<proc_macro2::TokenStream> {
    let hooked_functions: Vec<HookedFunction> = foreign_mod
        .items
        .into_iter()
        .map(|item| match item {
            ForeignItem::Fn(func) => HookedFunction::from_foreign_fn(func, &foreign_mod.abi),
            other => Err(syn::Error::new_spanned(
                other,
                "The `extern` block can only contain function declarations.",
            )),
        })
        .collect::<Result<_, _>>()?;

    let trait_name = args.trait_name;
    let trait_name_str = trait_name.to_string();
    let api_runtime_trait_name = format_ident!("__{}Runtime", trait_name_str);
    let api_dispatcher_struct_name = format_ident!("__{}Dispatcher", trait_name_str);
    let internal_mod_name = format_ident!("__{}_impl", trait_name_str.to_snake_case());

    let trait_methods = hooked_functions
        .iter()
        .map(HookedFunction::to_trait_method_tokens);
    let new_extern_fns = hooked_functions
        .iter()
        .map(|f| f.to_hooked_extern_fn_tokens(&api_dispatcher_struct_name));

    let backend_impl_block = match args.backend_type {
        Some(user_impl_type) => {
            quote! {
                use super::*;

                impl #api_runtime_trait_name for #api_dispatcher_struct_name {
                    type Impl = #user_impl_type;

                    #[inline(always)]
                    fn current() -> &'static Self::Impl {
                        static INSTANCE: #user_impl_type = #user_impl_type;
                        &INSTANCE
                    }
                }
            }
        }
        None => {
            let panic_stub_ident = format_ident!("__{}PanicStub", trait_name_str);
            let panic_impl_methods = hooked_functions
                .iter()
                .map(HookedFunction::to_panic_impl_tokens);
            quote! {
                use super::*;

                pub struct #panic_stub_ident;

                impl #trait_name for #panic_stub_ident {
                    #(#panic_impl_methods)*
                }

                impl #api_runtime_trait_name for #api_dispatcher_struct_name {
                    type Impl = #panic_stub_ident;

                    #[inline(always)]
                    fn current() -> &'static Self::Impl {
                        static INSTANCE: #panic_stub_ident = #panic_stub_ident;
                        &INSTANCE
                    }
                }
            }
        }
    };

    let definitions = quote! {
        pub trait #trait_name: Send + Sync + 'static { #(#trait_methods)* }

        /// Internal trait defining the contract for providing the backend implementation.
        trait #api_runtime_trait_name {
            type Impl: #trait_name;

            fn current() -> &'static Self::Impl;
        }

        /// Internal struct used to dispatch FFI calls to the backend implementation.
        struct #api_dispatcher_struct_name;
    };

    Ok(quote! {
        #definitions

        #(#new_extern_fns)*

        mod #internal_mod_name {
            #backend_impl_block
        }
    })
}

pub fn api_hook_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as ApiHookArgs);
    let item = parse_macro_input!(item as Item);

    let mut user_mod = match item {
        Item::Mod(module) => module,
        other => {
            return syn::Error::new_spanned(
                other,
                "The `api_hook` attribute can only be applied to a module containing an `extern` block.",
            ).to_compile_error().into();
        }
    };

    let foreign_mod = match find_and_remove_foreign_mod(&mut user_mod) {
        Ok(fm) => fm,
        Err(e) => return e.to_compile_error().into(),
    };

    let generated_tokens = match generate_api_hook_tokens(args, foreign_mod) {
        Ok(tokens) => tokens,
        Err(e) => return e.to_compile_error().into(),
    };

    let other_items = user_mod.content.map(|(_, items)| items).unwrap_or_default();
    let mod_attrs = user_mod.attrs;
    let mod_vis = user_mod.vis;
    let mod_ident = user_mod.ident;
    let mod_unsafety = user_mod.unsafety;

    quote! {
        #(#mod_attrs)*
        #mod_vis #mod_unsafety mod #mod_ident {
            #(#other_items)*

            #generated_tokens
        }
    }
    .into()
}