torch_npu.npu_rotary_mul

产品支持情况

产品 是否支持
Atlas A3 训练系列产品
Atlas A2 训练系列产品
Atlas 训练系列产品
Atlas 推理系列产品

功能说明

  • API功能:实现Rotary Position Embedding (RoPE) 旋转位置编码,通过对输入特征进行二维平面旋转注入位置信息。

  • 计算公式:

    output=x∗cos+rotate(x)∗sinoutput = x * cos + rotate(x) * sin

    其中xx是输入inputcoscossinsin分别是旋转系数输入r1r2,输入rotate支持两种计算模式:

    • 当rotary_mode='half'时,将输入向量沿最后一个维度分为两半,然后应用旋转:

      x1,x2=chunk(x,2,dim=−1)rotate(x)=concat(−x2,x1)x_1, x_2 = chunk(x,2,dim=-1)\\ rotate(x) = concat(-x_2,x_1)

    • 当rotary_mode='interleave'时,将输入向量按交错顺序处理,然后应用旋转:

      x1=x[...,::2],x2=x[...,1::2]rotate(x)=rearrange(torch.stack((−x2,x1),dim=−1),"...dtwo−>...(dtwo)",two=2)x_1 = x[..., ::2], x_2 = x[..., 1::2]\\ rotate(x) = rearrange(torch.stack((-x_2, x_1), dim=-1), "... d two -> ...(d two)", two=2)\\

  • 等价计算逻辑:

    可使用fused_rotary_position_embedding等价替换torch_npu.npu_rotary_mul,两者计算逻辑一致。

    import torch
    from einops import rearrange
    
    # mode = 0
    def rotate_half(x):
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    
    # mode = 1
    def rotate_interleaved(x):
       x1 = x[..., ::2]
       x2 = x[..., 1::2]
       return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ...(d two)", two=2)
    
    def fused_rotary_position_embedding(x, cos, sin, interleaved=False):
        if not interleaved:
            return x * cos + rotate_half(x) * sin
        else:
            return x * cos + rotate_interleaved(x) * sin
    

函数原型

torch_npu.npu_rotary_mul(input, r1, r2, rotary_mode='half') -> Tensor

Note

在模型训练场景中,正向算子的输入input将被保留以供反向计算时使用。在r1r2不需要计算反向梯度场景下(requires_grad=False),使用该接口相较融合前小算子使用的设备内存占用会有所增加。

参数说明

  • input (Tensor):必选参数,输入维度支持3维、4维,数据类型支持float16bfloat16float32
  • r1 (Tensor):必选参数,表示coscos旋转系数,输入维度支持3维、4维,数据类型支持float16bfloat16float32
  • r2 (Tensor):必选参数,表示sinsin旋转系数,输入维度支持3维、4维,数据类型支持float16bfloat16float32
  • rotary_mode (str):可选参数,数据类型支持str,用于选择计算模式,支持halfinterleave两种模式。默认值为half

返回值说明

Tensor

输出计算结果,shape和dtype需与input一致。

约束说明

  • jit_compile=False场景(适用Atlas A2 训练系列产品,Atlas A3 训练系列产品):

    • half模式:

      input:layout支持:BNSD、BSND、SBND、TND;D<896BNSD、BSND、SBND、TND;D < 896,且为2的倍数;B,N<1000B, N < 1000;当需要计算cos/sincos/sin的反向梯度时,B∗N<=1024B*N <= 1024

      r1、r2:数据范围:[-1, 1];对应input layout的支持情况:

      • x为BNSDBNSD: 11SD、B1SD、BNSD11SD、B1SD、BNSD

      • x为BSNDBSND: 1S1D、BS1D、BSND1S1D、BS1D、BSND

      • x为SBNDSBND: S11D、SB1D、SBNDS11D、SB1D、SBND;

      • x为TNDTND: T1D、TNDT1D、TND

        [!NOTICE]
        half模式下,当输入layout是BNSDBNSD,且DD为非32Bytes对齐时,建议不使用该融合算子(模型启动脚本中不开启--use-fused-rotary-pos-emb选项),否则可能出现性能下降。

    • interleave模式:

      input:layout支持:BNSD、BSND、SBND、TND;B∗N<1000;D<896BNSD、BSND、SBND、TND; B*N < 1000; D < 896, 且DD为2的倍数;

      r1、r2:数据范围:[-1, 1];对应input layout的支持情况:

      • x为BNSD:11SDBNSD: 11SD;
      • x为BSND:1S1DBSND: 1S1D;
      • x为SBND:S11DSBND: S11D
      • x为TND:T1DTND: T1D
  • jit_compile=True场景(适用Atlas 训练系列产品,Atlas A2 训练系列产品,Atlas 推理系列产品):

    仅支持rotary_mode为half模式,且r1/r2 layout一般为11SD、1S1D、S11D11SD、1S1D、S11D

    shape要求输入为4维,其中BB维度和NN维度数值需小于等于1000,DD维度数值为128。

    广播场景下,广播轴的总数据量不能超过1024。

调用示例

  • 四维输入示例:
>>> import torch
>>> import torch_npu
>>>
>>> x = torch.rand(2, 2, 5, 128).npu()
>>> r1 = torch.rand(1, 2, 1, 128).npu()
>>> r2 = torch.rand(1, 2, 1, 128).npu()
>>> out = torch_npu.npu_rotary_mul(x, r1, r2)
>>> out.shape
torch.Size([2, 2, 5, 128])
>>> out
tensor([[[[ 0.1017, -0.0871,  0.2722,  ...,  0.4668,  0.4320,  0.4252],
          [ 0.2908, -0.0068,  0.4026,  ...,  0.1540,  0.2653,  0.6754],
          [ 0.1124, -0.0637,  0.0834,  ...,  0.5127,  0.1423,  0.0636],
          [ 0.1014,  0.0129,  0.3392,  ...,  0.7390,  0.7147,  0.1751],
          [ 0.3266, -0.0177,  0.2263,  ...,  0.9936,  0.3717,  0.3403]],

         [[ 0.1999, -0.5646,  0.0910,  ...,  0.1747,  0.3801,  0.0675],
          [ 0.2688,  0.3714,  0.2647,  ...,  0.0769,  0.0481,  0.1988],
          [ 0.1404,  0.1749,  0.4082,  ...,  0.2291,  0.5246,  0.0615],
          [-0.4368,  0.2962,  0.2655,  ...,  0.0284,  0.5518,  0.2853],
          [ 0.0812,  0.4214,  0.4906,  ...,  0.1684,  0.5756,  0.2966]]],


        [[[ 0.3887, -0.0777,  0.0328,  ...,  0.4946,  0.5197,  0.8397],
          [ 0.0283, -0.0858,  0.2244,  ...,  0.2542,  0.3899,  0.8239],
          [ 0.1993, -0.0765,  0.2022,  ...,  0.7701,  0.6514,  0.0557],
          [ 0.1424, -0.0795,  0.4005,  ...,  0.3839,  0.5843,  0.2539],
          [ 0.2812, -0.0479,  0.1748,  ...,  0.6403,  0.5840,  0.3274]],

         [[ 0.1308, -0.2528,  0.6242,  ...,  0.2614,  0.4986,  0.0893],
          [ 0.3121,  0.1706,  0.6207,  ...,  0.0731,  0.1644,  0.2398],
          [ 0.3232,  0.0695,  0.2875,  ...,  0.1104,  0.3334,  0.2233],
          [ 0.4909,  0.3554,  0.8431,  ...,  0.2265,  0.4873,  0.3106],
          [-0.2269, -0.1447, -0.0395,  ...,  0.1374,  0.2142,  0.3628]]]],
       device='npu:0')
  • 三维输入示例:
>>> import torch
>>> import torch_npu
>>>
>>> x = torch.rand(2, 5, 128).npu()
>>> r1 = torch.rand(2, 1, 128).npu()
>>> r2 = torch.rand(2, 1, 128).npu()
>>> out = torch_npu.npu_rotary_mul(x, r1, r2, "half")
>>> out
tensor([[[-0.1200, -0.2515, -0.3189,  ...,  0.2283,  1.1038,  0.3439],
         [ 0.1083,  0.0257,  0.1864,  ...,  0.5940,  0.8644,  0.5961],
         [-0.0147, -0.1542,  0.0516,  ...,  0.7441,  0.2782,  0.4797],
         [-0.0601, -0.0338, -0.3731,  ...,  0.9809,  0.7416,  0.4876],
         [ 0.1785, -0.0542, -0.3634,  ...,  0.5057,  0.7511,  1.3088]],

        [[ 0.0076,  0.0931, -0.4161,  ...,  0.4964,  0.2680,  0.1291],
         [-0.2149,  0.1523, -0.0274,  ...,  0.1997,  0.8318,  0.2630],
         [ 0.1087,  0.4846,  0.0684,  ...,  0.0183,  0.9503,  0.0555],
         [-0.1946,  0.6020, -0.6751,  ...,  0.8629,  0.5454,  0.1392],
         [ 0.0772,  0.5112, -0.4875,  ...,  0.7065,  0.6798,  0.1513]]],
       device='npu:0')