import argparse
import logging
import re
from tensor_cast.device import DeviceProfile
# Accepted ``--log-level`` names mapped to the matching ``logging`` level constants.
# Insertion order is kept (most verbose first) because the mapping is also used
# directly as the argparse ``choices`` container.
LOG_LEVELS = dict(
    (
        ("debug", logging.DEBUG),
        ("info", logging.INFO),
        ("warning", logging.WARNING),
        ("error", logging.ERROR),
        ("critical", logging.CRITICAL),
    )
)

# Log record layout: level name, logger name, then the message text.
LOG_FORMAT = "[%(levelname)s] [%(name)s] %(message)s"
def check_device_targets(args: argparse.Namespace, logger: logging.Logger) -> list[str] | None:
    """Validate ``--device`` and return the de-duplicated list of target profile names.

    The checks run in this order and stop at the first failure:

    1. At least one device profile must be registered on ``DeviceProfile``.
    2. A missing/empty ``args.device`` is replaced by ``["TEST_DEVICE"]``.
    3. Names are de-duplicated, keeping first-seen order.
    4. Blank names are rejected.
    5. Names that are not registered profiles are rejected.
    6. Each profile's communication grid must hold at least ``args.num_devices`` elements.

    Args:
        args: Parsed CLI namespace. Reads ``args.device`` (a single name or a
            sequence of names) and ``args.num_devices``. When ``args.device``
            is falsy it is overwritten in place with ``["TEST_DEVICE"]``.
        logger: Receives one error record per failed check and an info record
            when the default device is applied.

    Returns:
        The validated device names, or ``None`` if any check failed (the
        reason has already been logged).
    """
    profiles = DeviceProfile.all_device_profiles
    if not profiles:
        logger.error(
            "No device profiles are registered. Import tensor_cast.device_profiles before defining CLI defaults."
        )
        return None

    if not args.device:
        args.device = ["TEST_DEVICE"]
        logger.info("No --device specified; using default profile %r.", args.device[0])

    # A parser that declares ``--device`` as a plain string hands us a ``str``;
    # iterating it directly would de-duplicate its characters, not device names.
    requested = [args.device] if isinstance(args.device, str) else args.device
    targets = list(dict.fromkeys(requested))

    if any(not str(name).strip() for name in targets):
        logger.error("Empty --device name is not allowed.")
        return None

    unknown = [name for name in targets if name not in profiles]
    if unknown:
        logger.error(
            "Unknown --device name(s): %s. Valid profiles: %s",
            ", ".join(repr(name) for name in unknown),
            ", ".join(sorted(profiles.keys())),
        )
        return None

    for name in targets:
        # Number of elements in the profile's communication grid.
        grid_n = profiles[name].comm_grid.grid.nelement()
        if grid_n < args.num_devices:
            logger.error(
                "Device profile %r cannot model num_devices=%s (communication grid size is %s).",
                name,
                args.num_devices,
                grid_n,
            )
            return None

    return targets
def check_positive_integer(value):
    """Argparse ``type=`` callback: convert *value* to an ``int`` strictly greater than zero.

    Args:
        value: Raw command-line token.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer literal or
            the parsed number is zero or negative.
    """
    try:
        number = int(value)
    except ValueError:
        # ``from None`` hides the int() traceback; argparse only shows the message.
        raise argparse.ArgumentTypeError(f"Invalid integer value: {value!r}") from None

    if number > 0:
        return number
    raise argparse.ArgumentTypeError(f"{number!r} is not a positive integer")
def check_non_negative_integer(value):
    """Argparse ``type=`` callback: convert *value* to an ``int`` that is zero or greater.

    Args:
        value: Raw command-line token.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer literal or
            the parsed number is negative.
    """
    try:
        number = int(value)
    except ValueError:
        # ``from None`` hides the int() traceback; argparse only shows the message.
        raise argparse.ArgumentTypeError(f"Invalid integer value: {value!r}") from None

    if number >= 0:
        return number
    raise argparse.ArgumentTypeError(f"{number!r} is not a non-negative integer")
def check_prefix_cache_hit_rate(value):
    """Argparse ``type=`` callback: convert *value* to a ``float`` in the half-open range [0, 1).

    Args:
        value: Raw command-line token.

    Returns:
        The parsed float.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a float literal or the
            parsed number lies outside [0, 1). NaN fails the range comparison
            and is rejected as well.
    """
    try:
        rate = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid float value for prefix cache hit rate: {value!r}") from None

    if 0 <= rate < 1:
        return rate
    raise argparse.ArgumentTypeError(f"{rate!r} is not in the valid range [0, 1)")
def parse_int_range(value: str, name: str) -> tuple[int, int]:
    """Parse a ``'start,end'`` string into a pair of non-negative integers.

    Whitespace around either number is ignored, so ``' 0 , 54 '`` is accepted.
    ``end`` may equal ``start`` but may not be smaller.

    Args:
        value: Raw CLI string, for example ``'11,45'``.
        name: Argument name used in error messages, for example
            ``'--cache-step-range'``.

    Returns:
        The tuple ``(start, end)``.

    Raises:
        ValueError: If *value* does not consist of exactly two non-empty
            comma-separated integers, if either number is negative, or if
            ``end < start``.
    """
    malformed = f"{name} must be 'start,end', got {value!r}."

    pieces = value.split(",")
    if len(pieces) != 2:
        raise ValueError(malformed)

    start_text, end_text = (piece.strip() for piece in pieces)
    if not (start_text and end_text):
        raise ValueError(malformed)

    try:
        start, end = int(start_text), int(end_text)
    except ValueError as exc:
        raise ValueError(malformed) from exc

    if min(start, end) < 0:
        raise ValueError(f"{name} must be non-negative, got {value!r}.")
    if start > end:
        raise ValueError(f"{name} must be 'start,end' with end >= start, got {value!r}.")
    return start, end
def check_string_valid(string: str, max_len=256):
    """Argparse ``type=`` callback: accept only short strings made of safe characters.

    Allowed characters are ASCII letters, digits, ``_``, ``/``, ``.`` and ``-``.

    Args:
        string: Raw command-line token.
        max_len: Maximum accepted length in characters.

    Returns:
        *string* unchanged when it passes both checks.

    Raises:
        argparse.ArgumentTypeError: If *string* is longer than *max_len*, is
            empty, or contains any character outside the allowed set.
    """
    if len(string) > max_len:
        raise argparse.ArgumentTypeError(f"String length exceeds {max_len} characters: {string!r}")
    # fullmatch rather than match with "$": "$" also matches just before a
    # trailing newline, which would let a value such as "name\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9_/.-]+", string):
        raise argparse.ArgumentTypeError(f"String contains invalid characters: {string!r}")
    return string
def get_common_argparser(reserved_memory_gb_default: float = 0.0):
    """Build the parser holding the options shared by the command-line entry points.

    The parser is created with ``add_help=False`` (presumably so it can be
    passed through ``parents=[...]`` to a concrete parser -- confirm against
    callers) and registers one "General Options" group containing
    ``model_id``, ``--device``, ``--num-devices``, ``--enable-multistream`` /
    ``--no-enable-multistream``, ``--reserved-memory-gb`` and ``--log-level``.

    Args:
        reserved_memory_gb_default: Default value for ``--reserved-memory-gb``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    common_parser = argparse.ArgumentParser(
        add_help=False,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    general_group = common_parser.add_argument_group("General Options")
    general_group.add_argument(
        "model_id",
        type=check_string_valid,
        help=(
            "Model source. Recommended safe mode: a reviewed absolute local model path. "
            "Model id mode also accepts Hugging Face or ModelScope ids, but may execute remote Python code through "
            "trust_remote_code=True and is not security-guaranteed."
        ),
    )
    general_group.add_argument(
        "--device",
        type=str,
        # Choices are snapshotted from the profiles registered at call time.
        choices=list(DeviceProfile.all_device_profiles.keys()),
        default="TEST_DEVICE",
        help=(
            "Specifies the target device profile to use for benchmarking and simulation. "
            "Must be a valid device name as defined in DeviceProfile. "
            "The default device 'TEST_DEVICE' is used for standard simulation runs."
        ),
    )
    general_group.add_argument(
        "--num-devices",
        type=check_positive_integer,
        default=1,
        help=(
            "Specifies the total number of devices/processes to use. "
            "Must be a positive integer. "
            "A value of 1 indicates single-device execution."
        ),
    )
    general_group.add_argument(
        "--enable-multistream",
        # store_true with default=True could never produce False; the optional
        # boolean action keeps the flag and adds --no-enable-multistream.
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Enable compiler-driven multi-stream simulation for torch.compile path. Enabled by default; "
            "pass --no-enable-multistream to disable."
        ),
    )
    general_group.add_argument(
        "--reserved-memory-gb",
        type=float,
        default=reserved_memory_gb_default,
        help=(
            "Amount of device memory (in gigabytes) reserved for system usage and unavailable for application. "
            "Set to 0 to disable memory reservation."
        ),
    )
    general_group.add_argument(
        "--log-level",
        # The dict itself serves as the choices container (its keys are the names).
        choices=LOG_LEVELS,
        default="error",
        help=(
            "Specifies the verbosity level for log output. "
            "Available levels: 'debug' (most verbose), 'info', 'warning', 'error', 'critical' (least verbose)."
        ),
    )
    return common_parser