""" """
import torch
def gen_uniform_data(data_shape, min_value, max_value, dtype):
    """Generate uniformly distributed random data with PyTorch.

    Intended to behave like the NumPy-based generator it mirrors, with samples
    drawn from the half-open interval ``[min_value, max_value)``.

    Args:
        data_shape: Output shape, as accepted by ``torch.rand`` / ``torch.randint``.
        min_value: Inclusive lower bound.
        max_value: Exclusive upper bound.
        dtype: Output ``torch.dtype``.

    Returns:
        A tensor of shape ``data_shape`` and dtype ``dtype``:
          * all zeros when ``min_value == max_value == 0``;
          * random 0/1 values when ``dtype`` is ``torch.bool`` (bounds ignored);
          * ``min_value + (max_value - min_value) * U[0, 1)`` for floating
            dtypes, guaranteed ``< max_value`` when ``max_value > min_value``;
          * ``torch.randint(min_value, max_value)`` samples for other dtypes.

    NOTE(review): if ``min_value`` itself is not representable in ``dtype``,
    it can round below ``min_value``; only the upper bound is enforced here.
    """
    if min_value == 0 and max_value == 0:
        return torch.zeros(data_shape, dtype=dtype)
    if dtype == torch.bool:
        return torch.randint(0, 2, data_shape, dtype=dtype)
    if dtype.is_floating_point:
        data = min_value + (max_value - min_value) * torch.rand(data_shape, dtype=dtype)
        if max_value > min_value:
            # Rounding in the target dtype can land a sample exactly on
            # max_value (e.g. bf16 on [1, 2): 1 + 255/256 rounds to 2.0).
            # Clamp such samples to the largest representable value strictly
            # below max_value so the half-open interval really holds.
            upper = torch.tensor(max_value, dtype=dtype)
            if float(upper) >= max_value:
                upper = torch.nextafter(
                    upper, torch.tensor(float("-inf"), dtype=dtype)
                )
            # Compare in float64 so max_value is not itself rounded to dtype.
            data = torch.where(data.to(torch.float64) >= max_value, upper, data)
        return data
    return torch.randint(
        low=min_value, high=max_value, size=data_shape, dtype=dtype
    )
def rotate_half(x):
    """Rotate the last dimension of ``x`` by half its length.

    The last axis is split into a front part (the first ``size // 2``
    elements) and a back part (the rest); the result is
    ``concat(-back, front)`` along that axis.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(q, cos, sin):
    """Apply an interleaved rotary position embedding to ``q``.

    Args:
        q: Tensor of shape (t, n_q, rope_dim), documented as bf16; rope_dim
            must be even.
        cos: Tensor of shape (t, rope_dim), documented as bf16; broadcast
            over the head axis.
        sin: Tensor of shape (t, rope_dim), documented as bf16; broadcast
            over the head axis.

    Returns:
        A tensor shaped like ``q`` and cast back to ``q``'s dtype. All math
        is done in float32. Adjacent pairs (q[2i], q[2i+1]) are combined as:
            out[2i]   = q[2i]   * cos[2i]   + q[2i+1] * sin[2i]
            out[2i+1] = q[2i+1] * cos[2i+1] - q[2i]   * sin[2i+1]

    NOTE(review): this is ``q*cos - rot(q)*sin``, the opposite sign from the
    common ``q*cos + rot(q)*sin`` convention, so it appears to rotate by -theta.
    Confirm this matches the kernel under test.
    """
    orig_dtype = q.dtype
    q32 = q.to(torch.float32)
    cos32 = cos.to(torch.float32).unsqueeze(1)
    sin32 = sin.to(torch.float32).unsqueeze(1)

    seq_len, heads, rope_dim = q32.shape
    # Treat the last axis as rope_dim // 2 pairs so rotate_half swaps within each pair.
    pairs = q32.reshape(seq_len, heads, rope_dim // 2, 2)
    rotated = rotate_half(pairs).reshape(seq_len, heads, rope_dim)

    out = q32 * cos32 - rotated * sin32
    return out if orig_dtype == torch.float32 else out.to(orig_dtype)
def quant_golden(x: torch.Tensor):
    """Reference symmetric int8 quantization, done per slice along the last axis.

    Each slice along the last dimension is scaled so that its largest absolute
    value maps to 127, then rounded (``torch.round``, half-to-even) to int8.

    Args:
        x: Floating-point tensor. Each slice along the last dimension is
            quantized independently.

    Returns:
        A tuple ``(y_int8, scale_dequant)``:
          * ``y_int8``: int8 tensor with the same shape as ``x``; values lie
            in [-127, 127] for finite input.
          * ``scale_dequant``: float32 tensor of shape ``x.shape[:-1] + (1,)``
            (``1 / scale_quant``), so that ``y_int8 * scale_dequant``
            approximates ``x``. It is 0 for slices that are entirely zero.
    """
    x_dtype = x.dtype
    x_fp32 = x.to(torch.float32)
    max_value = torch.amax(torch.abs(x_fp32), dim=-1, keepdim=True)
    scale_quant = 127.0 / max_value
    # An all-zero slice has max_value == 0, so scale_quant == inf and 0 * inf
    # is NaN. Casting NaN to an integer type is undefined, so quantize those
    # slices to 0 explicitly.
    y_fp32 = (x_fp32 * scale_quant).masked_fill(max_value == 0, 0.0)
    y_int32 = torch.round(y_fp32).to(torch.int32)
    # Round-trip through the input dtype before narrowing to int8. This is
    # kept from the original; it presumably mirrors the kernel's cast chain
    # (TODO confirm).
    y_int8 = torch.trunc(y_int32.to(x_dtype)).to(torch.int8)
    scale_dequant = 1.0 / scale_quant
    return y_int8, scale_dequant