量化感知训练(QAT)算子接口文档
本文档包含量化感知训练(Quantization-Aware Training, QAT)相关的算子接口说明,涵盖对称量化和非对称量化两种方式。
目录
- ai_infra_qat_symmetric_per_tensor - 对称张量级量化算子
- ai_infra_qat_symmetric_per_channel - 对称逐通道量化算子
- ai_infra_qat_asymmetric_per_group - 非对称分组量化算子
ai_infra_qat_symmetric_per_tensor
对称量化感知训练(Quantization-Aware Training, QAT)算子,包含正向和反向两个算子。该算子用于对权重进行对称量化模拟,在训练过程中引入量化噪声,使模型能够适应量化带来的精度损失。
适用于 Embedding 层场景,scale 为标量形式(shape 为 (1,1)),所有权重元素共享同一个缩放系数。
正向算子:ai_infra_qat_symmetric_per_tensor
功能描述
对称量化感知训练(Quantization-Aware Training, QAT)正向算子。该算子用于对权重进行对称量化模拟,在训练过程中引入量化噪声,使模型能够适应量化带来的精度损失。
该算子适用于 Embedding 层场景,scale 为标量形式(shape 为 (1,1)),所有权重元素共享同一个缩放系数。
接口定义
def ai_infra_qat_symmetric_per_tensor(
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (1, 1)
eps: float, # 最小scale阈值
min_v: float, # 量化下界
max_v: float, # 量化上界
) -> Tensor: # BF16, shape: (N, M)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| weight | Tensor(BF16) | (N, M) | 输入权重张量,N为输出特征维度,M为输入特征维度 |
| scale | Tensor(BF16) | (1, 1) | 量化缩放系数,标量形式,所有权重元素共享同一缩放系数 |
| eps | float | - | scale的最小阈值,防止除零错误,当scale小于eps时使用eps替代 |
| min_v | float | - | 量化下界,对于INT8量化为-128 |
| max_v | float | - | 量化上界,对于INT8量化为127 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| output | Tensor(BF16) | (N, M) | 伪量化后的权重张量,与输入权重形状相同 |
算法原理
对称量化感知训练的核心是使用直通估计器(Straight-Through Estimator, STE)来近似量化操作的梯度。算子通过以下步骤实现:
Step 1: Scale 防零保护
s′={s,if s>εε,otherwises' = \begin{cases} s, & \text{if } s > \varepsilon \\ \varepsilon, & \text{otherwise} \end{cases}
其中 ss 为输入的 scale,ε\varepsilon 为 eps 参数。
Step 2: 归一化
Wnorm=Ws′W_{\text{norm}} = \frac{W}{s'}
将权重除以缩放系数,得到归一化后的权重。
Step 3: 伪量化(STE近似的四舍五入 + 截断)
Wquant=detach(round(Wnorm)−Wnorm)+WnormW_{\text{quant}} = \text{detach}\left(\text{round}(W_{\text{norm}}) - W_{\text{norm}}\right) + W_{\text{norm}}
Wclamp=clamp(Wquant,Vmin,Vmax)W_{\text{clamp}} = \text{clamp}(W_{\text{quant}}, V_{\text{min}}, V_{\text{max}})
使用 detach() 阻断四舍五入操作的梯度传播,实现 STE 近似。然后将量化后的值截断到有效范围内。
Step 4: 反量化
Wq=Wclamp×s′W_q = W_{\text{clamp}} \times s'
将截断后的值乘以缩放系数,恢复到原始数值范围。
约束条件
weight必须为 2 维张量,形状为 (N, M)。M∈[128, 3072],且被128整除。数据类型为BF16。scale必须为 2 维张量,形状为 (1, 1),数据类型为BF16。eps应在 (0, 1) 范围内,数据类型为float。min_v应小于max_v,并且都为浮点型整数。数据类型为float,小数位为全0。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量
N, M = 1024, 2048
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0")
scale = torch.tensor([[0.1]], dtype=torch.bfloat16, device="npu:0")
# 创建并调用算子
output = ai_infra_qat_symmetric_per_tensor(weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0)
print(f"输入权重形状: {weight.shape}")
print(f"输出权重形状: {output.shape}")
反向算子:ai_infra_qat_symmetric_per_tensor_backward
功能描述
对称量化感知训练(QAT)反向算子。该算子用于计算对称量化操作的梯度,支持权重和缩放系数的梯度计算。
该算子适用于 Embedding 层场景,scale 为标量形式(shape 为 (1,1)),对应正向算子 ai_infra_qat_symmetric_per_tensor 的反向传播。
接口定义
def ai_infra_qat_symmetric_per_tensor_backward(
grad_output: Tensor, # BF16, shape: (N, M)
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (1, 1)
eps: float = 1e-4, # 最小scale阈值
min_v: float = -128.0, # 量化下界
max_v: float = 127.0, # 量化上界
) -> Tuple[Tensor, Tensor]: # (grad_weight, grad_scale)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 默认值 | 说明 |
|---|---|---|---|---|
| grad_output | Tensor(BF16) | (N, M) | 必填 | 上游传递的梯度张量 |
| weight | Tensor(BF16) | (N, M) | 必填 | 原始输入权重张量 |
| scale | Tensor(BF16) | (1, 1) | 必填 | 量化缩放系数 |
| eps | float | - | 1e-4 | scale的最小阈值 |
| min_v | float | - | -128.0 | 量化下界 |
| max_v | float | - | 127.0 | 量化上界 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| grad_weight | Tensor(BF16) | (N, M) | 对权重的梯度 |
| grad_scale | Tensor(BF16) | (1, 1) | 对缩放系数的梯度 |
算法原理
反向传播需要计算损失函数对权重 W 和缩放系数 s 的梯度。基于正向传播的计算图,梯度通过链式法则反向传递。
第一步:反量化(Wq=Wclamp×s′W_q = W_{clamp} \times s')的梯度
对 WclampW_{clamp} 的梯度:
∂Loss∂Wclamp=∂Loss∂Wq×s′\frac{\partial\text{Loss}}{\partial W_{\text{clamp}}} = \frac{\partial\text{Loss}}{\partial W_q} \times s'
对 s′s' 的梯度:
∂Loss∂s′=∑i=1N∑j=1M(∂Loss∂Wq[i,j]×Wclamp[i,j])\frac{\partial\text{Loss}}{\partial s'} = \sum_{i=1}^N \sum_{j=1}^M \left( \frac{\partial\text{Loss}}{\partial W_q}[i,j] \times W_{\text{clamp}}[i,j] \right)
第二步:截断的梯度(STE 近似)
∂Loss∂Wquant=∂Loss∂Wclamp⊙{1,Vmin≤Wquant≤Vmax0,otherwise\frac{\partial\text{Loss}}{\partial W_{\text{quant}}} = \frac{\partial\text{Loss}}{\partial W_{\text{clamp}}}\odot\begin{cases} 1, & V_{\text{min}}\leq W_{\text{quant}}\leq V_{\text{max}}\\ 0, &\text{otherwise}\end{cases}
第三步:伪量化的梯度
∂Loss∂Wnorm=∂Loss∂Wquant\frac{\partial\text{Loss}}{\partial W_{\text{norm}}} = \frac{\partial\text{Loss}}{\partial W_{\text{quant}}}
第四步:归一化的梯度
对原始权重 W 的梯度:
∂Loss∂W=∂Loss∂Wnorm×1s′\frac{\partial\text{Loss}}{\partial W} = \frac{\partial\text{Loss}}{\partial W_{\text{norm}}} \times\frac{1}{s'}
对 s' 的梯度(叠加):
∂Loss∂s′+=−1(s′)2×sum(∂Loss∂Wnorm⊙W)\frac{\partial\text{Loss}}{\partial s'} \mathrel{+}= -\frac{1}{(s')^2} \times\text{sum}\left( \frac{\partial\text{Loss}}{\partial W_{\text{norm}}} \odot W \right)
第五步:scale 防零的梯度
∂Loss∂s=∂Loss∂s′⋅{1,s>ε0,otherwise\frac{\partial\text{Loss}}{\partial s} = \frac{\partial\text{Loss}}{\partial s'}\cdot\begin{cases} 1, & s > \varepsilon\\ 0, &\text{otherwise}\end{cases}
最终合并公式
对 W 的梯度:
∂Loss∂W=∂Loss∂Wq⊙{1,Vmin≤Wquant≤Vmax0,otherwise\frac{\partial\text{Loss}}{\partial W} = \frac{\partial\text{Loss}}{\partial W_q}\odot\begin{cases} 1, & V_{\text{min}}\leq W_{\text{quant}}\leq V_{\text{max}}\\ 0, &\text{otherwise}\end{cases}
对 s 的梯度:
∂Loss∂s=[sum(∂Loss∂Wq⊙Wclamp)−1s′⋅sum(∂Loss∂Wq⊙1mask⊙W)]⋅1s>ε\frac{\partial \text{Loss}}{\partial s} = \left[ \text{sum}\left( \frac{\partial \text{Loss}}{\partial W_q} \odot W_{\text{clamp}} \right) - \frac{1}{s'} \cdot \text{sum}\left( \frac{\partial \text{Loss}}{\partial W_q} \odot \mathbf{1}_{\text{mask}} \odot W \right) \right] \cdot \mathbf{1}_{s > \varepsilon}
约束条件
grad_output和weight的形状必须相同,为 (N, M)。M∈[128, 3072],且被128整除。数据类型为BF16。scale的形状必须为 (1, 1),数据类型为BF16。eps应在 (0, 1) 范围内,数据类型为float。min_v应小于max_v,并且都为浮点型整数。数据类型为float,小数位为全0。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量
N, M = 1024, 2048
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0", requires_grad=True)
scale = torch.tensor([[0.1]], dtype=torch.bfloat16, device="npu:0", requires_grad=True)
# 前向传播
output = ai_infra_qat_symmetric_per_tensor(weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0)
# 模拟上游梯度
grad_output = torch.ones_like(output)
# 反向传播
grad_weight, grad_scale = ai_infra_qat_symmetric_per_tensor_backward(
grad_output, weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0
)
print(f"权重梯度形状: {grad_weight.shape}")
print(f"Scale梯度形状: {grad_scale.shape}")
ai_infra_qat_symmetric_per_channel
对称量化感知训练(Quantization-Aware Training, QAT)算子(N scale 版本),包含正向和反向两个算子。该算子用于对权重进行对称量化模拟,支持每个输出通道独立的缩放系数。
适用于 Lm Head 层场景,scale 为向量形式(shape 为 (N,1)),每个输出通道对应一个独立的缩放系数。
正向算子:ai_infra_qat_symmetric_per_channel
功能描述
对称量化感知训练(Quantization-Aware Training, QAT)正向算子(N scale 版本)。该算子用于对权重进行对称量化模拟,支持每个输出通道独立的缩放系数。
该算子适用于 Lm Head 层场景,scale 为向量形式(shape 为 (N,1)),每个输出通道对应一个独立的缩放系数。
接口定义
def ai_infra_qat_symmetric_per_channel(
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (N, 1)
eps: float = 1e-4, # 最小scale阈值
min_v: float = -128.0, # 量化下界
max_v: float = 127.0, # 量化上界
) -> Tensor: # BF16, shape: (N, M)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 默认值 | 说明 |
|---|---|---|---|---|
| weight | Tensor(BF16) | (N, M) | 必填 | 输入权重张量,N为输出特征维度,M为输入特征维度 |
| scale | Tensor(BF16) | (N, 1) | 必填 | 量化缩放系数,每个输出通道对应一个独立的缩放系数 |
| eps | float | - | 1e-4 | scale的最小阈值,防止除零错误 |
| min_v | float | - | -128.0 | 量化下界,对于INT8量化为-128 |
| max_v | float | - | 127.0 | 量化上界,对于INT8量化为127 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| output | Tensor(BF16) | (N, M) | 伪量化后的权重张量,与输入权重形状相同 |
算法原理
与 symmetric_qat 算法原理相同,主要区别在于 scale 的形状。当 scale 为 (N, 1) 时,每个输出通道使用独立的缩放系数。
Step 1: Scale 防零保护
si′={si,if si>εε,otherwises'_i = \begin{cases} s_i, & \text{if } s_i > \varepsilon \\ \varepsilon, & \text{otherwise} \end{cases}
其中 sis_i 为第 i 个输出通道的 scale。
Step 2: 归一化(广播缩放)
Wnorm[i,j]=W[i,j]si′W_{\text{norm}}[i, j] = \frac{W[i, j]}{s'_i}
scale 沿 M 维度广播,实现逐通道归一化。
Step 3: 伪量化(STE近似的四舍五入 + 截断)
Wquant=detach(round(Wnorm)−Wnorm)+WnormW_{\text{quant}} = \text{detach}\left(\text{round}(W_{\text{norm}}) - W_{\text{norm}}\right) + W_{\text{norm}}
Wclamp=clamp(Wquant,Vmin,Vmax)W_{\text{clamp}} = \text{clamp}(W_{\text{quant}}, V_{\text{min}}, V_{\text{max}})
Step 4: 反量化
Wq[i,j]=Wclamp[i,j]×si′W_q[i, j] = W_{\text{clamp}}[i, j] \times s'_i
约束条件
weight必须为 2 维张量,形状为 (N, M)。M∈[128, 3072],且被128整除。数据类型为BF16。scale必须为 2 维张量,形状为 (N, 1),数据类型为BF16eps应在 (0, 1) 范围内,数据类型为floatmin_v应小于max_v,并且都为浮点型整数。数据类型为float,小数位为全0。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量 - Lm Head 场景
# 例如:词表大小 153376,隐藏维度 2048
N, M = 153376, 2048
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0")
scale = torch.abs(torch.randn(N, 1, dtype=torch.bfloat16, device="npu:0")) + 0.01
# 创建并调用算子
output = ai_infra_qat_symmetric_per_channel(weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0)
print(f"输入权重形状: {weight.shape}")
print(f"Scale形状: {scale.shape}")
print(f"输出权重形状: {output.shape}")
反向算子:ai_infra_qat_symmetric_per_channel_backward
功能描述
对称量化感知训练(QAT)反向算子(N scale 版本)。该算子用于计算对称量化操作的梯度,支持每个输出通道独立的缩放系数梯度计算。
该算子适用于 Lm Head 层场景,scale 为向量形式(shape 为 (N,1)),对应正向算子 ai_infra_qat_symmetric_per_channel 的反向传播。
接口定义
def ai_infra_qat_symmetric_per_channel_backward(
grad_output: Tensor, # BF16, shape: (N, M)
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (N, 1)
eps: float = 1e-4, # 最小scale阈值
min_v: float = -128.0, # 量化下界
max_v: float = 127.0, # 量化上界
) -> Tuple[Tensor, Tensor]: # (grad_weight, grad_scale)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 默认值 | 说明 |
|---|---|---|---|---|
| grad_output | Tensor(BF16) | (N, M) | 必填 | 上游传递的梯度张量 |
| weight | Tensor(BF16) | (N, M) | 必填 | 原始输入权重张量 |
| scale | Tensor(BF16) | (N, 1) | 必填 | 量化缩放系数,每个输出通道一个 |
| eps | float | - | 1e-4 | scale的最小阈值 |
| min_v | float | - | -128.0 | 量化下界 |
| max_v | float | - | 127.0 | 量化上界 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| grad_weight | Tensor(BF16) | (N, M) | 对权重的梯度 |
| grad_scale | Tensor(BF16) | (N, 1) | 对缩放系数的梯度,每个输出通道一个 |
算法原理
与 ai_infra_qat_symmetric_per_tensor_backward 算法原理相似,但由于 scale 为 (N, 1) 形状,梯度计算有所不同。
关键差异
由于每个输出通道有独立的 scale,梯度计算无需跨通道求和,只需沿 M 维度求和:
对 W 的梯度:
∂Loss∂W[i,j]=∂Loss∂Wq[i,j]⋅1mask[i,j]⋅1si′\frac{\partial\text{Loss}}{\partial W}[i,j] = \frac{\partial\text{Loss}}{\partial W_q}[i,j] \cdot \mathbf{1}_{\text{mask}}[i,j] \cdot \frac{1}{s'_i}
对 s 的梯度:
∂Loss∂si=[∑j=1M(∂Loss∂Wq[i,j]⋅Wclamp[i,j])−1si′∑j=1M(∂Loss∂Wq[i,j]⋅1mask[i,j]⋅W[i,j])]⋅1si>ε\frac{\partial\text{Loss}}{\partial s_i} = \left[ \sum_{j=1}^{M} \left( \frac{\partial\text{Loss}}{\partial W_q}[i,j] \cdot W_{\text{clamp}}[i,j] \right) - \frac{1}{s'_i} \sum_{j=1}^{M} \left( \frac{\partial\text{Loss}}{\partial W_q}[i,j] \cdot \mathbf{1}_{\text{mask}}[i,j] \cdot W[i,j] \right) \right] \cdot \mathbf{1}_{s_i > \varepsilon}
梯度计算步骤
- 重算前向中间值:根据 weight 和 scale 重新计算 normalized、rounded、clamped
- 计算截断掩码:判断元素是否在 [min_v, max_v] 范围内
- 计算 scale 掩码:判断 scale 是否大于 eps
- 计算 grad_weight:上游梯度 × 截断掩码
- 计算 grad_scale:乘法路径 + 除法路径,沿 M 维度求和
约束条件
grad_output和weight的形状必须相同,为 (N, M)。M∈[128, 3072],且被128整除。数据类型为BF16。scale的形状必须为 (N, 1),数据类型为BF16。eps应在 (0, 1) 范围内,数据类型为float。min_v应小于max_v,并且都为浮点型整数。数据类型为float,小数位为全0。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量 - Lm Head 场景
N, M = 153376, 2048
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0", requires_grad=True)
scale = torch.abs(torch.randn(N, 1, dtype=torch.bfloat16, device="npu:0")) + 0.01
scale.requires_grad_(True)
# 前向传播
output = ai_infra_qat_symmetric_per_channel(weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0)
# 模拟上游梯度
grad_output = torch.ones_like(output)
# 反向传播
grad_weight, grad_scale = ai_infra_qat_symmetric_per_channel_backward(
grad_output, weight, scale, eps=1e-4, min_v=-128.0, max_v=127.0
)
print(f"权重梯度形状: {grad_weight.shape}")
print(f"Scale梯度形状: {grad_scale.shape}")
与 ai_infra_qat_symmetric_per_tensor 的区别
| 特性 | ai_infra_qat_symmetric_per_tensor | ai_infra_qat_symmetric_per_channel |
|---|---|---|
| scale 形状 | (1, 1) | (N, 1) |
| 缩放粒度 | 全局统一缩放 | 逐通道独立缩放 |
| 适用场景 | Embedding 层 | Lm Head 层 |
| 量化精度 | 较低 | 较高 |
| grad_scale 形状 | (1, 1) | (N, 1) |
| 梯度求和维度 | 沿 N 和 M 维度求和 | 仅沿 M 维度求和 |
ai_infra_qat_asymmetric_per_group
非对称量化感知训练(Asymmetric Quantization-Aware Training, QAT)算子,包含正向和反向两个算子。该算子用于对权重进行非对称量化模拟,支持分组量化(Group Quantization)和可学习的偏移量(offset)。
适用于 Transformer Linear 层场景,通过分组量化实现更精细的量化粒度,提高量化后模型的精度。非对称量化相比对称量化能够更好地适应权重分布的不对称性。
正向算子:ai_infra_qat_asymmetric_per_group
功能描述
非对称量化感知训练(Asymmetric Quantization-Aware Training, QAT)正向算子。该算子用于对权重进行非对称量化模拟,支持分组量化(Group Quantization)和可学习的偏移量(offset)。
该算子适用于 Transformer Linear 层场景,通过分组量化实现更精细的量化粒度,提高量化后模型的精度。非对称量化相比对称量化能够更好地适应权重分布的不对称性。
接口定义
def ai_infra_qat_asymmetric_per_group(
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (N*M/group_size, 1)
offset: Tensor, # BF16, shape: (N*M/group_size, 1)
group_size: int = 128, # 分组大小
bit: int = 4, # 量化位宽,支持2、3、4
eps: float = 1e-4, # 最小scale阈值
clip_val: float = 0.99, # 截断值
) -> Tensor: # BF16, shape: (N, M)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 默认值 | 说明 |
|---|---|---|---|---|
| weight | Tensor(BF16) | (N, M) | 必填 | 输入权重张量,N为输出特征维度,M为输入特征维度 |
| scale | Tensor(BF16) | (N*M/group_size, 1) | 必填 | 量化缩放系数,每组一个缩放系数 |
| offset | Tensor(BF16) | (N*M/group_size, 1) | 必填 | 量化偏移量,每组一个偏移量 |
| group_size | int | - | 128 | 分组大小,每组包含的权重元素数量 |
| bit | int | - | 4 | 量化位宽,支持2、3、4位量化 |
| eps | float | - | 1e-4 | scale的最小阈值,防止除零错误 |
| clip_val | float | - | 0.99 | 截断值,用于限制量化范围 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| output | Tensor(BF16) | (N, M) | 伪量化后的权重张量,与输入权重形状相同 |
算法原理
非对称量化感知训练基于增强型 LSQ+(Learned Step Size Quantization Plus)算法,通过分组量化和可学习的 scale/offset 实现更灵活的量化。
核心公式
Step 1: Scale 防零保护
s′={s,if s>εε,otherwises' = \begin{cases} s, & \text{if } s > \varepsilon \\ \varepsilon, & \text{otherwise} \end{cases}
Step 2: 权重重塑为分组形式
Wgroup=reshape(W,[G,group_size])W_{\text{group}} = \text{reshape}(W, [G, \text{group\_size}])
其中 G=N×Mgroup_sizeG = \frac{N \times M}{\text{group\_size}} 为总组数。
Step 3: 计算量化参数
α=s′×nlevels,nlevels=2(bit−1)\alpha = s' \times n_{\text{levels}}, \quad n_{\text{levels}} = 2^{(\text{bit}-1)}
shift=0.5\text{shift} = 0.5
Step 4: 非对称量化
Wshifted=Wgroup−offsetW_{\text{shifted}} = W_{\text{group}} - \text{offset}
Wclipped=clamp(Wshiftedα,−clip_val,clip_val)×nlevels−shiftW_{\text{clipped}} = \text{clamp}\left(\frac{W_{\text{shifted}}}{\alpha}, -\text{clip\_val}, \text{clip\_val}\right) \times n_{\text{levels}} - \text{shift}
Step 5: 伪量化(STE)
Wrounded=detach(round(Wclipped)−Wclipped)+WclippedW_{\text{rounded}} = \text{detach}(\text{round}(W_{\text{clipped}}) - W_{\text{clipped}}) + W_{\text{clipped}}
Step 6: 反量化
Wunshifted=Wrounded+shiftW_{\text{unshifted}} = W_{\text{rounded}} + \text{shift}
Wdenorm=WunshiftednlevelsW_{\text{denorm}} = \frac{W_{\text{unshifted}}}{n_{\text{levels}}}
Wout=Wdenorm×α+offsetW_{\text{out}} = W_{\text{denorm}} \times \alpha + \text{offset}
分组量化说明
分组量化将权重矩阵划分为多个小组,每组拥有独立的 scale 和 offset:
权重矩阵 (N, M):
┌─────────────────────────────────────┐
│ Group 0 │ Group 1 │ ... │ Group G-1 │
│ (128个元素)│ (128个元素)│ │ (128个元素) │
└─────────────────────────────────────┘
↓ ↓ ↓
scale[0] scale[1] scale[G-1]
offset[0] offset[1] offset[G-1]
约束条件
group_size取64、128、256,数据类型为int。weight必须为 2 维张量,形状为 (N, M),M∈[128, 3072]且被group_size整除。数据类型为BF16。scale和offset的形状必须为 (N*M/group_size, 1)。数据类型为BF16。bit只能取 2、3、4,数据类型为int。eps应在 (0, 1) 范围内,数据类型为float。clip_val应在 (0, 1) 范围内,数据类型为float。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
- 支持位宽: 2-bit, 3-bit, 4-bit
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量 - Transformer Linear 层场景
# 例如:FFN 层,输入维度 3072,输出维度 768
N, M = 768, 3072
group_size = 128
bit = 4
# 计算组数
num_groups = N * M // group_size
# 创建输入
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0")
scale = torch.abs(torch.randn(num_groups, 1, dtype=torch.bfloat16, device="npu:0")) + 0.01
offset = torch.randn(num_groups, 1, dtype=torch.bfloat16, device="npu:0")
# 创建并调用算子
output = ai_infra_qat_asymmetric_per_group(weight, scale, offset, group_size=group_size, bit=bit, eps=1e-4, clip_val=0.99)
print(f"输入权重形状: {weight.shape}")
print(f"Scale形状: {scale.shape}")
print(f"Offset形状: {offset.shape}")
print(f"输出权重形状: {output.shape}")
反向算子:ai_infra_qat_asymmetric_per_group_backward
功能描述
非对称量化感知训练(QAT)反向算子。该算子用于计算非对称量化操作的梯度,支持权重、缩放系数和偏移量的梯度计算。
该算子适用于 Transformer Linear 层场景,对应正向算子 ai_infra_qat_asymmetric_per_group 的反向传播,通过分组量化实现精细的梯度计算。
接口定义
def ai_infra_qat_asymmetric_per_group_backward(
grad_output: Tensor, # BF16, shape: (N, M)
weight: Tensor, # BF16, shape: (N, M)
scale: Tensor, # BF16, shape: (N*M/group_size, 1)
offset: Tensor, # BF16, shape: (N*M/group_size, 1)
group_size: int = 128, # 分组大小
bit: int = 4, # 量化位宽
eps: float = 1e-4, # 最小scale阈值
clip_val: float = 0.99, # 截断值
) -> Tuple[Tensor, Tensor, Tensor]: # (grad_weight, grad_scale, grad_offset)
参数说明
输入参数
| 参数名 | 类型 | 形状 | 默认值 | 说明 |
|---|---|---|---|---|
| grad_output | Tensor(BF16) | (N, M) | 必填 | 上游传递的梯度张量 |
| weight | Tensor(BF16) | (N, M) | 必填 | 原始输入权重张量 |
| scale | Tensor(BF16) | (N*M/group_size, 1) | 必填 | 量化缩放系数 |
| offset | Tensor(BF16) | (N*M/group_size, 1) | 必填 | 量化偏移量 |
| group_size | int | - | 128 | 分组大小 |
| bit | int | - | 4 | 量化位宽,支持2、3、4 |
| eps | float | - | 1e-4 | scale的最小阈值 |
| clip_val | float | - | 0.99 | 截断值 |
输出参数
| 参数名 | 类型 | 形状 | 说明 |
|---|---|---|---|
| grad_weight | Tensor(BF16) | (N, M) | 对权重的梯度 |
| grad_scale | Tensor(BF16) | (N*M/group_size, 1) | 对缩放系数的梯度 |
| grad_offset | Tensor(BF16) | (N*M/group_size, 1) | 对偏移量的梯度 |
算法原理
反向传播基于 LSQ+ 算法的梯度公式,需要重新计算前向传播的中间变量,并正确处理截断区域和非截断区域的梯度。
前向状态重计算
首先重算以下中间变量:
protected_scale: 经过防零保护的 scalealpha: = protected_scale × n_levelsweight_shifted: = weight - offsetweight_scaled: = weight_shifted / alpha(未截断)weight_clipped: 截断后的值weight_denorm: 反量化后的值
梯度计算
截断掩码(STE 激活区域)
mask[i,j]={1,−clip_val≤Wscaled[i,j]≤clip_val0,otherwise\text{mask}[i,j] = \begin{cases} 1, & -\text{clip\_val} \leq W_{\text{scaled}}[i,j] \leq \text{clip\_val} \\ 0, & \text{otherwise} \end{cases}
grad_weight 计算
只有在截断范围内的元素才能传导梯度:
∂Loss∂W=∂Loss∂Wout⊙mask\frac{\partial\text{Loss}}{\partial W} = \frac{\partial\text{Loss}}{\partial W_{\text{out}}} \odot \text{mask}
grad_offset 计算
截断区域外的梯度累加到 offset:
∂Loss∂offset=∑j∈group(∂Loss∂Wout⊙(1−mask))\frac{\partial\text{Loss}}{\partial \text{offset}} = \sum_{j \in \text{group}} \left( \frac{\partial\text{Loss}}{\partial W_{\text{out}}} \odot (1 - \text{mask}) \right)
沿 group 维度求和。
grad_scale 计算
∂Loss∂s=∂Loss∂α×nlevels×1s>ε\frac{\partial\text{Loss}}{\partial s} = \frac{\partial\text{Loss}}{\partial \alpha} \times n_{\text{levels}} \times \mathbf{1}_{s > \varepsilon}
其中:
∂Loss∂α=∑j∈group(∂Loss∂Wout⊙(Wdenorm−Wscaled⊙mask))\frac{\partial\text{Loss}}{\partial \alpha} = \sum_{j \in \text{group}} \left( \frac{\partial\text{Loss}}{\partial W_{\text{out}}} \odot (W_{\text{denorm}} - W_{\text{scaled}} \odot \text{mask}) \right)
关键实现细节
无缓存设计
该算子采用无缓存设计,在反向传播时重新计算前向传播的中间变量,避免了存储大量中间状态带来的内存开销。
掩码生成
使用数值方法生成掩码,避免直接使用条件判断:
# 判断是否在截断范围内
diff = weight_norm - weight_clipped
abs_diff = abs(diff)
is_out = clip(abs_diff * big_number, 0.0, 1.0)
mask = 1.0 - is_out
梯度累加
对于 scale 和 offset 的梯度,需要沿 group 维度进行求和:
grad_offset = (grad_output * (1 - mask)).sum(dim=1, keepdim=True)
grad_alpha = (grad_output * (weight_denorm - weight_scaled * mask)).sum(dim=1, keepdim=True)
约束条件
group_size取64、128、256,数据类型为int。grad_output和weight的形状必须相同,为 (N, M),M∈[128, 3072]且被group_size整除。数据类型为BF16。scale和offset的形状必须为 (N*M/group_size, 1),数据类型为BF16。bit只能取 2、3、4,数据类型为int。eps应在 (0, 1) 范围内,数据类型为float。clip_val应在 (0, 1) 范围内,数据类型为float。
支持规格
- 数据类型: BF16(输入/输出),FP32(内部计算)
- 芯片平台: A2/A3
- 支持位宽: 2-bit, 3-bit, 4-bit
使用示例
import torch
import pypto
# 设置设备
torch.npu.set_device(0)
# 创建输入张量
N, M = 1024, 2048
group_size = 128
bit = 4
num_groups = N * M // group_size
weight = torch.randn(N, M, dtype=torch.bfloat16, device="npu:0", requires_grad=True)
scale = torch.abs(torch.randn(num_groups, 1, dtype=torch.bfloat16, device="npu:0")) + 0.01
scale.requires_grad_(True)
offset = torch.randn(num_groups, 1, dtype=torch.bfloat16, device="npu:0", requires_grad=True)
# 前向传播
output = ai_infra_qat_asymmetric_per_group(weight, scale, offset, group_size=group_size, bit=bit, eps=1e-4, clip_val=0.99)
# 模拟上游梯度
grad_output = torch.ones_like(output)
# 反向传播
grad_weight, grad_scale, grad_offset = ai_infra_qat_asymmetric_per_group_backward(
grad_output, weight, scale, offset, group_size=group_size, bit=bit, eps=1e-4, clip_val=0.99
)
print(f"权重梯度形状: {grad_weight.shape}")
print(f"Scale梯度形状: {grad_scale.shape}")
print(f"Offset梯度形状: {grad_offset.shape}")
与对称量化的对比
| 特性 | 对称量化 | 非对称量化 |
|---|---|---|
| 量化范围 | [-Q, Q] | [-Q, Q] + offset |
| 参数数量 | scale | scale + offset |
| 适用场景 | 权重分布对称 | 权重分布不对称 |
| 量化精度 | 较低 | 较高 |
| 计算复杂度 | 较低 | 较高 |