import json
import logging
import os
import subprocess
from typing import Any, Optional, Union
from yr.datasystem.cli.benchmark.task import (
BenchCommandOutput,
BenchCommandTask,
BenchRemoteInfo,
)
# Logger shared by everything in this module, registered under the "dsbench" name.
logger = logging.getLogger(name="dsbench")
def _get_local_ips(timeout: float = 5.0) -> list[str]:
    """
    Safely retrieves a list of all valid local IP addresses.

    It's used to determine if a target host is local or remote. The loopback
    addresses ``127.0.0.1`` and ``::1`` are always included. Any addresses
    printed by ``hostname -I`` are added, with a trailing ``%zone`` or
    ``/prefix`` suffix stripped from each entry.

    Detection is best-effort. If ``hostname`` is missing, exits non-zero,
    does not finish within ``timeout`` seconds or fails in any other way, a
    warning is logged and only the loopback addresses are returned.

    Args:
        timeout: Seconds to wait for ``hostname -I`` before giving up. In this
            file the function runs while the module-level Executor is being
            created, so a hung command must not block indefinitely.

    Returns:
        list[str]: Unique, non-empty IP strings, in no particular order.
    """
    local_ips = {"127.0.0.1", "::1"}
    try:
        raw_output = subprocess.check_output(
            ["hostname", "-I"], stderr=subprocess.PIPE, timeout=timeout
        )
    except FileNotFoundError:
        logger.warning(
            "warning: `hostname` command not found. Cannot auto-detect local IPs."
        )
    except subprocess.TimeoutExpired:
        logger.warning(
            "warning: `hostname -I` timed out after %s seconds. "
            "Cannot auto-detect local IPs.",
            timeout,
        )
    except subprocess.CalledProcessError as e:
        # stderr was captured with PIPE so it should be bytes; guard against
        # None and undecodable bytes so the warning itself cannot fail.
        error_text = (e.stderr or b"").decode(errors="replace").strip()
        logger.warning(
            "warning: `hostname -I` command failed. Cannot auto-detect local IPs. "
            "Error: %s",
            error_text,
        )
    except Exception as e:  # best-effort: IP detection must never break startup
        logger.warning("warning: Failed to get local IPs automatically. Error: %s", e)
    else:
        for token in raw_output.decode(errors="replace").split():
            # Drop any "%zone" (scoped IPv6) or "/prefix" suffix from the entry.
            ip = token.split("%", 1)[0].split("/", 1)[0]
            if ip:
                local_ips.add(ip)
    return list(local_ips)
class Executor:
    """
    A singleton class to manage SSH configurations and execute BenchCommandTasks.

    It replaces direct calls to subprocess with a more structured approach.
    Every command is wrapped in a ``BenchCommandTask`` and run in one of two ways:

    * locally, when the target host is one of this machine's IPs;
    * otherwise over SSH, as the current user with ``~/.ssh/id_rsa`` on port 22.

    Per-worker lookups (package location and dsbench_cpp permission state) are
    cached on the singleton. Failures are cached too, so they are not retried.
    """

    _class_instance: Optional["Executor"] = None

    def __new__(cls, *args, **kwargs):
        # Every construction returns the same object. The `initialized` flag
        # lets __init__ skip re-initialisation on repeated `Executor()` calls.
        if not cls._class_instance:
            cls._class_instance = super(Executor, cls).__new__(cls)
            cls._class_instance.initialized = False
        return cls._class_instance

    def __init__(self):
        if self.initialized:
            return
        self.initialized = True
        # worker address -> package location (None when the lookup failed).
        self.pkg_location_cache: dict[str, Optional[str]] = {}
        # worker address -> whether `chmod +x dsbench_cpp` succeeded.
        self.dsbench_cpp_permissions_cache: dict[str, bool] = {}
        self.local_ips_cache: list[str] = _get_local_ips()

    @classmethod
    def get_instance(cls) -> "Executor":
        """Returns the singleton instance of the Executor class."""
        if not cls._class_instance:
            cls._class_instance = cls()
        return cls._class_instance

    @staticmethod
    def _extract_host(target_address: str) -> str:
        """Return the host part of ``host``, ``host:port`` or ``[ipv6]:port``.

        A bare IPv6 literal (more than one ``:`` and no brackets) is returned
        unchanged. Splitting it on the first ``:`` would mangle it, e.g. ``"::1"``
        would become ``""`` and never match the local-IP cache.
        """
        if target_address.startswith("["):
            closing = target_address.find("]")
            if closing != -1:
                return target_address[1:closing]
        if target_address.count(":") > 1:
            return target_address
        return target_address.split(":", 1)[0]

    @staticmethod
    def _dsbench_cpp_path(datasystem_location: str) -> str:
        """Return the dsbench_cpp binary path inside a package install location."""
        return f"{datasystem_location}/yr/datasystem/dsbench_cpp"

    def get_remote_info(self, target_address: str) -> Optional["BenchRemoteInfo"]:
        """
        **INTERNAL** helper method to retrieve SSH information.

        Uses default SSH parameters: current user, default identity file
        (~/.ssh/id_rsa), and port 22.

        Args:
            target_address: ``host``, ``host:port`` or ``[ipv6]:port``.

        Returns:
            None if the host is local, or if the identity file does not exist
            (a warning is logged in that case). Otherwise a BenchRemoteInfo
            for the host.
        """
        import getpass

        target_host = self._extract_host(target_address)
        if target_host in self.local_ips_cache:
            return None
        current_user = getpass.getuser()
        expanded_identity_file = os.path.expanduser("~/.ssh/id_rsa")
        if not os.path.isfile(expanded_identity_file):
            logger.warning(
                f"warning: Identity file '{expanded_identity_file}' for target '{target_address}' does not exist."
            )
            return None
        return BenchRemoteInfo(
            host=target_host,
            username=current_user,
            # The identity-file path is passed through as `ssh_config_path`.
            ssh_config_path=expanded_identity_file,
            ssh_port=22,
        )

    def get_datasystem_pkg_location(self, worker_address: str) -> Union[str, None]:
        """
        Retrieves the installation path of openyuanrong-datasystem.

        Runs ``pip show openyuanrong-datasystem`` in a login shell on the
        worker and returns its ``Location:`` field. The result is cached per
        worker address. A failure is cached as None and is not retried.
        """
        if worker_address in self.pkg_location_cache:
            return self.pkg_location_cache[worker_address]
        pip_show_result = self.execute(
            'bash -l -c "pip show openyuanrong-datasystem"', worker_address
        )
        location = None
        if isinstance(pip_show_result, BenchCommandOutput):
            stdout = pip_show_result.stdout.strip()
            if not stdout:
                logger.error(
                    f" [DEBUG] 'pip show' on {worker_address} executed successfully but had no output; "
                    f"'openyuanrong-datasystem' may not be installed."
                )
                self.pkg_location_cache[worker_address] = None
                return None
            # splitlines() also copes with CRLF output coming back over SSH.
            for line in stdout.splitlines():
                if line.startswith("Location:"):
                    location = line.split(":", 1)[1].strip()
                    break
            if not location:
                logger.error(
                    f" [DEBUG] 'Location:' field not found in 'pip show' output from {worker_address}."
                )
                self.pkg_location_cache[worker_address] = None
                return None
        elif isinstance(pip_show_result, str):
            # execute() reports failures as an error string.
            logger.error(
                f" [DEBUG] Failed to get openyuanrong-datasystem location from {worker_address}: {pip_show_result}"
            )
            self.pkg_location_cache[worker_address] = None
            return None
        self.pkg_location_cache[worker_address] = location
        return location

    def ensure_dsbench_cpp_executable(self, worker_address: str) -> Union[str, None]:
        """Ensures dsbench_cpp has execute permissions on the specified worker.

        Only sets permissions once per worker address. The outcome of the
        first attempt (success or failure) is cached and reused afterwards.

        Args:
            worker_address: Worker to prepare (``host`` or ``host:port``).

        Returns:
            str: The full path to dsbench_cpp if successful, None otherwise.
        """
        import shlex

        cached_ok = self.dsbench_cpp_permissions_cache.get(worker_address)
        if cached_ok is not None:
            if not cached_ok:
                return None
            location = self.get_datasystem_pkg_location(worker_address)
            return self._dsbench_cpp_path(location) if location else None

        datasystem_location = self.get_datasystem_pkg_location(worker_address)
        if not datasystem_location:
            logger.error(
                f" [DEBUG] Could not find openyuanrong-datasystem location on {worker_address}."
            )
            self.dsbench_cpp_permissions_cache[worker_address] = False
            return None
        dsbench_cpp_executable = self._dsbench_cpp_path(datasystem_location)
        # The location comes from remote `pip show` output; quote it so paths
        # with spaces or shell metacharacters cannot break or alter the command.
        chmod_command = f"chmod +x {shlex.quote(dsbench_cpp_executable)}"
        chmod_result = self.execute(chmod_command, worker_address, env=None)
        # NOTE(review): only an error string counts as failure here. Any
        # BenchCommandOutput is taken as success. Confirm whether it exposes
        # an exit code that should also be checked.
        if isinstance(chmod_result, str):
            logger.error(
                f" [DEBUG] Failed to set execute permissions for dsbench_cpp on {worker_address}: {chmod_result}"
            )
            self.dsbench_cpp_permissions_cache[worker_address] = False
            return None
        self.dsbench_cpp_permissions_cache[worker_address] = True
        logger.debug(
            " [DEBUG] Successfully set execute permissions for dsbench_cpp on %s.", worker_address
        )
        return dsbench_cpp_executable

    def execute(
        self, command_str: str, target_address: str, env=None
    ) -> Union["BenchCommandOutput", str]:
        """
        Public interface to execute a command.

        It intelligently decides whether to run the command locally or
        remotely. A target whose host is in the local-IP cache runs locally.
        Any other target runs over SSH with the info from get_remote_info().

        Args:
            command_str: Shell command to run.
            target_address: ``host``, ``host:port`` or ``[ipv6]:port``.
            env: Passed through to BenchCommandTask unchanged.

        Returns:
            The task's output object. An error string is returned instead if
            SSH info is unavailable or the task produced no output.
        """
        is_local = self._extract_host(target_address) in self.local_ips_cache
        remote_info = None if is_local else self.get_remote_info(target_address)
        if not is_local and not remote_info:
            return f"Error: SSH configuration for '{target_address}' is missing or invalid."
        task = BenchCommandTask(command=command_str, env=env, remote=remote_info)
        task.run()
        if task.output is None:
            return f"Error: Execution of '{command_str}' on '{target_address}' did not produce an output object."
        return task.output
# Shared module-level instance. It is created eagerly when this module is
# imported, so construction-time work (local IP detection) runs at import.
executor: Executor = Executor.get_instance()