import random
import torch
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.layers.attention import AttentionMetadataTensorCast
from tensor_cast.model_config import LinearQuantConfig, ModelConfig, QuantConfig
from tensor_cast.quantize_utils import LinearQuantType
from tensor_cast.transformers.utils import get_attention_quant_config, strip_module_name
from tensor_cast.utils import exact_division
from tests.helpers.model_cache import get_built_model, get_hf_config
def assert_close(self, value1, value2, rtol=0.01):
self.assertLessEqual(abs(value1 - value2) / value1, rtol, f"{value1} vs. {value2}, rtol={rtol}")
def count_events(runtime, op):
return sum(1 for event in runtime.event_list if event.op_invoke_info.func == op)
def create_attn_metadata_and_kv_cache(model, model_config: ModelConfig):
batch_size = 2
query_len_1 = 55
query_len_2 = 45
seq_len_1 = 2000
seq_len_2 = 1500
num_blocks = 10000
block_size = 128
max_seq_len = max(seq_len_1, seq_len_2)
query_start_loc = torch.tensor([0, query_len_1, query_len_1 + query_len_2], dtype=torch.long)
seq_lens = torch.tensor([seq_len_1, seq_len_2], dtype=torch.long)
query_lens = torch.tensor([query_len_1, query_len_2], dtype=torch.long)
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(batch_size):
block_table = [random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)]
block_tables.append(block_table)
block_table_tensor = torch.tensor(block_tables, dtype=torch.long)
attn_meta = AttentionMetadataTensorCast(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
query_lens=query_lens,
block_table_tensor=block_table_tensor,
)
num_tokens = query_len_1 + query_len_2
kv_cache_by_layers = {}
for i in range(model.num_hidden_layers):
kvcache_dtype = model_config.dtype
if (attention_config := get_attention_quant_config(model, i)) is not None:
kvcache_dtype = attention_config.get_quant_dtype()
if model.text_config.num_key_value_heads >= model_config.parallel_config.tensor_parallel_size:
kv_heads = exact_division(
model.text_config.num_key_value_heads,
model_config.parallel_config.tensor_parallel_size,
)
else:
assert model_config.parallel_config.tensor_parallel_size % model.text_config.num_key_value_heads == 0
kv_heads = 1
kv_cache_by_layers[i] = torch.empty(
[
2,
num_blocks,
block_size,
kv_heads,
model.text_config.head_dim,
],
dtype=kvcache_dtype,
device="meta",
)
return attn_meta, kv_cache_by_layers, num_tokens
def create_mla_metadata_and_kv_cache(model, model_config: ModelConfig, query_len_1=55, query_len_2=45):
batch_size = 2
seq_len_1 = 2000
seq_len_2 = 1500
num_blocks = 10000
block_size = 128
max_seq_len = max(seq_len_1, seq_len_2)
query_start_loc = torch.tensor([0, query_len_1, query_len_1 + query_len_2], dtype=torch.long)
seq_lens = torch.tensor([seq_len_1, seq_len_2], dtype=torch.long)
query_lens = torch.tensor([query_len_1, query_len_2], dtype=torch.long)
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(batch_size):
block_table = [random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)]
block_tables.append(block_table)
block_table_tensor = torch.tensor(block_tables, dtype=torch.long)
attn_meta = AttentionMetadataTensorCast(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
query_lens=query_lens,
block_table_tensor=block_table_tensor,
)
num_tokens = query_len_1 + query_len_2
kv_cache_by_layers = {}
for i in range(model.num_hidden_layers):
kvcache_dtype = model_config.dtype
if (attention_config := get_attention_quant_config(model, i)) is not None:
kvcache_dtype = attention_config.get_quant_dtype()
kv_cache_by_layers[i] = torch.empty(
[
num_blocks,
block_size,
model.text_config.kv_lora_rank + model.text_config.qk_rope_head_dim,
],
dtype=kvcache_dtype,
device="meta",
)
return attn_meta, kv_cache_by_layers, num_tokens
def has_submodule_with_cls_name(module, cls_name):
return any(type(sub_module).__name__ == cls_name for _, sub_module in module.named_modules())
def get_linear_quant_config(quant_type, weight=None, **kwargs):
"""Helper to create a default symmetric per-tensor weight quant config.
Can be customized via kwargs
"""
config_args = {
"quant_type": quant_type,
}
if "weight_scale" not in kwargs and weight is not None:
w_scale = torch.max(torch.abs(weight)) / 127.0
config_args.update({"weight_scale": w_scale})
config_args.update(kwargs)
return LinearQuantConfig(**config_args)
def get_quant_config(model=None, quant_type=LinearQuantType.W4A8, **kwargs):
quant_config = QuantConfig()
if model is None:
quant_config.linear_configs["*"] = get_linear_quant_config(
quant_type,
torch.randn(1) if "weight_group_size" not in kwargs else None,
**kwargs,
)
return quant_config
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
strip_name = strip_module_name(name)
quant_config.linear_configs[strip_name] = get_linear_quant_config(
quant_type,
module.weight.data,
**kwargs,
)
return quant_config
def update_parallel_parameter(user_input: UserInputConfig, world_size=1, tp_size=1, ep=False):
user_input.world_size = world_size
user_input.tp_size = tp_size
user_input.ep_size = world_size if ep else 1
def get_cached_build_model(
cache: dict,
user_config: UserInputConfig,
):
"""Return a session-cached TransformerModel from build_model(user_config)."""
return get_built_model(user_config)
def get_cached_hf_config(cache: dict, model_id: str):
"""Return a session-cached hf config for model_id."""
return get_hf_config(model_id)