RotaryPositionEmbeddingGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:执行单路旋转位置编码RotaryPositionEmbedding的反向计算。
-
计算公式:
取旋转位置编码的正向计算中,broadcast的轴列表为
dims,则计算公式可表达如下:- Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
(1)half模式(mode等于0):
dy1,dy2=chunk(dy,chunks=2,dim=−1)dy1, dy2 = chunk(dy, chunks=2, dim=-1)
cos1,cos2=chunk(cos,chunks=2,dim=−1)cos1, cos2 = chunk(cos, chunks=2, dim=-1)
sin1,sin2=chunk(sin,chunks=2,dim=−1)sin1, sin2 = chunk(sin, chunks=2, dim=-1)
x1,x2=chunk(x,chunks=2,dim=−1)x1, x2 = chunk(x, chunks=2, dim=-1)
dx=cat((cos1∗dy1+sin2∗dy2,cos2∗dy2−sin1∗dy1),dim=−1)dx = cat((cos1 * dy1 + sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim=-1)
dcos=sum(dy∗x,dims)dcos = sum(dy * x, dims)
dsin=sum(dy∗cat((−x2,x1),dim=−1),dims)dsin = sum(dy * cat((-x2, x1), dim=-1), dims)
(2)interleave模式(mode等于1):
dy1,dy2=dy[...,::2],dy[...,1::2]dy1, dy2 = dy[..., :: 2], dy[..., 1 :: 2]
cos1,cos2=cos[...,::2],cos[...,1::2]cos1, cos2 = cos[..., :: 2], cos[..., 1 :: 2]
sin1,sin2=sin[...,::2],sin[...,1::2]sin1, sin2 = sin[..., :: 2], sin[..., 1 :: 2]
x1,x2=x[...,::2],x[...,1::2]x1, x2 = x[..., :: 2], x[..., 1 :: 2]
dx=stack((cos1∗dy1+sin2∗dy2,cos2∗dy2−sin1∗dy1),dim=−1).reshape(dy.shape)dx = stack((cos1 * dy1 + sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim=-1).reshape(dy.shape)
dcos=sum(dy∗x,dims)dcos = sum(dy * x, dims)
dsin=sum(dy∗stack((−x2,x1),dim=−1).reshape(dy.shape),dims)dsin = sum(dy * stack((-x2, x1), dim=-1).reshape(dy.shape), dims)
- Ascend 950PR/Ascend 950DT:
(3)quarter模式(mode等于2):
dy1,dy2,dy3,dy4=chunk(dy,chunks=4,dim=−1)dy1, dy2, dy3, dy4 = chunk(dy, chunks=4, dim=-1)
cos1,cos2,cos3,cos4=chunk(cos,chunks=4,dim=−1)cos1, cos2, cos3, cos4 = chunk(cos, chunks=4, dim=-1)
sin1,sin2,sin3,sin4=chunk(sin,chunks=4,dim=−1)sin1, sin2, sin3, sin4 = chunk(sin, chunks=4, dim=-1)
x1,x2,x3,x4=chunk(x,chunks=4,dim=−1)x1, x2, x3, x4 = chunk(x, chunks=4, dim=-1)
dx=cat((cos1∗dy1+sin2∗dy2,cos2∗dy2−sin1∗dy1,cos3∗dy3+sin4∗dy4,cos4∗dy4−sin3∗dy3),dim=−1)dx = cat((cos1 * dy1 + sin2 * dy2, cos2 * dy2 - sin1 * dy1, cos3 * dy3 + sin4 * dy4, cos4 * dy4 - sin3 * dy3), dim=-1)
dcos=sum(dy∗x,dims)dcos = sum(dy * x, dims)
dsin=sum(dy∗cat((−x2,x1,−x4,x3),dim=−1),dims)dsin = sum(dy * cat((-x2, x1, -x4, x3), dim=-1), dims)
(4)interleave-half模式(mode等于3):
dy1,dy2=chunk(dy,chunks=2,dim=−1)dy1, dy2 = chunk(dy, chunks=2, dim=-1)
cos1,cos2=chunk(cos,chunks=2,dim=−1)cos1, cos2 = chunk(cos, chunks=2, dim=-1)
sin1,sin2=chunk(sin,chunks=2,dim=−1)sin1, sin2 = chunk(sin, chunks=2, dim=-1)
x1,x2=x[...,::2],x[...,1::2]x1, x2 = x[..., :: 2], x[..., 1 :: 2]
dx=stack((cos1∗dy1+sin2∗dy2,cos2∗dy2−sin1∗dy1),dim=−1).reshape(dy.shape)dx = stack((cos1 * dy1 + sin2 * dy2, cos2 * dy2 - sin1 * dy1), dim=-1).reshape(dy.shape)
dcos=sum(dy∗cat((x1,x2),dim=−1),dims)dcos = sum(dy * cat((x1, x2), dim=-1), dims)
dsin=sum(dy∗cat((−x2,x1),dim=−1),dims)dsin = sum(dy * cat((-x2, x1), dim=-1), dims)
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| dy | 输入 | 公式中的dy,表示正向计算输出y的导数。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| cos | 输入 | 公式中的cos,正向计算输入,需与dy数据类型一致。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| sin | 输入 | 公式中的sin,正向计算输入,需与dy数据类型一致。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| xOptional | 可选输入 | 公式中的x,正向计算输入。如果为空指针,则不计算dcosOut和dsinOut。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| mode | 输入 | 公式中的旋转模式。 | INT64 | - |
| dxOut | 输出 | 公式中的dx,输入x的导数。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| dcosOut | 输出 | 公式中的dcos,输入cos的导数,仅当xOptional非空时有效。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| dsinOut | 输出 | 公式中的dsin,输入sin的导数,仅当xOptional非空时有效。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
约束说明
-
Ascend 950PR/Ascend 950DT: 输入张量dy支持BNSD、BSND、SBND、TND排布。各参数的shape约束可以描述如下:
- 输入张量dy、cos、sin及输出张量dx的最后一维大小必须相同,且小于等于1024。对于half、interleave和interleave-half模式,最后一维必须能被2整除,对于quarter模式,最后一维必须能被4整除。
- 输入张量dy和输出张量dx的shape必须完全相同。
- 输入张量cos和sin的shape必须完全相同,cos和sin的shape需要与dy满足broadcast关系,且广播后的shape必须等于dy的shape。
- 当dy为TND时,cos、sin支持T1D、TND。
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- 输入张量dy支持BNSD、BSND、SBND排布。
- 输入张量dy、cos、sin、xOptional及输出张量dxOut、dcosOut、dsinOut的D维度大小必须相同,满足D<896,且必须为2的倍数。
- 输入张量dy、xOptional和输出张量dxOut的shape必须完全相同。
- 输入张量cos、sin和输出张量dcosOut、dsinOut的shape必须完全相同,且cos和sin的shape必须完全相同。
- half模式:
- B,N < 1000;当需要计算dsin、dcos时,B * N <= 1024
- 当dy为BNSD时,cos、sin支持11SD、B1SD、BNSD;当cos、sin为B1SD时需满足B < S
- 当dy为BSND时,cos、sin支持1S1D、BS1D、BSND;当cos、sin为BS1D时需满足B < S
- 当dy为SBND时,cos、sin支持S11D、SB1D、SBND
- interleave模式:
- B * N < 1000
- 当dy为BNSD时,cos、sin支持11SD
- 当dy为BSND时,cos、sin支持1S1D
- 当dy为SBND时,cos、sin支持S11D
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_rotary_position_embedding_grad | 通过aclnnRotaryPositionEmbeddingGrad接口方式调用RotaryPositionEmbeddingGrad算子。 |