torch_npu.contrib.module.LinearA8W8Quant
[!NOTICE]
该接口计划废弃,可以使用torch_npu.contrib.module.LinearQuant接口进行替换。
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 推理系列产品 | √ |
功能说明
LinearA8W8Quant是对torch_npu.npu_quant_matmul接口的封装类,完成A8W8量化算子的矩阵乘计算。
函数原型
torch_npu.contrib.module.LinearA8W8Quant(in_features, out_features, *, bias=True, offset=False, pertoken_scale=False, output_dtype=None)
参数说明
计算参数
- in_features(
int):matmul计算中k轴的值。 - out_features(
int):matmul计算中n轴的值。 - bias(
bool):代表是否需要bias计算参数。如果设置成False,则bias不会加入量化matmul的计算。 - offset(
bool):代表是否需要offset计算参数。如果设置成False,则offset不会加入量化matmul的计算。 - pertoken_scale(
bool):代表是否需要pertoken_scale计算参数。如果设置成False,则pertoken_scale不会加入量化matmul的计算。Atlas 推理系列产品当前不支持pertoken_scale。 - output_dtype(
ScalarType):表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8。- Atlas 推理系列产品:支持输入
int8、float16。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持输入
int8、float16、bfloat16。
- Atlas 推理系列产品:支持输入
计算输入
x1(Tensor):数据类型支持int8。数据格式支持NDND,shape需要在2-6维范围。
变量说明
-
weight(
Tensor):矩阵乘中的weight。数据格式支持int8。数据格式支持NDND,shape为(batch, n, k),shape需要在2-6维范围。- Atlas 推理系列产品:需要调用torchair.experimental.inference.use_internal_format_weight或torch_npu.npu_format_cast完成weight(batch, n, k)高性能数据排布功能。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:需要调用torch_npu.npu_format_cast完成weight(batch,n,k)高性能数据排布功能,但不推荐使用该module方式,推荐torch_npu.npu_quant_matmul。
-
scale(
Tensor):量化计算的scale。数据格式支持NDND,shape为1维(t,),t=1或n,其中n与weight的n一致。如需传入int64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param接口来获取int64数据类型的scale。- Atlas 推理系列产品:数据类型支持
float32、int64。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持
float32、int64、bfloat16。
- Atlas 推理系列产品:数据类型支持
-
offset(
Tensor):量化计算的offset。可选参数。数据类型支持float32,数据格式支持NDND,shape为1维(t,),t=1或n,其中n与weight的n一致。 -
pertoken_scale(
Tensor):可选参数。量化计算的pertoken。数据类型支持float32,数据格式支持NDND,shape为1维(m,),其中m与x1的m一致。目前仅在输出为float16和bfloat16场景下可不为空。Atlas 推理系列产品当前不支持pertoken_scale。 -
bias(
Tensor):可选参数。矩阵乘中的bias。数据格式支持NDND,shape支持1维(n,)或3维(batch, 1, n),n与weight的n一致,同时batch值需要等于x1,weight broadcast后推导出的batch值。当输出为2、4、5、6维情况下,bias shape为1维;当输出为3维情况下,bias shape为1维或3维。- Atlas 推理系列产品:数据类型支持
int32。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持
int32、bfloat16、float16、float32。
- Atlas 推理系列产品:数据类型支持
-
output_dtype(
ScalarType):可选参数。表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8。- Atlas 推理系列产品:支持输入
int8、float16。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持输入
int8、float16、bfloat16。
- Atlas 推理系列产品:支持输入
返回值说明
Tensor
代表量化matmul的计算结果:
- 如果
output_dtype为float16,输出的数据类型为float16。 - 如果
output_dtype为int8或者None,输出的数据类型为int8。 - 如果
output_dtype为bfloat16,输出的数据类型为bfloat16。
约束说明
-
该接口支持推理场景下使用。
-
该接口支持图模式。
-
x1、weight、scale不能是空。 -
x1与weight最后一维的shape大小不能超过65535。 -
输入参数或变量间支持的数据类型组合情况如下:
表1 Atlas 推理系列产品
表2 Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例
-
单算子模式调用
# int8类型输入 import torch import torch_npu import logging import os from torch_npu.contrib.module import LinearA8W8Quant x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu() x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu() scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(128, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.int8 model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.offset.data = offset model.bias.data = bias output = model(x1) -
图模式调用
import torch import torch_npu import torchair as tng from torchair.ge_concrete_graph import ge_apis as ge from torchair.configs.compiler_config import CompilerConfig from torch_npu.contrib.module import LinearA8W8Quant import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) import os import numpy as np os.environ["ENABLE_ACLNN"] = "true" config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu() x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu() scale = torch.randn(1, dtype=torch.float32).npu() offset = torch.randn(128, dtype=torch.float32).npu() bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu() in_features = 512 out_features = 128 output_dtype = torch.int8 model = LinearA8W8Quant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype) model = model.npu() model.weight.data = x2 model.scale.data = scale model.offset.data = offset if output_dtype != torch.bfloat16: # 包含npu_trans_quant_param功能,<term>Atlas 推理系列产品</term>还包含使能高带宽的x2数据排布功能 tng.experimental.inference.use_internal_format_weight(model) model.bias.data = bias model = torch.compile(model, backend=npu_backend, dynamic=False) output = model(x1)