import logging
from typing import List
import torch
from mindspeed.te.pytorch.fp8.constants import AMAX_COMPUTE_MAP, FP8Format
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class ScaleData:
    """Per-tensor scaling state for the FP8 delayed-scaling recipe.

    Holds a window of observed amax values (``amax_history``), the amax derived
    from that window (``amax``) and the resulting ``scale``. All buffers are
    allocated on the current NPU device.
    """

    def __init__(self, recipe_config, fp8_format, scale_shape: List[int] = None):
        """Allocate the scaling buffers.

        Args:
            recipe_config: object exposing ``amax_reduce_group`` and a nested
                ``config`` with ``fp8_margin``, ``fp8_amax_history_len``,
                ``fp8_amax_compute_algo`` and ``fp8_interval``.
            fp8_format: FP8 format descriptor; ``fp8_format.max`` is read as the
                largest representable magnitude.
            scale_shape: shape of the ``scale``/``amax`` tensors (any sequence of
                ints); defaults to ``[1]``.

        Raises:
            AssertionError: if ``fp8_amax_compute_algo`` is not a key of
                ``AMAX_COMPUTE_MAP``.
        """
        self.config = recipe_config
        self.ori_dtype = None
        # Copy so tuples are accepted and the caller's list is never aliased.
        self.scale_shape = list(scale_shape) if scale_shape is not None else [1]
        self.device = 'npu:{}'.format(torch.npu.current_device())
        self.fp8_format: FP8Format = fp8_format
        self.fp8_max = self.fp8_format.max
        self.margin = self.config.config.fp8_margin
        self.scale = torch.ones(self.scale_shape, device=self.device)
        self.amax_history_len = self.config.config.fp8_amax_history_len
        # Number of history rows written so far; saturates at amax_history_len.
        self.amax_history_current_len = 0
        if self.config.config.fp8_amax_compute_algo not in AMAX_COMPUTE_MAP:
            raise AssertionError('Unsupported amax compute algo {}'.format(self.config.config.fp8_amax_compute_algo))
        self.amax_compute = AMAX_COMPUTE_MAP[self.config.config.fp8_amax_compute_algo]
        # Dim 0 is the history axis; the remaining dims match scale_shape.
        self.amax_history = torch.zeros([self.amax_history_len] + self.scale_shape, device=self.device)
        self.amax = torch.zeros(self.scale_shape, device=self.device)
        self.current_interval = 1

    @property
    def quantization_scale(self):
        """Return ``scale`` itself when it has one element, else ``scale[0][0]``.

        NOTE(review): the multi-element branch assumes a 2-D scale; confirm.
        """
        return self.scale if self.scale.numel() == 1 else self.scale[0][0]

    @property
    def last_history_index(self):
        """Row index (along dim 0) of the most recently appended amax.

        Returns ``-1`` (the last row) once the history is full.
        """
        if self.amax_history_current_len < self.amax_history_len:
            return self.amax_history_current_len - 1
        return -1

    def append_amax(self, amax):
        """Record ``amax`` as the newest entry of the history.

        The history fills from row 0. Once it is full, it is shifted one row
        toward index 0 (dropping the oldest entry) and ``amax`` is written to
        the last row.
        """
        if self.amax_history_current_len < self.amax_history_len:
            self.amax_history[self.amax_history_current_len].copy_(amax, non_blocking=True)
            self.amax_history_current_len += 1
        else:
            # Roll along dim 0, the history axis. Rolling a scale axis instead
            # would leave the window unshifted (a no-op for a [1] scale shape).
            self.amax_history = self.amax_history.roll(-1, 0)
            self.amax_history[-1].copy_(amax, non_blocking=True)

    def reduce_amax(self, group=None):
        """All-reduce (MAX) the most recent history entry in place across ``group``.

        Does nothing when ``group`` is None or has at most one rank.
        """
        if group is None or torch.distributed.get_world_size(group) <= 1:
            return
        # Indexing dim 0 returns a view, so the reduced value lands in the history.
        latest = self.amax_history[self.last_history_index]
        torch.distributed.all_reduce(latest, op=torch.distributed.ReduceOp.MAX, group=group)

    def delayed_recipe_update_scale(self):
        """Reduce the latest amax, recompute ``amax`` and refresh ``scale``.

        The new scale is ``amax * 2**margin / fp8_max``. Entries where that
        value is zero (all-zero input) or non-finite (overflow) keep their
        previous scale. The scale appears to be used as a divisor, so a zero
        or non-finite scale would otherwise propagate NaN/Inf downstream.
        """
        self.reduce_amax(self.config.amax_reduce_group)
        self.amax_compute(self.amax, self.amax_history, self.last_history_index)
        new_scale = (self.amax * (2 ** self.margin)) / self.fp8_max
        usable = torch.isfinite(new_scale) & (new_scale > 0)
        self.scale.copy_(torch.where(usable, new_scale, self.scale), non_blocking=True)

    def delayed_recipe_update_amax(self, tensor, stream):
        """Return the scale to use for ``tensor`` and maybe schedule an update.

        On the first call (empty history), amax and scale are computed on the
        current stream and a copy of the freshly computed scale is returned.

        On later calls, a copy of the current scale is returned. Once
        ``current_interval`` reaches ``fp8_interval``, the amax/scale update
        for ``tensor`` is issued on ``stream``; otherwise the counter advances.
        """
        if self.amax_history_current_len == 0:
            self.current_interval = 1
            amax = torch.amax(torch.abs(tensor))
            self.append_amax(amax)
            self.delayed_recipe_update_scale()
            return self.scale.clone()
        # The side stream must observe all work already queued on the current stream.
        stream.wait_stream(torch.cuda.current_stream())
        scale = self.scale.clone()
        if self.current_interval >= self.config.config.fp8_interval:
            self.current_interval = 1
            # `tensor` belongs to the current stream but is read on `stream`.
            # Recording it stops the caching allocator from reusing its memory
            # before the side-stream work has finished.
            tensor.record_stream(stream)
            with torch.cuda.stream(stream):
                amax = torch.amax(torch.abs(tensor))
                self.append_amax(amax)
                self.delayed_recipe_update_scale()
        else:
            self.current_interval += 1
        return scale