"""Tensor class with HIF8 data"""
from __future__ import annotations
__all__ = []
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error
# Handle to torch_npu's custom-dtype extension module; this file uses it for
# the ``cast_to_fp8`` / ``cast_from_fp8`` / ``cast_to_fp8_noalloc`` kernels and
# the ``DType`` enum.
tex = torch_npu._C._cd
# Shorthand used when matching ATen overloads inside ``__torch_dispatch__``.
aten = torch.ops.aten

# torch dtype -> extension ``DType`` member. Dequantization looks the requested
# output dtype up here, so only these dtypes are accepted as cast targets.
NPU_CUSTOM_DType = dict(
    [
        (torch.uint8, tex.DType.uint8),
        (torch.int32, tex.DType.int32),
        (torch.float32, tex.DType.float32),
        (torch.half, tex.DType.float16),
        (torch.bfloat16, tex.DType.bfloat16),
    ]
)
class _FromHiFloat8Func(torch.autograd.Function):
    """Dequantize a _HiFloat8Tensor into a regular (higher-precision) tensor.

    The gradient flows straight through: ``backward`` hands the incoming
    gradient back to the input unchanged.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,
        tensor: _HiFloat8Tensor,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        # Fall back to the wrapper's nominal dtype when no target is given.
        target_dtype = tensor.dtype if dtype is None else dtype
        # The cast kernel is fed a flattened (1, numel) view of the raw buffer.
        flat_raw = tensor._data.contiguous().view(1, -1).detach()
        dequantized = tex.cast_from_fp8(
            flat_raw,
            tex.DType.hifloat8,
            NPU_CUSTOM_DType[target_dtype],
        )
        # Restore the logical shape of the wrapper tensor.
        return dequantized.view(tensor.size())

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # Pass-through gradient for ``tensor``; ``dtype`` is not differentiable.
        return grad, None
class _ToHiFloat8Func(torch.autograd.Function):
    """Quantize a regular tensor into a new _HiFloat8Tensor.

    The result's nominal dtype is the (possibly promoted) input dtype; the
    gradient flows straight through in ``backward``.
    """

    @staticmethod
    def forward(
        _ctx: torch.autograd.function.FunctionCtx,
        tensor: torch.Tensor,
    ) -> _HiFloat8Tensor:
        # Dense, on-NPU, graph-detached copy of the input.
        src = tensor.contiguous().npu().detach()
        # Anything other than fp32 / bf16 / fp16 is promoted to fp32 first.
        supported = (torch.float32, torch.bfloat16, torch.float16)
        if src.dtype not in supported:
            src = src.float()
        raw = tex.cast_to_fp8(src.view(1, -1), tex.DType.hifloat8)
        return _HiFloat8Tensor(data=raw.view(src.size()), dtype=src.dtype)

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # Pass-through gradient (the trailing None is tolerated by autograd).
        return grad, None
class _IdentityFunc(torch.autograd.Function):
    """Identity function

    If constructor keyword-arguments are provided, then construct a
    new _HiFloat8Tensor, using the provided tensor's ``_data`` and ``dtype``
    for any keyword that is not given explicitly.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: _HiFloat8Tensor,
        init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Return ``tensor`` itself, or a rebuilt _HiFloat8Tensor.

        Args:
            tensor: Input tensor; its dtype is recorded for ``backward``.
            init_kwargs: Optional constructor keywords for _HiFloat8Tensor.
                Never modified by this function.
        """
        ctx.input_dtype = tensor.dtype
        if init_kwargs is None:
            return tensor
        # Build a fresh dict instead of filling defaults into the caller's
        # mapping; explicitly provided keys take precedence over defaults.
        merged_kwargs = {
            "data": tensor._data,
            "dtype": tensor.dtype,
            **init_kwargs,
        }
        return _HiFloat8Tensor(**merged_kwargs)

    @staticmethod
    def backward(ctx, grad):
        # Hand the gradient back in the input's nominal dtype; no gradient
        # for ``init_kwargs``.
        return grad.to(ctx.input_dtype), None
class _ViewFunc(torch.autograd.Function):
    """View function

    Reinterpret a tensor (HIF8 wrapper or plain) with a new shape; for a
    _HiFloat8Tensor the raw buffer is viewed and re-wrapped with the same
    nominal dtype.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        shape: Tuple[int] = None,
    ) -> torch.Tensor:
        # Remember the original shape so backward can undo the view.
        ctx.shape = tensor.shape
        if shape is None:
            return tensor
        # ``shape`` is unpacked into ``view``, so both ``(2, 3)`` and
        # ``((2, 3),)`` forms are forwarded as-is.
        if not isinstance(tensor, _HiFloat8Tensor):
            return tensor.view(*shape)
        viewed_raw = tensor._data.view(*shape)
        return _HiFloat8Tensor.make_like(tensor, data=viewed_raw)

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        if not isinstance(grad, _HiFloat8Tensor):
            return grad.view(ctx.shape), None
        restored = _HiFloat8Tensor.make_like(
            grad,
            data=grad._data.view(ctx.shape),
        )
        return restored, None
class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape a tensor (HIF8 wrapper or plain); for a _HiFloat8Tensor the raw
    buffer is reshaped and re-wrapped with the same nominal dtype.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        shape: Tuple[int] = None,
    ) -> torch.Tensor:
        # Remember the original shape so backward can restore it.
        ctx.shape = tensor.shape
        if shape is None:
            return tensor
        if not isinstance(tensor, _HiFloat8Tensor):
            return tensor.reshape(*shape)
        reshaped_raw = tensor._data.reshape(*shape)
        return _HiFloat8Tensor.make_like(tensor, data=reshaped_raw)

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        if not isinstance(grad, _HiFloat8Tensor):
            return grad.reshape(ctx.shape), None
        restored = _HiFloat8Tensor.make_like(
            grad,
            data=grad._data.reshape(ctx.shape),
        )
        return restored, None
class _TransposeFunc(torch.autograd.Function):
    """Transpose function

    Swap dimensions ``dim0`` and ``dim1``. For a _HiFloat8Tensor the raw HIF8
    buffer is transposed and re-wrapped; plain tensors use ``Tensor.transpose``.
    """

    @staticmethod
    def forward(ctx, tensor, dim0, dim1):
        # dim0/dim1 are Python ints; ``save_for_backward`` only accepts
        # tensors (or None) and raises TypeError otherwise, so keep them as
        # plain ctx attributes instead.
        ctx.dim0 = dim0
        ctx.dim1 = dim1
        if isinstance(tensor, _HiFloat8Tensor):
            return _HiFloat8Tensor.make_like(
                tensor,
                data=tensor._data.transpose(dim0, dim1),
            )
        return tensor.transpose(dim0, dim1)

    @staticmethod
    def backward(ctx, grad):
        # Swapping the same two dims again undoes the transpose.
        dim0, dim1 = ctx.dim0, ctx.dim1
        if isinstance(grad, _HiFloat8Tensor):
            dgrad = _HiFloat8Tensor.make_like(
                grad,
                data=grad._data.transpose(dim0, dim1),
            )
        else:
            dgrad = grad.transpose(dim0, dim1)
        # One entry per forward input (tensor, dim0, dim1) on every path.
        return dgrad, None, None
class _HiFloat8Tensor(torch.Tensor):
    """Experimental tensor class with HIF8 data

    The tensor presents as having a standard, higher-precision dtype,
    but the data itself is HIF8, held in an 8-bit buffer (``self._data``);
    no scaling factor is stored. For most tensor operations, the data will
    be cast to the nominal dtype before performing the operation.

    Parameters
    ----------
    data: torch.Tensor
        Raw HIF8 data in a tensor with 1-byte elements (e.g. uint8). Must
        not require grad; it is moved to the NPU if not already there.
    dtype: torch.dtype, default = torch.float32
        Nominal tensor datatype.
    """

    def __new__(
        cls,
        *,
        data: torch.Tensor,
        dtype: torch.dtype = torch.float32,
    ):
        if data.element_size() != 1:
            raise ValueError(
                f"HiFloat8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})"
                + pta_error(ErrCode.VALUE)
            )
        if data.requires_grad:
            raise ValueError(
                "HiFloat8Tensor requires non-differentiable data buffer"
                + pta_error(ErrCode.VALUE)
            )
        if not data.is_npu:
            data = data.npu()
        # Wrapper subclass: mirrors the geometry of ``data`` (size, strides,
        # offset, layout, device) but reports the nominal ``dtype``; the raw
        # HIF8 bytes are kept in ``self._data``.
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=dtype,
            layout=data.layout,
            requires_grad=data.requires_grad,
            device=data.device,
        )
        self._data: torch.Tensor = data
        return self

    @classmethod
    def make_like(
        cls,
        tensor: _HiFloat8Tensor,
        *,
        data: torch.Tensor,
        **kwargs,
    ) -> _HiFloat8Tensor:
        """Use attributes of a _HiFloat8Tensor to create another _HiFloat8Tensor

        ``tensor``'s nominal dtype is used unless ``dtype`` is passed in
        ``kwargs``. See constructor for list of keyword arguments.
        """
        kwargs.setdefault("dtype", tensor.dtype)
        return _HiFloat8Tensor(data=data, **kwargs)

    def __repr__(self):
        return (
            "HiFloat8Tensor("
            f"data={self.from_hifloat8(dtype=self.dtype)}"
            ")"
        )

    def from_hifloat8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Construct PyTorch tensor from _HiFloat8Tensor

        By default the resulting tensor's dtype is the
        _HiFloat8Tensor's nominal dtype.
        """
        return _FromHiFloat8Func.apply(self, dtype)

    @classmethod
    def to_hifloat8(
        cls,
        tensor: torch.Tensor,
    ) -> _HiFloat8Tensor:
        """Construct _HiFloat8Tensor from PyTorch tensor"""
        return _ToHiFloat8Func.apply(tensor)

    def float(self) -> torch.Tensor:
        """Dequantize to a float32 tensor."""
        return self.from_hifloat8(dtype=torch.float32)

    def bfloat16(self) -> torch.Tensor:
        """Dequantize to a bfloat16 tensor."""
        return self.from_hifloat8(dtype=torch.bfloat16)

    def half(self) -> torch.Tensor:
        """Dequantize to a float16 tensor."""
        return self.from_hifloat8(dtype=torch.float16)

    def cpu(self) -> torch.Tensor:
        """Dequantize to the nominal dtype and move the result to the CPU."""
        return self.from_hifloat8().cpu()

    def clone(self) -> _HiFloat8Tensor:
        """Return a new _HiFloat8Tensor backed by a copy of the HIF8 buffer."""
        return _IdentityFunc.apply(self, {"data": self._data.detach().clone()})

    def view(self, *shape: Tuple[int]) -> _HiFloat8Tensor:
        """View the HIF8 buffer with a new shape (autograd-aware)."""
        return _ViewFunc.apply(self, shape)

    def reshape(self, *shape: Tuple[int]) -> _HiFloat8Tensor:
        """Reshape the HIF8 buffer (autograd-aware)."""
        return _ReshapeFunc.apply(self, shape)

    def contiguous(
        self,
        *,
        memory_format: torch.memory_format = torch.contiguous_format,
    ) -> _HiFloat8Tensor:
        """Returns tensor with data in provided memory format

        Returns `self` if data is already in correct memory format.
        """
        if self._data.is_contiguous(memory_format=memory_format):
            return self
        return _IdentityFunc.apply(
            self,
            {"data": self._data.detach().contiguous(memory_format=memory_format)},
        )

    def to_dtype(self, dtype: torch.dtype) -> _HiFloat8Tensor:
        """Create `_HiFloat8Tensor` with given nominal dtype

        The new tensor has the same underlying HIF8 data.
        """
        return _HiFloat8Tensor.make_like(
            self,
            data=self._data,
            dtype=dtype,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """Handle ATen ops that involve a _HiFloat8Tensor.

        ``copy_``, ``slice``, ``detach`` and ``view`` act on the raw HIF8
        buffer. Every other op runs on dequantized copies of the HIF8
        arguments; for mutable ops the written results are re-cast back into
        the original HIF8 tensors and ``None`` is returned.
        """
        # Normalize once: the mutable branch below iterates over kwargs.
        if kwargs is None:
            kwargs = {}

        if func == aten.copy_.default:
            dst = args[0]
            src = args[1]
            if not isinstance(dst, torch.Tensor):
                raise RuntimeError(
                    "Attempted to copy into something that isn't a PyTorch tensor"
                    + pta_error(ErrCode.TYPE)
                )
            if not isinstance(src, torch.Tensor):
                raise RuntimeError(
                    "Attempted to copy from something that isn't a PyTorch tensor"
                    + pta_error(ErrCode.TYPE)
                )
            dst_is_hif8 = isinstance(dst, _HiFloat8Tensor)
            src_is_hif8 = isinstance(src, _HiFloat8Tensor)
            if dst_is_hif8 and src_is_hif8:
                # Raw byte copy; no re-quantization needed.
                dst._data.copy_(src._data)
            elif not dst_is_hif8 and src_is_hif8:
                dst.copy_(src.from_hifloat8())
            elif dst_is_hif8 and not src_is_hif8:
                if not dst._data.is_contiguous():
                    raise RuntimeError(
                        "Transformer Engine cast kernels require contiguous data"
                        + pta_error(ErrCode.INTERNAL)
                    )
                # Same input-dtype policy as _ToHiFloat8Func: anything other
                # than fp32 / bf16 / fp16 is promoted to fp32 before the cast.
                if src.dtype in (torch.float32, torch.bfloat16, torch.float16):
                    src_dtype = src.dtype
                else:
                    src_dtype = torch.float32
                # Broadcast to dst's shape and materialize a dense buffer.
                # ``expand`` introduces stride-0 dims and ``src`` may itself
                # be non-contiguous; ``Tensor.to`` returns its input unchanged
                # when no conversion is needed, so an explicit
                # ``.contiguous()`` is required before ``view(1, -1)``.
                src = (
                    src.expand(dst.size())
                    .to(device=dst.device, dtype=src_dtype)
                    .contiguous()
                )
                tex.cast_to_fp8_noalloc(
                    src.view(1, -1),
                    dst._data.view(1, -1),
                    tex.DType.hifloat8,
                )
            else:
                raise RuntimeError(
                    "Using HiFloat8Tensor copy logic, but no HiFloat8Tensor found"
                    + pta_error(ErrCode.INTERNAL)
                )
            return None

        if func == aten.slice.Tensor:
            # Slice the raw buffer with dispatch disabled, then re-wrap.
            tensor = args[0]
            data = tensor._data
            data_slice = data.__torch_dispatch__(
                func,
                types,
                [data] + list(args[1:]),
                kwargs,
            )
            return _HiFloat8Tensor.make_like(tensor, data=data_slice)

        if func == aten.detach.default:
            # Shares the same HIF8 buffer.
            return _HiFloat8Tensor.make_like(
                args[0],
                data=args[0]._data,
            )

        if func == aten.view.default:
            tensor = args[0]
            data = tensor._data
            data_view = data.__torch_dispatch__(
                func,
                types,
                [data] + list(args[1:]),
                kwargs,
            )
            return _HiFloat8Tensor.make_like(
                tensor,
                data=data_view,
            )

        def maybe_unwrap(t):
            """Dequantize HIF8 tensors; leave everything else untouched."""
            if isinstance(t, _HiFloat8Tensor):
                return t.from_hifloat8()
            return t

        def maybe_update_inplace(arg, new_arg, schema_arg):
            """Write an in-place op's result back into a HIF8 tensor.

            ``new_arg`` is the dequantized stand-in the op actually mutated;
            its values are re-cast into ``arg``'s HIF8 buffer via ``copy_``
            when the schema marks the argument as written. Lists/tuples of
            tensors (e.g. ``_foreach_*_`` ops) are handled element-wise.
            """
            if isinstance(arg, (list, tuple)) and isinstance(new_arg, (list, tuple)):
                for sub_arg, sub_new_arg in zip(arg, new_arg):
                    maybe_update_inplace(sub_arg, sub_new_arg, schema_arg)
                return
            check_args = isinstance(arg, _HiFloat8Tensor) and isinstance(new_arg, torch.Tensor)
            check_schema = (
                hasattr(schema_arg, "alias_info")
                and hasattr(schema_arg.alias_info, "is_write")
                and schema_arg.alias_info.is_write
            )
            if check_args and check_schema:
                arg.copy_(new_arg)

        if func._schema.is_mutable:
            # Run the op on dequantized copies, then re-cast written args.
            new_args = tree_map(maybe_unwrap, args)
            new_kwargs = tree_map(maybe_unwrap, kwargs)
            schema_args = func._schema.arguments
            args_len = len(args)
            super().__torch_dispatch__(func, types, new_args, new_kwargs)
            for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
                maybe_update_inplace(arg, new_arg, schema_arg)
            for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]):
                if not (kwarg == new_kwarg == schema_arg.name):
                    raise ValueError('name of the kw argument should match' + pta_error(ErrCode.VALUE))
                maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg)
            return None

        # Out-of-place op: compute in the nominal (higher-precision) dtype.
        args = tree_map(maybe_unwrap, args)
        kwargs = tree_map(maybe_unwrap, kwargs)
        return super().__torch_dispatch__(func, types, args, kwargs)

    @classmethod
    def _make_in_reduce_ex(
        cls,
        data: torch.Tensor,
        dtype: torch.dtype,
    ) -> _HiFloat8Tensor:
        """Build _HiFloat8Tensor, for use in __reduce__

        __reduce_ex__ assumes object constructor has positional
        arguments.
        """
        return _HiFloat8Tensor(
            data=data,
            dtype=dtype,
        )

    def __reduce_ex__(self, protocol: int) -> tuple:
        """Custom pickling: serialize only the raw HIF8 buffer and nominal dtype"""
        return (
            _HiFloat8Tensor._make_in_reduce_ex,
            (self._data, self.dtype),
        )

    def _get_data(self) -> _HiFloat8Tensor:
        """Get tensor data property"""
        return super().data

    def _set_data(self, tensor: torch.Tensor) -> None:
        """Set tensor data property

        Cast tensor to HIF8 and store in HIF8 buffer.
        """
        with torch.no_grad():
            self.copy_(tensor)

    data = property(_get_data, _set_data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Skip the __torch_function__ layer; all handling is done in
        # __torch_dispatch__.
        if kwargs is None:
            kwargs = {}
        return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

    def transpose(self, dim0, dim1):
        """Swap two dimensions (autograd-aware)."""
        return _TransposeFunc.apply(self, dim0, dim1)