from abc import ABC, abstractmethod
from typing import Optional
import torch
from torch.utils._cxx_pytree import tree_map
from ..config import performance_model as perf_config
from ..device import DeviceProfile
from .base import PerformanceModel
from .op_invoke_info import OpInvokeInfo
from .utils import is_noop_self_copy_op, is_view_op
# Module-level registry of benchmark implementations. Keys are
# ``(op, device)`` tuples and values are the callables stored by
# ``register_op_impl``; ``get_op_impl`` reads from it.
_op_impl_registry = {}
def get_op_impl(op, device_str):
    """Return the implementation registered for ``(op, device_str)``.

    Args:
        op: The operator the implementation was registered for.
        device_str: The device the implementation was registered for. The
            lookup uses it as part of a dict key, so it must compare and
            hash equal to the value used at registration time.

    Returns:
        The registered callable, or ``None`` when nothing is registered
        for this pair.
    """
    lookup_key = (op, device_str)
    return _op_impl_registry.get(lookup_key)
def register_op_impl(op, device_str):
    """Build a decorator that registers a callable for ``(op, device_str)``.

    Args:
        op: The operator the decorated callable implements.
        device_str: The device the decorated callable targets.

    Returns:
        A decorator that stores the decorated callable in the module
        registry and hands it back unchanged.

    Raises:
        ValueError: When the decorator is applied and the pair already has
            a registered implementation.
    """
    registry_key = (op, device_str)

    def register(benchmark_functor):
        # Refuse to silently replace an existing registration.
        if registry_key in _op_impl_registry:
            raise ValueError(f"Implementation for {op} on {device_str} already registered")
        _op_impl_registry[registry_key] = benchmark_functor
        return benchmark_functor

    return register
@register_op_impl(torch.ops.tensor_cast.quantize.default, torch.device("cpu"))
def _(
    x: torch.Tensor,
    scale: torch.Tensor,
    offset: Optional[torch.Tensor],
    out_dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    """CPU implementation of ``tensor_cast.quantize`` used for benchmarking.

    Computes ``round(x / scale)``, adds ``offset`` when one is given, and
    casts the result to ``out_dtype``. No clamping is applied before the
    cast.

    Args:
        x: Tensor to quantize.
        scale: Divisor applied to ``x`` before rounding.
        offset: Optional tensor added after rounding; skipped when ``None``.
        out_dtype: Dtype of the returned tensor.

    Returns:
        The quantized tensor in ``out_dtype``.
    """
    rounded = torch.round(x / scale)
    if offset is None:
        return rounded.to(dtype=out_dtype)
    shifted = rounded + offset
    return shifted.to(dtype=out_dtype)
class OpBenchmarkBase(ABC):
    """Abstract interface for objects that benchmark a single op invocation.

    Subclasses must implement :meth:`benchmark`.
    """

    def __init__(self, device_profile: DeviceProfile):
        """Store the device profile the benchmark is associated with."""
        self.device_profile = device_profile

    @abstractmethod
    def benchmark(self, op_invoke_info: OpInvokeInfo) -> PerformanceModel.Result:
        """Benchmark the invocation described by ``op_invoke_info``.

        Args:
            op_invoke_info: Description of the op call to benchmark.

        Returns:
            The benchmark outcome as a ``PerformanceModel.Result``.
        """
class OpBenchmark(OpBenchmarkBase):
    """Benchmark ops empirically by executing them on a runtime device.

    The runtime device is resolved once at construction time (see
    :meth:`infer_runtime_device`) and reused for every benchmark.
    """

    def __init__(self, device_profile: DeviceProfile):
        """Store ``device_profile`` and resolve the runtime device."""
        super().__init__(device_profile)
        self.runtime_device = self.infer_runtime_device()

    def benchmark(self, op_invoke_info: OpInvokeInfo) -> PerformanceModel.Result:
        """Measure the average latency of one op invocation.

        Ops for which ``is_view_op`` or ``is_noop_self_copy_op`` returns
        true are reported as ``PerformanceModel.Result(0.0)`` without being
        executed. Ops in the ``tensor_cast`` namespace are executed through
        the implementation registered for the runtime device; all other ops
        are called directly.

        Args:
            op_invoke_info: The op (``func``) together with its ``args`` and
                ``kwargs``.

        Returns:
            The measured result from :meth:`do_bench`, or a zero result for
            skipped ops.

        Raises:
            ValueError: If a ``tensor_cast`` op has no implementation
                registered for the runtime device.
        """
        func = op_invoke_info.func
        if is_view_op(func) or is_noop_self_copy_op(func, op_invoke_info.args):
            return PerformanceModel.Result(0.0)

        if func.namespace == "tensor_cast":
            op_impl = get_op_impl(func, self.runtime_device)
            if op_impl is None:
                raise ValueError(f"No implementation registered for {op_invoke_info.func} on {self.runtime_device}")
        else:
            op_impl = func
        return self.do_bench(op_impl, op_invoke_info.args, op_invoke_info.kwargs)

    def do_bench(self, op_impl, args, kwargs) -> PerformanceModel.Result:
        """Time ``op_impl`` on the runtime device and return the mean latency.

        Every tensor found in ``args``/``kwargs`` is replaced by an
        uninitialized tensor (``torch.empty_like``) on the runtime device;
        non-tensor leaves are passed through unchanged. The callable is run
        ``warmup_runs`` times untimed, then ``benchmark_runs`` times timed.

        Args:
            op_impl: Callable to benchmark.
            args: Positional arguments (pytree; tensors are re-created).
            kwargs: Keyword arguments (pytree; tensors are re-created).

        Returns:
            ``PerformanceModel.Result`` whose ``execution_time_s`` is the
            average wall-clock seconds per timed call.

        Raises:
            ValueError: If the configured ``benchmark_runs`` is less than 1.
        """
        import time

        warmup_runs = perf_config.empirical.warmup_runs
        benchmark_runs = perf_config.empirical.benchmark_runs
        if benchmark_runs < 1:
            # Zero would divide by zero below; a negative count would yield
            # a meaningless (negative or zero) latency.
            raise ValueError(f"benchmark_runs must be a positive integer, got {benchmark_runs}")

        def materialize(leaf):
            if isinstance(leaf, torch.Tensor):
                return torch.empty_like(leaf, device=self.runtime_device)
            return leaf

        real_args, real_kwargs = tree_map(materialize, (args, kwargs))

        for _ in range(warmup_runs):
            op_impl(*real_args, **real_kwargs)

        # Accelerator backends launch work asynchronously. Drain anything
        # still queued (allocation, warm-up) before starting the clock, and
        # wait for the timed work to finish before stopping it; otherwise
        # only the host-side launch overhead would be measured.
        self._synchronize()
        start_time = time.perf_counter()
        for _ in range(benchmark_runs):
            op_impl(*real_args, **real_kwargs)
        self._synchronize()
        end_time = time.perf_counter()

        avg_latency_s = (end_time - start_time) / benchmark_runs
        return PerformanceModel.Result(execution_time_s=avg_latency_s)

    def _synchronize(self) -> None:
        """Block until pending work on the runtime device has completed.

        Looks up ``torch.<device type>.synchronize`` (for example
        ``torch.cuda.synchronize``) and calls it when available. Backends
        that expose no such function are left alone.
        """
        device = torch.device(self.runtime_device)
        backend = getattr(torch, device.type, None)
        synchronize = getattr(backend, "synchronize", None)
        if not callable(synchronize):
            return
        if device.index is None:
            # No explicit index: synchronize the backend's current device.
            synchronize()
        else:
            synchronize(device)

    def infer_runtime_device(self) -> torch.device:
        """Choose the device on which benchmarks are executed.

        Resolution order:
            1. ``perf_config.empirical.runtime_device_override`` when it is
               not ``None`` (returned as configured).
            2. CPU when the device profile is named ``"TEST_DEVICE"``.
            3. ``npu`` for vendor ``"HUAWEI"``, ``cuda`` for ``"NVIDIA"``.

        Returns:
            The device to run benchmarks on.

        Raises:
            ValueError: If the vendor is not one of the supported vendors.
        """
        if (device_override := perf_config.empirical.runtime_device_override) is not None:
            return device_override
        if self.device_profile.name == "TEST_DEVICE":
            return torch.device("cpu")
        if self.device_profile.vendor == "HUAWEI":
            return torch.device("npu")
        if self.device_profile.vendor == "NVIDIA":
            return torch.device("cuda")
        raise ValueError(f"Unsupported benchmarking ops on vendor: {self.device_profile.vendor}")