// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::ffi::{c_char, c_int, c_void, OsStr};

use once_cell::sync::Lazy;
use tracing::{debug, trace};

use xgpu_common::sys::dynlib;

use crate::hook::dl::DlApi;

mod interceptor {
    use std::{
        collections::HashSet,
        ffi::{CStr, CString, OsString},
        os::unix::ffi::OsStrExt,
        path::{Path, PathBuf},
        ptr,
    };

    use indexmap::IndexMap;
    use parking_lot::RwLock;
    use thiserror::Error;

    use super::*;

    /// Errors that can occur while constructing a [`DlInterceptor`].
    #[derive(Debug, Error)]
    pub enum Error {
        /// `dladdr` succeeded but reported no file name for the library containing this code.
        #[error("Failed to resolve current library path")]
        CurrentLibraryPathResolveError,

        /// An underlying dynamic-loader call failed.
        #[error(transparent)]
        DlError(#[from] dynlib::DlError),
    }

    /// Redirects `dlopen` calls for selected libraries to this library, while keeping
    /// handles to the real libraries so their symbols can still be resolved.
    #[derive(Debug)]
    pub struct DlInterceptor {
        /// Path of the shared object containing this code, as reported by `dladdr`.
        self_path: CString,
        /// Normalized file names (e.g. `libcuda.so`) whose `dlopen` calls are redirected.
        hook_files: HashSet<OsString>,
        /// Handles to the original libraries loaded on behalf of redirected calls, keyed by
        /// the path the caller requested, kept in insertion order.
        hooked_libs: RwLock<IndexMap<PathBuf, dynlib::DlHandle>>,
    }

    impl DlInterceptor {
        /// Reduces a library path to its file name, truncated right after the last `.so`
        /// extension that either ends the name or is followed by a `.`-separated suffix.
        ///
        /// `/usr/lib/libcuda.so.1` -> `libcuda.so`, `libcudart.so` -> `libcudart.so`.
        /// Names without such an extension (e.g. `libfoo.sock`) are returned unchanged.
        #[inline]
        fn normalize_library_name(path: &Path) -> &OsStr {
            const LIB_EXT_BYTES: &[u8] = b".so";
            const LIB_EXT_LEN: usize = LIB_EXT_BYTES.len();

            // `file_name` is `None` for paths such as `/` or ones ending in `..`; fall back
            // to the whole path in that case.
            let file_name = path.file_name().unwrap_or(path.as_os_str());
            let name_bytes = file_name.as_bytes();

            name_bytes
                .windows(LIB_EXT_LEN)
                .enumerate()
                .rev()
                .find_map(|(pos, window)| {
                    let end = pos + LIB_EXT_LEN;
                    // Only accept `.so` at the end of the name or before a version component
                    // (`.so.12`), so that e.g. `libcuda.sock` is not mistaken for `libcuda.so`.
                    let at_boundary = matches!(name_bytes.get(end), None | Some(b'.'));
                    (window == LIB_EXT_BYTES && at_boundary).then_some(end)
                })
                .map(|end| OsStr::from_bytes(&name_bytes[..end]))
                .unwrap_or(file_name)
        }

        /// Creates an interceptor that redirects the libraries named in `file_names`.
        ///
        /// Each requested path is compared via [`Self::normalize_library_name`], so the
        /// entries should be given in normalized form (e.g. `libcuda.so`).
        ///
        /// # Errors
        ///
        /// Returns [`Error::DlError`] if `dladdr` fails, or
        /// [`Error::CurrentLibraryPathResolveError`] if it reports no file name.
        pub fn new<I, S>(file_names: I) -> Result<Self, Error>
        where
            I: IntoIterator<Item = S>,
            S: AsRef<OsStr>,
        {
            static LIB_MARKER: () = (); // A trick to get an address within our own library.

            let dl_info = dynlib::dladdr(ptr::from_ref(&LIB_MARKER).cast())?;
            let self_path = dl_info
                .file_name()
                .ok_or(Error::CurrentLibraryPathResolveError)?
                .to_owned();

            let hook_files = file_names
                .into_iter()
                .map(|s| s.as_ref().to_owned())
                .collect();

            Ok(Self {
                self_path,
                hook_files,
                hooked_libs: RwLock::new(IndexMap::new()),
            })
        }

        /// `dlopen` replacement: redirects hooked libraries to this library and passes all
        /// other requests through unchanged.
        ///
        /// For a hooked library outside the WSL driver directory, the original library is
        /// loaded with `DlOpenFlag::LAZY` (once per requested path) and its handle is kept for
        /// [`Self::find_original_symbol`]. The caller then receives the result of opening
        /// this library with the caller's own `flag`.
        ///
        /// Returns null if the original library cannot be loaded.
        pub fn dlopen(&self, filename: *const c_char, flag: c_int) -> *mut c_void {
            const WSL_DRIVER_PATH: &str = "/usr/lib/wsl/drivers"; // TODO: need to fix this

            if filename.is_null() {
                // A null filename has no library name to match; pass it straight through.
                // SAFETY: forwards the caller's arguments unchanged to the real `dlopen`.
                return unsafe { dynlib::ffi::dlopen(filename, flag) };
            }

            // SAFETY: `filename` is non-null and, under the `dlopen` contract the caller is
            // bound by, points to a NUL-terminated string that stays valid for this call.
            let filename = unsafe { CStr::from_ptr(filename) };
            let lib_path = Path::new(OsStr::from_bytes(filename.to_bytes()));
            let lib_name = Self::normalize_library_name(lib_path);

            debug!("filename:{:?}", filename);
            debug!("lib_path:{:?}", lib_path);
            debug!("lib_name:{:?}", lib_name);

            let is_wsl_driver = lib_path.starts_with(WSL_DRIVER_PATH);
            let filename = if self.hook_files.contains(lib_name) && !is_wsl_driver {
                debug!(
                    "[dlopen] Hooked '{}' -> '{}'.",
                    lib_path.display(),
                    self.self_path.to_string_lossy()
                );
                // The read guard is a temporary of the `if` condition, so it is released
                // before the body runs; the load below does not happen under the lock.
                if !self.hooked_libs.read().contains_key(lib_path) {
                    match dynlib::dlopen(lib_path, dynlib::DlOpenFlag::LAZY) {
                        Ok(handle) => {
                            // Another thread may have inserted the same path in the meantime;
                            // `or_insert` keeps whichever handle was stored first.
                            self.hooked_libs
                                .write()
                                .entry(lib_path.to_owned())
                                .or_insert(handle);
                        }
                        Err(err) => {
                            debug!(
                                "[dlopen] Failed to load original library '{}': {}",
                                lib_path.display(),
                                err
                            );
                            return ptr::null_mut();
                        }
                    }
                }
                self.self_path.as_ptr()
            } else {
                trace!("[dlopen] Bypassing '{}'.", lib_path.to_string_lossy());
                filename.as_ptr()
            };

            // SAFETY: `filename` points to a NUL-terminated string owned either by `self`
            // or by the caller, both of which outlive this call.
            unsafe { dynlib::ffi::dlopen(filename, flag) }
        }

        /// Resolves `name` in the original libraries loaded by [`Self::dlopen`], most recently
        /// loaded first, then falls back to `dynlib::dlsym` on `dynlib::NEXT_HANDLE`.
        ///
        /// Lookup errors are treated the same as "symbol not found"; returns `None` when no
        /// source yields the symbol.
        pub fn find_original_symbol<S>(&self, name: S) -> Option<dynlib::DlSymbol>
        where
            S: AsRef<OsStr>,
        {
            // Reverse search: the latest loaded library wins. The read guard is a temporary
            // of this statement, so the lock is released before the fallback below.
            let from_hooked = self
                .hooked_libs
                .read()
                .values()
                .rev()
                .find_map(|handle| handle.find_symbol(&name).ok().flatten());

            from_hooked.or_else(|| dynlib::dlsym(&dynlib::NEXT_HANDLE, name).ok().flatten())
        }
    }
}

/// Process-wide `dlopen` interceptor, constructed on first access.
///
/// # Panics
///
/// The first access panics if the interceptor cannot be constructed.
static DL_INTERCEPTOR: Lazy<interceptor::DlInterceptor> = Lazy::new(|| {
    // Normalized library names whose `dlopen` calls are redirected to this library.
    const HOOKED_LIBRARY_NAMES: &[&str] = &[
        "libcuda.so",
        "libcudart.so",
        "libnvidia-ml.so",
        "libnccl.so",
        "libcublas.so",
        "libcublasLt.so",
    ];

    let dl_interceptor = interceptor::DlInterceptor::new(HOOKED_LIBRARY_NAMES.iter().copied())
        .expect("FATAL: Failed to initialize dynamic library interceptor.");

    debug!("Dynamic library interceptor initialized: {:#?}", dl_interceptor);
    dl_interceptor
});

pub struct DlApiImpl;

impl DlApi for DlApiImpl {
    fn dlopen(&self, filename: *const c_char, flag: c_int) -> *mut c_void {
        DL_INTERCEPTOR.dlopen(filename, flag)
    }
}

#[inline]
pub fn find_original_symbol<S: AsRef<OsStr>>(name: S) -> Option<dynlib::DlSymbol> {
    DL_INTERCEPTOR.find_original_symbol(name)
}