import logging
import pandas as pd
from tensor_cast.core.model_runner import ModelRunner
from .base_throughput_optimizer import BaseThroughputOptimizer
from .optimizer_summary import OptimizerSummary
from .utils import (
DISAGG_COLUMNS,
format_breakdowns,
format_parallel_label,
OptimizerData,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class DisaggThroughputOptimizer(BaseThroughputOptimizer):
    """Throughput optimizer registered under the name ``"disaggregation"``.

    Each call to :meth:`get_inference_info` evaluates one configuration.
    When ``optimizer_data.ttft_limits`` is ``None`` the configuration is
    treated as a decode step and a TPOT is reported; otherwise it is treated
    as a prefill pass (possibly split into several chunks) and a TTFT is
    reported.
    """

    name = "disaggregation"

    def initialize(self, model_runner: ModelRunner):
        """Bind ``model_runner`` and cache its parallelism settings.

        Args:
            model_runner: Runner whose ``model.model_config`` supplies the MTP,
                MoE and parallel configuration read below.
        """
        self.model_runner = model_runner
        model_config = self.model_runner.model.model_config

        mtp_config = model_config.mtp_config
        self.num_mtp_tokens = mtp_config.num_mtp_layers if mtp_config is not None else 0

        # Read the parallel configuration once instead of re-walking the
        # attribute chain for every field.
        parallel_config = model_config.parallel_config
        self.dp = parallel_config.data_parallel_size
        self.tp = parallel_config.tensor_parallel_size
        self.pp = parallel_config.pipeline_parallel_size
        self.ep = parallel_config.expert_parallel_size
        self.moe_tp = parallel_config.moe_tensor_parallel_size
        self.moe_dp = parallel_config.moe_data_parallel_size
        self.is_moe_model = model_config.moe_config is not None

    @staticmethod
    def _accumulate_breakdown_shares(breakdowns: dict, share_sums: dict, share_counts: dict) -> None:
        """Add one forward pass' normalized breakdowns into the running totals.

        For every named breakdown whose values do not sum to zero, each
        ``float`` value is divided by the sum of *all* values of that
        breakdown and added to ``share_sums[name][category]``. When at least
        one share was added, ``share_counts[name]`` is incremented by one.

        Args:
            breakdowns: Mapping of breakdown name to ``{category: value}``.
            share_sums: Running ``{name: {category: summed share}}``; mutated.
            share_counts: Running ``{name: number of contributions}``; mutated.
        """
        for breakdown_name, breakdown in breakdowns.items():
            total = sum(breakdown.values())
            if total == 0:
                # Nothing to normalize against; skip to avoid dividing by zero.
                continue
            shares = {
                category: value / total
                for category, value in breakdown.items()
                if isinstance(value, float)
            }
            if not shares:
                continue
            accumulated = share_sums.setdefault(breakdown_name, {})
            for category, share in shares.items():
                accumulated[category] = accumulated.get(category, 0.0) + share
            share_counts[breakdown_name] = share_counts.get(breakdown_name, 0) + 1

    def _run_chunked_prefill(
        self,
        concurrency: int,
        optimizer_data: OptimizerData,
        decode_flag: bool,
        chunk_plan,
        max_batched_tokens: int,
    ) -> tuple:
        """Evaluate a prefill whose chunk plan does not consist of exactly one chunk.

        Every chunk is executed in waves: at most
        ``max(max_batched_tokens // chunk.query_len, 1)`` requests per forward
        pass, repeated until all ``concurrency`` requests are covered.

        Args:
            concurrency: Total number of requests to push through each chunk.
            optimizer_data: Forwarded to ``_get_forward_info``; also supplies
                ``serving_cost``, the starting value of the latency.
            decode_flag: Forwarded unchanged to ``_get_forward_info``.
            chunk_plan: Iterable of chunks exposing ``query_len`` and ``seq_len``.
            max_batched_tokens: Token budget used to size a wave.

        Returns:
            ``(latency_ms, device_memory_available_gb, breakdowns)`` where
            ``latency_ms`` is ``serving_cost`` plus the "analytic" time of every
            wave in milliseconds, ``device_memory_available_gb`` is the minimum
            reported by any wave (``inf`` if no wave ran), and ``breakdowns`` is
            the formatted per-wave average of the normalized breakdowns
            (``""`` if nothing was accumulated).
        """
        latency_ms = optimizer_data.serving_cost
        device_memory_available_gb = float("inf")
        share_sums = {}
        share_counts = {}

        for chunk in chunk_plan:
            wave_size = max(max_batched_tokens // chunk.query_len, 1)
            remaining = concurrency
            while remaining > 0:
                wave_concurrency = min(wave_size, remaining)
                batch_result = self._get_forward_info(
                    wave_concurrency,
                    optimizer_data,
                    decode_flag,
                    query_len=chunk.query_len,
                    seq_len=chunk.seq_len,
                )
                latency_ms += batch_result.execution_time_s.get("analytic") * 1000
                # Track the tightest memory headroom seen across all waves.
                device_memory_available_gb = min(
                    device_memory_available_gb,
                    batch_result.device_memory_available_gb,
                )
                self._accumulate_breakdown_shares(batch_result.breakdowns, share_sums, share_counts)
                remaining -= wave_concurrency

        breakdowns = ""
        if share_sums:
            average_breakdowns = {
                breakdown_name: {
                    category: value / share_counts[breakdown_name] for category, value in shares.items()
                }
                for breakdown_name, shares in share_sums.items()
            }
            breakdowns = format_breakdowns(average_breakdowns)
        return latency_ms, device_memory_available_gb, breakdowns

    def get_inference_info(self, optimizer_data: OptimizerData) -> OptimizerSummary:
        """Evaluate one configuration and package the result as a summary.

        Args:
            optimizer_data: Describes the candidate (batch size, input/output
                lengths, token budget, serving cost, MTP acceptance rates, ...).

        Returns:
            An ``OptimizerSummary`` carrying a one-row result DataFrame (columns
            ``DISAGG_COLUMNS``, rounded to 3 decimals) and the early-stop flag.
        """
        decode_flag = optimizer_data.ttft_limits is None
        batch_size = optimizer_data.batch_size
        input_length = optimizer_data.input_length
        effective_input_length = optimizer_data.get_effective_input_length()
        max_batched_tokens = optimizer_data.max_batched_tokens
        # The chunk plan is only requested for prefill; decode uses an empty plan.
        chunk_plan = [] if decode_flag else optimizer_data.get_prefill_chunk_plan()
        output_length = optimizer_data.output_length
        concurrency = batch_size * self.dp * self.pp

        if decode_flag or len(chunk_plan) == 1:
            # A single forward pass covers the whole batch.
            batch_result = self._get_forward_info(concurrency, optimizer_data, decode_flag)
            latency_ms = batch_result.execution_time_s.get("analytic") * 1000 + optimizer_data.serving_cost
            device_memory_available_gb = batch_result.device_memory_available_gb
            breakdowns = format_breakdowns(batch_result.breakdowns)
        else:
            # Prefill with zero or several chunks: run every chunk in waves.
            latency_ms, device_memory_available_gb, breakdowns = self._run_chunked_prefill(
                concurrency,
                optimizer_data,
                decode_flag,
                chunk_plan,
                max_batched_tokens,
            )

        ttft = tpot = None
        if decode_flag:
            # One step yields 1 token plus the summed acceptance rates of the
            # first ``num_mtp_tokens`` MTP positions, so spread the step
            # latency over that many tokens.
            average_tokens = sum(optimizer_data.mtp_acceptance_rate[: optimizer_data.num_mtp_tokens]) + 1
            latency_ms /= average_tokens
            tpot = latency_ms
            output_throughput = concurrency / tpot * 1000 if tpot > 0 else 0
        else:
            ttft = latency_ms
            output_throughput = concurrency / latency_ms * 1000 * input_length if latency_ms > 0 else 0
        token_s_device = output_throughput / self.dp / self.pp / self.tp

        parallel = format_parallel_label(
            self.model_runner.model.model_config.parallel_config,
            self.is_moe_model,
        )
        logger.info(
            "TTFT: %r ms, TPOT: %r ms, "
            "Output Throughput: %.2f token/s, "
            "Concurrency: %d, "
            "parallel: %s, "
            "Memory Left: %.2f GB",
            ttft,
            tpot,
            output_throughput,
            concurrency,
            parallel,
            device_memory_available_gb,
        )

        summary = OptimizerSummary(optimizer_data)
        user_input = self.model_runner.user_input
        # Value order must match DISAGG_COLUMNS.
        row = [
            user_input.device,
            optimizer_data.num_devices,
            user_input.model_id,
            user_input.quantize_linear_action,
            user_input.quantize_attention_action,
            input_length,
            output_length,
            effective_input_length,
            max_batched_tokens,
            len(chunk_plan),
            concurrency,
            ttft,
            tpot,
            output_throughput,
            token_s_device,
            parallel,
            batch_size,
            breakdowns,
        ]
        result_df = pd.DataFrame(columns=DISAGG_COLUMNS, data=[row]).round(3)
        summary.set_summary_df(result_df)
        summary.set_early_stop_flag(device_memory_available_gb, tpot, ttft)
        self._maybe_set_search_info(optimizer_data, device_memory_available_gb, batch_size, ttft, tpot, summary)
        return summary