"""
ModelRunner
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
import torch
from ..device import DeviceProfile
from ..layers.sampler import Sampler
from ..performance_model.analytic import AnalyticPerformanceModel
from ..performance_model.empirical import EmpiricalPerformanceModel
from ..performance_model.memory_tracker import MemoryTracker
from ..performance_model.profiling_database import ProfilingDataSource
from ..performance_model.utils import bytes_of_tensor
from ..runtime import Runtime
from ..transformers.custom_model_registry import get_visual
from .input_generator import (
generate_inputs_varlen,
get_inputs_num_bytes,
get_kv_cache_info,
RequestInfo,
)
from .model_builder import build_model
if TYPE_CHECKING:
from ..performance_model import PerformanceModel
from .user_config import UserInputConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelRunner:
    """
    Build a model from a user configuration and run it under a ``Runtime``,
    corresponding to one data-parallel partition ('dp_rank').

    Attributes set by ``__init__``:
        user_input: the configuration object passed in.
        device_profile: entry of ``DeviceProfile.all_device_profiles`` for
            ``user_input.device``.
        perf_models: performance models selected by
            ``user_input.performance_model`` (defaults to ``["analytic"]``).
        request_info_default: ``[user_input.get_request_info()]`` or ``None``
            when ``user_input.num_queries`` is 0.
        model: result of ``build_model(user_input).eval()``.
        total_device_memory_gb: device memory size in GiB.
        model_weight_size_gb: weight size in GiB that the most recent
            ``run_inference`` call accounted for (full weight before any run).
        sampler: a ``Sampler`` instance.
    """

    def __init__(self, user_input: UserInputConfig):
        """
        Args:
            user_input: configuration; must name a device present in
                ``DeviceProfile.all_device_profiles``.

        Raises:
            ValueError: if the device is unknown, or if the "profiling"
                performance model is requested without a profiling database.
        """
        self.user_input = user_input

        if user_input.device not in DeviceProfile.all_device_profiles:
            logger.error(
                "Unsupported device: %s. Available devices: %s",
                user_input.device,
                list(DeviceProfile.all_device_profiles.keys()),
            )
            raise ValueError(f"Device '{user_input.device}' not recognized.")

        logger.info("Loading device profile")
        self.device_profile = DeviceProfile.all_device_profiles[user_input.device]
        logger.debug("Device profile loaded: %s", self.device_profile)

        logger.info("Initializing performance model")
        perf_model_types: List[str] = getattr(user_input, "performance_model", ["analytic"])
        profiling_database = getattr(user_input, "profiling_database", None)
        self.perf_models: List[PerformanceModel] = []
        for perf_model_type in perf_model_types:
            if perf_model_type == "profiling":
                if not profiling_database:
                    raise ValueError("--profiling-database must be specified when using --performance-model profiling")
                data_source = ProfilingDataSource(
                    profiling_database,
                    self.device_profile,
                    parallel_config=user_input.get_parallel_config(),
                )
                self.perf_models.append(
                    EmpiricalPerformanceModel(
                        self.device_profile,
                        data_source=data_source,
                        # The analytic model is handed to the empirical model as its fallback.
                        fallback_model=AnalyticPerformanceModel(self.device_profile),
                    )
                )
            elif perf_model_type == "analytic":
                self.perf_models.append(AnalyticPerformanceModel(self.device_profile))
            else:
                # Previously dropped without a trace; keep going but make it visible.
                logger.warning("Ignoring unknown performance model type: %s", perf_model_type)
        logger.debug("Performance models initialized: %s", self.perf_models)

        logger.info("Generating request information")
        if user_input.num_queries != 0:
            self.request_info_default = [user_input.get_request_info()]
            logger.debug("Request configured: %s", self.request_info_default)
        else:
            logger.debug("No default requests configured (num_queries = 0)")
            self.request_info_default = None

        logger.info("Building model architecture")
        self.model = build_model(user_input).eval()
        logger.debug("Model built: %s", self.model)

        self.total_device_memory_gb = self.device_profile.memory_size_bytes / 1024**3
        # Immutable baseline: run_inference derives the per-run weight size
        # from this value instead of repeatedly adjusting the public attribute.
        self._full_model_weight_size_gb = self.model.weight_size / 1024**3
        self.model_weight_size_gb = self._full_model_weight_size_gb

        logger.info("Initializing Sampler")
        self.sampler = Sampler()
        logger.debug("Sampler initialized: %s", self.sampler)

    def run_inference(
        self,
        requests: Optional[List[RequestInfo]] = None,
        generate_inputs_func: Callable = generate_inputs_varlen,
        with_sampler: bool = False,
        runtime_observer: Optional[Callable[[Runtime], None]] = None,
    ) -> ModelRunnerMetrics:
        """
        Run one forward pass of the model inside a ``Runtime`` context and
        collect timing and memory metrics.

        Args:
            requests: requests to generate inputs for; when ``None``,
                ``self.request_info_default`` is used (which is itself ``None``
                if ``user_input.num_queries`` is 0).
            generate_inputs_func: called as
                ``generate_inputs_func(self.model, requests, block_size=...)``.
                Its result is passed as keyword arguments to
                ``self.model.forward`` and must contain
                ``"kv_cache_by_layers"`` and ``"kv_cache_per_token"``
                (plus ``"sampling_metadata"`` when ``with_sampler`` is set).
            with_sampler: also invoke ``self.sampler`` on the logits.
            runtime_observer: optional callback invoked with the ``Runtime``
                object just before the metrics are returned.

        Returns:
            A ``ModelRunnerMetrics`` populated from this run.

        Side effects:
            Sets ``self.model_weight_size_gb`` to the weight size accounted
            for in this run; optionally exports a chrome trace to
            ``user_input.chrome_trace``.
        """

        def calculate_single_card_tps(execution_time_s: float) -> float:
            # NOTE(review): the token count comes from user_input, not from
            # the `requests` actually passed in — confirm this is intended.
            if not execution_time_s or execution_time_s <= 0:
                raise ValueError("execution_time_s must be positive")
            return (self.user_input.num_queries * self.user_input.query_len) / (
                execution_time_s * self.user_input.world_size
            )

        data_parallel_size = self.model.model_config.parallel_config.data_parallel_size
        logger.debug("data_parallel_size: %s", data_parallel_size)
        # Ceiling division: queries handled by each data-parallel rank.
        batch_size = (self.user_input.num_queries + data_parallel_size - 1) // data_parallel_size
        logger.debug("batch_size: %s", batch_size)

        if requests is None:
            requests = self.request_info_default
        logger.debug("requests: %s", requests)

        input_kwargs = generate_inputs_func(
            self.model,
            requests,
            block_size=self.user_input.block_size,
        )

        run_start = time.perf_counter()
        with (
            Runtime(
                self.perf_models,
                self.device_profile,
                memory_tracker=MemoryTracker(self.device_profile),
            ) as runtime,
            torch.no_grad(),
        ):
            logits = self.model.forward(**input_kwargs)
            if with_sampler:
                _ = self.sampler(logits, input_kwargs["sampling_metadata"])
        run_end = time.perf_counter()

        for pm in self.perf_models:
            if isinstance(pm, EmpiricalPerformanceModel):
                # Imported lazily: only needed when an empirical model is in use.
                from ..performance_model.metrics_collector import MetricsCollector

                collector = MetricsCollector()
                collector.collect_from_records(pm.op_records)
                collector.log_stats()

        all_execution_time_s = runtime.total_execution_time_s()
        run_time_s = run_end - run_start
        table_result = runtime.table_averages(
            group_by_input_shapes=self.user_input.dump_input_shapes,
            dump_op_bound_results=self.user_input.dump_op_bound_results,
        )

        # Per-op aggregation is reported for the first performance model only.
        perf_model_name = self.perf_models[0].name if self.perf_models else None
        runtime_event_list = self._aggregate_runtime_events(runtime.event_list, perf_model_name=perf_model_name)

        # Models without a positive execution time get no TPS entry.
        tps_per_model: Dict[str, float] = {}
        for model_name, exec_time in all_execution_time_s.items():
            if exec_time and exec_time > 0:
                tps_per_model[model_name] = calculate_single_card_tps(exec_time)

        peak_memory_usage_gb = runtime.memory_tracker.peak_mem_usage() / 1024**3
        kv_cache_bytes = sum(bytes_of_tensor(kv_cache) for kv_cache in input_kwargs["kv_cache_by_layers"].values())
        indexer_cache_bytes = sum(
            bytes_of_tensor(kv_cache) for kv_cache in input_kwargs.get("indexer_cache_by_layers", {}).values()
        )
        # The reported KV cache size includes the indexer cache.
        kv_cache_size_gb = (kv_cache_bytes + indexer_cache_bytes) / 1024**3
        indexer_cache_per_token = input_kwargs.get("indexer_cache_per_token", 0)
        kv_cache_per_token_gb = (input_kwargs["kv_cache_per_token"] + indexer_cache_per_token) / 1024**3

        # When the model has a visual component but this run carries no
        # pixel_values, its weights are excluded from the weight size.
        # Always start from the unmodified baseline so that repeated calls do
        # not subtract the visual weights more than once.
        model_weight_size_gb = self._full_model_weight_size_gb
        visual = get_visual(self.model)
        if visual and input_kwargs.get("pixel_values") is None:
            visual_weight_size_gb = self.model.get_weight_size_nested([visual]) / 1024**3
            model_weight_size_gb -= visual_weight_size_gb
        self.model_weight_size_gb = model_weight_size_gb

        model_activation_size_gb = peak_memory_usage_gb - kv_cache_size_gb - model_weight_size_gb
        if model_activation_size_gb < 0:
            logger.warning(
                "Negative activation memory estimate (peak=%.6f GB, weight=%.6f GB, kv=%.6f GB); "
                "clamping activation to 0 and adjusting peak for consistency.",
                peak_memory_usage_gb,
                model_weight_size_gb,
                kv_cache_size_gb,
            )
            model_activation_size_gb = 0.0
            peak_memory_usage_gb = model_weight_size_gb + kv_cache_size_gb

        device_memory_available_gb = (
            self.total_device_memory_gb - peak_memory_usage_gb - self.user_input.reserved_memory_gb
        )

        if self.user_input.chrome_trace:
            runtime.export_chrome_trace(self.user_input.chrome_trace)
        if runtime_observer is not None:
            runtime_observer(runtime)

        return ModelRunnerMetrics(
            total_device_memory_gb=self.total_device_memory_gb,
            model_weight_size_gb=model_weight_size_gb,
            peak_memory_usage_gb=peak_memory_usage_gb,
            kv_cache_size_gb=kv_cache_size_gb,
            kv_cache_per_token_gb=kv_cache_per_token_gb,
            indexer_cache_size_gb=indexer_cache_bytes / 1024**3,
            indexer_cache_per_token_gb=indexer_cache_per_token / 1024**3,
            model_activation_size_gb=model_activation_size_gb,
            reserved_memory_gb=self.user_input.reserved_memory_gb,
            device_memory_available_gb=device_memory_available_gb,
            execution_time_s=all_execution_time_s,
            tps_per_model=tps_per_model,
            run_time_s=run_time_s,
            batch_size=batch_size,
            table_result=table_result,
            breakdowns=runtime.get_breakdowns(),
            runtime_event_list=runtime_event_list,
            perf_model_name=perf_model_name,
        )

    def get_inputs_num_bytes(self, requests: List[RequestInfo]) -> int:
        """Delegate to the module-level ``get_inputs_num_bytes`` for this model and block size."""
        # Inside the method body the bare name resolves to the imported
        # module-level function, not to this method.
        return get_inputs_num_bytes(self.model, requests, self.user_input.block_size)

    def get_kv_cache_num_bytes(self, num_tokens: int) -> int:
        """Return ``get_kv_cache_info(self.model, 1, 1)`` scaled by ``num_tokens``."""
        # presumably get_kv_cache_info(model, 1, 1) is the per-token KV cache
        # size in bytes — TODO confirm against input_generator.
        return get_kv_cache_info(self.model, 1, 1) * num_tokens

    def _aggregate_runtime_events(self, event_list, perf_model_name: Optional[str] = None) -> List[Dict]:
        """
        Group runtime events by op (``str(event.op_invoke_info.func)``).

        Args:
            event_list: iterable of events exposing ``op_invoke_info.func``
                and a ``perf_results`` mapping keyed by performance-model name.
            perf_model_name: model whose ``execution_time_s`` is summed; when
                ``None`` only call counts are collected and totals stay 0.

        Returns:
            One dict per op with keys ``name``, ``perf_model``, ``perf_total``,
            ``perf_avg`` and ``call_times``, sorted by ``perf_total`` descending.
        """
        aggregated: Dict[str, Dict[str, float]] = {}
        for event in event_list:
            name = str(event.op_invoke_info.func)
            entry = aggregated.setdefault(name, {"total": 0.0, "count": 0})
            entry["count"] += 1
            if perf_model_name is None:
                continue
            result = event.perf_results.get(perf_model_name)
            if result is not None:
                entry["total"] += result.execution_time_s

        items: List[Dict] = [
            {
                "name": name,
                "perf_model": perf_model_name,
                "perf_total": entry["total"],
                "perf_avg": entry["total"] / entry["count"] if entry["count"] else 0.0,
                "call_times": entry["count"],
            }
            for name, entry in aggregated.items()
        ]
        items.sort(key=lambda x: x["perf_total"], reverse=True)
        return items
@dataclass
class ModelRunnerMetrics:
    """
    Timing and memory figures produced by one inference run.

    All ``*_gb`` fields are expressed in GiB (bytes / 1024**3) by the producer.
    ``kv_cache_size_gb`` already includes ``indexer_cache_size_gb``;
    ``print_info`` reports the difference as the "Main KV cache".
    """

    total_device_memory_gb: float
    model_weight_size_gb: float
    peak_memory_usage_gb: float
    kv_cache_size_gb: float
    kv_cache_per_token_gb: float
    model_activation_size_gb: float
    reserved_memory_gb: float
    device_memory_available_gb: float
    execution_time_s: Dict[str, float]
    """Execution time per performance model, keyed by model name."""
    tps_per_model: Dict[str, float]
    """TPS per performance model, keyed by model name."""
    run_time_s: float
    batch_size: int
    indexer_cache_size_gb: float = 0.0
    indexer_cache_per_token_gb: float = 0.0
    table_result: str = ""
    breakdowns: Dict[str, Dict[str, float]] = field(default_factory=dict)
    runtime_event_list: List[Dict] = field(default_factory=list)
    perf_model_name: Optional[str] = None

    def _breakdown_percentages(self) -> Dict[str, Dict[str, float]]:
        """Convert each breakdown to percentages of its own sum.

        Breakdowns whose values sum to 0 are omitted to avoid dividing by zero.
        """
        percentages: Dict[str, Dict[str, float]] = {}
        for name, breakdown in self.breakdowns.items():
            total = sum(breakdown.values())
            if total == 0:
                continue
            percentages[name] = {key: val * 100 / total for key, val in breakdown.items()}
        return percentages

    def print_info(self):
        """Print a human-readable summary of the metrics to stdout."""
        print(f"Number of Queries per DP rank: {self.batch_size}")
        print(f"Model compilation and execution time: {self.run_time_s:.3f} s")
        print(self.table_result)
        for model_name, exec_time in self.execution_time_s.items():
            print(f"[{model_name}] Execution time: {exec_time:.6f} s")
            tps = self.tps_per_model.get(model_name)
            if tps is not None:
                print(f"[{model_name}] TPS/Device: {tps:.4g} token/s")
        print(f"Total device memory: {self.total_device_memory_gb:.3f} GB")
        print(f" Model weight size: {self.model_weight_size_gb:.3f} GB")
        print(f" KV cache: {self.kv_cache_size_gb:.3f} GB")
        if self.indexer_cache_size_gb > 0:
            # kv_cache_size_gb includes the indexer cache, so split it out here.
            print(f" Main KV cache: {self.kv_cache_size_gb - self.indexer_cache_size_gb:.3f} GB")
            print(f" Indexer cache: {self.indexer_cache_size_gb:.3f} GB")
            print(f" Indexer cache per token: {self.indexer_cache_per_token_gb:.6f} GB")
        print(f" Model activation size: {self.model_activation_size_gb:.3f} GB")
        print(f" Reserved memory: {self.reserved_memory_gb:.3f} GB")
        print(f" Memory available: {self.device_memory_available_gb:.3f} GB")
        print("Stats breakdowns:")
        for breakdown_name, percents in self._breakdown_percentages().items():
            formatted = ", ".join(f"{key}: {pct:.2f}" for key, pct in percents.items())
            print(f" {breakdown_name}: {formatted}")

    def dump_json(self, path: str) -> None:
        """
        Write the metrics to ``path`` as UTF-8 JSON (indent 2).

        The payload carries raw breakdowns plus percentages rounded to four
        decimals, and the memory figures under ``"memory_gb"`` — including the
        indexer cache sizes, matching what ``print_info`` reports.

        Args:
            path: destination file; overwritten if it exists.
        """
        breakdowns_percent: Dict[str, Dict[str, float]] = {
            name: {key: round(pct, 4) for key, pct in percents.items()}
            for name, percents in self._breakdown_percentages().items()
        }
        payload = {
            "batch_size": self.batch_size,
            "run_time_s": self.run_time_s,
            "execution_time_s": dict(self.execution_time_s),
            "tps_per_model": dict(self.tps_per_model),
            "memory_gb": {
                "total_device": self.total_device_memory_gb,
                "model_weight": self.model_weight_size_gb,
                "peak_usage": self.peak_memory_usage_gb,
                "kv_cache": self.kv_cache_size_gb,
                "kv_cache_per_token": self.kv_cache_per_token_gb,
                "indexer_cache": self.indexer_cache_size_gb,
                "indexer_cache_per_token": self.indexer_cache_per_token_gb,
                "model_activation": self.model_activation_size_gb,
                "reserved": self.reserved_memory_gb,
                "available": self.device_memory_available_gb,
            },
            "breakdowns_raw": {k: dict(v) for k, v in self.breakdowns.items()},
            "breakdowns_percent": breakdowns_percent,
            "perf_model_name": self.perf_model_name,
            "runtime_event_list": self.runtime_event_list,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)