torch_npu.npu_grouped_matmul
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 推理系列产品 | √ |
功能说明
-
API功能:
npu_grouped_matmul是一种对多个矩阵乘法(matmul)操作进行分组计算的高效方法。该API实现了对多个矩阵乘法操作的批量处理,通过将具有相同形状或相似形状的矩阵乘法操作组合在一起,减少内存访问开销和计算资源的浪费,从而提高计算效率。 -
计算公式:
公式中@@符号表示矩阵乘法,×\times符号表示矩阵Hadamard乘积:
-
非量化场景(公式1):
yi=xi@weighti+biasiy_i = x_i @ weight_i + bias_i
-
perchannel、pertensor量化场景(公式2):
yi=(xi@weighti)×scalei+offsetiy_i = (x_i @ weight_i) \times scale_i + offset_i
-
x为int8输入,bias为int32输入(公式2-1):yi=(xi@weighti+biasi)×scalei+offsetiy_i = (x_i @ weight_i + bias_i) \times scale_i + offset_i
-
x为int8输入,bias为bfloat16、float16、float32输入,无offset(公式2-2):yi=(xi@weighti)×scalei+biasiy_i = (x_i @ weight_i) \times scale_i + bias_i
-
-
pertoken、pertensor+pertensor、pertensor+perchannel量化场景(公式3):
yi=(xi@weighti+biasi)×scalei×pertokenscaleiy_i = (x_i @ weight_i + bias_i) \times scale_i \times pertokenscale_i
-
x为int8输入,bias为int32输入(公式3-1):yi=(xi@weighti+biasi)×scalei×pertokenscaleiy_i = (x_i @ weight_i + bias_i) \times scale_i \times pertokenscale_i
-
x为int8输入,bias为bfloat16,float16,float32输入(公式3-2):yi=(xi@weighti)×scalei×pertokenscalei+biasiy_i = (x_i @ weight_i) \times scale_i \times pertokenscale_i + bias_i
-
x为int4输入,weight的数据类型为int4,数据排布格式为NZ的输入(公式3-3):yi=xi@(weighti×scalei)×pertokenscaleiy_i=x_i@ (weight_i \times scale_i) \times pertokenscale_i
-
-
伪量化场景(公式4):
yi=xi@((weighti+antiquant_offseti)×antiquant_scalei)+biasiy_i = x_i @ ((weight_i + antiquant\_offset_i) \times antiquant\_scale_i) + bias_i
-
函数原型
npu_grouped_matmul(x, weight, *, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, per_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=None, group_list_type=0, act_type=0, output_dtype=None, tuning_config=None) -> List[Tensor]
参数说明
-
x (
List[Tensor]):必选参数。输入矩阵列表,表示矩阵乘法中的左矩阵。-
支持的数据类型如下:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
float16、float32、bfloat16、int8和int4。 - Atlas 推理系列产品:
float16。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
-
列表最大长度为128。
-
当split_item=0时,张量支持2至6维输入;其他情况下,张量仅支持2维输入。
-
-
weight (
List[Tensor]):必选参数。权重矩阵列表,表示矩阵乘法中的右矩阵。-
支持的数据类型如下:
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
- 当
group_list输入类型为List[int]时,支持float16、float32、bfloat16和int8。 - 当
group_list输入类型为Tensor时,支持float16、float32、bfloat16、int4和int8。
- 当
-
Atlas 推理系列产品:
float16。
-
-
列表最大长度为128。
-
每个张量支持2维或3维输入。
-
-
*:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
-
bias (
List[Tensor]):可选参数。每个分组的矩阵乘法输出的独立偏置项。-
支持的数据类型如下:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
float16、float32和int32。 - Atlas 推理系列产品:
float16。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
-
列表长度与weight列表长度相同。
-
每个张量仅支持1维输入。
-
-
scale (
List[Tensor]):可选参数。用于缩放原数值以匹配量化后的范围值,代表量化参数中的缩放因子,对应公式(2)、公式(3)。-
支持的数据类型如下:
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
- 当
group_list输入类型为List[int]时,支持int64。 - 当
group_list输入类型为Tensor时,支持float32、bfloat16和int64。
- 当
-
Atlas 推理系列产品:仅支持传入
None。
-
-
列表长度与weight列表长度相同。
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:每个张量仅支持1维输入。
-
-
offset (
List[Tensor]):可选参数。用于调整量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(2)。当前仅支持传入None。 -
antiquant_scale (
List[Tensor]):可选参数。用于缩放原数值以匹配伪量化后的范围值,代表伪量化参数中的缩放因子,对应公式(4)。-
支持的数据类型如下:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
float16、bfloat16。 - Atlas 推理系列产品:仅支持传入
None。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
-
列表长度与weight列表长度相同。
-
每个张量支持输入维度如下(其中gg为matmul组数,GG为pergroup数,GiG_i为第i个tensor的pergroup数):
- 伪量化perchannel场景,
weight为单tensor时,shape限制为[g,n][g, n];weight为多tensor时,shape限制为[ni][n_i]。 - 伪量化pergroup场景,weight为单tensor时,shape限制为[g,G,n][g, G, n]; weight为多tensor时,shape限制为[Gi,ni][G_i, n_i]。
- 伪量化perchannel场景,
-
-
antiquant_offset (
List[Tensor]):可选参数。用于调整伪量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(4)。-
支持的数据类型如下:
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
float16、bfloat16。 - Atlas 推理系列产品:仅支持传入
None。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
-
列表长度与
weight列表长度相同。 -
每个张量输入维度和
antiquant_scale输入维度一致。
-
-
per_token_scale (
List[Tensor]):可选参数。用于缩放原数值以匹配量化后的范围值,代表pertoken量化参数中由x量化引入的缩放因子,对应公式(3)和公式(5)。group_list输入类型为List[int]时,当前只支持传入None。group_list输入类型为Tensor时:- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持
float32。 - 列表长度与
x列表长度相同。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:每个张量仅支持1维输入。
- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持
-
group_list (
List[int]/Tensor):可选参数。用于指定分组的索引,表示x的第0维矩阵乘法的索引情况。数据类型支持int64。- Atlas 推理系列产品:仅支持
Tensor类型。仅支持1维输入,长度与weight列表长度相同。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持
List[int]或Tensor类型。若为Tensor类型,仅支持1维输入,长度与weight列表长度相同。 - 配置值要求如下:
group_list输入类型为List[int]时,配置值必须为非负递增数列,且长度不能为1。group_list输入类型为Tensor时:- 当
group_list_type为0时,group_list必须为非负、单调非递减数列。 - 当
group_list_type为1时,group_list必须为非负数列,且长度不能为1。 - 当
group_list_type为2时,group_listshape为[E,2][E, 2],E表示Group大小,数据排布为[[groupIdx0,groupSize0],[groupIdx1,groupSize1]...][[groupIdx0, groupSize0], [groupIdx1, groupSize1]...],其中groupSize为分组轴上每组大小,必须为非负数。
- 当
- Atlas 推理系列产品:仅支持
-
activation_input (
List[Tensor]):可选参数。代表激活函数的反向输入,当前仅支持传入None。 -
activation_quant_scale (
List[Tensor]):可选参数。预留参数,当前只支持传入None。 -
activation_quant_offset (
List[Tensor]):可选参数。预留参数,当前只支持传入None。 -
split_item (
int):可选参数。用于指定切分模式。数据类型支持int32。- 0、1:输出为多个张量,数量与
weight相同。 - 2、3:输出为单个张量。
- 0、1:输出为多个张量,数量与
-
group_type (
int):可选参数。代表需要分组的轴。数据类型支持int32。-
group_list输入类型为List[int]时仅支持传入None。 -
group_list输入类型为Tensor时,若矩阵乘为C[m,n]=A[m,k]∗B[k,n]C[m,n]=A[m,k]*B[k,n],group_type支持的枚举值为:-1代表不分组;0代表m轴分组;2代表k轴分组。- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前支持取-1、0、2。
- Atlas 推理系列产品:当前只支持取0。
-
-
group_list_type (
int):可选参数。代表group_list的表达形式。数据类型支持int32。-
group_list输入类型为List[int]时仅支持传入None。 -
group_list输入类型为Tensor时可取值0、1或2:- 0:默认值,
group_list中数值为分组轴大小的cumsum结果(累积和)。 - 1:
group_list中数值为分组轴上每组大小。 - 2:
group_listshape为[E,2][E, 2],E表示Group大小,数据排布为[[groupIdx0,groupSize0],[groupIdx1,groupSize1]...][[groupIdx0, groupSize0], [groupIdx1, groupSize1]...],其中groupSize为分组轴上每组大小。 - Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:仅当
x和weight参数输入类型为INT8,并且group_type取0(m轴分组)时,支持取2。 - Atlas 推理系列产品:不支持取2。
- 0:默认值,
-
-
act_type (
int):可选参数。代表激活函数类型。数据类型支持int32。-
group_list输入类型为List[int]时仅支持传入None。 -
group_list输入类型为Tensor时,支持的枚举值包括:0代表不激活;1代表RELU激活;2代表GELU_TANH激活;3代表暂不支持;4代表FAST_GELU激活;5代表SILU激活。- Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0-5。
- Atlas 推理系列产品:当前只支持传入0。
-
-
output_dtype (
torch.dtype):可选参数。输出数据类型。支持的配置包括:None:默认值,表示输出数据类型与输入x的数据类型相同。- 与输出
y数据类型一致的类型,具体参考约束说明。
-
tuning_config (
List[int]):可选参数,数组中的第一个元素表示各个专家处理的token数的预期值,算子tiling时会按照数组中的第一个元素进行最优tiling,性能更优(使用场景参见约束说明);从第二个元素开始预留,用户无须填写,未来会进行扩展。如不使用该参数不传即可。- Atlas 推理系列产品:当前暂不支持该参数。
返回值说明
List[Tensor]:
- 当
split_item为0或1时,返回的张量数量与weight相同。 - 当
split_item为2或3时,返回的张量数量为1。
约束说明
-
该接口支持推理场景下使用。
-
该接口支持图模式。
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:内轴限制InnerLimit为65536。
-
x和weight中每一组tensor的最后一维大小都应小于InnerLimit。xi的最后一维指当x不转置时xi的K轴或当x转置时xi的MM轴。weighti的最后一维指当weight不转置时weighti的NN轴或当weight转置时weighti的KK轴。 -
tuning_config使用场景限制:仅在量化场景(输入
int8,输出为int32/bfloat16/float16/int8,数据类型如下表),且为单tensor单专家的场景下使用。x weight output_dtype y int8int8int8int8int8int8bfloat16bfloat16int8int8float16float16int8int8int32int32 -
各场景输入与输出数据类型使用约束:
-
group_list输入类型为List[int]时,Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品数据类型使用约束。表 1 数据类型约束
场景 x weight bias scale antiquant_scale antiquant_offset output_dtype y 非量化 float16float16float16无需赋值 无需赋值 无需赋值 float16float16非量化 bfloat16bfloat16float32无需赋值 无需赋值 无需赋值 bfloat16bfloat16非量化 float32float32float32无需赋值 无需赋值 无需赋值 float32float32perchannel全量化 int8int8int32int64无需赋值 无需赋值 int8int8伪量化 float16int8float16无需赋值 float16float16float16float16伪量化 bfloat16int8float32无需赋值 bfloat16bfloat16bfloat16bfloat16 -
group_list输入类型为Tensor时,数据类型使用约束。-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
表 2 数据类型约束
场景 x weight bias scale antiquant_scale antiquant_offset per_token_scale output_dtype y 非量化 float16float16float16无需赋值 无需赋值 无需赋值 无需赋值 None/ float16float16非量化 bfloat16bfloat16float32无需赋值 无需赋值 无需赋值 无需赋值 None/ bfloat16bfloat16非量化 float32float32float32无需赋值 无需赋值 无需赋值 无需赋值 None/ float32(仅x/weight/y均为单张量)float32perchannel全量化 int8int8int32int64无需赋值 无需赋值 无需赋值 None/ int8int8perchannel全量化 int8int8int32bfloat16无需赋值 无需赋值 无需赋值 bfloat16bfloat16perchannel全量化 int8int8int32float32无需赋值 无需赋值 无需赋值 float16float16pertoken全量化 int8int8int32bfloat16无需赋值 无需赋值 float32bfloat16bfloat16pertoken全量化 int8int8int32float32无需赋值 无需赋值 float32float16float16pertoken全量化 int4int4无需赋值 uint64无需赋值 无需赋值 None/ float32float16float16pertoken全量化 int4int4无需赋值 uint64无需赋值 无需赋值 None/ float32bfloat16bfloat16伪量化 float16int8/int4float16无需赋值 float16float16无需赋值 None/ float16float16伪量化 bfloat16int8/int4float32无需赋值 bfloat16bfloat16无需赋值 None/ bfloat16bfloat16Note
- 伪量化场景,若
weight的类型为int8,仅支持perchannel模式;若weight的类型为int4,支持perchannel和pergroup两种模式。若为pergroup,pergroup数GG或GiG_i必须要能整除对应的kik_i。若weight为多tensor,定义pergroup长度si=ki/Gis_i= k_i/G_i,要求所有si(i=1,2,...g)s_i(i=1,2,...g)都相等。 - 伪量化场景,若
weight的类型为int4,则weight中每一组tensor的最后一维大小都应是偶数。weighti的最后一维指weight不转置时weighti的N轴或当weight转置时weighti的KK轴。并且在pergroup场景下,当weight转置时,要求pergroup长度sis_i是偶数。tensor转置:指若tensor shape为[M,K][M,K]时,则stride为[1,M][1,M],数据排布为[K,M][K,M]的场景,即非连续tensor。 - 当前PyTorch不支持
int4类型数据,需要使用时可以通过torch_npu.npu_quantize接口使用int32数据表示int4。
- 伪量化场景,若
-
Atlas 推理系列产品:
表 3 数据类型约束
x weight bias scale antiquant_scale antiquant_offset per_token_scale output_dtype y float16float16float16无需赋值 无需赋值 无需赋值 float32float16float16
-
-
-
根据输入
x、输入weight与输出y的Tensor数量不同,支持以下几种场景。场景中的“单”表示单个张量,“多”表示多个张量。场景顺序为x、weight、y,例如“单多单”表示x为单张量,weight为多张量,y为单张量。-
group_list输入类型为List[int]时,Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品各场景的限制。支持场景 场景说明 场景限制 多多多 x和weight为多张量,y为多张量。每组数据的张量是独立的。1.仅支持 split_item为0或1。
2.x中tensor要求维度一致且支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致。
3.x中tensor大于2维,group_list必须传空。
4.x中tensor为2维且传入group_list,group_list的差值需与x中tensor的第一维一一对应。单多单 x为单张量,weight为多张量,y为单张量。1.仅支持 split_item为2或3。
2.必须传group_list,且最后一个值与x中tensor的第一维相等。
3.x、weight、y中tensor需为2维。
4.weight中每个tensor的N轴必须相等。单多多 x为单张量,weight为多张量,y为多张量。1.仅支持 split_item为0或1。
2.必须传group_list,group_list的差值需与y中tensor的第一维一一对应。
3.x、weight、y中tensor需为2维。多多单 x和weight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。1.仅支持 split_item为2或3。
2.x、weight、y中tensor需为2维。
3.weight中每个tensor的N轴必须相等。
4.若传入group_list,group_list的差值需与x中tensor的第一维一一对应。 -
group_list输入类型为Tensor时,各场景的限制。-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
Note
- 量化、伪量化仅支持
group_type为-1和0场景。 - 仅pertoken量化场景支持激活函数计算。
group_type 支持场景 场景说明 场景限制 -1 多多多 x和weight为多张量,y为多张量。每组数据的张量是独立的。1.仅支持 split_item为0或1。
2.x中tensor要求维度一致且支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致。
3.group_list必须传空。
4.支持weight转置,但weight中每个tensor是否转置需保持统一。
5.x不支持转置。0 单单单 x、weight与y均为单张量。1.仅支持 split_item为2或3。
2.weight中tensor需为3维,x、y中tensor需为2维。
3.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等,当group_list_type为2时,第二列数值的总和与x中tensor的第一维相等。
4.group_list第1维最大支持1024,即最多支持1024个group。
5.支持weight转置。
6.x不支持转置。0 单多单 x为单张量,weight为多张量,y为单张量。1.仅支持 split_item为2或3。
2.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等且长度最大为128,当group_list_type为2时,第二列数值的总和与x中tensor的第一维相等且长度最大为128。
3.x、weight、y中tensor需为2维。
4.weight中每个tensor的N轴必须相等。
5.支持weight转置,但weight中每个tensor是否转置需保持统一。
6.x不支持转置。0 多多单 x和weight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。1.仅支持 split_item为2或3。
2.x、weight、y中tensor需为2维。
3.weight中每个tensor的N轴必须相等。
4.若传入group_list,当group_list_type为0时,group_list的差值需与x中tensor的第一维一一对应,当group_list_type为1时,group_list的数值需与x中tensor的第一维一一对应且长度最大为128,当group_list_type为2时,group_list第二列的数值需与x中tensor的第一维一一对应且长度最大为128。
5.支持weight转置,但weight中每个tensor是否转置需保持统一。
6.x不支持转置。 - 量化、伪量化仅支持
-
Atlas 推理系列产品:
输入输出只支持
float16的数据类型,输出y的n轴大小需要是16的倍数。group_type 支持场景 场景说明 场景限制 0 单单单 x、weight与y均为单张量。1.仅支持 split_item为2或3。
2.weight中tensor需为3维,x、y中tensor需为2维。
3.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等。
4.group_list第1维最大支持1024,即最多支持1024个group。
5.支持weight转置,不支持x转置。
-
-
调用示例
-
单算子模式调用
-
通用调用示例
import torch import torch_npu x1 = torch.randn(256, 256, device='npu', dtype=torch.float16) x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16) x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16) x = [x1, x2, x3] weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16) weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16) weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16) weight = [weight1, weight2, weight3] bias1 = torch.randn(256, device='npu', dtype=torch.float16) bias2 = torch.randn(1024, device='npu', dtype=torch.float16) bias3 = torch.randn(128, device='npu', dtype=torch.float16) bias = [bias1, bias2, bias3] group_list = None split_item = 0 npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, group_list=group_list, split_item=split_item, group_type=-1) -
x为int4输入, weight的数据类型为int4数据排布格式为NZ,调用示例如下:
import numpy as np import torch import torch_npu E, K, N = 1, 16, 64 x = torch.randint(10, (15, 16), dtype=torch.int8).npu() weight = torch.randint(10, (1, 16, 64), dtype=torch.int8).npu() x_quant = torch_npu.npu_quantize(x.to(torch.float32), torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) weight_nz = torch_npu.npu_format_cast(weight.to(torch.float32), 29) weight_quant = torch_npu.npu_quantize(weight_nz, torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False) scale = torch.rand((E, 1, N), dtype=torch.float32).npu() k_group = scale.shape[1] scale_np = scale.cpu().numpy() scale_uint32 = scale_np.astype(np.float32) scale_uint32.dtype = np.uint32 scale_uint64 = np.zeros((E, k_group, N * 2), dtype=np.uint32) scale_uint64[...,::2] = scale_uint32 scale_uint64.dtype = np.int64 scale = torch.from_numpy(scale_uint64).npu() group_list = torch.Tensor([14]).to(torch.int64).npu() per_token_scale = torch.rand((15), dtype=torch.float32).npu() output = torch_npu.npu_grouped_matmul([x_quant], [weight_quant], scale=[scale], per_token_scale=[per_token_scale], group_list=group_list, group_list_type=0, group_type=0, split_item=3, output_dtype=torch.float16)
-
-
图模式调用
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:
import torch import torch.nn as nn import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() npu_backend = tng.get_npu_backend(compiler_config=config) class GMMModel(nn.Module): def __init__(self): super().__init__() def forward(self, x, weight): return torch_npu.npu_grouped_matmul(x, weight, group_type=-1) def main(): x1 = torch.randn(256, 256, device='npu', dtype=torch.float16) x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16) x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16) x = [x1, x2, x3] weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16) weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16) weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16) weight = [weight1, weight2, weight3] model = GMMModel().npu() model = torch.compile(model, backend=npu_backend, dynamic=False) custom_output = model(x, weight) if __name__ == '__main__': main()
-