import shlex
import subprocess
import logging
from abc import ABC, abstractmethod
from argparse import Namespace
from dataclasses import dataclass
from typing import Any, Optional, Union
logger = logging.getLogger("dsbench")
@dataclass
class BenchArgs:
name: str
start_time: str
cwd: str
log_dir: str
result_csv_file: str
args: Namespace
@dataclass
class BenchRemoteInfo:
host: str
username: str
ssh_config_path: str
ssh_port: int
@dataclass
class BenchCommandOutput:
stdout: str
stderr: str
class BenchTask(ABC):
@abstractmethod
def run(self, handler=None):
"""Executes the main logic associated with this task."""
pass
class BenchCommandTask(BenchTask):
command: str
env: dict
remote: Union[BenchRemoteInfo, None]
output: Union[BenchCommandOutput, None]
worker_address: Optional[str]
def __init__(
self,
command: str,
env: Optional[dict[str, str]],
remote: Optional[BenchRemoteInfo],
worker_address: Optional[str] = None,
):
self.command = command
self.env = env
self.remote = remote
self.worker_address = worker_address
@classmethod
def concat_args(cls, prefix: str, args: dict[str, Any]) -> str:
"""Constructs a command-line string from a dictionary of arguments."""
args_list = []
for key, val in args.items():
if val is None:
continue
if isinstance(val, str):
if len(val) == 0:
continue
args_list.append(f"{prefix}{key}={shlex.quote(val)}")
elif isinstance(val, bool):
if val:
args_list.append(f"{prefix}{key}")
elif isinstance(val, int):
args_list.append(f"{prefix}{key}={val}")
else:
raise RuntimeError(f"unknown type of key {key}")
return " ".join(args_list)
def execute_command_locally(self):
"""Executes a command on the local machine and captures its output."""
env = self.env or {}
logger.debug("Executing local command: %s", self.command)
logger.debug("Local command environment: %s", env)
process = subprocess.Popen(
shlex.split(self.command),
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
try:
stdout, stderr = process.communicate(timeout=300)
except subprocess.TimeoutExpired:
process.kill()
error_msg = f"Local command execution timed out: '{self.command}'"
logger.error(error_msg)
raise RuntimeError(error_msg) from None
stdout_str = stdout.decode("utf-8")
stderr_str = stderr.decode("utf-8")
logger.debug("Local command stdout: %s", stdout_str)
logger.debug("Local command stderr: %s", stderr_str)
logger.debug("Local command return code: %d", process.returncode)
if process.returncode != 0 and stderr_str.strip():
logger.error("Local command failed with return code %d. Error: %s", process.returncode, stderr_str.strip())
return stdout_str, stderr_str
def execute_command_remotely(self, remote: BenchRemoteInfo):
"""Executes a command on a remote machine via SSH and captures its output."""
logger.debug("Executing remote command on %s", remote.host)
logger.debug("Original remote command: %s", self.command)
logger.debug("Remote command environment: %s", self.env)
new_command = self.command
if self.env is not None:
new_command = f"{BenchCommandTask.concat_args('', self.env)} {new_command}"
ssh_command = (
f"ssh -q {remote.username}@{remote.host} "
f"-i {shlex.quote(remote.ssh_config_path)} {shlex.quote(new_command)}"
)
logger.debug("SSH command: %s", ssh_command)
process = subprocess.Popen(
shlex.split(ssh_command),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
try:
stdout, stderr = process.communicate(timeout=300)
except subprocess.TimeoutExpired:
process.kill()
error_msg = (
f"Remote command execution timed out after 600 seconds. "
f"Process has been killed. SSH command: '{ssh_command}'"
)
logger.error(error_msg)
raise RuntimeError(error_msg) from None
stdout_str = stdout.decode("utf-8")
stderr_str = stderr.decode("utf-8")
logger.debug("Remote command stdout: %s", stdout_str)
logger.debug("Remote command stderr: %s", stderr_str)
logger.debug("Remote command return code: %d", process.returncode)
if process.returncode != 0 and stderr_str.strip():
logger.error("Remote command failed with return code %d. Error: %s", process.returncode, stderr_str.strip())
return stdout_str, stderr_str
def get_output(self) -> Union[BenchCommandOutput, None]:
"""Retrieves the output of the executed command."""
return self.output
def execute(self):
"""Executes the command either locally or remotely."""
logger.debug("Running command task for address: %s", self.worker_address)
if self.remote is None:
stdout, stderr = self.execute_command_locally()
else:
stdout, stderr = self.execute_command_remotely(self.remote)
self.output = BenchCommandOutput(stdout, stderr)
logger.debug("Command task completed for address: %s", self.worker_address)
def run(self, handler=None):
if handler:
handler.before_execute(self)
self.execute()
if handler:
handler.after_execute(self)
class BenchParallelCommandTask(BenchTask):
def __init__(self, tasks: list[BenchCommandTask]):
self.tasks = tasks
def run(self, handler=None):
import concurrent.futures
if not self.tasks:
return
with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.tasks)) as executor:
futures = []
for task in self.tasks:
if handler:
handler.before_execute(task)
futures.append(executor.submit(task.run))
for future, task in zip(futures, self.tasks):
future.result()
if handler:
handler.after_execute(task)