use std::fs::File;
use std::io::{BufReader, Read};
use rustls::{Certificate, RootCertStore, PrivateKey, ClientConnection, StreamOwned};
use std::sync::Arc;
use std::net::TcpStream;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
use std::io;
use rpassword::prompt_password;
use anyhow::Result;
use clap::Parser;
use std::collections::HashSet;
use x509_parser::prelude::*;
use x509_parser::num_bigint::ToBigInt;
use std::fs::read_to_string;
use x509_parser::public_key::RSAPublicKey;
use x509_parser::der_parser::oid;
use num_bigint::BigUint;
use openssl::pkey::PKey;
mod commands;
use commands::gputrace::GpuTraceConfig;
use commands::gputrace::GpuTraceOptions;
use commands::gputrace::GpuTraceTriggerConfig;
use commands::nputrace::NpuTraceConfig;
use commands::nputrace::NpuTraceOptions;
use commands::nputrace::NpuTraceTriggerConfig;
use commands::npumonitor::NpuMonitorConfig;
use commands::*;
/// Default TCP port of the server; used as the default value of `Opts::port`.
const DYNO_PORT: u16 = 1778;
/// Minimum RSA modulus size, in bits, that `verify_certificate` is meant to accept.
const MIN_RSA_KEY_LENGTH: u64 = 3072;
// Top-level command-line options, parsed with clap in `main`.
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments into help/about text, which would change the CLI output.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
struct Opts {
// Host name to connect to; defaults to "localhost".
#[arg(long, default_value = "localhost")]
hostname: String,
// TCP port to connect to; defaults to DYNO_PORT.
#[arg(long, default_value_t = DYNO_PORT)]
port: u16,
// Directory that `create_dyno_client` expects to contain `client.crt`,
// `client.key` and `ca.crt`. The literal value "NO_CERTS" selects the
// plain (non-TLS) TCP connection instead.
#[arg(long, required = true)]
certs_dir: String,
// Subcommand to execute; see `Command`.
#[command(subcommand)]
cmd: Command,
}
/// MSPTI activity kinds accepted by `--mspti-activity-kind`.
const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy", "Communication"];

/// clap value parser for a comma-separated list of MSPTI activity kinds.
///
/// Every entry is trimmed and must be one of [`ALLOWED_VALUES`].
///
/// # Returns
/// The validated list re-joined with `,` and with surrounding whitespace
/// removed from each entry, so the value handed on is exactly what was
/// validated (the same normalisation `parse_host_sys` performs).
///
/// # Errors
/// Returns a message naming the first invalid entry. The allowed values are
/// listed in declaration order so the message is deterministic.
fn parse_mspti_activity_kinds(src: &str) -> Result<String, String> {
    // A set keeps the membership test O(1); the slice is used for display
    // because `HashSet` iteration order is randomised per process.
    let allowed_values: HashSet<&str> = ALLOWED_VALUES.iter().copied().collect();
    let kinds: Vec<&str> = src.split(',').map(str::trim).collect();

    if let Some(invalid) = kinds.iter().find(|kind| !allowed_values.contains(*kind)) {
        return Err(format!(
            "Invalid MSPTI activity kind: {}, Possible values: {:?}.",
            invalid, ALLOWED_VALUES
        ));
    }

    Ok(kinds.join(","))
}
/// Host system collectors accepted by `--host-sys`.
const ALLOWED_HOST_SYSTEM_VALUES: &[&str] = &["cpu", "mem", "disk", "network", "osrt"];

/// clap value parser for the comma-separated `--host-sys` list.
///
/// The literal `"None"` is passed through unchanged. Otherwise every entry is
/// trimmed and must be one of [`ALLOWED_HOST_SYSTEM_VALUES`].
///
/// # Returns
/// `"None"`, or the validated entries re-joined with `,`.
///
/// # Errors
/// Returns a message naming the first invalid entry. The allowed values are
/// listed in declaration order so the message is deterministic.
fn parse_host_sys(src: &str) -> Result<String, String> {
    if src == "None" {
        return Ok(src.to_string());
    }

    // A set keeps the membership test O(1); the slice is used for display
    // because `HashSet` iteration order is randomised per process.
    let allowed_host_sys_values: HashSet<&str> =
        ALLOWED_HOST_SYSTEM_VALUES.iter().copied().collect();
    let host_systems: Vec<&str> = src.split(',').map(str::trim).collect();

    if let Some(invalid) = host_systems
        .iter()
        .find(|host_system| !allowed_host_sys_values.contains(*host_system))
    {
        return Err(format!(
            "Invalid NPU Trace host system: {}, Possible values: {:?}.",
            invalid, ALLOWED_HOST_SYSTEM_VALUES
        ));
    }

    Ok(host_systems.join(","))
}
/// Smallest value accepted by `parse_start_step`.
const INSTANT_START_STEP: i64 = -1;

/// clap value parser for `--start-step`.
///
/// Accepts any integer greater than or equal to [`INSTANT_START_STEP`];
/// surrounding whitespace is ignored.
///
/// # Errors
/// Returns the integer parse error text, or a range message when the value is
/// below [`INSTANT_START_STEP`].
fn parse_start_step(src: &str) -> Result<i64, String> {
    match src.trim().parse::<i64>() {
        Err(parse_err) => Err(parse_err.to_string()),
        Ok(step) if step < INSTANT_START_STEP => Err(format!(
            "Must be non-negative integer or {}",
            INSTANT_START_STEP
        )),
        Ok(step) => Ok(step),
    }
}
/// clap value parser for `--iterations` of the `Nputrace` subcommand.
///
/// Accepts strictly positive integers; surrounding whitespace is ignored.
///
/// # Errors
/// Returns the integer parse error text, or `"Must be a positive integer"`
/// when the value is zero or negative.
fn parse_iterations(src: &str) -> Result<i64, String> {
    let count: i64 = src
        .trim()
        .parse()
        .map_err(|parse_err: std::num::ParseIntError| parse_err.to_string())?;

    if count > 0 {
        Ok(count)
    } else {
        Err("Must be a positive integer".to_string())
    }
}
// Subcommands of the CLI; `main` matches on this enum and forwards each
// variant to the corresponding `commands::*::run_*` function.
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments into help text, which would change the CLI output.
#[derive(Debug, Parser)]
enum Command {
// Forwarded to `status::run_status`.
Status,
// Forwarded to `version::run_version`.
Version,
// Forwarded to `gputrace::run_gputrace`. In `main`, `iterations > 0` selects
// the iteration-based trigger (with `profile_start_iteration_roundup`);
// otherwise the duration-based trigger (`profile_start_time`, `duration_ms`)
// is used.
Gputrace {
#[arg(long, default_value_t = 0)]
job_id: u64,
#[arg(long, default_value = "0")]
pids: String,
#[arg(long, default_value_t = 500)]
duration_ms: u64,
// Defaults to -1, which selects the duration-based trigger in `main`.
#[arg(long, default_value_t = -1)]
iterations: i64,
// No default: the flag must be supplied.
#[arg(long)]
log_file: String,
#[arg(long, default_value_t = 0)]
profile_start_time: u64,
#[arg(long, default_value_t = 1)]
profile_start_iteration_roundup: u64,
#[arg(long, default_value_t = 3)]
process_limit: u32,
// The boolean flags below are copied into `GpuTraceOptions` unchanged.
#[arg(long)]
record_shapes: bool,
#[arg(long)]
profile_memory: bool,
#[arg(long)]
with_stacks: bool,
#[arg(long)]
with_flops: bool,
#[arg(long)]
with_modules: bool,
},
// Forwarded to `nputrace::run_nputrace`. `main` picks the iteration-based
// trigger (`start_step`, `iterations`) when `iterations > 0`, otherwise the
// duration-based trigger (`profile_start_time`, `duration_ms`).
Nputrace {
#[clap(long, default_value_t = 0)]
job_id: u64,
#[clap(long, default_value = "0")]
pids: String,
#[clap(long, default_value_t = 500)]
duration_ms: u64,
// Validated by `parse_iterations` (must be > 0) and has no default.
// NOTE(review): this appears to make the flag mandatory and the
// duration-based branch in `main` unreachable for Nputrace - confirm intent.
#[clap(long, value_parser = parse_iterations, allow_negative_numbers = true)]
iterations: i64,
// No default: the flag must be supplied.
#[clap(long)]
log_file: String,
#[clap(long, default_value_t = 0)]
profile_start_time: u64,
// Validated by `parse_start_step` (must be >= INSTANT_START_STEP); no default.
#[clap(long, value_parser = parse_start_step, allow_negative_numbers = true)]
start_step: i64,
#[clap(long, default_value_t = 3)]
process_limit: u32,
// Everything from here to the end of the variant is copied into
// `NpuTraceOptions` unchanged.
#[clap(long, action)]
record_shapes: bool,
#[clap(long, action)]
profile_memory: bool,
#[clap(long, action)]
with_stack: bool,
#[clap(long, action)]
with_flops: bool,
#[clap(long, action)]
with_modules: bool,
// Restricted to the listed literal values.
#[clap(long, value_parser = ["CPU,NPU", "NPU,CPU", "CPU", "NPU"], default_value = "CPU,NPU")]
activities: String,
// Restricted to the listed literal values.
#[clap(long, value_parser = ["Level0", "Level1", "Level2", "Level_none"], default_value = "Level0")]
profiler_level: String,
// Restricted to the listed literal values.
#[clap(long, value_parser = ["AiCoreNone", "PipeUtilization", "ArithmeticUtilization", "Memory", "MemoryL0", "ResourceConflictRatio", "MemoryUB", "L2Cache", "MemoryAccess"], default_value = "AiCoreNone")]
aic_metrics: String,
#[clap(long, action)]
analyse: bool,
#[clap(long, action)]
l2_cache: bool,
#[clap(long, action)]
op_attr: bool,
#[clap(long, action)]
msprof_tx: bool,
// Optional; `None` when the flag is omitted.
#[clap(long)]
gc_detect_threshold: Option<f32>,
// Kept as the string "true"/"false" rather than a bool.
#[clap(long, value_parser = ["true", "false"], default_value = "true")]
data_simplification: String,
#[clap(long, value_parser = ["Text", "Db"], default_value = "Text")]
export_type: String,
// Validated by `parse_host_sys`; "None" is the default and is passed through.
#[clap(long, value_parser = parse_host_sys, default_value = "None")]
host_sys: String,
#[clap(long, action)]
sys_io: bool,
#[clap(long, action)]
sys_interconnection: bool,
// Optional; `None` when the flag is omitted.
#[clap(long)]
mstx_domain_include: Option<String>,
// Optional; `None` when the flag is omitted.
#[clap(long)]
mstx_domain_exclude: Option<String>,
},
// Forwarded to `npumonitor::run_npumonitor` as an `NpuMonitorConfig`.
NpuMonitor {
#[clap(long, action)]
npu_monitor_start: bool,
#[clap(long, action)]
npu_monitor_stop: bool,
#[clap(long, default_value_t = 60)]
report_interval_s: u32,
// Validated by `parse_mspti_activity_kinds`.
#[clap(long, value_parser = parse_mspti_activity_kinds, default_value = "Marker")]
mspti_activity_kind: String,
// Defaults to the empty string.
#[clap(long, default_value = "")]
log_file: String,
},
// Forwarded to `dcgm::run_dcgm_pause` with `duration_s`.
DcgmPause {
#[clap(long, default_value_t = 300)]
duration_s: i32,
},
// Forwarded to `dcgm::run_dcgm_resume`.
DcgmResume,
}
/// File locations of the TLS material used by `create_dyno_client_with_certs`.
struct ClientConfigPath {
/// PEM file holding the client certificate (chain). The CRL file `ca.crl`
/// is looked up in this file's parent directory.
cert_path: PathBuf,
/// PEM file holding the client private key; may be passphrase-encrypted.
key_path: PathBuf,
/// PEM file holding the CA certificate(s) added to the root store.
ca_cert_path: PathBuf,
}
fn verify_certificate(cert_der: &[u8], is_root_cert: bool) -> Result<()> {
let (_, cert) = X509Certificate::from_der(cert_der)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to parse cert: {:?}", e)))?;
if cert.tbs_certificate.version != X509Version(2) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Certificate is not X.509v3"
).into());
}
let sig_alg = cert.signature_algorithm.algorithm;
let md2_rsa = oid!(1.2.840.113549.1.1.2);
let md5_rsa = oid!(1.2.840.113549.1.1.4);
let sha1_rsa = oid!(1.2.840.113549.1.1.5);
if sig_alg == md2_rsa || sig_alg == md5_rsa || sig_alg == sha1_rsa {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Certificate uses insecure signature algorithm"
).into());
}
let rsa_sha256 = oid!(1.2.840.113549.1.1.11);
let rsa_sha384 = oid!(1.2.840.113549.1.1.12);
let rsa_sha512 = oid!(1.2.840.113549.1.1.13);
if sig_alg == rsa_sha256 || sig_alg == rsa_sha384 || sig_alg == rsa_sha512 {
if let Ok((_, public_key)) = SubjectPublicKeyInfo::from_der(&cert.tbs_certificate.subject_pki.subject_public_key.data) {
if let Ok((_, rsa_key)) = RSAPublicKey::from_der(&public_key.subject_public_key.data) {
let modulus = BigUint::from_bytes_be(&rsa_key.modulus);
let key_length = modulus.bits();
if key_length < MIN_RSA_KEY_LENGTH {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("RSA key length {} bits is less than required {} bits", key_length, MIN_RSA_KEY_LENGTH)
).into());
}
}
}
}
let mut has_ca_constraint = false;
let mut has_key_usage = false;
let mut has_crl_sign = false;
let mut has_cert_sign = false;
for ext in cert.tbs_certificate.extensions() {
if ext.oid == oid_registry::OID_X509_EXT_BASIC_CONSTRAINTS {
if let Ok((_, constraints)) = BasicConstraints::from_der(ext.value) {
has_ca_constraint = constraints.ca;
} else {
println!("Failed to parse Basic Constraints");
}
} else if ext.oid == oid_registry::OID_X509_EXT_KEY_USAGE {
println!("Found Key Usage extension");
if let Ok((_, usage)) = KeyUsage::from_der(ext.value) {
has_key_usage = true;
has_cert_sign = usage.key_cert_sign();
has_crl_sign = usage.crl_sign();
} else {
println!("Failed to parse Key Usage");
}
}
}
if is_root_cert {
if !has_ca_constraint {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Root certificate must have CA constraint"
).into());
}
if !has_key_usage {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Root certificate must have key usage extension"
).into());
}
if !has_cert_sign {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Root certificate must have certificate signature permission"
).into());
}
if !has_crl_sign {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Root certificate must have CRL signature permission"
).into());
}
} else {
if has_ca_constraint {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Client certificate should not have CA constraint"
).into());
}
if !has_key_usage {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Client certificate must have key usage extension"
).into());
}
}
let now = chrono::Utc::now();
let not_before = chrono::DateTime::from_timestamp(
cert.tbs_certificate.validity.not_before.timestamp(),
0
).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_before date"))?;
let not_after = chrono::DateTime::from_timestamp(
cert.tbs_certificate.validity.not_after.timestamp(),
0
).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid not_after date"))?;
if now < not_before {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Certificate is not yet valid. Valid from: {}", not_before)
).into());
}
if now > not_after {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Certificate has expired. Expired at: {}", not_after)
).into());
}
Ok(())
}
/// Reports whether a certificate's serial number is listed in a PEM-encoded CRL.
///
/// # Parameters
/// * `cert_der` - the certificate to look up, in DER encoding.
/// * `crl_path` - path of the PEM file containing the CRL.
///
/// # Returns
/// `Ok(true)` when the serial number appears among the revoked entries,
/// `Ok(false)` otherwise.
///
/// # Errors
/// Fails when the certificate or CRL cannot be read or parsed, or when the
/// current time is outside the CRL's validity window. A CRL without
/// `nextUpdate` is treated as valid for 365 days after `thisUpdate`.
fn is_cert_revoked(cert_der: &[u8], crl_path: &PathBuf) -> Result<bool> {
    // All failures raised here share the `InvalidData` kind.
    let invalid_data = |msg: String| io::Error::new(io::ErrorKind::InvalidData, msg);

    let (_, cert) = X509Certificate::from_der(cert_der)
        .map_err(|e| invalid_data(format!("Failed to parse cert: {:?}", e)))?;

    // Load the CRL: PEM text -> DER -> parsed structure.
    let crl_text = read_to_string(crl_path)?;
    let (_, crl_pem) = pem::parse_x509_pem(crl_text.as_bytes())
        .map_err(|e| invalid_data(format!("Failed to parse CRL PEM: {:?}", e)))?;
    let (_, crl) = CertificateRevocationList::from_der(&crl_pem.contents)
        .map_err(|e| invalid_data(format!("Failed to parse CRL: {:?}", e)))?;

    // Determine the CRL validity window.
    let now = chrono::Utc::now();
    let valid_from = chrono::DateTime::from_timestamp(crl.tbs_cert_list.this_update.timestamp(), 0)
        .ok_or_else(|| invalid_data("Invalid CRL this_update date".to_string()))?;
    let valid_until = match crl.tbs_cert_list.next_update {
        Some(next_update) => chrono::DateTime::from_timestamp(next_update.timestamp(), 0)
            .ok_or_else(|| invalid_data("Invalid CRL next_update date".to_string()))?,
        // No nextUpdate: fall back to one year after thisUpdate.
        None => valid_from + chrono::Duration::days(365),
    };

    if now < valid_from {
        return Err(invalid_data(format!("CRL is not yet valid. Valid from: {}", valid_from)).into());
    }
    if now > valid_until {
        return Err(invalid_data(format!("CRL has expired. Expired at: {}", valid_until)).into());
    }

    // Compare serial numbers as big integers.
    let wanted_serial = cert.tbs_certificate.serial.to_bigint().ok_or_else(|| {
        invalid_data("Failed to convert certificate serial to BigInt".to_string())
    })?;
    for entry in crl.iter_revoked_certificates() {
        match entry.user_certificate.to_bigint() {
            Some(serial) if serial == wanted_serial => return Ok(true),
            Some(_) => {}
            None => {
                return Err(invalid_data(
                    "Failed to convert revoked certificate serial to BigInt".to_string(),
                )
                .into());
            }
        }
    }

    Ok(false)
}
/// Connection handle returned by `create_dyno_client` and passed to the
/// subcommand runners.
enum DynoClient {
/// TLS session layered over a TCP stream (certificate mode).
Secure(StreamOwned<ClientConnection, TcpStream>),
/// Plain TCP stream, used when `certs_dir` is "NO_CERTS".
Insecure(TcpStream),
}
/// Opens a connection to the server, with or without TLS.
///
/// When `certs_dir` is the literal `"NO_CERTS"` a plain TCP connection is
/// made. Otherwise `certs_dir` is treated as a directory containing
/// `client.crt`, `client.key` and `ca.crt`, and a TLS connection using client
/// authentication is set up.
///
/// # Errors
/// Propagates any error from the selected connection routine.
fn create_dyno_client(
    host: &str,
    port: u16,
    certs_dir: &str,
) -> Result<DynoClient> {
    if certs_dir == "NO_CERTS" {
        println!("Running in no-certificate mode");
        return create_dyno_client_with_no_certs(host, port);
    }

    println!("Running in certificate mode");
    let base_dir = PathBuf::from(certs_dir);
    let tls_paths = ClientConfigPath {
        cert_path: base_dir.join("client.crt"),
        key_path: base_dir.join("client.key"),
        ca_cert_path: base_dir.join("ca.crt"),
    };
    create_dyno_client_with_certs(host, port, &tls_paths).map(DynoClient::Secure)
}
fn create_dyno_client_with_no_certs(
host: &str,
port: u16,
) -> Result<DynoClient> {
let addr = (host, port)
.to_socket_addrs()?
.next()
.expect("Failed to connect to the server");
let stream = TcpStream::connect(addr)?;
Ok(DynoClient::Insecure(stream))
}
/// Overwrites a password buffer with zeros and releases its storage.
///
/// The bytes are cleared with volatile writes followed by a compiler fence.
/// Ordinary stores to memory that is freed immediately afterwards are dead
/// stores which the optimizer is allowed to remove, leaving the secret in
/// memory.
///
/// Only the `len()` initialised bytes are wiped. An empty vector is left
/// untouched.
fn secure_clear_password(password: &mut Vec<u8>) {
    if password.is_empty() {
        return;
    }

    for byte in password.iter_mut() {
        // SAFETY: `byte` is a valid, aligned, exclusive reference to an
        // initialised `u8` inside the vector.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Keep the compiler from reordering the wipe past the deallocation below.
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);

    password.clear();
    password.shrink_to_fit();
}
fn create_dyno_client_with_certs(
host: &str,
port: u16,
config: &ClientConfigPath,
) -> Result<StreamOwned<ClientConnection, TcpStream>> {
let addr = (host, port)
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(
io::ErrorKind::NotFound,
"Could not resolve the host address"
))?;
let stream = TcpStream::connect(addr)?;
println!("Loading CA cert from: {}", config.ca_cert_path.display());
let mut root_store = RootCertStore::empty();
let ca_file = File::open(&config.ca_cert_path)?;
let mut ca_reader = BufReader::new(ca_file);
let ca_certs = rustls_pemfile::certs(&mut ca_reader)?;
for ca_cert in &ca_certs {
verify_certificate(ca_cert, true)?;
}
for ca_cert in ca_certs {
root_store.add(&Certificate(ca_cert))?;
}
println!("Loading client cert from: {}", config.cert_path.display());
let cert_file = File::open(&config.cert_path)?;
let mut cert_reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut cert_reader)?;
for cert in &certs {
verify_certificate(cert, false)?;
}
let crl_path = config.cert_path.parent().unwrap().join("ca.crl");
if crl_path.exists() {
println!("Checking CRL file: {}", crl_path.display());
for cert in &certs {
match is_cert_revoked(cert, &crl_path) {
Ok(true) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Certificate is revoked"
).into());
}
Ok(false) => {
continue;
}
Err(e) => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("CRL verification failed: {}", e)
).into());
}
}
}
} else {
println!("CRL file does not exist: {}", crl_path.display());
}
let certs = certs.into_iter().map(Certificate).collect();
println!("Loading client key from: {}", config.key_path.display());
let key_file = File::open(&config.key_path)?;
let mut key_reader = BufReader::new(key_file);
let mut key_data = Vec::new();
key_reader.read_to_end(&mut key_data)?;
let key_str = String::from_utf8_lossy(&key_data);
let is_encrypted = key_str.contains("ENCRYPTED");
let keys = if is_encrypted {
let mut password = prompt_password("Please enter the certificate password: ")?.into_bytes();
let pkey = PKey::private_key_from_pem_passphrase(&key_data, &password)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to decrypt private key: {}", e)))?;
secure_clear_password(&mut password);
vec![pkey.private_key_to_der()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("Failed to convert private key to DER: {}", e)))?]
} else {
let mut key_reader = BufReader::new(File::open(&config.key_path)?);
rustls_pemfile::pkcs8_private_keys(&mut key_reader)?
};
if keys.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"No private key found in the key file"
).into());
}
let key = PrivateKey(keys[0].clone());
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_client_auth_cert(certs, key)?;
let server_name = rustls::ServerName::try_from(host)
.map_err(|e| io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid hostname: {}", e)
))?;
let conn = rustls::ClientConnection::new(
Arc::new(config),
server_name
)?;
Ok(StreamOwned::new(conn, stream))
}
/// Program entry point: parses the command line, connects to the server and
/// dispatches the selected subcommand.
///
/// # Errors
/// Returns the connection error (with the context "Couldn't connect to the
/// server...") when the client cannot be created, instead of panicking, and
/// otherwise whatever the selected subcommand runner returns.
fn main() -> Result<()> {
    let Opts {
        hostname,
        port,
        certs_dir,
        cmd,
    } = Opts::parse();

    // Connection setup depends on the network, the file system and user
    // input, so failures are reported as ordinary errors rather than panics.
    let client = create_dyno_client(&hostname, port, &certs_dir)
        .map_err(|e| e.context("Couldn't connect to the server..."))?;

    match cmd {
        Command::Status => status::run_status(client),
        Command::Version => version::run_version(client),
        Command::Gputrace {
            job_id,
            pids,
            log_file,
            duration_ms,
            iterations,
            profile_start_time,
            profile_start_iteration_roundup,
            process_limit,
            record_shapes,
            profile_memory,
            with_stacks,
            with_flops,
            with_modules,
        } => {
            // A positive iteration count takes precedence over the duration.
            let trigger_config = if iterations > 0 {
                GpuTraceTriggerConfig::IterationBased {
                    profile_start_iteration_roundup,
                    iterations,
                }
            } else {
                GpuTraceTriggerConfig::DurationBased {
                    profile_start_time,
                    duration_ms,
                }
            };
            let trace_options = GpuTraceOptions {
                record_shapes,
                profile_memory,
                with_stacks,
                with_flops,
                with_modules,
            };
            let trace_config = GpuTraceConfig {
                log_file,
                trigger_config,
                trace_options,
            };
            gputrace::run_gputrace(client, job_id, &pids, process_limit, trace_config)
        }
        Command::Nputrace {
            job_id,
            pids,
            log_file,
            duration_ms,
            iterations,
            profile_start_time,
            start_step,
            process_limit,
            record_shapes,
            profile_memory,
            with_stack,
            with_flops,
            with_modules,
            activities,
            analyse,
            profiler_level,
            aic_metrics,
            l2_cache,
            op_attr,
            msprof_tx,
            gc_detect_threshold,
            data_simplification,
            export_type,
            host_sys,
            sys_io,
            sys_interconnection,
            mstx_domain_include,
            mstx_domain_exclude,
        } => {
            // A positive iteration count takes precedence over the duration.
            let trigger_config = if iterations > 0 {
                NpuTraceTriggerConfig::IterationBased {
                    start_step,
                    iterations,
                }
            } else {
                NpuTraceTriggerConfig::DurationBased {
                    profile_start_time,
                    duration_ms,
                }
            };
            let trace_options = NpuTraceOptions {
                record_shapes,
                profile_memory,
                with_stack,
                with_flops,
                with_modules,
                activities,
                analyse,
                profiler_level,
                aic_metrics,
                l2_cache,
                op_attr,
                msprof_tx,
                gc_detect_threshold,
                data_simplification,
                export_type,
                host_sys,
                sys_io,
                sys_interconnection,
                mstx_domain_include,
                mstx_domain_exclude,
            };
            let trace_config = NpuTraceConfig {
                log_file,
                trigger_config,
                trace_options,
            };
            nputrace::run_nputrace(client, job_id, &pids, process_limit, trace_config)
        }
        Command::NpuMonitor {
            npu_monitor_start,
            npu_monitor_stop,
            report_interval_s,
            mspti_activity_kind,
            log_file,
        } => {
            let npu_mon_config = NpuMonitorConfig {
                npu_monitor_start,
                npu_monitor_stop,
                report_interval_s,
                mspti_activity_kind,
                log_file,
            };
            npumonitor::run_npumonitor(client, npu_mon_config)
        }
        Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(client, duration_s),
        Command::DcgmResume => dcgm::run_dcgm_resume(client),
    }
}