torch_npu.contrib.module.LinearQuant

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品

功能说明

LinearQuant是对torch_npu.npu_quant_matmul接口的封装类,完成A8W8、A4W4量化算子的矩阵乘计算。

函数原型

torch_npu.contrib.module.LinearQuant(in_features, out_features, *, bias=True, offset=False, pertoken_scale=False, device=None, dtype=None, output_dtype=None)

参数说明

计算参数

  • in_featuresint):matmul计算中k轴的值。
  • out_featuresint):matmul计算中n轴的值。
  • biasbool):代表是否需要bias计算参数。如果设置成False,则bias不会加入量化matmul的计算。
  • offsetbool):代表是否需要offset计算参数。如果设置成False,则offset不会加入量化matmul的计算。
  • pertoken_scalebool):可选参数,代表是否需要pertoken_scale计算参数。如果设置成False,则pertoken_scale不会加入量化matmul的计算。Atlas 推理系列产品当前不支持pertoken_scale。
  • device:默认值为None。预留参数,暂未使用
  • dtype:默认值为None。预留参数,暂未使用
  • output_dtypeScalarType):表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8
    • Atlas 推理系列产品:支持输入int8float16
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持输入int8float16bfloat16int32

计算输入

x1Tensor):数据格式支持NDND,shape最少是2维,最多是6维。

  • Atlas 推理系列产品:数据类型支持int8
  • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int8int32,其中int32表示使用本接口进行int4类型矩阵乘计算,int32类型承载的是int4数据,每个int32数据存放8个int4数据。

变量说明

  • weight(Tensor):与x1的数据类型须保持一致。数据格式支持NDND,shape需要在2-6维范围。当数据类型为int32时,shape必须为2维。

    • Atlas 推理系列产品:数据类型支持int8,需要调用torchair.experimental.inference.use_internal_format_weight或torch_npu.npu_format_cast完成weight(batch, n, k)高性能数据排布功能。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int8int32(同x1,表示int4的数据计算),需要调用torch_npu.npu_format_cast完成weight(batch, n, k)高性能数据排布功能,但不推荐使用该module方式,推荐torch_npu.npu_quant_matmul。
  • scale(Tensor):量化计算的scale。数据格式支持NDND,shape需要是1维(t,),t=1或n,其中n与weight的n一致。如需传入int64数据类型的scale,需要提前调用torch_npu.npu_trans_quant_param接口来获取int64数据类型的scale。

    • Atlas 推理系列产品:数据类型支持float32int64
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32int64bfloat16
  • offset(Tensor):量化计算的offset。可选参数。数据类型支持float32,数据格式支持NDND,shape需要是1维(t,),t=1或n,其中n与weight的n一致。

  • pertoken_scale(Tensor):可选参数,量化计算的pertoken。数据类型支持float32,数据格式支持NDND,shape需要是1维(m,),其中m与x1的m一致。Atlas 推理系列产品当前不支持pertoken_scale。

  • bias(Tensor):可选参数。矩阵乘中的bias。数据格式支持NDND,shape支持1维(n,)或3维(batch, 1, n),n与weight的n一致,同时batch值需要等于x1,weight broadcast后推导出的batch值。当输出为2、4、5、6维情况下,bias shape为1维;当输出为3维情况下,bias shape为1维或3维。

    • Atlas 推理系列产品:数据类型支持int32
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持int32bfloat16float16float32
  • output_dtype(ScalarType):可选参数。表示输出Tensor的数据类型。默认值为None,代表输出Tensor数据类型为int8

    • Atlas 推理系列产品:支持输入int8float16
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持输入int8float16bfloat16int32

返回值说明

Tensor

代表量化matmul的计算结果:

  • 如果output_dtype为int8或者None,输出的数据类型为int8
  • 如果output_dtype为float16,输出的数据类型为float16
  • 如果output_dtype为bfloat16,输出的数据类型为bfloat16
  • 如果output_dtype为int32,输出的数据类型为int32

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式。

  • x1weightscale不能是空。

  • x1weight最后一维的shape大小不能超过65535。

  • int4类型计算的额外约束:

    x1weight的数据类型均为int32,每个int32类型的数据存放8个int4数据。输入shape需要将数据原本int4类型时的最后一维shape缩小8倍。int4数据的最后一维shape应为8的倍数,例如:进行(m, k)乘(k, n)的int4类型矩阵乘计算时,需要输入int32类型,shape为(m, k//8)、(k, n//8)的数据,其中k与n都应是8的倍数。x1只能接受shape为(m, k//8)且数据排布连续的数据,weight只能接受shape为(n, k//8)且数据排布连续的数据。

    Note

    数据排布连续是指数组中所有相邻的数,包括换行时内存地址连续,使用Tensor.is_contiguous返回值为true则表明tensor数据排布连续。

  • 输入参数或变量间支持的数据类型组合情况如下:

    表1 Atlas 推理系列产品

    x1(入参)

    weight(变量)

    scale(变量)

    offset(变量)

    bias(变量)

    pertoken_scale(变量)

    output_dtype(入参或变量)

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    注:None表示传入参数或变量为False的场景。

    表2 Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品

    x1(入参)

    weight(变量)

    scale(变量)

    offset(变量)

    bias(变量)

    pertoken_scale(变量)

    output_dtype(入参或变量)

    int8

    int8

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    int64/float32

    float32/None

    int32/None

    None

    int8

    int8

    int8

    float32/bfloat16

    None

    int32/bfloat16/float32/None

    float32/None

    bfloat16

    int8

    int8

    float32

    None

    int32/float16/float32/None

    float32/None

    float16

    int32

    int32

    int64/float32

    None

    int32/None

    None

    float16

    int8

    int8

    float32/bfloat16

    None

    int32/None

    None

    int32

    注:None表示传入参数或变量为False的场景。

调用示例

  • 单算子模式调用

    • int8类型输入场景,示例代码如下:

      import torch
      import torch_npu
      import logging
      import os
      from torch_npu.contrib.module import LinearQuant
      x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu()
      x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu()
      scale = torch.randn(1, dtype=torch.float32).npu()
      offset = torch.randn(128, dtype=torch.float32).npu()
      bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
      in_features = 512
      out_features = 128
      output_dtype = torch.int8
      model = LinearQuant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
      model = model.npu()
      model.weight.data = x2
      model.scale.data = scale
      model.offset.data = offset
      model.bias.data = bias
      # 接口内部调用npu_trans_quant_param功能
      output = model(x1)
      
    • int32类型输入场景,示例代码如下,仅支持如下产品:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品
      import torch
      import torch_npu
      import logging
      import os
      from torch_npu.contrib.module import LinearQuant
      # 用int32类型承载int4数据,实际int4 shape为x1:(1, 512) x2: (128, 512)
      x1 = torch.randint(-1, 1, (1, 64), dtype=torch.int32).npu()
      x2 = torch.randint(-1, 1, (128, 64), dtype=torch.int32).npu()
      scale = torch.randn(1, dtype=torch.float32).npu()
      bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
      in_features = 512
      out_features = 128
      output_dtype = torch.float16
      model = LinearQuant(in_features, out_features, bias=True, offset=False, output_dtype=output_dtype)
      model = model.npu()
      model.weight.data = x2
      model.scale.data = scale
      model.bias.data = bias
      output = model(x1)
      
  • 图模式调用,示例代码如下,仅支持如下产品:

    import torch
    import torch_npu
    import torchair as tng
    from torchair.ge_concrete_graph import ge_apis as ge
    from torchair.configs.compiler_config import CompilerConfig
    from torch_npu.contrib.module import LinearQuant
    import logging
    from torchair.core.utils import logger
    logger.setLevel(logging.DEBUG)
    import os
    import numpy as np
    os.environ["ENABLE_ACLNN"] = "true"
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    x1 = torch.randint(-1, 1, (1, 512), dtype=torch.int8).npu()
    x2 = torch.randint(-1, 1, (128, 512), dtype=torch.int8).npu()
    scale = torch.randn(1, dtype=torch.float32).npu()
    offset = torch.randn(128, dtype=torch.float32).npu()
    bias = torch.randint(-1,1, (128,), dtype=torch.int32).npu()
    in_features = 512
    out_features = 128
    output_dtype = torch.int8
    model = LinearQuant(in_features, out_features, bias=True, offset=True, output_dtype=output_dtype)
    model = model.npu()
    model.weight.data = x2
    model.scale.data = scale
    model.offset.data = offset
    if output_dtype != torch.bfloat16:
        # 使能高带宽x2的数据排布功能
        tng.experimental.inference.use_internal_format_weight(model)
    model.bias.data = bias
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    output = model(x1)