from typing import Optional
from typing import Dict, Any, Dict
import torch
class PostQuantWeight(torch.Tensor):
"""
Tensor subclass that carries quantized data, scale metadata, original dtype,
and axiswise quantization layout for MX/FP8 linear forward and backward paths.
"""
_weight_fwd: torch.Tensor
_scale_fwd: torch.Tensor
_weight_bwd: torch.Tensor
_scale_bwd: torch.Tensor
_orig_dtype: torch.dtype
__slots__ = [
"_weight_fwd",
"_scale_fwd",
"_weight_bwd",
"_scale_bwd",
"_orig_dtype",
]
def __new__(
cls,
weight_fwd: Optional[torch.Tensor],
scale_fwd: Optional[torch.Tensor],
weight_bwd: Optional[torch.Tensor],
scale_bwd: Optional[torch.Tensor],
orig_dtype: torch.dtype,
):
reference_tensor = weight_fwd if weight_fwd is not None else weight_bwd
if reference_tensor is None:
raise ValueError("At least one of data or data2 must be provided")
self = torch.Tensor._make_wrapper_subclass(
cls,
reference_tensor.size(),
strides=reference_tensor.stride(),
storage_offset=reference_tensor.storage_offset(),
dtype=orig_dtype,
layout=reference_tensor.layout,
requires_grad=reference_tensor.requires_grad,
device=reference_tensor.device,
)
self._weight_fwd = weight_fwd
self._scale_fwd = scale_fwd
self._weight_bwd = weight_bwd
self._scale_bwd = scale_bwd
self._orig_dtype = orig_dtype
return self
def __repr__(self):
return "QuantTrainingTensor(" f"lp_dtype={self._weight_fwd.dtype}"
def __tensor_flatten__(self):
metadata = {
"_orig_dtype": self._orig_dtype,
}
tensors = ["_weight_fwd", "_scale_fwd", "_weight_bwd", "_scale_bwd"]
return tensors, metadata
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
return PostQuantWeight(
inner_tensors["_weight_fwd"],
inner_tensors["_scale_fwd"],
inner_tensors["_weight_bwd"],
inner_tensors["_scale_bwd"],
metadata["_orig_dtype"],
)
def to_original_precision(self):
raise NotImplementedError()
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
def allowed_subclasses(type_):
return issubclass(cls, type_) or issubclass(torch._subclasses.fake_tensor.FakeTensor, type_) or issubclass(
torch._subclasses.functional_tensor.FunctionalTensor, type_)
if not all(allowed_subclasses(t) for t in types):
return NotImplemented
if func in OPS_TABLE:
return OPS_TABLE[func](func, args, kwargs)
raise NotImplementedError(f"attempting to run {func}, this is not supported")
__torch_function__ = torch._C._disabled_torch_function_impl
aten = torch.ops.aten
OPS_TABLE: Dict[Any, Any] = {}
def implements(aten_ops):
def decorator(func):
for op in aten_ops:
if op in OPS_TABLE:
raise RuntimeError(f"Float8 op {op} is already registered to {OPS_TABLE[op].__name__}")
OPS_TABLE[op] = func
return func
return decorator
@implements(
[
aten.view.default,
aten._unsafe_view.default,
aten.as_strided.default,
aten.clone.default,
aten.slice.Tensor,
aten.fill_.Scalar,
aten.reshape.default,
]
)
def float8_desugar_op(aten_op, args, kwargs=None):
arg0 = args[0]
new_data_fwd, new_scale_fwd, new_data_bwd, new_scale_bwd = None, None, None, None
if hasattr(arg0, "_weight_fwd") and arg0._weight_fwd is not None and arg0._weight_fwd.numel() > 0:
new_data_fwd = aten_op(arg0._weight_fwd, *args[1:], **kwargs)
new_scale_fwd = arg0._scale_fwd
if hasattr(arg0, "_weight_bwd") and arg0._weight_bwd is not None and arg0._weight_bwd.numel() > 0:
new_data_bwd = aten_op(arg0._weight_bwd, *args[1:], **kwargs)
new_scale_bwd = arg0._scale_bwd
return PostQuantWeight(
new_data_fwd,
new_scale_fwd,
new_data_bwd,
new_scale_bwd,
args[0]._orig_dtype,
)
@implements([aten.detach.default])
def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
arg0 = args[0]
new_data_fwd, new_scale_fwd, new_data_bwd, new_scale_bwd = None, None, None, None
if hasattr(arg0, "_weight_fwd") and arg0._weight_fwd is not None and arg0._weight_fwd.numel() > 0:
new_data_fwd = aten_op(arg0._weight_fwd, *args[1:], **kwargs)
new_scale_fwd = arg0._scale_fwd
if hasattr(arg0, "_weight_bwd") and arg0._weight_bwd is not None and arg0._weight_bwd.numel() > 0:
new_data_bwd = aten_op(arg0._weight_bwd, *args[1:], **kwargs)
new_scale_bwd = arg0._scale_bwd
return PostQuantWeight(
new_data_fwd,
new_scale_fwd,
new_data_bwd,
new_scale_bwd,
args[0]._orig_dtype,
)