"""
QAT Test Module
This module provides test cases and golden reference implementations for QAT operators.
Main Functions:
- create_asymmetric_qat_golden: Golden reference for asymmetric per-group quantization
- create_symmetric_qat_nscale_golden: Golden reference for symmetric per-channel quantization
- create_symmetric_qat_golden: Golden reference for symmetric per-tensor quantization
- forward_test: Generic forward test interface
- backward_test_autograd: Generic backward test interface
Example:
python ai_infra_pypto_qat.py
"""
import logging
import math
import os
import re
from typing import Any, Tuple, Union
import numpy as np
import pytest
import torch
import torch_npu
from numpy.testing import assert_allclose
from qat_impl import (
ai_infra_qat_asymmetric_per_group,
ai_infra_qat_asymmetric_per_group_backward,
ai_infra_qat_symmetric_per_channel,
ai_infra_qat_symmetric_per_channel_backward,
ai_infra_qat_symmetric_per_tensor,
ai_infra_qat_symmetric_per_tensor_backward,
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
DISTRIBUTION = [
"uniform_large",
]
def _uniform(shape: Tuple[int, ...], low: float, high: float) -> torch.Tensor:
"""Return a float32 tensor with uniform distribution in [low, high]."""
return torch.empty(shape, dtype=torch.float32).uniform_(low, high)
def _normal(shape: Tuple[int, ...]) -> torch.Tensor:
"""Return a float32 tensor drawn from a normal distribution.
μ is sampled uniformly from [-100, 100] and σ from [1, 25], matching the
original test configuration.
"""
mu = np.random.uniform(-100, 100)
sigma = np.random.uniform(1, 25)
return torch.randn(shape, dtype=torch.float32) * sigma + mu
def _outlier(shape: Tuple[int, ...]) -> torch.Tensor:
"""Generate a normal tensor and inject outliers.
0.1% of the elements are multiplied by 1000 to create extreme values.
"""
tensor = _normal(shape)
mask = torch.rand(shape) < 0.001
tensor[mask] *= 1000.0
return tensor
def create_input(
shape: Union[Tuple[int, ...], list],
dtype: str = "float32",
device: str = "cpu",
distribution: str = "uniform_large",
seed: int = 33,
) -> torch.Tensor:
"""Generate a tensor according to the requested distribution.
Parameters
----------
shape : tuple or list of ints
Desired tensor shape, e.g. (n, m).
dtype : str, optional
Torch dtype name (default "float32").
device : str, optional
Device string understood by ``torch.device`` (default "cpu").
distribution : str, optional
Supported formats:
- "uniform[low,high]" where low and high are floats, e.g., "uniform[-5,5]" or "uniform[-0.001,0.001]"
- "normal"
- "outlier"
The default is "uniform_small" (equivalent to "uniform[-0.001,0.001]").
Returns
-------
torch.Tensor
Tensor on the requested device with the requested dtype.
"""
torch.manual_seed(seed)
if isinstance(shape, list):
shape = tuple(shape)
if not isinstance(shape, tuple):
raise TypeError("shape must be a tuple or list of ints")
if distribution is None:
distribution = "uniform_large"
if distribution.startswith("uniform["):
match = re.match(r"uniform\[\s*([-+]?[0-9]*\.?[0-9]+)\s*,\s*([-+]?[0-9]*\.?[0-9]+)\s*\]$", distribution)
if not match:
raise ValueError(
f"Invalid uniform distribution spec '{distribution}'. Expected format uniform[low,high]"
)
low, high = map(float, match.groups())
tensor_fp32 = _uniform(shape, low, high)
elif distribution == "uniform_small":
tensor_fp32 = _uniform(shape, -0.001, 0.001)
elif distribution == "uniform_large":
tensor_fp32 = _uniform(shape, -5.0, 5.0)
elif distribution == "normal":
tensor_fp32 = _normal(shape)
elif distribution == "outlier":
tensor_fp32 = _outlier(shape)
else:
raise ValueError(
f"Unsupported distribution '{distribution}'. "
"Supported: uniform[low,high], uniform_small, uniform_large, normal, outlier"
)
tensor = tensor_fp32.to(dtype)
try:
tensor = tensor.to(device)
except Exception as exc:
raise RuntimeError(f"Failed to move tensor to device '{device}'.") from exc
return tensor
small_value_thres_dict = {
torch.float16: 2**-11,
torch.bfloat16: 2**-8,
torch.float32: 2**-14,
torch.uint8: 2**-4, torch.float8_e4m3fn: 2**-4
}
small_value_error_thres_dict = {
torch.float16: 2**-16,
torch.bfloat16: 2**-16,
torch.float32: 2**-30,
torch.uint8: 2**-6, torch.float8_e4m3fn: 2**-6
}
def get_split_index(golden_data, dtype):
thres = small_value_thres_dict[dtype]
large_mask = torch.abs(golden_data) >= thres
small_mask = torch.abs(golden_data) < thres
return large_mask, small_mask, thres
def compute_matrix_small_value(input_data, golden_data, dtype, small_mask):
if not torch.any(small_mask):
return 0
thres = small_value_error_thres_dict[dtype]
error_count = torch.sum(torch.abs(input_data[small_mask] - golden_data[small_mask]) > thres).item()
return error_count
def compute_matrix_large_value(input_data, golden_data, large_mask):
if not torch.any(large_mask):
return 0, 0, 0
input_large = input_data[large_mask]
golden_large = golden_data[large_mask]
abs_diff = torch.abs(input_large - golden_large)
relative_error = abs_diff / (torch.abs(golden_large) + 1e-7)
mare = torch.max(relative_error).item()
mere = torch.mean(relative_error).item()
rmse = torch.sqrt(torch.mean((input_large - golden_large) ** 2)).item()
return mare, mere, rmse
def compute_re_matrix(input_value, bm_value, small_value_thres):
if math.isinf(bm_value) or math.isnan(bm_value):
return 1
if math.isinf(input_value) or math.isnan(input_value):
return 1000
return input_value / max(bm_value, small_value_thres)
def compute_re_triplet_matrix(npu_matrix, golden_matrix, small_value_thres):
mare_npu, mere_npu, rmse_npu = npu_matrix
mare_bm, mere_bm, rmse_bm = golden_matrix
mare_matrix = compute_re_matrix(mare_npu, mare_bm, small_value_thres)
mere_matrix = compute_re_matrix(mere_npu, mere_bm, small_value_thres)
rmse_matrix = compute_re_matrix(rmse_npu, rmse_bm, small_value_thres)
return mare_matrix, mere_matrix, rmse_matrix
def precision_compare_triple(pto_data, bm_data, golden_data, thres=(2, 1.2, 1.2)):
logger.info(f"{pto_data.dtype=}")
logger.info(f"{bm_data.dtype=}")
logger.info(f"{golden_data.dtype=}")
dtype = pto_data.dtype
if dtype in ["int8", "int32"]:
raise NotImplementedError("precision compare triplet only support float")
if dtype == torch.uint8:
pto_data = torch_npu.npu_dtype_cast(pto_data, torch.float32, input_dtype=torch_npu.hifloat8)
bm_data = torch_npu.npu_dtype_cast(bm_data, torch.float32, input_dtype=torch_npu.hifloat8)
golden_data = torch_npu.npu_dtype_cast(golden_data, torch.float32, input_dtype=torch_npu.hifloat8)
else:
pto_data = pto_data.to(torch.float32)
bm_data = bm_data.to(torch.float32)
golden_data = golden_data.to(torch.float32)
pto_data = pto_data.cpu()
bm_data = bm_data.cpu()
golden_data = golden_data.cpu()
large_value_idx, small_value_idx, small_value_thres = get_split_index(golden_data, dtype)
npu_error_count = compute_matrix_small_value(pto_data, golden_data, dtype, small_value_idx)
bm_error_count = compute_matrix_small_value(bm_data, golden_data, dtype, small_value_idx)
small_value_matrix = npu_error_count / max(bm_error_count, 1)
mare_npu, mere_npu, rmse_npu = compute_matrix_large_value(pto_data, golden_data, large_value_idx)
mare_bm, mere_bm, rmse_bm = compute_matrix_large_value(bm_data, golden_data, large_value_idx)
mare_matrix, mere_matrix, rmse_matrix = compute_re_triplet_matrix(
[mare_npu, mere_npu, rmse_npu], [mare_bm, mere_bm, rmse_bm], small_value_thres)
is_mare_acceptable = mare_matrix <= thres[0]
is_mere_acceptable = mere_matrix <= thres[1]
is_rmse_acceptable = rmse_matrix <= thres[2]
if all([
small_value_matrix <= 2,
is_mare_acceptable,
is_mere_acceptable,
is_rmse_acceptable
]):
result = "PASS"
else:
result = "FAILED"
return result, mare_matrix, mere_matrix, rmse_matrix, small_value_matrix
def compare(pto_grad_w, bm_grad_w, golden_grad_w):
result, mare_matrix, mere_matrix, rmse_matrix, small_value_matrix = precision_compare_triple(
pto_grad_w, bm_grad_w, golden_grad_w)
logger.info(f" precision result: {result}")
logger.info(f" mare_matrix: {mare_matrix}")
logger.info(f" mere_matrix: {mere_matrix}")
logger.info(f" rmse_matrix: {rmse_matrix}")
logger.info(f" small_value_matrix: {small_value_matrix}")
if result != "PASS":
raise Exception("fail precision check")
return result, mare_matrix, mere_matrix, rmse_matrix, small_value_matrix
def clone_inputs(inputs):
"""Clone inputs and preserve requires_grad"""
new_inputs = []
for x in inputs:
if isinstance(x, torch.Tensor):
y = x.detach().clone()
y.requires_grad_(x.requires_grad)
new_inputs.append(y)
else:
new_inputs.append(x)
return tuple(new_inputs)
def to_double_inputs(inputs):
new_inputs = []
for x in inputs:
if isinstance(x, torch.Tensor):
y = x.detach().cpu().double()
y.requires_grad_(x.requires_grad)
new_inputs.append(y)
else:
new_inputs.append(x)
return tuple(new_inputs)
def _to_double_cpu_backward(inputs):
new_inputs = []
for x in inputs:
if isinstance(x, torch.Tensor):
new_inputs.append(x.detach().cpu().double().requires_grad_(x.requires_grad))
else:
new_inputs.append(x)
return tuple(new_inputs)
def normalize_outputs(out):
if isinstance(out, (tuple, list)):
return tuple(out)
return (out,)
def collect_grads(inputs):
grads = []
for x in inputs:
if isinstance(x, torch.Tensor) and x.requires_grad:
if x.grad is None:
raise Exception(f"[ERROR]grad is None, shape={tuple(x.shape)}")
else:
grads.append(x.grad.detach().clone())
else:
grads.append(None)
return grads
def forward_test(inputs: Tuple[Any, ...], pto_inputs, golden_func, pto_func):
"""
Generic forward test interface.
Parameters
----------
inputs : tuple
All inputs (tensor + non-tensor)
golden_func : callable
Reference implementation
golden_func(*inputs, is_golden=False)
golden_func(*inputs, is_golden=True)
pto_func : callable
Kernel under test
pto_func(*inputs)
Constraint
----------
golden_func last parameter must be is_golden
"""
bm_inputs = clone_inputs(inputs)
golden_inputs = to_double_inputs(inputs)
bm_out = normalize_outputs(golden_func(*bm_inputs, is_golden=False))
golden_out = normalize_outputs(golden_func(*golden_inputs, is_golden=True))
kernel_out = normalize_outputs(pto_func(*pto_inputs))
assert len(bm_out) == len(kernel_out)
compare_results = []
for i, _ in enumerate(bm_out):
logger.info(f"=== Forward Output[{i}] ===")
try:
assert_allclose(kernel_out[i].float().cpu(), bm_out[i].float().cpu(), rtol=1e-3, atol=1e-3)
except Exception as e:
logger.error(e)
result = compare(
kernel_out[i],
bm_out[i],
golden_out[i]
)
compare_results.append(result)
return compare_results
def backward_test_autograd(inputs, pto_inputs, golden_func, pto_func):
"""
Constraint:
1. inputs passes all input tensors and other parameters
2. golden_func last parameter is is_golden
3. pto_func first parameter is grad_outputs, rest are gradients of corresponding parameters in golden_func
"""
bm_inputs = clone_inputs(inputs)
bm_out = golden_func(*bm_inputs, is_golden=False)
bm_out = normalize_outputs(bm_out)
bm_out = tuple(o.to(torch.bfloat16) for o in bm_out)
grad_outputs = tuple(torch.randn(o.shape, device=o.device, dtype=o.dtype) for o in bm_out)
torch.autograd.backward(bm_out, grad_outputs)
bm_grads = collect_grads(bm_inputs)
golden_inputs = _to_double_cpu_backward(inputs)
golden_out = golden_func(*golden_inputs, is_golden=True)
golden_out = normalize_outputs(golden_out)
grad_outputs_golden = tuple(g.detach().cpu().double() for g in grad_outputs)
torch.autograd.backward(golden_out, grad_outputs_golden)
golden_grads = collect_grads(golden_inputs)
if len(grad_outputs) == 1:
pto_grads = pto_func(grad_outputs[0], *pto_inputs)
else:
pto_grads = pto_func(*grad_outputs, *pto_inputs)
if not isinstance(pto_grads, (tuple, list)):
pto_grads = (pto_grads,)
compare_results = []
idx = 0
for i, x in enumerate(inputs):
if isinstance(x, torch.Tensor) and x.requires_grad:
logger.info(f"=== compare grad of input[{i}] ===")
pto_grad = pto_grads[idx]
bm_grad = bm_grads[i]
golden_grad = golden_grads[i]
try:
assert_allclose(pto_grad.float().cpu(), bm_grad.float().cpu(), rtol=1e-3, atol=1e-3)
except Exception as e:
logger.error(e)
result = compare(pto_grad, bm_grad, golden_grad)
idx += 1
compare_results.append(result)
return compare_results
def create_asymmetric_qat_golden(group_size, bit, eps=1e-4, clip_val=0.99):
def asymmetric_qat_golden(weight, scale, offset, is_golden=False):
"""PyTorch reference implementation for Enhanced LSQ+ asymmetric quantization (BF16 I/O, FP32 compute).
Args:
weight: Input weight tensor (n, m) in BF16
scale: Quantization scale tensor (num_groups, 1) in BF16
offset: Quantization offset tensor (num_groups, 1) in BF16
group_size: Number of elements per group (default: 128)
bit: Quantization bit-width (2, 3, or 4)
"""
if not is_golden:
weight_in = weight.float()
scale_in = scale.float()
offset_in = offset.float()
else:
weight_in = weight.double()
scale_in = scale.double()
offset_in = offset.double()
if weight_in.ndim != 2:
raise ValueError(f"weight must be 2D (n, m), got shape {tuple(weight.shape)}")
if weight_in.shape[1] % group_size != 0:
raise ValueError(
f"weight.shape[1] must be divisible by group_size={group_size}, got m={weight_in.shape[1]}"
)
eps_tensor = torch.tensor(eps, device=scale_in.device, dtype=torch.float32)
protected_scale = torch.where(scale_in > eps_tensor, scale_in, eps_tensor)
orig_shape = weight.shape
num_groups = weight.numel() // group_size
expected_group_shape = (num_groups, 1)
if tuple(scale.shape) != expected_group_shape:
raise ValueError(f"scale must have shape {expected_group_shape}, got {tuple(scale.shape)}")
if tuple(offset.shape) != expected_group_shape:
raise ValueError(f"offset must have shape {expected_group_shape}, got {tuple(offset.shape)}")
weight_2d = weight_in.view(num_groups, group_size)
n_levels = 2 ** (bit - 1)
shift = 0.5
weight_shifted = weight_2d - offset_in
alpha = protected_scale * n_levels
weight_clipped = torch.clamp(weight_shifted / alpha, -clip_val, clip_val) * n_levels - shift
weight_rounded = (weight_clipped.round() - weight_clipped).detach() + weight_clipped
weight_unshifted = weight_rounded + shift
weight_denorm = weight_unshifted / n_levels
output_2d = weight_denorm * alpha + offset_in
output = output_2d.view(orig_shape)
if is_golden:
return output
else:
return output.to(torch.bfloat16)
return asymmetric_qat_golden
def create_symmetric_qat_nscale_golden(eps, min_v, max_v):
def symmetric_qat_golden(weight, scale, is_golden=False):
"""PyTorch reference implementation for embedding head quantization (BF16 I/O, FP32 compute).
Args:
weight: Input weight tensor (n, m) in BF16
scale: Quantization scale tensor (n, 1) in BF16
eps: Minimum scale threshold (default: 1e-4)
min_v: Quantization lower bound (default: -128.0)
max_v: Quantization upper bound (default: 127.0)
Returns:
Tuple of (quantized_weight, clamped, protected_scale) all in BF16
"""
if not is_golden:
weight_in = weight.float()
scale_in = scale.float()
else:
weight_in = weight.to(torch.float64)
scale_in = scale.to(torch.float64)
eps_tensor = torch.tensor(eps, device=weight.device, dtype=torch.float64)
protected_scale = torch.where(scale_in > eps_tensor, scale_in, eps_tensor)
weight_normalized = weight_in / protected_scale
weight_rounded = (weight_normalized.round() - weight_normalized).detach() + weight_normalized
clamped = torch.clamp(weight_rounded, min_v, max_v)
output = clamped * protected_scale
if is_golden:
return output
else:
return output.to(torch.bfloat16)
return symmetric_qat_golden
def create_symmetric_qat_golden(eps, min_v, max_v):
def symmetric_qat_golden(weight, scale, is_golden=False):
"""PyTorch reference implementation for embedding head quantization (BF16 I/O, FP32 compute).
Args:
weight: Input weight tensor (n, m) in BF16
scale: Quantization scale tensor (1, 1) scalar in BF16
eps: Minimum scale threshold (default: 1e-4)
min_v: Quantization lower bound (default: -128.0)
max_v: Quantization upper bound (default: 127.0)
Returns:
Tuple of (quantized_weight, clamped, protected_scale) all in BF16
"""
if not is_golden:
weight_in = weight.float()
scale_in = scale.float()
else:
weight_in = weight.double()
scale_in = scale.double()
eps_tensor = torch.tensor(eps, device=scale_in.device, dtype=scale_in.dtype)
protected_scale = torch.where(scale_in > eps_tensor, scale_in, eps_tensor)
weight_normalized = weight_in / protected_scale
weight_rounded = weight_normalized + (weight_normalized.round() - weight_normalized).detach()
clamped = torch.clamp(weight_rounded, min_v, max_v)
output = clamped * protected_scale
if is_golden:
return output
else:
return output.to(torch.bfloat16)
return symmetric_qat_golden
def run_asymmetric_per_group_test(n, m, group_size, bit, eps, clip_val, distribution, device_id):
"""Run a single test case for asymmetric per-group quantization."""
device = f"npu:{device_id}"
seed = 33
weight_shape = (n, m)
groups_per_row = m // group_size
num_groups = n * groups_per_row
scale_shape = (num_groups, 1)
offset_shape = (num_groups, 1)
weight_bm = create_input(weight_shape, torch.bfloat16, device, distribution, seed)
scale_bm = create_input(scale_shape, torch.bfloat16, device, distribution, seed)
offset_bm = create_input(offset_shape, torch.bfloat16, device, distribution, seed)
golden_inputs = [weight_bm, scale_bm, offset_bm]
pto_inputs = [weight_bm, scale_bm, offset_bm, group_size, bit, eps, clip_val]
golden = create_asymmetric_qat_golden(group_size, bit, eps, clip_val)
return forward_test(golden_inputs, pto_inputs, golden, ai_infra_qat_asymmetric_per_group)
@pytest.mark.parametrize(
('n', 'm', 'group', 'bit', 'eps', 'clip_val'),
[
pytest.param(1024, 2048, 128, 2, 0.0001, 0.99,
id="N1024-M2048-group128-bit2-eps0.0001-clip_val0.99",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(768, 2048, 128, 3, 0.0001, 0.99,
id="N768-M2048-group128-bit3-eps0.0001-clip_val0.99"),
]
)
def test_asymmetric_per_group(n, m, group, bit, eps, clip_val) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_asymmetric_per_group_test(n, m, group, bit, eps, clip_val, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, group, bit, eps, clip_val, dis]]
results.append(str_params + flattened_result)
def run_symmetric_per_channel_test(n, m, bit, eps, distribution, device_id):
device = f"npu:{device_id}"
seed = 33
min_v = float(-2 ** (bit - 1))
max_v = float(2 ** (bit - 1) - 1)
weight_shape = (n, m)
scale_shape = (n, 1)
weight = create_input(weight_shape, torch.bfloat16, device, distribution, seed)
scale = create_input(scale_shape, torch.bfloat16, device, distribution, seed)
golden_inputs = [weight, scale]
pto_inputs = [weight, scale, eps, min_v, max_v]
golden = create_symmetric_qat_nscale_golden(eps, min_v, max_v)
return forward_test(golden_inputs, pto_inputs, golden, ai_infra_qat_symmetric_per_channel)
@pytest.mark.parametrize(
('n', 'm', 'bit', 'eps'),
[
pytest.param(153376, 2048, 4, 0.0001,
id="N153376-M2048-bit4-eps0.0001",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(38344, 2048, 4, 0.0001,
id="N38344-M2048-bit4-eps0.0001"),
]
)
def test_symmetric_per_channel(n, m, bit, eps) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_symmetric_per_channel_test(n, m, bit, eps, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, bit, eps, dis]]
results.append(str_params + flattened_result)
def run_symmetric_per_tensor_test(n, m, bit, eps, distribution, device_id):
device = f"npu:{device_id}"
seed = 33
min_v = float(-2 ** (bit - 1))
max_v = float(2 ** (bit - 1) - 1)
weight_shape = (n, m)
scale_shape = (1, 1)
weight = create_input(weight_shape, torch.bfloat16, device, distribution, seed)
scale = create_input(scale_shape, torch.bfloat16, device, distribution, seed)
golden_inputs = [weight, scale]
pto_inputs = [weight, scale, eps, min_v, max_v]
golden = create_symmetric_qat_golden(eps, min_v, max_v)
return forward_test(golden_inputs, pto_inputs, golden, ai_infra_qat_symmetric_per_tensor)
@pytest.mark.parametrize(
('n', 'm', 'bit', 'eps'),
[
pytest.param(153376, 2048, 8, 0.0001,
id="N153376-M2048-bit8-eps0.0001",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(38344, 2048, 8, 0.0001,
id="N38344-M2048-bit8-eps0.0001"),
]
)
def test_symmetric_per_tensor(n, m, bit, eps) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_symmetric_per_tensor_test(n, m, bit, eps, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, bit, eps, dis]]
results.append(str_params + flattened_result)
if __name__ == "__main__":
test_asymmetric_per_group(1024, 2048, 128, 2, 0.0001, 0.99)
def run_asymmetric_per_group_backward_test(n, m, group_size, bit, eps, clip_val, distribution, device_id):
"""Run a single backward test case for asymmetric per-group quantization."""
device = f"npu:{device_id}"
seed = 33
weight_shape = (n, m)
groups_per_row = m // group_size
num_groups = n * groups_per_row
scale_shape = (num_groups, 1)
offset_shape = (num_groups, 1)
weight_bm = create_input(weight_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
scale_bm = create_input(scale_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
offset_bm = create_input(offset_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
golden_inputs = [weight_bm, scale_bm, offset_bm]
pto_inputs = [weight_bm, scale_bm, offset_bm, group_size, bit, eps, clip_val]
golden = create_asymmetric_qat_golden(group_size, bit, eps, clip_val)
return backward_test_autograd(golden_inputs, pto_inputs, golden, ai_infra_qat_asymmetric_per_group_backward)
@pytest.mark.parametrize(
('n', 'm', 'group', 'bit', 'eps', 'clip_val'),
[
pytest.param(1024, 2048, 128, 2, 0.0001, 0.99,
id="N1024-M2048-group128-bit2-eps0.0001-clip_val0.99",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(768, 2048, 128, 3, 0.0001, 0.99,
id="N768-M2048-group128-bit3-eps0.0001-clip_val0.99"),
]
)
def test_asymmetric_per_group_backward(n, m, group, bit, eps, clip_val) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_asymmetric_per_group_backward_test(n, m, group, bit, eps, clip_val, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, group, bit, eps, clip_val, dis]]
results.append(str_params + flattened_result)
def run_symmetric_per_channel_backward_test(n, m, bit, eps, distribution, device_id):
device = f"npu:{device_id}"
seed = 33
min_v = float(-2 ** (bit - 1))
max_v = float(2 ** (bit - 1) - 1)
weight_shape = (n, m)
scale_shape = (n, 1)
weight = create_input(weight_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
scale = create_input(scale_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
inputs = [weight, scale]
pto_inputs = [weight, scale, eps, min_v, max_v]
golden = create_symmetric_qat_nscale_golden(eps, min_v, max_v)
return backward_test_autograd(inputs, pto_inputs, golden, ai_infra_qat_symmetric_per_channel_backward)
@pytest.mark.parametrize(
('n', 'm', 'bit', 'eps'),
[
pytest.param(153376, 2048, 4, 0.0001,
id="N153376-M2048-bit4-eps0.0001",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(38344, 2048, 4, 0.0001,
id="N38344-M2048-bit4-eps0.0001"),
]
)
def test_symmetric_per_channel_backward(n, m, bit, eps) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_symmetric_per_channel_backward_test(n, m, bit, eps, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, bit, eps, dis]]
results.append(str_params + flattened_result)
def run_symmetric_per_tensor_backward_test(n, m, bit, eps, distribution, device_id):
device = f"npu:{device_id}"
seed = 33
min_v = float(-2 ** (bit - 1))
max_v = float(2 ** (bit - 1) - 1)
weight_shape = (n, m)
scale_shape = (1, 1)
weight = create_input(weight_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
scale = create_input(scale_shape, torch.bfloat16, device, distribution, seed).requires_grad_(True)
inputs = [weight, scale]
pto_inputs = [weight, scale, eps, min_v, max_v]
golden = create_symmetric_qat_golden(eps, min_v, max_v)
return backward_test_autograd(inputs, pto_inputs, golden, ai_infra_qat_symmetric_per_tensor_backward)
@pytest.mark.parametrize(
('n', 'm', 'bit', 'eps'),
[
pytest.param(153376, 2048, 8, 0.0001,
id="N153376-M2048-bit8-eps0.0001",
marks=pytest.mark.skip(reason="temporarily disabled")),
pytest.param(38344, 2048, 8, 0.0001,
id="N38344-M2048-bit8-eps0.0001"),
]
)
def test_symmetric_per_tensor_backward(n, m, bit, eps) -> None:
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
results = []
for dis in DISTRIBUTION:
compare_result = run_symmetric_per_tensor_backward_test(n, m, bit, eps, dis, device_id)
flattened_result = [str(item) for sublist in compare_result for item in sublist]
str_params = [str(param) for param in [n, m, bit, eps, dis]]
results.append(str_params + flattened_result)