* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This software is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
use std::{env, fs::File, path::Path, time::Duration};
use anyhow::{bail, Context};
use tracing::{debug, error, trace};
use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};
use xgpu_common::ipc::{
error::IpcError,
framer::Framer,
message::{Request, Response},
peer::Client,
transport::{Transport, TransportError},
};
mod api;
mod handler;
use api::Api;
/// Log filter used when the `RUST_LOG` environment variable is unavailable.
const DEFAULT_LOG_LEVEL: &str = "ERROR";

/// Installs the process-wide `tracing` subscriber.
///
/// The base filter is taken from the `RUST_LOG` environment variable, or
/// [`DEFAULT_LOG_LEVEL`] when that variable is unset or not valid Unicode.
/// The directive `xgpu_common::ipc=error` is always added on top of it.
///
/// When `file_path` is `Some`, events are written without ANSI colors to a file
/// created (or truncated) at that path. Otherwise they go to the default `fmt`
/// writer (stdout) with ANSI colors enabled. Exactly one sink is installed.
///
/// # Panics
///
/// Panics if the log file cannot be created, or if `init` fails (for example
/// because a global subscriber has already been installed).
fn init_tracing(file_path: Option<&str>) {
    let log_level = env::var("RUST_LOG").unwrap_or_else(|_| DEFAULT_LOG_LEVEL.to_string());
    let filter = EnvFilter::new(log_level).add_directive(
        "xgpu_common::ipc=error"
            .parse()
            .expect("static directive literal is valid"),
    );

    // `filter` is moved into whichever branch runs, so no clone is needed.
    let layer = match file_path {
        Some(path) => {
            let file = File::create(Path::new(path))
                .unwrap_or_else(|e| panic!("create log file failed: {}: {}", path, e));
            fmt::layer()
                .with_writer(file)
                .with_ansi(false)
                .with_filter(filter)
                .boxed()
        }
        None => fmt::layer().with_ansi(true).with_filter(filter).boxed(),
    };

    Registry::default().with(layer).init();
}
fn main() {
init_tracing(None);
let args: Vec<String> = env::args().collect();
if args.len() < 2 {
eprintln!("Usage: {} <shmem addr>", args[0]);
std::process::exit(1);
}
let pid = args[1].clone();
let handle = std::thread::spawn(|| self::server_main(pid));
if let Err(e) = handle.join().expect("Failed to join server thread") {
error!("{:?}", e);
}
}
/// Runs the IPC request/response loop for the peer identified by `pid`.
///
/// A [`Transport`] and a [`Framer`] are used to connect a [`Client`], with
/// `CONNECT_TIMEOUT` (30 s) as the timeout. Both are sized `BUFFER_SIZE`
/// (512 MiB). Each received [`Request`] is dispatched through [`Api::invoke`],
/// and the returned value is sent back as a [`Response`] built from that request.
///
/// Returns `Ok(())` in two cases:
/// * the connection attempt fails with `ConnectTimeout` (logged at `trace`);
/// * receiving fails with `ConnectionClosed`.
///
/// # Errors
///
/// * Any other connect or receive error is returned, carrying the IPC error's
///   `Display` text.
/// * A failing `Api::invoke` is returned with the request and method ids
///   attached as context.
/// * A failed send is returned with the context "Failed to send response".
fn server_main(pid: String) -> anyhow::Result<()> {
    const BUFFER_SIZE: usize = 512 * 1024 * 1024;
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    let transport = Transport::new(BUFFER_SIZE);
    let connection = Client::connect(
        &transport,
        Framer::new(BUFFER_SIZE),
        &pid,
        CONNECT_TIMEOUT,
    );
    let client = match connection {
        Ok(connected) => connected,
        // A peer that never showed up is not treated as a failure.
        Err(IpcError::TransportError(TransportError::ConnectTimeout)) => {
            trace!("'{}': Connection timeout", pid);
            return Ok(());
        }
        Err(other) => bail!("{}", other),
    };
    debug!("{:#?}", client);

    loop {
        let mut request = match client.receive_message::<Request>() {
            Ok(Some(incoming)) => incoming,
            // NOTE(review): nothing was received; this retries immediately. Whether
            // that blocks or spins depends on `receive_message`; confirm.
            Ok(None) => continue,
            // The peer hung up: this is the normal way the loop ends.
            Err(IpcError::TransportError(TransportError::ConnectionClosed)) => return Ok(()),
            Err(other) => bail!("{}", other),
        };

        debug!(
            "[Server] Received request: request_id={}, method_id={}, argc={}",
            request.request_id(),
            request.method_id(),
            request.argc()
        );

        let outcome = Api::invoke(request.method_id(), request.args_mut());
        let ret = outcome.with_context(|| {
            format!(
                "Failed to handle request, request_id={}, method_id={}",
                request.request_id(),
                request.method_id()
            )
        })?;

        client
            .send_message(&Response::with_request(&request, ret))
            .context("Failed to send response")?;
    }
}