use std::{env, mem::transmute, path::PathBuf, sync::LazyLock};
mod db;
mod doctor;
mod model;
mod pretrain;
mod schedule;
mod affinity;
mod onnx;
use affinity::try_apply_cpu_affinity;
use db::execute_db;
use doctor::doctor_main;
use ek_base::config::get_ek_settings_base;
use ek_computation::{controller::controller_main, worker::worker_main};
use env_logger::fmt::default_kv_format;
use opentelemetry::{
KeyValue, propagation::TextMapCompositePropagator, trace::TracerProvider as _,
};
use std::io::Write;
use tokio::runtime::Runtime;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use ek_db::weight_srv;
use clap::{Parser, Subcommand};
use model::execute_model;
use opentelemetry_sdk::{
Resource,
propagation::{BaggagePropagator, TraceContextPropagator},
trace::{RandomIdGenerator, Sampler, SdkTracerProvider},
};
use opentelemetry_semantic_conventions::{
SCHEMA_URL,
resource::{DEPLOYMENT_ENVIRONMENT_NAME, SERVICE_VERSION},
};
use pretrain::{PretrainCommand, execute_pretrain};
use schedule::execute_schedule;
use tracing::Level;
// Top-level subcommands of the CLI; `main` dispatches on this enum.
// NOTE: plain `//` comments are used on purpose. `///` doc comments are picked
// up by the clap derive macros and would alter the generated help output.
#[derive(Subcommand, Debug)]
enum Command {
// Dispatched to `doctor_main`.
#[command(about = "check the environment")]
Doctor {},
// Dispatched to `worker_main`; also the only command for which
// `init_tokio_runtime` applies CPU affinity and custom thread counts.
#[command(about = "run expert-kit worker")]
Worker {},
// Dispatched to `controller_main`.
#[command(about = "run expert-kit controller")]
Controller {},
// Dispatched to `weight_srv::server::listen` with `(host, port)` and the model paths.
#[command(about = "run expert-kit weight server")]
WeightServer {
// `--host`, defaults to "0.0.0.0".
#[arg(long, default_value_t = ("0.0.0.0").to_string())]
host: String,
// `-p` / `--port`, defaults to 6543.
#[arg(short, long, default_value_t = 6543)]
port: u16,
// `--model`; a `Vec`, so the flag can be repeated to pass several paths.
#[arg(long)]
model: Vec<PathBuf>,
},
// Dispatched to `execute_pretrain`.
#[command(about = "safetensor pretrain weight manipulation")]
Pretrain {
#[command(subcommand)]
command: PretrainCommand,
},
// Dispatched to `execute_db`.
#[command(about = "low-level db operations")]
DB {
#[command(subcommand)]
command: db::DBCommand,
},
// Dispatched to `execute_model`.
#[command(about = "model operations")]
Model {
#[command(subcommand)]
command: model::ModelCommand,
},
// Dispatched to `execute_schedule`.
#[command(about = "schedule operations")]
Schedule {
#[command(subcommand)]
command: schedule::ScheduleCommand,
},
// Dispatched to `onnx::execute_onnx`.
#[command(about = "onnx operations")]
Onnx {
#[command(subcommand)]
command: onnx::OnnxCommand,
},
}
// Root command-line arguments parsed by `main` via `RootCli::parse()`.
// NOTE: plain `//` comments are used on purpose. `///` doc comments are picked
// up by the clap derive macros and would alter the generated help output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct RootCli {
// `--debug`: when set, `main` sets the `RUST_LOG` environment variable to "debug".
#[arg(long, default_value_t = false)]
debug: bool,
// `--config`: optional config source, declared `global` so it is also accepted
// after a subcommand. `main` appends it after the `EK_CONFIG` environment value.
#[arg(long, global = true)]
config: Option<String>,
// The subcommand to run.
#[command(subcommand)]
command: Command,
}
/// Installs the global `env_logger` logger.
///
/// The filter is taken from the environment (`RUST_LOG` by default) and falls
/// back to `info`. Records are written to stderr as
/// `<LEVEL>(timestamp) message key=value...`, with the level coloured when the
/// terminal supports it.
///
/// # Panics
///
/// `init()` panics if a global logger has already been installed.
fn init_log() {
    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
        .write_style(env_logger::WriteStyle::Auto)
        .target(env_logger::Target::Stderr)
        .format(|buf, record| {
            let level = record.level();
            let level_color = buf.default_level_style(level);
            // A custom format closure replaces the default formatter, so the
            // builder-level `format_timestamp_*` settings are not consulted here;
            // millisecond precision has to be requested from the formatter itself.
            let timestamp = buf.timestamp_millis();
            // Propagate write failures to the logger instead of panicking in
            // the middle of a log call.
            write!(
                buf,
                "<{level_color}{level}{level_color:#}>({timestamp}) {} ",
                record.args(),
            )?;
            default_kv_format(buf, record.key_values())?;
            writeln!(buf)
        })
        .init();
}
/// Builds the OpenTelemetry `Resource` describing this process.
///
/// `cmd` becomes the service name; the crate version and a fixed `"develop"`
/// deployment environment are attached under the semantic-conventions schema URL.
fn resource(cmd: &'static str) -> Resource {
    let builder = Resource::builder().with_service_name(cmd);
    let attributes = [
        KeyValue::new(SERVICE_VERSION, env!("CARGO_PKG_VERSION")),
        KeyValue::new(DEPLOYMENT_ENVIRONMENT_NAME, "develop"),
    ];
    builder.with_schema_url(attributes, SCHEMA_URL).build()
}
/// Creates the tracer provider for `svc_name` and registers the global
/// text-map propagator (baggage + trace context).
///
/// Spans are always sampled and exported in batches through an OTLP/tonic
/// span exporter.
///
/// # Panics
///
/// Panics if the OTLP span exporter cannot be built.
fn init_tracer_provider(svc_name: &'static str) -> SdkTracerProvider {
    let span_exporter = opentelemetry_otlp::SpanExporter::builder()
        .with_tonic()
        .build()
        .unwrap();

    let tracer_provider = SdkTracerProvider::builder()
        .with_sampler(Sampler::AlwaysOn)
        .with_id_generator(RandomIdGenerator::default())
        .with_resource(resource(svc_name))
        .with_batch_exporter(span_exporter)
        .build();

    // Baggage first, then trace context — same order as the propagator list
    // has always been registered in.
    opentelemetry::global::set_text_map_propagator(TextMapCompositePropagator::new(vec![
        Box::new(BaggagePropagator::new()),
        Box::new(TraceContextPropagator::new()),
    ]));

    tracer_provider
}
/// Installs the global `tracing` subscriber: a registry with an `INFO` level
/// filter and an OpenTelemetry layer backed by a tracer for `svc_name`.
fn init_tracing_subscriber(svc_name: &'static str) {
    let provider = init_tracer_provider(svc_name);
    let otel_layer =
        tracing_opentelemetry::layer().with_tracer(provider.tracer("tracing-otel-subscriber"));
    let level_filter = tracing_subscriber::filter::LevelFilter::from_level(Level::INFO);

    tracing_subscriber::registry()
        .with(level_filter)
        .with(otel_layer)
        .init();
}
/// Returns the name under which the given subcommand is reported to tracing:
/// `"worker"`, `"controller"`, or `"others"` for everything else.
fn get_command_name(cmd: &Command) -> &'static str {
    match cmd {
        Command::Worker {} => "worker",
        Command::Controller {} => "controller",
        // Listed explicitly so that adding a variant forces a decision here.
        Command::Doctor {}
        | Command::WeightServer { .. }
        | Command::Pretrain { .. }
        | Command::DB { .. }
        | Command::Model { .. }
        | Command::Schedule { .. }
        | Command::Onnx { .. } => "others",
    }
}
/// Number of Tokio worker threads used when no CPU-affinity core list
/// provides a count.
const DEFAULT_THREAD_NUM: usize = 6;

/// Thread limit read once from the `EK_WORKER_THREADS` environment variable.
///
/// Falls back to `1` when the variable is unset, is not a valid `usize`, or is
/// `0`. Zero is rejected because this value is passed to Tokio's
/// `max_blocking_threads`, which panics on `0`.
static WORKER_THREAD_NUM: LazyLock<usize> = LazyLock::new(|| {
    env::var("EK_WORKER_THREADS")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&count| count > 0)
        .unwrap_or(1)
});
/// Builds the multi-threaded Tokio runtime used to run `command`.
///
/// For `Command::Worker` the CPU affinity from the worker settings is applied
/// first (failure is only logged), the worker-thread count is taken from the
/// number of configured affinity cores, and the blocking-thread limit comes
/// from `WORKER_THREAD_NUM`. Every other command gets `DEFAULT_THREAD_NUM`
/// worker threads and Tokio's default blocking-thread limit.
///
/// # Errors
///
/// Returns the `std::io::Error` reported by Tokio if the runtime cannot be built.
fn init_tokio_runtime(command: &Command) -> Result<Runtime, std::io::Error> {
    match command {
        Command::Worker {} => {
            let settings = ek_base::config::get_ek_settings();
            // Affinity is applied before the runtime exists so that the threads
            // it spawns are created after the affinity call.
            if let Err(e) = try_apply_cpu_affinity(&settings.worker) {
                log::warn!("Failed to apply CPU affinity before runtime creation: {e}");
            } else {
                log::debug!("✅ CPU affinity applied before Tokio runtime creation");
            }
            let worker_threads = settings
                .worker
                .advanced
                .as_ref()
                .and_then(|advanced| advanced.cpu_affinity.as_ref())
                .and_then(|cpu_config| cpu_config.cores.as_ref())
                .map(|cores| cores.len())
                // An empty core list would request zero worker threads, which
                // makes the Tokio builder panic; fall back to the default.
                .filter(|&count| count > 0)
                .unwrap_or(DEFAULT_THREAD_NUM);
            log::info!("Creating Tokio runtime with {worker_threads} worker threads");
            tokio::runtime::Builder::new_multi_thread()
                .worker_threads(worker_threads)
                // Tokio also panics on a zero blocking-thread limit.
                .max_blocking_threads((*WORKER_THREAD_NUM).max(1))
                .enable_all()
                .build()
        }
        _ => tokio::runtime::Builder::new_multi_thread()
            .worker_threads(DEFAULT_THREAD_NUM)
            .enable_all()
            .build(),
    }
}
/// Process entry point: parses the CLI, installs logging, loads the
/// configuration, builds the Tokio runtime and runs the selected subcommand
/// to completion.
///
/// Exits with status 1 if the runtime cannot be created or the subcommand
/// returns an error.
fn main() {
    let cli = RootCli::parse();
    if cli.debug {
        // SAFETY: this runs at the very start of `main`, before the Tokio
        // runtime (or anything else in this function) has started a thread,
        // so no other thread can be reading the environment concurrently.
        unsafe { std::env::set_var("RUST_LOG", "debug") };
    }
    // Install the logger before the first `log::` call; records emitted while
    // no logger is installed are silently discarded.
    init_log();

    let command_name = get_command_name(&cli.command);

    // Config sources in order: `EK_CONFIG` from the environment first, then
    // the `--config` argument.
    let mut config_src = vec![];
    if let Ok(path) = std::env::var("EK_CONFIG") {
        config_src.push(path);
    }
    if let Some(path) = cli.config {
        config_src.push(path);
    }
    get_ek_settings_base(&config_src.iter().map(String::as_str).collect::<Vec<_>>());
    log::info!("config source: {config_src:?}");
    // NOTE(review): this dumps the full settings at info level — confirm the
    // settings type does not expose anything sensitive through `Debug`.
    let settings = ek_base::config::get_ek_settings();
    log::info!("settings: {settings:?}");

    let tokio_rt = match init_tokio_runtime(&cli.command) {
        Ok(rt) => rt,
        Err(e) => {
            eprintln!("Failed to create Tokio runtime: {e}");
            std::process::exit(1);
        }
    };

    let res = tokio_rt.block_on(async {
        // Set up inside the runtime, as before: the tracing exporter is
        // created while a Tokio context is active.
        init_tracing_subscriber(command_name);
        match cli.command {
            Command::Onnx { command } => onnx::execute_onnx(command).await,
            Command::Pretrain { command } => execute_pretrain(command).await,
            Command::Worker {} => worker_main().await,
            Command::Controller {} => controller_main().await,
            Command::Doctor {} => doctor_main().await,
            Command::WeightServer { host, port, model } => {
                // The model list must live as long as the server does. Leaking
                // the allocation yields a genuine `'static` slice without
                // `unsafe`; the memory is reclaimed when the process exits.
                let model: &'static [PathBuf] = Box::leak(model.into_boxed_slice());
                weight_srv::server::listen(model, (host, port)).await
            }
            Command::DB { command } => execute_db(command).await,
            Command::Model { command } => execute_model(command).await,
            Command::Schedule { command } => execute_schedule(command).await,
        }
    });

    if let Err(e) = res {
        eprintln!("Error: {e}");
        std::process::exit(1);
    }
}