import dataclasses
import typing
from typing import Optional
import torch
import torch_npu
from mindspeed.te.pytorch.fp8.recipes.recipe import Recipe, RecipeScaling
from mindspeed.te.pytorch.fp8.scale_data import ScaleData
from mindspeed.te.pytorch.fp8.tensor import Float8Tensor, is_fp8_tensor
class DelayedScalingRecipe(Recipe):
    """FP8 recipe whose quantization scale comes from a ``ScaleData`` helper.

    Class-level state shared by every instance:

    * ``ALL_SCALING`` -- list to which each constructed instance appends
      itself.
    * ``MAX_STREAM`` -- one stream created lazily by the first constructor
      call and handed to ``ScaleData.delayed_recipe_update_amax``.
      NOTE(review): presumably the side stream on which amax values are
      computed -- confirm against ``ScaleData``.
    """

    ALL_SCALING = []
    MAX_STREAM = None

    def __init__(self, key, recipe_config: 'TEDelayedScaling', tensor_shape) -> None:
        """Build the recipe, register it, and sync the shared stream.

        Args:
            key: Forwarded unchanged to ``Recipe.__init__``.
            recipe_config: Recipe configuration; also passed to ``ScaleData``.
            tensor_shape: Forwarded unchanged to ``Recipe.__init__``.
        """
        super().__init__(key, recipe_config, tensor_shape)

        recipe_cls = DelayedScalingRecipe
        # The shared stream is created once, by whichever instance is
        # constructed first.
        if recipe_cls.MAX_STREAM is None:
            recipe_cls.MAX_STREAM = torch.cuda.Stream()

        # ``self.fp8_format`` is not set here; presumably the base
        # ``Recipe.__init__`` provides it.
        self.scale = ScaleData(recipe_config, self.fp8_format)
        recipe_cls.ALL_SCALING.append(self)

        # Order the shared stream after everything already queued on the
        # current stream.
        active_stream = torch.cuda.current_stream()
        recipe_cls.MAX_STREAM.wait_stream(active_stream)

    def finally_step(self):
        """Wait for the shared stream, then refresh the delayed scale."""
        active_stream = torch.cuda.current_stream()
        shared_stream = DelayedScalingRecipe.MAX_STREAM
        # The current stream must not run ahead of work queued on the shared
        # stream before the scale is updated.
        active_stream.wait_stream(shared_stream)
        self.scale.delayed_recipe_update_scale()

    def quantization(self, tensor, key, colwise, rowwise):
        """Quantize ``tensor`` and wrap the result in a ``Float8Tensor``.

        Args:
            tensor: Tensor to quantize. ``None`` and tensors for which
                ``is_fp8_tensor`` is true are returned unchanged.
            key: Forwarded to ``self.run_quantizer``.
            colwise: Accepted for interface compatibility; not read here.
            rowwise: Accepted for interface compatibility; not read here.

        Returns:
            The input itself when it is ``None`` or already FP8, otherwise a
            ``Float8Tensor`` built from the quantized data, ``self.quant_dtype``,
            the scale used, and the input's original dtype.
        """
        # Nothing to do for a missing tensor or one that is already FP8.
        if tensor is None or is_fp8_tensor(tensor):
            return tensor

        shared_stream = DelayedScalingRecipe.MAX_STREAM
        scale = self.scale.delayed_recipe_update_amax(tensor, shared_stream)

        # Keyword arguments handed to ``run_quantizer`` together with the
        # ``torch_npu.npu_quantize`` kernel.
        quantizer_kwargs = {
            "allow_reuse": False,
            "scales": scale,
            "zero_points": None,
            "dtype": self.quant_dtype,
            "axis": -1,
        }
        quantized = self.run_quantizer(
            tensor, key, torch_npu.npu_quantize, **quantizer_kwargs
        )
        return Float8Tensor(quantized, self.quant_dtype, scale, dtype=tensor.dtype)
@dataclasses.dataclass
class TEDelayedScaling(RecipeScaling):
    """Configuration dataclass associated with ``DelayedScalingRecipe``.

    Declares two fields, both defaulting to ``None``, in addition to whatever
    the base ``RecipeScaling`` defines (the base class is not visible in this
    file).
    """

    # Recipe class tied to this configuration. It carries no annotation, so
    # ``dataclasses`` treats it as a plain class attribute, not an init field.
    recipe = DelayedScalingRecipe
    # NOTE(review): presumably the process group across which amax values are
    # reduced -- not read anywhere in this file, confirm against its users.
    amax_reduce_group: Optional[torch.distributed.ProcessGroup] = None
    # NOTE(review): meaning and tuple layout are not visible here -- confirm
    # against the code that consumes this field.
    override_linear_precision: Optional[typing.Tuple] = None