"""
QAT Implementation Module
This module implements the core computation functions for Quantization-Aware Training (QAT).
It provides both symmetric and asymmetric quantization operations with forward and backward passes.
Main Functions:
- ai_infra_qat_asymmetric_per_group: Asymmetric per-group quantization
- ai_infra_qat_asymmetric_per_group_backward: Backward pass for asymmetric per-group quantization
- ai_infra_qat_symmetric_per_channel: Symmetric per-channel quantization
- ai_infra_qat_symmetric_per_channel_backward: Backward pass for symmetric per-channel quantization
- ai_infra_qat_symmetric_per_tensor: Symmetric per-tensor quantization
- ai_infra_qat_symmetric_per_tensor_backward: Backward pass for symmetric per-tensor quantization
Example:
See ai_infra_pypto_qat.py for usage examples.
"""
import pypto
import torch
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 64,
},
pass_options={"vec_nbuffer_setting": {-1: 2, -2: 1}},
)
def ai_infra_qat_asymmetric_per_group_kernel(
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
offset: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
output_bf16: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
eps,
n_levels,
neg_clip_val,
clip_val,
shift
):
pypto.experimental.set_operation_options(combine_axis=True)
unroll_list = [512, 256]
num_groups = scale.shape[0]
group_size = weight.shape[1]
pypto.set_vec_tile_shapes(128, 128)
for g_offset, unroll_length in pypto.loop_unroll(
0, num_groups, 1,
name="LOOP_GROUPS",
idx_name="g_offset",
unroll_list=unroll_list
):
tile_groups = unroll_length
weight_tile = pypto.view(weight, [tile_groups, group_size], [g_offset, 0])
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_tile = pypto.view(scale, [tile_groups, 1], [g_offset, 0])
offset_tile = pypto.view(offset, [tile_groups, 1], [g_offset, 0])
scale_fp32 = pypto.cast(scale_tile, pypto.DT_FP32)
offset_fp32 = pypto.cast(offset_tile, pypto.DT_FP32)
protected_scale = pypto.maximum(scale_fp32, eps)
alpha = pypto.mul(protected_scale, n_levels)
weight_shifted = pypto.sub(weight_fp32, offset_fp32)
weight_norm = pypto.div(weight_shifted, alpha)
weight_clipped = pypto.clip(weight_norm, neg_clip_val, clip_val)
weight_scaled = pypto.mul(weight_clipped, n_levels)
weight_shifted2 = pypto.sub(weight_scaled, shift)
weight_rounded = pypto.round(weight_shifted2, decimals=0)
weight_unshifted = pypto.add(weight_rounded, shift)
weight_denorm = pypto.div(weight_unshifted, n_levels)
weight_rescaled = pypto.mul(weight_denorm, alpha)
output = pypto.add(weight_rescaled, offset_fp32)
output_tile = pypto.cast(output, pypto.DT_BF16)
pypto.assemble(output_tile, [g_offset, 0], output_bf16)
def ai_infra_qat_asymmetric_per_group(weight, scale, offset, group_size=128, bit=4,
eps=1e-4, clip_val=0.99):
n_levels = 2 ** (bit - 1)
shift = 0.5
neg_clip_val = -clip_val
output_bf16 = torch.empty(weight.shape, dtype=weight.dtype, device=weight.device)
weight_grouped = weight.view(-1, group_size)
output_bf16_grouped = output_bf16.view(-1, group_size)
inputs = [
weight_grouped,
scale,
offset,
output_bf16_grouped,
eps,
n_levels,
neg_clip_val,
clip_val,
shift
]
ai_infra_qat_asymmetric_per_group_kernel(*inputs)
return output_bf16_grouped.view(weight.shape)
@pypto.frontend.jit(
pass_options={
"vec_nbuffer_setting": {-1: 2, -2: 1}
},
runtime_options={
"stitch_function_max_num": 64
},
)
def ai_infra_qat_asymmetric_per_group_backward_kernel(
grad_output: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
offset: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
grad_weight_out: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
grad_scale_out: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
grad_offset_out: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
eps,
n_levels,
neg_clip_val,
clip_val,
shift
):
group_size = grad_output.shape[1]
num_groups = scale.shape[0]
unroll_list = [512, 256]
pypto.experimental.set_operation_options(combine_axis=True)
for g_offset, unroll_length in pypto.loop_unroll(
0, num_groups, 1,
name="LOOP_GROUPS",
idx_name="g_offset",
unroll_list=unroll_list
):
tile_groups = unroll_length
pypto.set_vec_tile_shapes(64, 128)
grad_out_tile = pypto.view(grad_output, [tile_groups, group_size], [g_offset, 0])
weight_tile = pypto.view(weight, [tile_groups, group_size], [g_offset, 0])
scale_tile = pypto.view(scale, [tile_groups, 1], [g_offset, 0])
offset_tile = pypto.view(offset, [tile_groups, 1], [g_offset, 0])
grad_out_fp32 = pypto.cast(grad_out_tile, pypto.DT_FP32)
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_fp32 = pypto.cast(scale_tile, pypto.DT_FP32)
offset_fp32 = pypto.cast(offset_tile, pypto.DT_FP32)
protected_scale = pypto.maximum(scale_fp32, eps)
alpha = pypto.mul(protected_scale, n_levels)
weight_shifted = pypto.sub(weight_fp32, offset_fp32)
weight_norm = pypto.div(weight_shifted, alpha)
weight_clipped = pypto.clip(weight_norm, neg_clip_val, clip_val)
weight_scaled = pypto.mul(weight_clipped, n_levels)
weight_shifted2 = pypto.sub(weight_scaled, shift)
weight_rounded = pypto.round(weight_shifted2, decimals=0)
weight_unshifted = pypto.add(weight_rounded, shift)
weight_denorm = pypto.div(weight_unshifted, n_levels)
diff = pypto.sub(weight_norm, weight_clipped)
abs_diff = pypto.abs(diff)
big_number = 1e15
sign = pypto.mul(abs_diff, big_number)
is_out = pypto.clip(sign, 0.0, 1.0)
one = pypto.full(is_out.shape, 1.0, is_out.dtype)
mask_f32 = pypto.sub(one, is_out)
one_tile_group_gs = pypto.full([tile_groups, group_size], 1.0, pypto.DT_FP32)
inv_mask_f32 = pypto.sub(one_tile_group_gs, mask_f32)
scale_diff = pypto.sub(scale_fp32, eps)
diff_pos = pypto.maximum(scale_diff, 0.0)
amplified_diff = pypto.mul(diff_pos, 1e15)
scale_mask_f32 = pypto.clip(amplified_diff, 0.0, 1.0)
grad_weight_fp32 = pypto.mul(grad_out_fp32, mask_f32)
grad_offset_pre = pypto.mul(grad_out_fp32, inv_mask_f32)
grad_offset_fp32 = pypto.sum(grad_offset_pre, dim=1, keepdim=True)
term_w_norm = pypto.mul(weight_norm, mask_f32)
term_diff = pypto.sub(weight_denorm, term_w_norm)
grad_alpha_pre = pypto.mul(grad_out_fp32, term_diff)
grad_alpha_fp32 = pypto.sum(grad_alpha_pre, dim=1, keepdim=True)
grad_scale_pre = pypto.mul(grad_alpha_fp32, n_levels)
grad_scale_fp32 = pypto.mul(grad_scale_pre, scale_mask_f32)
grad_w_bf16 = pypto.cast(grad_weight_fp32, pypto.DT_BF16)
grad_s_bf16 = pypto.cast(grad_scale_fp32, pypto.DT_BF16)
grad_o_bf16 = pypto.cast(grad_offset_fp32, pypto.DT_BF16)
pypto.assemble(grad_w_bf16, [g_offset, 0], grad_weight_out)
pypto.assemble(grad_s_bf16, [g_offset, 0], grad_scale_out)
pypto.assemble(grad_o_bf16, [g_offset, 0], grad_offset_out)
def ai_infra_qat_asymmetric_per_group_backward(grad_output, weight_pto, scale_pto, offset_pto,
group_size=128, bit=4, eps=1e-4, clip_val=0.99):
n_levels = 2 ** (bit - 1)
shift = 0.5
neg_clip_val = -clip_val
grad_weight_out = torch.empty(weight_pto.shape, dtype=weight_pto.dtype, device=weight_pto.device)
grad_scale_out = torch.empty(scale_pto.shape, dtype=scale_pto.dtype, device=scale_pto.device)
grad_offset_out = torch.empty(offset_pto.shape, dtype=offset_pto.dtype, device=offset_pto.device)
grad_output_grouped = grad_output.view(-1, group_size)
weight_pto_grouped = weight_pto.view(-1, group_size)
grad_weight_out_grouped = grad_weight_out.view(-1, group_size)
inputs = [
grad_output_grouped,
weight_pto_grouped,
scale_pto,
offset_pto,
grad_weight_out_grouped,
grad_scale_out,
grad_offset_out,
eps,
n_levels,
neg_clip_val,
clip_val,
shift,
]
ai_infra_qat_asymmetric_per_group_backward_kernel(*inputs)
return grad_weight_out_grouped.view(weight_pto.shape), grad_scale_out, grad_offset_out
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 64,
},
)
def ai_infra_qat_symmetric_per_channel_kernel(
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
output_bf16: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
eps,
min_v,
max_v
):
pypto.experimental.set_operation_options(combine_axis=True)
n, m = weight.shape
unroll_list = [512, 32, 8]
pypto.set_vec_tile_shapes(32, 512)
for n_offset, unroll_length in pypto.loop_unroll(
0, n, 1,
name="LOOP_N_UNROLL",
idx_name="n_offset",
unroll_list=unroll_list
):
tile_n = unroll_length
weight_tile = pypto.view(weight, [tile_n, m], [n_offset, 0])
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_tile = pypto.view(scale, [tile_n, 1], [n_offset, 0])
scale_fp32 = pypto.cast(scale_tile, pypto.DT_FP32)
protected_scale = pypto.maximum(scale_fp32, eps)
normalized = pypto.div(weight_fp32, protected_scale)
rounded = pypto.round(normalized, decimals=0)
clamped = pypto.clip(rounded, min_v, max_v)
output = pypto.mul(clamped, protected_scale)
output_tile = pypto.cast(output, pypto.DT_BF16)
pypto.assemble(output_tile, [n_offset, 0], output_bf16)
def ai_infra_qat_symmetric_per_channel(weight, scale, eps, min_v, max_v):
output_pto = torch.empty(weight.shape, dtype=torch.bfloat16, device=weight.device)
inputs = [
weight,
scale,
output_pto,
eps,
min_v,
max_v
]
ai_infra_qat_symmetric_per_channel_kernel(*inputs)
return output_pto
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 64,
},
pass_options={"vec_nbuffer_setting": {-1: 2, -2: 1}},
)
def ai_infra_qat_symmetric_per_channel_backward_kernel(
grad_output: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
grad_weight_out: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
grad_scale_out: pypto.Tensor([pypto.DYNAMIC, 1], pypto.DT_BF16),
eps,
min_v,
max_v,
):
n, m = weight.shape
unroll_list = [512, 32, 8]
tile_shapes_m = min(m, 4096)
pypto.set_vec_tile_shapes(4, tile_shapes_m)
for n_offset, unroll_length in pypto.loop_unroll(
0, n, 1,
name="BACKWARD_LOOP_N_UNROLL",
idx_name="n_offset",
unroll_list=unroll_list
):
tile_n = unroll_length
grad_out_tile = pypto.view(grad_output, [tile_n, m], [n_offset, 0])
weight_tile = pypto.view(weight, [tile_n, m], [n_offset, 0])
scale_tile = pypto.view(scale, [tile_n, 1], [n_offset, 0])
grad_out_fp32 = pypto.cast(grad_out_tile, pypto.DT_FP32)
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_fp32_tile = pypto.cast(scale_tile, pypto.DT_FP32)
protected_scale_tile = pypto.maximum(scale_fp32_tile, eps)
scale_mask_tile = pypto.ge(scale_fp32_tile, eps)
scale_mask_fp32_tile = pypto.where(scale_mask_tile, 1.0, 0.0)
normalized = pypto.div(weight_fp32, protected_scale_tile)
rounded = pypto.round(normalized, decimals=0)
clamped = pypto.clip(rounded, min_v, max_v)
diff = pypto.sub(rounded, clamped)
abs_diff = pypto.abs(diff)
out_of_bounds = pypto.clip(abs_diff, 0.0, 1.0)
neg_out_of_bounds = pypto.mul(out_of_bounds, -1.0)
mask_float = pypto.add(neg_out_of_bounds, 1.0)
grad_weight_fp32 = pypto.mul(grad_out_fp32, mask_float)
grad_weight_tile = pypto.cast(grad_weight_fp32, pypto.DT_BF16)
pypto.assemble(grad_weight_tile, [n_offset, 0], grad_weight_out)
grad_scale_mul_tile = pypto.mul(grad_out_fp32, clamped)
grad_scale_mul_tile_sum = pypto.sum(grad_scale_mul_tile, dim=1, keepdim=True)
neg_weight_fp32 = pypto.mul(weight_fp32, -1.0)
weight_div_scale_tile = pypto.div(neg_weight_fp32, protected_scale_tile)
grad_scale_div_step1 = pypto.mul(grad_out_fp32, mask_float)
grad_scale_div_tile = pypto.mul(grad_scale_div_step1, weight_div_scale_tile)
grad_scale_div_tile_sum = pypto.sum(grad_scale_div_tile, dim=1, keepdim=True)
grad_scale_sum_m = pypto.add(grad_scale_mul_tile_sum, grad_scale_div_tile_sum)
grad_scale_fp32_masked = pypto.mul(grad_scale_sum_m, scale_mask_fp32_tile)
grad_scale_bf16_tile = pypto.cast(grad_scale_fp32_masked, pypto.DT_BF16)
pypto.assemble(grad_scale_bf16_tile, [n_offset, 0], grad_scale_out)
def ai_infra_qat_symmetric_per_channel_backward(grad_output, weight, scale, eps, min_v, max_v):
grad_weight_out = torch.empty(weight.shape, dtype=weight.dtype, device=weight.device)
grad_scale_out = torch.empty(scale.shape, dtype=scale.dtype, device=scale.device)
inputs = [
grad_output,
weight,
scale,
grad_weight_out,
grad_scale_out,
eps,
min_v,
max_v
]
ai_infra_qat_symmetric_per_channel_backward_kernel(*inputs)
return grad_weight_out, grad_scale_out
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 64
},
)
def ai_infra_qat_symmetric_per_tensor_kernel(
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([1, 1], pypto.DT_BF16),
output_bf16: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
eps,
min_v,
max_v
):
unroll_list = [512, 32, 8]
pypto.experimental.set_operation_options(combine_axis=True)
n, m = weight.shape
pypto.set_vec_tile_shapes(32, 512)
scale_fp32 = pypto.cast(scale, pypto.DT_FP32)
protected_scale = pypto.maximum(scale_fp32, eps)
for n_offset, unroll_length in pypto.loop_unroll(
0, n, 1,
name="LOOP_N_UNROLL",
idx_name="n_offset",
unroll_list=unroll_list
):
tile_n = unroll_length
weight_tile = pypto.view(weight, [tile_n, m], [n_offset, 0])
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_n = pypto.expand_clone(protected_scale, [tile_n, 1])
normalized = pypto.div(weight_fp32, scale_n)
rounded = pypto.round(normalized, decimals=0)
clamped = pypto.clip(rounded, min_v, max_v)
output = pypto.mul(clamped, scale_n)
output_tile = pypto.cast(output, pypto.DT_BF16)
pypto.assemble(output_tile, [n_offset, 0], output_bf16)
def ai_infra_qat_symmetric_per_tensor(weight, scale, eps, min_v, max_v):
output_pto = torch.empty(weight.shape, dtype=torch.bfloat16, device=weight.device)
inputs = [
weight,
scale,
output_pto,
eps,
min_v,
max_v
]
ai_infra_qat_symmetric_per_tensor_kernel(*inputs)
return output_pto
@pypto.frontend.jit(
runtime_options={
"stitch_function_max_num": 64,
},
pass_options={"vec_nbuffer_setting": {-1: 4}},
)
def ai_infra_qat_symmetric_per_tensor_backward_kernel(
grad_output: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
weight: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
scale: pypto.Tensor([1, 1], pypto.DT_BF16),
grad_weight_out: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_BF16),
grad_scale_out: pypto.Tensor([1, 1], pypto.DT_BF16),
eps,
min_v,
max_v,
):
pypto.experimental.set_operation_options(combine_axis=False)
n, m = weight.shape
unroll_list = [512, 32, 8]
tile_shapes_m = min(m, 4096)
pypto.set_vec_tile_shapes(4, tile_shapes_m)
scale_fp32 = pypto.cast(scale, pypto.DT_FP32)
protected_scale = pypto.maximum(scale_fp32, eps)
scale_mask = pypto.ge(scale_fp32, eps)
scale_mask_fp32 = pypto.where(scale_mask, 1.0, 0.0)
grad_scale_acc = pypto.full([1, 1], 0.0, pypto.DT_FP32)
for n_offset, unroll_length in pypto.loop_unroll(
0, n, 1,
name="BACKWARD_LOOP_N_UNROLL",
idx_name="n_offset",
unroll_list=unroll_list
):
tile_n = unroll_length
grad_out_tile = pypto.view(grad_output, [tile_n, m], [n_offset, 0])
grad_out_fp32 = pypto.cast(grad_out_tile, pypto.DT_FP32)
weight_tile = pypto.view(weight, [tile_n, m], [n_offset, 0])
weight_fp32 = pypto.cast(weight_tile, pypto.DT_FP32)
scale_n = pypto.expand_clone(protected_scale, [tile_n, 1])
normalized = pypto.div(weight_fp32, scale_n)
rounded = pypto.round(normalized, decimals=0)
clamped = pypto.clip(rounded, min_v, max_v)
diff = pypto.sub(rounded, clamped)
abs_diff = pypto.abs(diff)
out_of_bounds = pypto.clip(abs_diff, 0.0, 1.0)
neg_out_of_bounds = pypto.mul(out_of_bounds, -1.0)
mask_float = pypto.add(neg_out_of_bounds, 1.0)
grad_weight_fp32 = pypto.mul(grad_out_fp32, mask_float)
grad_weight_tile = pypto.cast(grad_weight_fp32, pypto.DT_BF16)
pypto.assemble(grad_weight_tile, [n_offset, 0], grad_weight_out)
grad_scale_mul_tile = pypto.mul(grad_out_fp32, clamped)
grad_scale_mul_tile_m = pypto.sum(grad_scale_mul_tile, dim=1, keepdim=True)
neg_weight_fp32 = pypto.mul(weight_fp32, -1.0)
weight_div_scale_tile = pypto.div(neg_weight_fp32, scale_n)
grad_scale_div_step1 = pypto.mul(grad_out_fp32, mask_float)
grad_scale_div_tile = pypto.mul(grad_scale_div_step1, weight_div_scale_tile)
grad_scale_div_tile_m = pypto.sum(grad_scale_div_tile, dim=1, keepdim=True)
pypto.set_vec_tile_shapes(512, 1)
grad_scale_mul_tile_n = pypto.sum(grad_scale_mul_tile_m, dim=0, keepdim=True)
grad_scale_div_tile_n = pypto.sum(grad_scale_div_tile_m, dim=0, keepdim=True)
grad_scale_tile = pypto.add(grad_scale_mul_tile_n, grad_scale_div_tile_n)
grad_scale_acc[:] = pypto.add(grad_scale_acc, grad_scale_tile)
final_grad_scale_fp32 = pypto.mul(grad_scale_acc, scale_mask_fp32)
final_grad_scale_bf16 = pypto.cast(final_grad_scale_fp32, pypto.DT_BF16)
pypto.assemble(final_grad_scale_bf16, [0, 0], grad_scale_out)
def ai_infra_qat_symmetric_per_tensor_backward(grad_output, weight, scale, eps, min_v, max_v):
grad_weight_out = torch.empty(weight.shape, dtype=weight.dtype, device=weight.device)
grad_scale_out = torch.empty(scale.shape, dtype=scale.dtype, device=scale.device)
inputs = [
grad_output,
weight,
scale,
grad_weight_out,
grad_scale_out,
eps,
min_v,
max_v
]
ai_infra_qat_symmetric_per_tensor_backward_kernel(*inputs)
return grad_weight_out, grad_scale_out