from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
from triton_kernels.tensor import convert_layout
from triton_kernels.tensor import wrap_torch_tensor, FP4
from triton_kernels.target_info import is_cuda, get_cdna_version, cuda_capability_geq
import torch
def _transpose_storage(t):
    """Return ``t`` with unchanged shape/values but its last two dims stored transposed.

    ``transpose -> contiguous -> transpose`` keeps the logical shape while making
    dim -2 the fastest-varying (column-major) dimension. ``contiguous`` only copies
    when the transposed view is not already contiguous.
    """
    return t.transpose(-1, -2).contiguous().transpose(-1, -2)


def quantize_weight(w, dtype, **opt):
    """Quantize weight tensor ``w`` to ``dtype``.

    Args:
        w: Weight tensor. Its last two dims are treated as the matrix dims for the
            column-major relayout. The shape is presumably ``(..., K, N)``; TODO confirm
            against callers.
        dtype: One of ``"bf16"``, ``"fp8"`` or ``"mx4"``.
        **opt: Only consulted for ``"mx4"``. When non-empty it must contain
            ``value_layout``, ``value_layout_opts``, ``scale_layout`` and
            ``scale_layout_opts``. These are forwarded to ``convert_layout`` for
            the quantized values and their scales respectively.

    Returns:
        A tuple ``(wq, flex, w_scale)``:

        * ``"bf16"``: ``w`` cast to bfloat16 with column-major storage in the
          last two dims, ``InFlexData()``, ``None``.
        * ``"fp8"``: ``w`` cast to float8 e4m3. The fnuz variant is used when
          ``get_cdna_version() == 3``. Storage is column-major only on CUDA
          devices below compute capability 10.0. ``flex`` is
          ``InFlexData(dtype=wq.dtype, scale=max|w|)`` with a 1-element scale.
          ``w_scale`` is ``None``.
        * ``"mx4"``: values and scales from ``downcast_to_mxfp`` (target
          ``torch.uint8``, ``axis=1``), then ``InFlexData()``. When ``opt`` is
          given, values are wrapped as ``FP4`` and both tensors are
          layout-converted.

    Raises:
        ValueError: if ``dtype`` is not one of the supported names.
        KeyError: if ``opt`` is non-empty but lacks one of the required keys.
    """
    if dtype not in ("bf16", "fp8", "mx4"):
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
        # In that mode an unknown dtype would silently have been quantized as mx4.
        raise ValueError(f"unsupported weight dtype {dtype!r}; expected one of 'bf16', 'fp8', 'mx4'")

    if dtype == "bf16":
        return _transpose_storage(w.to(torch.bfloat16)), InFlexData(), None

    if dtype == "fp8":
        # CDNA3 targets use the "fnuz" float8 encoding; everything else uses e4m3fn.
        fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 else torch.float8_e4m3fnuz
        wq = w.to(fp8e4_dtype)
        if is_cuda() and not cuda_capability_geq(10, 0):
            # Column-major storage is only requested on CUDA below compute capability 10.0.
            wq = _transpose_storage(wq)
        # NOTE(review): the scale is max|w|, but wq is a direct cast of w (not w / scale).
        # If InFlexData.scale is a dequantization factor this double-scales the weight.
        # Confirm this is intentional (e.g. benchmark-only numerics).
        scale = w.abs().max().unsqueeze(0)
        return wq, InFlexData(dtype=wq.dtype, scale=scale), None

    # dtype == "mx4" (guaranteed by the validation above).
    # axis=1 is the quantization block axis. Presumably this is K for an (E, K, N)
    # expert weight; TODO confirm for 2-D inputs.
    wq, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
    if opt:
        wq = convert_layout(wrap_torch_tensor(wq, dtype=FP4), opt["value_layout"], **opt["value_layout_opts"])
        w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"], **opt["scale_layout_opts"])
    return wq, InFlexData(), w_scale