import argparse
import os
from cli.logo import print_logo
from serving_cast.config import Config
from serving_cast.instance import Instance
from serving_cast.load_gen import FixedLengthLoadGen
from serving_cast.profiler import profiler_interface
from serving_cast.serving import PdAggregationServing, PdDisaggregationServing
from serving_cast.utils import (
gen_profiling_config_set_env_variable,
get_basic_timestamp,
main_processing,
summarize,
)
import serving_cast.stime as stime
def parse_command_line_args():
    """
    Parse command line arguments for the simulation.

    Expected usage:
        python main.py --instance_config_path=instances.yaml --common_config_path=common.yaml
            [--enable_profiling] [--profiling_output_path=DIR] [--output_json=FILE]

    Returns:
        argparse.Namespace with the attributes ``instance_config_path``,
        ``common_config_path``, ``enable_profiling``, ``profiling_output_path``
        and ``output_json``.

    Exits via argparse (SystemExit) when a required argument is missing or
    when one of the two config paths does not exist.
    """
    parser = argparse.ArgumentParser(description="Simulation service")

    def validate_file_path(path):
        """Return ``path`` unchanged if it exists, otherwise reject the argument."""
        if not os.path.exists(path):
            # argparse prints the text of an ArgumentTypeError verbatim; a plain
            # ValueError is replaced by a generic "invalid ... value" message,
            # which would hide the reason from the user.
            raise argparse.ArgumentTypeError(f"File does not exist: {path}")
        return path

    parser.add_argument(
        "--instance_config_path",
        type=validate_file_path,
        required=True,
        help="Path to a YAML file that declares one or more instance groups. "
        "Each group defines a homogeneous pool of nodes (role, count, TP/DP parallelism) "
        "and can be mixed-and-matched in a single benchmark run.",
    )
    parser.add_argument(
        "--common_config_path",
        type=validate_file_path,
        required=True,
        help="Path to a YAML file with global settings: model architecture, "
        "request-generation workload, and serving limits.",
    )
    parser.add_argument(
        "--enable_profiling",
        action="store_true",
        help="Enable profiling during simulation (default: False)",
    )
    parser.add_argument(
        "--profiling_output_path",
        default="./profiling_results",
        help="Path to directory where profiling results will be saved (default: ./profiling_results)",
    )
    parser.add_argument(
        "--output_json",
        default=None,
        help="If set, write the benchmark summary (per-metric table and overall summary) "
        "as a structured JSON to this file path.",
    )
    return parser.parse_args()
def instance_group2pd_type(instance_group):
    """
    Classify an instance group by which roles are populated.

    Returns:
        "pd_aggregation" when only the "both" list is non-empty,
        "pd_disaggregation" when "both" is empty and both "prefill" and
        "decode" are non-empty, and None for every other combination.
    """
    if instance_group["both"]:
        # Combined instances may not be mixed with dedicated prefill/decode ones.
        if not instance_group["prefill"] and not instance_group["decode"]:
            return "pd_aggregation"
        return None
    # No combined instances: both dedicated roles must be present.
    if instance_group["prefill"] and instance_group["decode"]:
        return "pd_disaggregation"
    return None
def get_instance_group(instance_config_list, common_config):
    """
    Build the instances described by ``instance_config_list``, grouped by PD role.

    Args:
        instance_config_list: iterable of instance configs. Each one is read for
            ``pd_role`` and ``num_instances`` and is passed to ``Instance``.
        common_config: not used by this function; kept so existing callers
            keep working.

    Returns:
        dict mapping "prefill", "decode" and "both" to the list of ``Instance``
        objects created for that role.

    Raises:
        ValueError: if a config has a ``pd_role`` other than the three above, or
            if the resulting mix of roles is neither a pure aggregation nor a
            prefill+decode disaggregation layout.
    """
    instance_group = {"prefill": [], "decode": [], "both": []}
    for instance_config in instance_config_list:
        role = instance_config.pd_role
        # Validate the role before constructing anything, so a bad config does
        # not leave behind a live Instance that nobody will shut down.
        if role not in instance_group:
            raise ValueError(f"{role} is not supported")
        instance_group[role].extend(Instance(instance_config) for _ in range(instance_config.num_instances))

    pd_type = instance_group2pd_type(instance_group)
    if pd_type not in ("pd_aggregation", "pd_disaggregation"):
        counts = ", ".join(f"{role}={len(members)}" for role, members in instance_group.items())
        raise ValueError(
            "check instance's pd_role: expected only 'both' instances, or both "
            f"'prefill' and 'decode' instances without 'both' (got {counts})"
        )
    return instance_group
def get_serving(instance_group):
    """
    Create the serving object that matches the layout of ``instance_group``.

    Returns:
        ``PdAggregationServing`` built from the "both" instances, or
        ``PdDisaggregationServing`` built from the "prefill" and "decode"
        instances.

    Raises:
        ValueError: if ``instance_group2pd_type`` does not recognise the layout.
    """
    deployment_mode = instance_group2pd_type(instance_group)
    if deployment_mode == "pd_aggregation":
        return PdAggregationServing(instance_group["both"])
    if deployment_mode == "pd_disaggregation":
        return PdDisaggregationServing(instance_group["prefill"], instance_group["decode"])
    raise ValueError(f"Unknown pd type: {deployment_mode}")
def get_load_gen(load_gen_config):
    """
    Create the load generator selected by ``load_gen_config.load_gen_type``.

    Only "fixed_length" is handled; it yields a ``FixedLengthLoadGen`` configured
    from ``num_requests``, ``num_input_tokens``, ``num_output_tokens`` and
    ``request_rate`` of the config, with ``model_name`` set to None.

    Raises:
        ValueError: for any other ``load_gen_type``.
    """
    kind = load_gen_config.load_gen_type
    if kind != "fixed_length":
        raise ValueError(f"Unknown load generator type: {kind!r}")
    return FixedLengthLoadGen(
        model_name=None,
        num_requests=load_gen_config.num_requests,
        num_input_tokens=load_gen_config.num_input_tokens,
        num_output_tokens=load_gen_config.num_output_tokens,
        request_rate=load_gen_config.request_rate,
    )
def init_profiling(args):
    """
    Prepare the output directory for this run's profiling data and start profiling.

    A sub-directory named by ``get_basic_timestamp()`` is created under
    ``args.profiling_output_path``, handed to
    ``gen_profiling_config_set_env_variable`` as ``prof_dir``, and then the
    profiler is initialised.

    Returns:
        The path of the directory that was created for this run.
    """
    base_dir = args.profiling_output_path
    run_dir = os.path.join(base_dir, get_basic_timestamp())
    # The directory must exist before the profiling config points at it.
    os.makedirs(run_dir, exist_ok=True)
    gen_profiling_config_set_env_variable(prof_dir=run_dir)
    profiler_interface.init_profiling()
    return run_dir
def parse_profiling_results(profiling_path_with_timestamp):
    """
    Hand the given profiling directory to the profiler for parsing.

    Thin wrapper around ``profiler_interface.parse_profiling_results``; the
    argument is forwarded unchanged and nothing is returned.
    """
    profiler_interface.parse_profiling_results(profiling_path_with_timestamp)
def main():
    """
    Run one simulation end to end.

    Steps: parse the command line, optionally initialise profiling, load the
    configuration, build instances, load generator and serving object, run the
    simulation, print/write the summary, optionally parse profiling results,
    and finally shut down every instance.

    Instances are shut down even when the simulation or the summary step
    raises, so a failed run does not leave them alive.
    """
    args = parse_command_line_args()
    print_logo()

    # Bound unconditionally so the name always exists; only used when profiling is on.
    profiling_path_with_timestamp = None
    if args.enable_profiling:
        profiling_path_with_timestamp = init_profiling(args)

    config = Config(parsed_args=args)
    stime.init_simulation()

    instance_group = get_instance_group(config.instance_config_list, config.common_config)
    try:
        load_gen = get_load_gen(config.common_config.load_gen)
        serving = get_serving(instance_group)

        # NOTE(review): the task is presumably registered with the simulation
        # on construction -- confirm. The reference is kept for the whole run.
        _ = stime.CallableTask(main_processing, serving, load_gen)
        stime.start_simulation()

        summarize(load_gen.get_finished_requests().values(), output_json_path=args.output_json)

        if args.enable_profiling:
            parse_profiling_results(profiling_path_with_timestamp)
    finally:
        # Always release the instances, including on failure paths.
        for instances in instance_group.values():
            for instance in instances:
                instance.shutdown()
# Run the simulation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()