算子名称:RotaryStride

产品支持情况

产品 是否支持
Atlas A2 训练系列产品

功能说明

  • 算子功能:带跨步的位置编码。
  • 计算公式:

    ymi=xmicos⁡(mθi)−xmi+1sin⁡(mθi)ymi+1=xmisin⁡(mθi)+xmi+1cos⁡(mθi)y_m^{i} = x_m^{i} \cos(m\theta_i) - x_m^{i+1} \sin(m\theta_i) \\ y_m^{i+1} = x_m^{i} \sin(m\theta_i) + x_m^{i+1} \cos(m\theta_i)

  • 在基本RoPE位置编码的基础上,追加了跨步功能。输入是shape为(Batch, Seqlen, Headnum, Stride)的张量,Stride为跨步大小,仅计算Stride的后head_dim维。head_dim为实际需要计算旋转位置编码的维度大小,Stride大于等于head_dim。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
blockDim 输入 AI CORE的数量,比如:Ascend910B是40。 int64_t -
in 输入 公式中的输入张量x,shape为 (B, S, N, stride) BFLOAT16/HALF ND
sin 输入 公式中的输入张量sin,shape为 (MaxS, D),MaxS指最大序列长度 FLOAT ND
cos 输入 公式中的输入张量x,shape为 (MaxS, D),MaxS指最大序列长度 FLOAT ND
out 输入 公式中的输出张量y,shape为 (B, S, N, stride) BFLOAT16/HALF ND
gbD 输入 head_dim维度的大小 int64_t -

约束说明

  • x/y 仅支持BFLOAT16/HAFL类型。
  • stride >= head_dim。

调用说明

torch.ops.npu_ops_transformer_ext.rotary_stride(blockDim, in, sin, cos, out, gbD)