from dataclasses import dataclass, field
from enum import Enum
from typing import Any
import torch
@dataclass
class Sample:
"""The sample generated"""
group_index: int | None = None
index: int | None = None
prompt: str | list[dict[str, str]] = ""
tokens: list[int] = field(default_factory=list)
multimodal_inputs: dict[str, Any] | None = None
multimodal_train_inputs: dict[str, Any] | None = None
response: str = ""
response_length: int = 0
label: str | None = None
reward: float | dict[str, Any] | None = None
loss_mask: list[int] | None = None
weight_versions: list[str] = field(default_factory=list)
rollout_log_probs: list[float] | None = None
rollout_routed_experts: list[list[int]] | None = None
remove_sample: bool = False
teacher_log_probs: list[float] | None = None
class Status(Enum):
PENDING = "pending"
COMPLETED = "completed"
TRUNCATED = "truncated"
ABORTED = "aborted"
FAILED = "failed"
status: Status = Status.PENDING
metadata: dict = field(default_factory=dict)
generate_function_path: str | None = None
train_metadata: dict | None = None
session_id: str | None = None
non_generation_time: float = 0.0
@dataclass
class SpecInfo:
spec_accept_token_num: int = 0
spec_draft_token_num: int = 0
spec_verify_ct: int = 0
completion_token_num: int = 0
@property
def spec_accept_rate(self) -> float:
return self.spec_accept_token_num / self.spec_draft_token_num if self.spec_draft_token_num > 0 else 0.0
@property
def spec_accept_length(self) -> float:
return self.completion_token_num / self.spec_verify_ct if self.spec_verify_ct > 0 else 0.0
def add(self, meta_info: dict):
self.spec_accept_token_num += meta_info.get("spec_accept_token_num", 0)
self.spec_draft_token_num += meta_info.get("spec_draft_token_num", 0)
self.spec_verify_ct += meta_info.get("spec_verify_ct", 0)
self.completion_token_num += meta_info.get("completion_tokens", 0)
def to_dict(self):
return {
"spec_accept_token_num": self.spec_accept_token_num,
"spec_draft_token_num": self.spec_draft_token_num,
"spec_verify_ct": self.spec_verify_ct,
"completion_token_num": self.completion_token_num,
}
@staticmethod
def from_dict(data: dict):
info = Sample.SpecInfo()
info.spec_accept_token_num = data.get("spec_accept_token_num", 0)
info.spec_draft_token_num = data.get("spec_draft_token_num", 0)
info.spec_verify_ct = data.get("spec_verify_ct", 0)
info.completion_token_num = data.get("completion_token_num", 0)
return info
spec_info: SpecInfo = field(default_factory=SpecInfo)
@dataclass
class PrefixCacheInfo:
cached_tokens: int = 0
total_prompt_tokens: int = 0
@property
def prefix_cache_hit_rate(self) -> float:
return self.cached_tokens / self.total_prompt_tokens if self.total_prompt_tokens > 0 else 0.0
def add(self, meta_info: dict):
self.cached_tokens += meta_info.get("cached_tokens", 0)
self.total_prompt_tokens += meta_info.get("prompt_tokens", 0)
def to_dict(self):
return {
"cached_tokens": self.cached_tokens,
"total_prompt_tokens": self.total_prompt_tokens,
}
@staticmethod
def from_dict(data: dict):
info = Sample.PrefixCacheInfo()
info.cached_tokens = data.get("cached_tokens", 0)
info.total_prompt_tokens = data.get("total_prompt_tokens", 0)
return info
prefix_cache_info: PrefixCacheInfo = field(default_factory=PrefixCacheInfo)
def to_dict(self):
value = self.__dict__.copy()
value["status"] = self.status.value
value["spec_info"] = self.spec_info.to_dict()
value["prefix_cache_info"] = self.prefix_cache_info.to_dict()
return value
@staticmethod
def from_dict(data: dict):
data = dict(data)
data["status"] = Sample.Status(data["status"])
data["spec_info"] = Sample.SpecInfo.from_dict(data.get("spec_info", {}))
data["prefix_cache_info"] = Sample.PrefixCacheInfo.from_dict(data.get("prefix_cache_info", {}))
field_names = set(Sample.__dataclass_fields__.keys())
init_data = {k: v for k, v in data.items() if k in field_names}
sample = Sample(**init_data)
for key, value in data.items():
if key not in field_names:
setattr(sample, key, value)
return sample
def get_reward_value(self, args) -> float:
return self.reward if not args.reward_key else self.reward[args.reward_key]
@property
def effective_response_length(self):
return sum(self.loss_mask) if self.loss_mask is not None else self.response_length
def update_from_meta_info(self, args, meta_info: dict):
"""
Update the sample with new information from meta_info returned by the rollout engine.
And extract
"""
if args.sglang_speculative_algorithm:
self.spec_info.add(meta_info=meta_info)
self.prefix_cache_info.add(meta_info=meta_info)
if "weight_version" in meta_info:
self.weight_versions.append(meta_info["weight_version"])
match meta_info["finish_reason"]["type"]:
case "length":
self.status = Sample.Status.TRUNCATED
case "abort":
self.status = Sample.Status.ABORTED
case "stop":
self.status = Sample.Status.COMPLETED
@dataclass(frozen=True)
class ParamInfo:
    """Immutable record describing a single named parameter.

    NOTE(review): the per-field notes below are inferred from names and type
    annotations only; confirm against the code that constructs ``ParamInfo``.
    """

    name: str
    dtype: torch.dtype
    shape: torch.Size
    # Presumably extra per-parameter attributes -- TODO confirm expected keys.
    attrs: dict
    # Unit (element count vs. bytes) is not visible from this file -- TODO confirm.
    size: int
    # Presumably the rank the parameter originates from -- TODO confirm.
    src_rank: int
RolloutBatch = dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]
@dataclass
class MultimodalType:
    """A modality identified by ``name`` together with its ``placeholder`` string."""

    # Identifier of the modality.
    name: str
    # Placeholder text associated with the modality.
    placeholder: str
class MultimodalTypes:
    """Registry of the known ``MultimodalType`` constants, with lookup helpers."""

    IMAGE = MultimodalType(name="image", placeholder="<image>")
    VIDEO = MultimodalType(name="video", placeholder="<video>")
    AUDIO = MultimodalType(name="audio", placeholder="<audio>")

    @classmethod
    def all(cls) -> list[MultimodalType]:
        """Return a new list with every registered type: image, video, audio."""
        registered = (cls.IMAGE, cls.VIDEO, cls.AUDIO)
        return list(registered)

    @classmethod
    def get(cls, name: str) -> MultimodalType | None:
        """Return the registered type whose ``name`` equals ``name``, else ``None``."""
        for candidate in cls.all():
            if candidate.name == name:
                return candidate
        return None