import unittest
import torch
from tensor_cast.core.model_builder import build_model
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import TEST_DEVICE
from tensor_cast.model_config import ModelConfig, ParallelConfig, QuantConfig
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.performance_model.base import PerformanceModel
from tensor_cast.performance_model.memory_tracker import MemoryTracker
from tensor_cast.performance_model.op_estimator_registry import _op_estimator_table, register_op_estimator
from tensor_cast.performance_model.op_invoke_info import OpInvokeInfo
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.custom_model_registry import get_moe_config
from tensor_cast.transformers.model import TransformerModel
from tensor_cast.transformers.utils import AutoModelConfigLoader
from .test_common import create_attn_metadata_and_kv_cache
def _restore_op_properties_functor(target_op, original_functor) -> None:
if original_functor is None:
OpInvokeInfo._op_properties_functors.pop(target_op, None)
return
OpInvokeInfo._op_properties_functors[target_op] = original_functor
def _restore_op_estimator(target_dtype, target_op, original_estimator) -> None:
estimator_by_dtype = _op_estimator_table.setdefault(target_dtype, {})
if original_estimator is None:
estimator_by_dtype.pop(target_op, None)
if not estimator_by_dtype:
_op_estimator_table.pop(target_dtype, None)
return
estimator_by_dtype[target_op] = original_estimator
class CustomModelingOperatorTestCase(unittest.TestCase):
def test_custom_operator_properties(self):
MEMORY_READ_BYTES = 32768
MEMORY_WRITE_BYTES = 32768
MEMORY_READWRITE_BYTES = 0
MMA_OPS = 100000
GP_OPS = 5000
NUM_TOKENS = 100
MODEL_ID = "Qwen/Qwen3-32B"
TARGET_OP_NAME = "reshape_and_cache"
target_op = torch.ops.tensor_cast.reshape_and_cache.default
original_properties_functor = OpInvokeInfo._op_properties_functors.get(target_op)
self.addCleanup(_restore_op_properties_functor, target_op, original_properties_functor)
@OpInvokeInfo.register_op_properties(target_op, True)
def simple_operator_properties(
op_invoke_info: OpInvokeInfo,
) -> OpInvokeInfo.PerformanceProperties:
properties = OpInvokeInfo.PerformanceProperties()
properties.memory_read_bytes = MEMORY_READ_BYTES
properties.memory_write_bytes = MEMORY_WRITE_BYTES
properties.memory_readwrite_bytes = MEMORY_READWRITE_BYTES
compute_ops = properties.compute_ops.setdefault(torch.float16, OpInvokeInfo.ComputeOps())
compute_ops.mma_ops = MMA_OPS
compute_ops.gp_ops = GP_OPS
return properties
user_config = UserInputConfig(model_id=MODEL_ID)
model = build_model(user_config)
inputs = torch.empty([1, NUM_TOKENS], dtype=torch.long, device="meta")
position_ids = torch.empty([1, NUM_TOKENS], dtype=torch.long, device="meta")
device_profile = TEST_DEVICE
perf_model = AnalyticPerformanceModel(device_profile)
attn_meta, kv_cache_by_layers, num_tokens = create_attn_metadata_and_kv_cache(model, model.model_config)
with (
Runtime(perf_model, device_profile, memory_tracker=MemoryTracker(device_profile)) as runtime,
torch.no_grad(),
):
model.forward(
inputs,
position_ids,
attention_meta=attn_meta,
kv_cache_by_layers=kv_cache_by_layers,
)
self.assertIn(
target_op,
OpInvokeInfo._op_properties_functors,
"failed to register operator",
)
result = None
for event in runtime.event_list:
if (
hasattr(event.op_invoke_info, "func")
and event.op_invoke_info.func is not None
and TARGET_OP_NAME in event.op_invoke_info.func._name
):
result = event
break
self.assertIsNotNone(result, "Failed to get result")
perf_props = result.op_invoke_info.get_perf_properties()
self.assertEqual(perf_props.memory_read_bytes, MEMORY_READ_BYTES)
self.assertEqual(perf_props.memory_write_bytes, MEMORY_WRITE_BYTES)
self.assertEqual(perf_props.memory_readwrite_bytes, MEMORY_READWRITE_BYTES)
compute_ops = perf_props.compute_ops.get(torch.float16)
self.assertEqual(compute_ops.mma_ops, MMA_OPS)
self.assertEqual(compute_ops.gp_ops, GP_OPS)
def test_custom_estimate_operator_estimator(self):
all_to_all_execution_time_s = 3.0
target_op = torch.ops.tensor_cast.all_to_all.default
original_estimator = _op_estimator_table.get(None, {}).get(target_op)
self.addCleanup(_restore_op_estimator, None, target_op, original_estimator)
@register_op_estimator(target_op, None, True)
def _estimate_custom_comm(op_invoke_info, device_profile) -> object:
return PerformanceModel.Result(all_to_all_execution_time_s)
model_id = "Qwen/Qwen3-235B-A22B"
auto_loader = AutoModelConfigLoader()
hf_config = auto_loader.load_config(model_id)
moe_config = get_moe_config(hf_config.model_type)
parallel_config = ParallelConfig(
world_size=16,
tensor_parallel_size=2,
mlp_tensor_parallel_size=4,
lmhead_tensor_parallel_size=1,
expert_parallel_size=16,
moe_data_parallel_size=1,
moe_tensor_parallel_size=1,
)
model_config = ModelConfig(
parallel_config,
QuantConfig(),
enable_repetition=True,
moe_config=moe_config,
hf_config=hf_config,
)
model = TransformerModel(model_id, model_config)
NUM_TOKENS = 100
inputs = torch.empty([1, NUM_TOKENS], dtype=torch.long, device="meta")
position_ids = torch.empty([1, NUM_TOKENS], dtype=torch.long, device="meta")
attn_meta, kv_cache_by_layers, num_tokens = create_attn_metadata_and_kv_cache(model, model.model_config)
machine_config = TEST_DEVICE
perf_model = AnalyticPerformanceModel(machine_config)
with (
Runtime(perf_model, machine_config) as runtime,
torch.no_grad(),
):
model.forward(
inputs,
position_ids,
attention_meta=attn_meta,
kv_cache_by_layers=kv_cache_by_layers,
)
TARGET_OP_NAME = "all_to_all"
result = None
for event in runtime.event_list:
if (
hasattr(event.op_invoke_info, "func")
and event.op_invoke_info.func is not None
and TARGET_OP_NAME in event.op_invoke_info.func._name
):
result = event
break
self.assertIsNotNone(result)
self.assertEqual(
result.perf_results.get("analytic").execution_time_s,
all_to_all_execution_time_s,
)
if __name__ == "__main__":
unittest.main()