// SPDX-License-Identifier: Mulan PSL v2
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This software is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *         http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

use std::{env, fs::File, path::Path, time::Duration};

use anyhow::{bail, Context};
use tracing::{debug, error, trace};
use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};

use xgpu_common::ipc::{
    error::IpcError,
    framer::Framer,
    message::{Request, Response},
    peer::Client,
    transport::{Transport, TransportError},
};

mod api;
mod handler;

use api::Api;

/// Filter directive used by `init_tracing` when the `RUST_LOG` environment
/// variable cannot be read (unset or not valid Unicode).
const DEFAULT_LOG_LEVEL: &str = "ERROR";

/// Installs the global `tracing` subscriber for this process.
///
/// The filter is built from `RUST_LOG` (falling back to [`DEFAULT_LOG_LEVEL`]
/// when the variable cannot be read), with an additional
/// `xgpu_common::ipc=error` directive layered on top.
///
/// * `file_path` — when `Some`, log output is written without ANSI colours to
///   the file at that path (created, or truncated if it already exists);
///   when `None`, output goes to stdout with ANSI colours enabled.
///
/// # Panics
///
/// Panics if the log file cannot be created, or if a global subscriber has
/// already been installed (`init` panics in that case).
fn init_tracing(file_path: Option<&str>) {
    let log_level = env::var("RUST_LOG").unwrap_or_else(|_| DEFAULT_LOG_LEVEL.to_string());

    // ipc mod
    let filter = EnvFilter::new(log_level).add_directive(
        "xgpu_common::ipc=error"
            .parse()
            .expect("hard-coded filter directive must be valid"),
    );

    let mut layers = Vec::new();

    if let Some(path) = file_path {
        let file = File::create(Path::new(path)).expect("create log file failed");
        let file_layer = fmt::layer()
            .with_writer(file)
            .with_ansi(false)
            .with_filter(filter);
        layers.push(file_layer.boxed());
    } else {
        // Exactly one branch runs, so the filter is moved here as well;
        // no clone is needed.
        let stdout_layer = fmt::layer().with_ansi(true).with_filter(filter);
        layers.push(stdout_layer.boxed());
    }

    Registry::default().with(layers).init();
}

/// Process entry point.
///
/// Initialises logging to stdout, reads the single required command-line
/// argument, runs [`server_main`] on a spawned thread and waits for it.
///
/// Exit status: `1` when the argument is missing or when `server_main`
/// returns an error; `0` otherwise.
///
/// # Panics
///
/// Panics if the server thread itself panicked (the join fails).
fn main() {
    init_tracing(None); // -> stdout

    let mut args = env::args();
    // argv[0] is not guaranteed to be present on every platform, so do not
    // index into the argument list unconditionally.
    let program = args.next().unwrap_or_else(|| String::from("<program>"));
    let pid = match args.next() {
        Some(arg) => arg,
        None => {
            eprintln!("Usage: {} <shmem addr>", program);
            std::process::exit(1);
        }
    };

    let handle = std::thread::spawn(move || self::server_main(pid));
    if let Err(e) = handle.join().expect("Failed to join server thread") {
        error!("{:?}", e);
        // Report the failure to the parent process instead of exiting with 0.
        std::process::exit(1);
    }
}

/// Connects to the peer identified by `pid` and serves its requests until the
/// connection is closed.
///
/// Each received [`Request`] is dispatched through `Api::invoke`, and the
/// result is sent back as a [`Response`] built from that request.
///
/// # Errors
///
/// Returns an error when connecting fails for any reason other than a
/// timeout, when receiving a message fails for any reason other than the
/// connection being closed, when `Api::invoke` fails, or when the response
/// cannot be sent.
///
/// A connect timeout is only logged at trace level and returns `Ok(())`.
fn server_main(pid: String) -> anyhow::Result<()> {
    const BUFFER_SIZE: usize = 512 * 1024 * 1024;
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    let transport = Transport::new(BUFFER_SIZE);
    let framer = Framer::new(BUFFER_SIZE);

    let connected = Client::connect(&transport, framer, &pid, CONNECT_TIMEOUT);
    let client = match connected {
        Err(IpcError::TransportError(TransportError::ConnectTimeout)) => {
            // Nobody showed up within the timeout: treated as a clean exit.
            trace!("'{}': Connection timeout", pid);
            return Ok(());
        }
        Err(e) => bail!("{}", e),
        Ok(c) => c,
    };
    debug!("{:#?}", client);

    loop {
        let incoming = client.receive_message::<Request>();
        let mut request = match incoming {
            // Peer closed the connection: normal shutdown.
            Err(IpcError::TransportError(TransportError::ConnectionClosed)) => return Ok(()),
            Err(e) => bail!("{}", e),
            // Nothing to handle this round; try again.
            Ok(None) => continue,
            Ok(Some(r)) => r,
        };
        debug!(
            "[Server] Received request: request_id={}, method_id={}, argc={}",
            request.request_id(),
            request.method_id(),
            request.argc()
        );

        let outcome = Api::invoke(request.method_id(), request.args_mut());
        let ret = outcome.with_context(|| {
            format!(
                "Failed to handle request, request_id={}, method_id={}",
                request.request_id(),
                request.method_id()
            )
        })?;

        let response = Response::with_request(&request, ret);
        let sent = client.send_message(&response);
        sent.context("Failed to send response")?;
    }
}