torch_npu.npu_grouped_matmul

产品支持情况

产品 是否支持
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 推理系列产品

功能说明

  • API功能:npu_grouped_matmul是一种对多个矩阵乘法(matmul)操作进行分组计算的高效方法。该API实现了对多个矩阵乘法操作的批量处理,通过将具有相同形状或相似形状的矩阵乘法操作组合在一起,减少内存访问开销和计算资源的浪费,从而提高计算效率。

  • 计算公式:

    公式中@@符号表示矩阵乘法,×\times符号表示矩阵Hadamard乘积:

    • 非量化场景(公式1):

      yi=xi@weighti+biasiy_i = x_i @ weight_i + bias_i

    • perchannel、pertensor量化场景(公式2):

      yi=(xi@weighti)×scalei+offsetiy_i = (x_i @ weight_i) \times scale_i + offset_i

      • xint8输入,biasint32输入(公式2-1):

        yi=(xi@weighti+biasi)×scalei+offsetiy_i = (x_i @ weight_i + bias_i) \times scale_i + offset_i

      • xint8输入,biasbfloat16float16float32输入,无offset(公式2-2):

        yi=(xi@weighti)×scalei+biasiy_i = (x_i @ weight_i) \times scale_i + bias_i

    • pertoken、pertensor+pertensor、pertensor+perchannel量化场景(公式3):

      yi=(xi@weighti+biasi)×scalei×pertokenscaleiy_i = (x_i @ weight_i + bias_i) \times scale_i \times pertokenscale_i

      • xint8输入,bias为int32输入(公式3-1):

        yi=(xi@weighti+biasi)×scalei×pertokenscaleiy_i = (x_i @ weight_i + bias_i) \times scale_i \times pertokenscale_i

      • xint8输入,biasbfloat16float16float32输入(公式3-2):

        yi=(xi@weighti)×scalei×pertokenscalei+biasiy_i = (x_i @ weight_i) \times scale_i \times pertokenscale_i + bias_i

      • xint4输入, weight的数据类型为int4,数据排布格式为NZ的输入(公式3-3):

        yi=xi@(weighti×scalei)×pertokenscaleiy_i=x_i@ (weight_i \times scale_i) \times pertokenscale_i

    • 伪量化场景(公式4):

      yi=xi@((weighti+antiquant_offseti)×antiquant_scalei)+biasiy_i = x_i @ ((weight_i + antiquant\_offset_i) \times antiquant\_scale_i) + bias_i

函数原型

npu_grouped_matmul(x, weight, *, bias=None, scale=None, offset=None, antiquant_scale=None, antiquant_offset=None, per_token_scale=None, group_list=None, activation_input=None, activation_quant_scale=None, activation_quant_offset=None, split_item=0, group_type=None, group_list_type=0, act_type=0, output_dtype=None, tuning_config=None) -> List[Tensor]

参数说明

  • x (List[Tensor]):必选参数。输入矩阵列表,表示矩阵乘法中的左矩阵。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:float16float32bfloat16int8int4
      • Atlas 推理系列产品:float16
    • 列表最大长度为128。

    • 当split_item=0时,张量支持2至6维输入;其他情况下,张量仅支持2维输入。

  • weight (List[Tensor]):必选参数。权重矩阵列表,表示矩阵乘法中的右矩阵。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:

        • group_list输入类型为List[int]时,支持float16float32bfloat16int8
        • group_list输入类型为Tensor时,支持float16float32bfloat16int4int8
      • Atlas 推理系列产品:float16

    • 列表最大长度为128。

    • 每个张量支持2维或3维输入。

  • *:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。

  • bias (List[Tensor]):可选参数。每个分组的矩阵乘法输出的独立偏置项。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:float16float32int32
      • Atlas 推理系列产品:float16
    • 列表长度与weight列表长度相同。

    • 每个张量仅支持1维输入。

  • scale (List[Tensor]):可选参数。用于缩放原数值以匹配量化后的范围值,代表量化参数中的缩放因子,对应公式(2)、公式(3)。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:

        • group_list输入类型为List[int]时,支持int64
        • group_list输入类型为Tensor时,支持float32bfloat16int64
      • Atlas 推理系列产品:仅支持传入None

    • 列表长度与weight列表长度相同。

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:每个张量仅支持1维输入。

  • offset (List[Tensor]):可选参数。用于调整量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(2)。当前仅支持传入None

  • antiquant_scale (List[Tensor]):可选参数。用于缩放原数值以匹配伪量化后的范围值,代表伪量化参数中的缩放因子,对应公式(4)。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:float16bfloat16
      • Atlas 推理系列产品:仅支持传入None
    • 列表长度与weight列表长度相同。

    • 每个张量支持输入维度如下(其中gg为matmul组数,GG为pergroup数,GiG_i为第i个tensor的pergroup数):

      • 伪量化perchannel场景,weight为单tensor时,shape限制为[g,n][g, n]weight为多tensor时,shape限制为[ni][n_i]
      • 伪量化pergroup场景,weight为单tensor时,shape限制为[g,G,n][g, G, n]; weight为多tensor时,shape限制为[Gi,ni][G_i, n_i]
  • antiquant_offset (List[Tensor]):可选参数。用于调整伪量化后的数值偏移量,从而更准确地表示原始浮点数值,对应公式(4)。

    • 支持的数据类型如下:

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:float16bfloat16
      • Atlas 推理系列产品:仅支持传入None
    • 列表长度与weight列表长度相同。

    • 每个张量输入维度和antiquant_scale输入维度一致。

  • per_token_scale (List[Tensor]):可选参数。用于缩放原数值以匹配量化后的范围值,代表pertoken量化参数中由x量化引入的缩放因子,对应公式(3)和公式(5)。

    • group_list输入类型为List[int]时,当前只支持传入None
    • group_list输入类型为Tensor时:
      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float32
      • 列表长度与x列表长度相同。
      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:每个张量仅支持1维输入。
  • group_list (List[int]/Tensor):可选参数。用于指定分组的索引,表示x的第0维矩阵乘法的索引情况。数据类型支持int64

    • Atlas 推理系列产品:仅支持Tensor类型。仅支持1维输入,长度与weight列表长度相同。
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持List[int]Tensor类型。若为Tensor类型,仅支持1维输入,长度与weight列表长度相同。
    • 配置值要求如下:
      • group_list输入类型为List[int]时,配置值必须为非负递增数列,且长度不能为1。
      • group_list输入类型为Tensor时:
        • group_list_type为0时,group_list必须为非负、单调非递减数列。
        • group_list_type为1时,group_list必须为非负数列,且长度不能为1。
        • group_list_type为2时,group_list shape为[E,2][E, 2],E表示Group大小,数据排布为[[groupIdx0,groupSize0],[groupIdx1,groupSize1]...][[groupIdx0, groupSize0], [groupIdx1, groupSize1]...],其中groupSize为分组轴上每组大小,必须为非负数。
  • activation_input (List[Tensor]):可选参数。代表激活函数的反向输入,当前仅支持传入None

  • activation_quant_scale (List[Tensor]):可选参数。预留参数,当前只支持传入None

  • activation_quant_offset (List[Tensor]):可选参数。预留参数,当前只支持传入None

  • split_item (int):可选参数。用于指定切分模式。数据类型支持int32

    • 0、1:输出为多个张量,数量与weight相同。
    • 2、3:输出为单个张量。
  • group_type (int):可选参数。代表需要分组的轴。数据类型支持int32

    • group_list输入类型为List[int]时仅支持传入None

    • group_list输入类型为Tensor时,若矩阵乘为C[m,n]=A[m,k]∗B[k,n]C[m,n]=A[m,k]*B[k,n]group_type支持的枚举值为:-1代表不分组;0代表m轴分组;2代表k轴分组。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前支持取-1、0、2。
      • Atlas 推理系列产品:当前只支持取0。
  • group_list_type (int):可选参数。代表group_list的表达形式。数据类型支持int32

    • group_list输入类型为List[int]时仅支持传入None

    • group_list输入类型为Tensor时可取值0、1或2:

      • 0:默认值,group_list中数值为分组轴大小的cumsum结果(累积和)。
      • 1:group_list中数值为分组轴上每组大小。
      • 2:group_list shape为[E,2][E, 2],E表示Group大小,数据排布为[[groupIdx0,groupSize0],[groupIdx1,groupSize1]...][[groupIdx0, groupSize0], [groupIdx1, groupSize1]...],其中groupSize为分组轴上每组大小。
      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:仅当xweight参数输入类型为INT8,并且group_type取0(m轴分组)时,支持取2。
      • Atlas 推理系列产品:不支持取2。
  • act_type (int):可选参数。代表激活函数类型。数据类型支持int32

    • group_list输入类型为List[int]时仅支持传入None

    • group_list输入类型为Tensor时,支持的枚举值包括:0代表不激活;1代表RELU激活;2代表GELU_TANH激活;3代表暂不支持;4代表FAST_GELU激活;5代表SILU激活。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0-5。
      • Atlas 推理系列产品:当前只支持传入0。
  • output_dtype (torch.dtype):可选参数。输出数据类型。支持的配置包括:

    • None:默认值,表示输出数据类型与输入x的数据类型相同。
    • 与输出y数据类型一致的类型,具体参考约束说明
  • tuning_config (List[int]):可选参数,数组中的第一个元素表示各个专家处理的token数的预期值,算子tiling时会按照数组中的第一个元素进行最优tiling,性能更优(使用场景参见约束说明);从第二个元素开始预留,用户无须填写,未来会进行扩展。如不使用该参数不传即可。

    • Atlas 推理系列产品:当前暂不支持该参数。

返回值说明

List[Tensor]

  • split_item为0或1时,返回的张量数量与weight相同。
  • split_item为2或3时,返回的张量数量为1。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式。

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:内轴限制InnerLimit为65536。

  • xweight中每一组tensor的最后一维大小都应小于InnerLimit。xi的最后一维指当x不转置时xi的K轴或当x转置时xiMM轴。weighti的最后一维指当weight不转置时weightiNN轴或当weight转置时weightiKK轴。

  • tuning_config使用场景限制:

    仅在量化场景(输入int8,输出为int32/bfloat16/float16/int8,数据类型如下表),且为单tensor单专家的场景下使用。

    x weight output_dtype y
    int8 int8 int8 int8
    int8 int8 bfloat16 bfloat16
    int8 int8 float16 float16
    int8 int8 int32 int32
  • 各场景输入与输出数据类型使用约束:

    • group_list输入类型为List[int],Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品数据类型使用约束。

      表 1 数据类型约束

      场景 x weight bias scale antiquant_scale antiquant_offset output_dtype y
      非量化 float16 float16 float16 无需赋值 无需赋值 无需赋值 float16 float16
      非量化 bfloat16 bfloat16 float32 无需赋值 无需赋值 无需赋值 bfloat16 bfloat16
      非量化 float32 float32 float32 无需赋值 无需赋值 无需赋值 float32 float32
      perchannel全量化 int8 int8 int32 int64 无需赋值 无需赋值 int8 int8
      伪量化 float16 int8 float16 无需赋值 float16 float16 float16 float16
      伪量化 bfloat16 int8 float32 无需赋值 bfloat16 bfloat16 bfloat16 bfloat16
    • group_list输入类型为Tensor,数据类型使用约束。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:

        表 2 数据类型约束

        场景 x weight bias scale antiquant_scale antiquant_offset per_token_scale output_dtype y
        非量化 float16 float16 float16 无需赋值 无需赋值 无需赋值 无需赋值 None/float16 float16
        非量化 bfloat16 bfloat16 float32 无需赋值 无需赋值 无需赋值 无需赋值 None/bfloat16 bfloat16
        非量化 float32 float32 float32 无需赋值 无需赋值 无需赋值 无需赋值 None/float32(仅x/weight/y均为单张量) float32
        perchannel全量化 int8 int8 int32 int64 无需赋值 无需赋值 无需赋值 None/int8 int8
        perchannel全量化 int8 int8 int32 bfloat16 无需赋值 无需赋值 无需赋值 bfloat16 bfloat16
        perchannel全量化 int8 int8 int32 float32 无需赋值 无需赋值 无需赋值 float16 float16
        pertoken全量化 int8 int8 int32 bfloat16 无需赋值 无需赋值 float32 bfloat16 bfloat16
        pertoken全量化 int8 int8 int32 float32 无需赋值 无需赋值 float32 float16 float16
        pertoken全量化 int4 int4 无需赋值 uint64 无需赋值 无需赋值 None/float32 float16 float16
        pertoken全量化 int4 int4 无需赋值 uint64 无需赋值 无需赋值 None/float32 bfloat16 bfloat16
        伪量化 float16 int8/int4 float16 无需赋值 float16 float16 无需赋值 None/float16 float16
        伪量化 bfloat16 int8/int4 float32 无需赋值 bfloat16 bfloat16 无需赋值 None/bfloat16 bfloat16

        Note

        • 伪量化场景,若weight的类型为int8,仅支持perchannel模式;若weight的类型为int4,支持perchannel和pergroup两种模式。若为pergroup,pergroup数GGGiG_i必须要能整除对应的kik_i。若weight为多tensor,定义pergroup长度si=ki/Gis_i= k_i/G_i,要求所有si(i=1,2,...g)s_i(i=1,2,...g)都相等。
        • 伪量化场景,若weight的类型为int4,则weight中每一组tensor的最后一维大小都应是偶数。weighti的最后一维指weight不转置时weighti的N轴或当weight转置时weightiKK轴。并且在pergroup场景下,当weight转置时,要求pergroup长度sis_i是偶数。tensor转置:指若tensor shape为[M,K][M,K]时,则stride为[1,M][1,M],数据排布为[K,M][K,M]的场景,即非连续tensor。
        • 当前PyTorch不支持int4类型数据,需要使用时可以通过torch_npu.npu_quantize接口使用int32数据表示int4
      • Atlas 推理系列产品:

        表 3 数据类型约束

        x weight bias scale antiquant_scale antiquant_offset per_token_scale output_dtype y
        float16 float16 float16 无需赋值 无需赋值 无需赋值 float32 float16 float16
  • 根据输入x、输入weight与输出y的Tensor数量不同,支持以下几种场景。场景中的“单”表示单个张量,“多”表示多个张量。场景顺序为xweighty,例如“单多单”表示x为单张量,weight为多张量,y为单张量。

    • group_list输入类型为List[int],Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品各场景的限制。

      支持场景 场景说明 场景限制
      多多多 xweight为多张量,y为多张量。每组数据的张量是独立的。 1.仅支持split_item为0或1。
      2.x中tensor要求维度一致且支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致。
      3.x中tensor大于2维,group_list必须传空。
      4.x中tensor为2维且传入group_listgroup_list的差值需与x中tensor的第一维一一对应。
      单多单 x为单张量,weight为多张量,y为单张量。 1.仅支持split_item为2或3。
      2.必须传group_list,且最后一个值与x中tensor的第一维相等。
      3.xweighty中tensor需为2维。
      4.weight中每个tensor的N轴必须相等。
      单多多 x为单张量,weight为多张量,y为多张量。 1.仅支持split_item为0或1。
      2.必须传group_listgroup_list的差值需与y中tensor的第一维一一对应。
      3.xweighty中tensor需为2维。
      多多单 xweight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。 1.仅支持split_item为2或3。
      2.xweighty中tensor需为2维。
      3.weight中每个tensor的N轴必须相等。
      4.若传入group_listgroup_list的差值需与x中tensor的第一维一一对应。
    • group_list输入类型为Tensor,各场景的限制。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:

        Note

        • 量化、伪量化仅支持group_type为-1和0场景。
        • 仅pertoken量化场景支持激活函数计算。
        group_type 支持场景 场景说明 场景限制
        -1 多多多 xweight为多张量,y为多张量。每组数据的张量是独立的。 1.仅支持split_item为0或1。
        2.x中tensor要求维度一致且支持2-6维,weight中tensor需为2维,y中tensor维度和x保持一致。
        3.group_list必须传空。
        4.支持weight转置,但weight中每个tensor是否转置需保持统一。
        5.x不支持转置。
        0 单单单 xweighty均为单张量。 1.仅支持split_item为2或3。
        2.weight中tensor需为3维,xy中tensor需为2维。
        3.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等,当group_list_type为2时,第二列数值的总和与x中tensor的第一维相等。
        4.group_list第1维最大支持1024,即最多支持1024个group。
        5.支持weight转置。
        6.x不支持转置。
        0 单多单 x为单张量,weight为多张量,y为单张量。 1.仅支持split_item为2或3。
        2.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等且长度最大为128,当group_list_type为2时,第二列数值的总和与x中tensor的第一维相等且长度最大为128。
        3.xweighty中tensor需为2维。
        4.weight中每个tensor的N轴必须相等。
        5.支持weight转置,但weight中每个tensor是否转置需保持统一。
        6.x不支持转置。
        0 多多单 xweight为多张量,y为单张量。每组矩阵乘法的结果连续存放在同一个张量中。 1.仅支持split_item为2或3。
        2.xweighty中tensor需为2维。
        3.weight中每个tensor的N轴必须相等。
        4.若传入group_list,当group_list_type为0时,group_list的差值需与x中tensor的第一维一一对应,当group_list_type为1时,group_list的数值需与x中tensor的第一维一一对应且长度最大为128,当group_list_type为2时,group_list第二列的数值需与x中tensor的第一维一一对应且长度最大为128。
        5.支持weight转置,但weight中每个tensor是否转置需保持统一。
        6.x不支持转置。
      • Atlas 推理系列产品:

        输入输出只支持float16的数据类型,输出y的n轴大小需要是16的倍数。

        group_type 支持场景 场景说明 场景限制
        0 单单单 xweighty均为单张量。 1.仅支持split_item为2或3。
        2.weight中tensor需为3维,xy中tensor需为2维。
        3.必须传group_list,且当group_list_type为0时,最后一个值与x中tensor的第一维相等,当group_list_type为1时,数值的总和与x中tensor的第一维相等。
        4.group_list第1维最大支持1024,即最多支持1024个group。
        5.支持weight转置,不支持x转置。

调用示例

  • 单算子模式调用

    • 通用调用示例

      import torch
      import torch_npu
      
      x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
      x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
      x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
      x = [x1, x2, x3]
      
      weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
      weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
      weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
      weight = [weight1, weight2, weight3]
      
      bias1 = torch.randn(256, device='npu', dtype=torch.float16)
      bias2 = torch.randn(1024, device='npu', dtype=torch.float16)
      bias3 = torch.randn(128, device='npu', dtype=torch.float16)
      bias = [bias1, bias2, bias3]
      
      group_list = None
      split_item = 0
      npu_out = torch_npu.npu_grouped_matmul(x, weight, bias=bias, group_list=group_list, split_item=split_item, group_type=-1)
      
    • x为int4输入, weight的数据类型为int4数据排布格式为NZ,调用示例如下:

      import numpy as np
      import torch
      import torch_npu
      
      E, K, N = 1, 16, 64
      x = torch.randint(10, (15, 16), dtype=torch.int8).npu()
      weight = torch.randint(10, (1, 16, 64), dtype=torch.int8).npu()
      
      x_quant = torch_npu.npu_quantize(x.to(torch.float32), torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False)
      weight_nz = torch_npu.npu_format_cast(weight.to(torch.float32), 29)
      weight_quant = torch_npu.npu_quantize(weight_nz, torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False)
      
      scale = torch.rand((E, 1, N), dtype=torch.float32).npu()
      
      k_group = scale.shape[1]
      scale_np = scale.cpu().numpy()
      scale_uint32 = scale_np.astype(np.float32)
      scale_uint32.dtype = np.uint32
      scale_uint64 = np.zeros((E, k_group, N * 2), dtype=np.uint32)
      scale_uint64[...,::2] = scale_uint32
      scale_uint64.dtype = np.int64
      scale = torch.from_numpy(scale_uint64).npu()
      
      group_list = torch.Tensor([14]).to(torch.int64).npu()
      per_token_scale = torch.rand((15), dtype=torch.float32).npu()
      
      output = torch_npu.npu_grouped_matmul([x_quant], [weight_quant], scale=[scale], per_token_scale=[per_token_scale],
                                              group_list=group_list, group_list_type=0, group_type=0,
                                              split_item=3, output_dtype=torch.float16)
      
  • 图模式调用

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品/Atlas 推理系列产品/Atlas A3 训练系列产品/Atlas A3 推理系列产品:

      import torch
      import torch.nn as nn
      import torch_npu
      import torchair as tng
      from torchair.configs.compiler_config import CompilerConfig
      
      config = CompilerConfig()
      npu_backend = tng.get_npu_backend(compiler_config=config)
      
      class GMMModel(nn.Module):
          def __init__(self):
              super().__init__()
          
          def forward(self, x, weight):
              return torch_npu.npu_grouped_matmul(x, weight, group_type=-1)
      
      def main():
          x1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
          x2 = torch.randn(1024, 256, device='npu', dtype=torch.float16)
          x3 = torch.randn(512, 1024, device='npu', dtype=torch.float16)
          x = [x1, x2, x3]
          
          weight1 = torch.randn(256, 256, device='npu', dtype=torch.float16)
          weight2 = torch.randn(256, 1024, device='npu', dtype=torch.float16)
          weight3 = torch.randn(1024, 128, device='npu', dtype=torch.float16)
          weight = [weight1, weight2, weight3]
          
          model = GMMModel().npu()
          model = torch.compile(model, backend=npu_backend, dynamic=False)
          custom_output = model(x, weight)
      
      if __name__ == '__main__':
          main()