算子名称:RotaryStride
产品支持情况
| 产品 |
是否支持 |
| Atlas A2 训练系列产品 |
是 |
功能说明
参数说明
| 参数名 |
输入/输出/属性 |
描述 |
数据类型 |
数据格式 |
| blockDim |
输入 |
AI CORE的数量,比如:Ascend910B是40。 |
int64_t |
- |
| in |
输入 |
公式中的输入张量x,shape为 (B, S, N, stride) |
BFLOAT16/HALF |
ND |
| sin |
输入 |
公式中的输入张量sin,shape为 (MaxS, D),MaxS指最大序列长度 |
FLOAT |
ND |
| cos |
输入 |
公式中的输入张量x,shape为 (MaxS, D),MaxS指最大序列长度 |
FLOAT |
ND |
| out |
输入 |
公式中的输出张量y,shape为 (B, S, N, stride) |
BFLOAT16/HALF |
ND |
| gbD |
输入 |
head_dim维度的大小 |
int64_t |
- |
约束说明
- x/y 仅支持BFLOAT16/HAFL类型。
- stride >= head_dim。
调用说明
torch.ops.npu_ops_transformer_ext.rotary_stride(blockDim, in, sin, cos, out, gbD)