GridSampler3DGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | √ |
功能说明
-
算子功能:GridSampler中3D场景的反向传播,完成张量input与张量grid的梯度计算。
-
计算流程:
- 根据grid存储的(x, y, z)值,计算出映射到input上的坐标,这些坐标和align_corners、padding_mode有关。
- 坐标根据输入的interpolation_mode,选择使用bilinear、nearest不同插值模式计算输出值。
- 根据grad存储的梯度值乘上对应点的权重值,计算出最终dx、dgrid的结果。
-
计算公式:
grad、input、grid、dx、dgrid的尺寸如下:
grad:(N,C,Dout,Hout,Wout)input:(N,C,Din,Hin,Win)grid:(N,Dout,Hout,Wout,3)dx:(N,C,Din,Hin,Win)dgrid:(N,Dout,Hout,Wout,3)grad: (N, C, D_{out}, H_{out}, W_{out})\\ input: (N, C, D_{in}, H_{in}, W_{in})\\ grid: (N, D_{out}, H_{out}, W_{out}, 3)\\ dx: (N, C, D_{in}, H_{in}, W_{in})\\ dgrid: (N, D_{out}, H_{out}, W_{out}, 3)
其中grad、input、grid、dx、dgrid中的N是一致的,grad、input和dx中的C是一致的,input和dx中的DinD_{in}、HinH_{in}、WinW_{in}是一致的,grad、grid和dgrid中的DoutD_{out}、HoutH_{out}、WoutW_{out}一致的,grid最后一维大小为3,表示input像素位置信息为(x, y, z),会将x、y、z的取值范围归一化到[-1, 1]之间。
-
坐标反归一化:
grid中的(x, y, z)需要先反归一化到input像素坐标(ix, iy, iz),同时计算梯度乘子
gix_mult、giy_mult、giz_mult(用于后续dgrid计算):-
align_corners = true:
ix=(x+1)2⋅(Win−1),gix_mult=Win−12iy=(y+1)2⋅(Hin−1),giy_mult=Hin−12iz=(z+1)2⋅(Din−1),giz_mult=Din−12ix = \frac{(x+1)}{2} \cdot (W_{in} - 1), \quad gix\_mult = \frac{W_{in} - 1}{2} \\ iy = \frac{(y+1)}{2} \cdot (H_{in} - 1), \quad giy\_mult = \frac{H_{in} - 1}{2} \\ iz = \frac{(z+1)}{2} \cdot (D_{in} - 1), \quad giz\_mult = \frac{D_{in} - 1}{2}
-
align_corners = false:
ix=(x+1)⋅Win−12,gix_mult=Win2iy=(y+1)⋅Hin−12,giy_mult=Hin2iz=(z+1)⋅Din−12,giz_mult=Din2ix = \frac{(x+1) \cdot W_{in} - 1}{2}, \quad gix\_mult = \frac{W_{in}}{2} \\ iy = \frac{(y+1) \cdot H_{in} - 1}{2}, \quad giy\_mult = \frac{H_{in}}{2} \\ iz = \frac{(z+1) \cdot D_{in} - 1}{2}, \quad giz\_mult = \frac{D_{in}}{2}
-
-
padding_mode对梯度乘子的影响:
- padding_mode="zeros",
gix_mult不变 - padding_mode="border",gix_mult=gix_mult×grad_clipgix\_mult = gix\_mult × grad\_clip(坐标在边界外时grad_clip=0,否则=1)
- padding_mode="reflection",gix_mult=gix_mult×grad_refl×grad_clipgix\_mult = gix\_mult × grad\_refl × grad\_clip(grad_refl是反射坐标变换函数对输入坐标的导数,表示反射后输出坐标随输入坐标变化的方向和速率。取值为-1,0,1)。
- padding_mode="zeros",
-
各插值模式的梯度公式:
-
Bilinear(三线性插值,Trilinear)
四个角点坐标和权重为:
角点 坐标(ip,jp,kp)(i_p, j_p, k_p) 权重wpw_p tnw(顶-北-西) (⌊iz⌋,⌊iy⌋,⌊ix⌋)(⌊iz⌋, ⌊iy⌋, ⌊ix⌋) (ixbse−ix)×(iybse−iy)×(izbse−iz)(ix_{bse} - ix) × (iy_{bse} - iy) × (iz_{bse} - iz) tne(顶-北-东) (⌊iz⌋,⌊iy⌋,⌊ix⌋+1)(⌊iz⌋, ⌊iy⌋, ⌊ix⌋+1) (ix−ixbsw)×(iybsw−iy)×(izbsw−iz)(ix - ix_{bsw}) × (iy_{bsw} - iy) × (iz_{bsw} - iz) tsw(顶-南-西) (⌊iz⌋,⌊iy⌋+1,⌊ix⌋)(⌊iz⌋, ⌊iy⌋+1, ⌊ix⌋) (ixbne−ix)×(iy−iybne)×(izbne−iz)(ix_{bne} - ix) × (iy - iy_{bne}) × (iz_{bne} - iz) tse(顶-南-东) (⌊iz⌋,⌊iy⌋+1,⌊ix⌋+1)(⌊iz⌋, ⌊iy⌋+1, ⌊ix⌋+1) (ix−ixbnw)×(iy−iybnw)×(izbnw−iz)(ix - ix_{bnw}) × (iy - iy_{bnw}) × (iz_{bnw} - iz) bnw(底-北-西) (⌊iz⌋+1,⌊iy⌋,⌊ix⌋)(⌊iz⌋+1, ⌊iy⌋, ⌊ix⌋) (ixtse−ix)×(iytse−iy)×(iz−iztse)(ix_{tse} - ix) × (iy_{tse} - iy) × (iz - iz_{tse}) bne(底-北-东) (⌊iz⌋+1,⌊iy⌋,⌊ix⌋+1)(⌊iz⌋+1, ⌊iy⌋, ⌊ix⌋+1) (ix−ixtsw)×(iytsw−iy)×(iz−iztsw)(ix - ix_{tsw}) × (iy_{tsw} - iy) × (iz - iz_{tsw}) bsw(底-南-西) (⌊iz⌋+1,⌊iy⌋+1,⌊ix⌋)(⌊iz⌋+1, ⌊iy⌋+1, ⌊ix⌋) (ixtne−ix)×(iy−iytne)×(iz−iztne)(ix_{tne} - ix) × (iy - iy_{tne}) × (iz - iz_{tne}) bse(底-南-东) (⌊iz⌋+1,⌊iy⌋+1,⌊ix⌋+1)(⌊iz⌋+1, ⌊iy⌋+1, ⌊ix⌋+1) (ix−ixtnw)×(iy−iytnw)×(iz−iztnw)(ix - ix_{tnw}) × (iy - iy_{tnw}) × (iz - iz_{tnw}) 其中:
ixtnw=floor(ix)iytnw=floor(iy)iztnw=floor(iz)ix_{tnw} = floor(ix) \\ iy_{tnw} = floor(iy) \\ iz_{tnw} = floor(iz) \\
ixtne=ixtnw+1iytne=iytnwiztne=iztnwix_{tne} = ix_{tnw} + 1 \\ iy_{tne} = iy_{tnw} \\ iz_{tne} = iz_{tnw}\\
ixtsw=ixtnwiytsw=iytnw+1iztsw=iztnwix_{tsw} = ix_{tnw} \\ iy_{tsw} = iy_{tnw} + 1 \\ iz_{tsw} = iz_{tnw}\\
ixtse=ixtnw+1iytse=iytnw+1iztse=iztnwix_{tse} = ix_{tnw} + 1 \\ iy_{tse} = iy_{tnw} + 1 \\ iz_{tse} = iz_{tnw} \\
ixbnw=ixtnwiybnw=iytnwizbnw=iztnw+1ix_{bnw} = ix_{tnw} \\ iy_{bnw} = iy_{tnw} \\ iz_{bnw} = iz_{tnw} + 1 \\
ixbne=ixtnw+1iybne=iytnwizbne=iztnw+1ix_{bne} = ix_{tnw} + 1 \\ iy_{bne} = iy_{tnw} \\ iz_{bne} = iz_{tnw} + 1 \\
ixbsw=ixtnwiybsw=iytnw+1izbsw=iztnw+1ix_{bsw} = ix_{tnw} \\ iy_{bsw} = iy_{tnw} + 1 \\ iz_{bsw} = iz_{tnw} + 1 \\
ixbse=ixtnw+1iybse=iytnw+1izbse=iztnw+1ix_{bse} = ix_{tnw} + 1 \\ iy_{bse} = iy_{tnw} + 1 \\ iz_{bse} = iz_{tnw} + 1 \\
-
dx(input梯度):将上游梯度按三线性权重散射到input对应位置
dx(N,C,ip,jp,kp)+=wp⋅grad(N,C,Dout,Hout,Wout)dx(N, C, i_p, j_p, k_p) \mathrel{+}= w_p \cdot grad(N, C, D_{out}, H_{out}, W_{out})
即对每个输出像素(d, h, w),将其梯度乘以三线性权重,累加到input的8个相邻体素位置(越界位置不累加)。
-
dgrid(grid梯度):对(ix, iy, iz)的偏导
gix=∑c[−Vtnw⋅(iybse−iy)(izbse−iz)+Vtne⋅(iybsw−iy)(izbsw−iz)−Vtsw⋅(iy−iybne)(izbne−iz)+Vtse⋅(iy−iybnw)(izbnw−iz)−Vbnw⋅(iytse−iy)(iz−iztse)+Vbne⋅(iytsw−iy)(iz−iztsw)−Vbsw⋅(iy−iytne)(iz−iztne)+Vbse⋅(iy−iytnw)(iz−iztnw)]⋅grad(N,C,Dout,Hout,Wout)gix = \sum_{c} \left[ -V_{tnw} \cdot (iy_{bse}-iy)(iz_{bse}-iz) + V_{tne} \cdot (iy_{bsw}-iy)(iz_{bsw}-iz) - V_{tsw} \cdot (iy-iy_{bne})(iz_{bne}-iz) + V_{tse} \cdot (iy-iy_{bnw})(iz_{bnw}-iz) - V_{bnw} \cdot (iy_{tse}-iy)(iz-iz_{tse}) + V_{bne} \cdot (iy_{tsw}-iy)(iz-iz_{tsw}) - V_{bsw} \cdot (iy-iy_{tne})(iz-iz_{tne}) + V_{bse} \cdot (iy-iy_{tnw})(iz-iz_{tnw}) \right] \cdot grad(N, C, D_{out}, H_{out}, W_{out})
giy=∑c[−Vtnw⋅(ixbse−ix)(izbse−iz)−Vtne⋅(ix−ixbsw)(izbsw−iz)+Vtsw⋅(ixbne−ix)(izbne−iz)+Vtse⋅(ix−ixbnw)(izbnw−iz)−Vbnw⋅(ixtse−ix)(iz−iztse)−Vbne⋅(ix−ixtsw)(iz−iztsw)+Vbsw⋅(ixtne−ix)(iz−iztne)+Vbse⋅(ix−ixtnw)(iz−iztnw)]⋅grad(N,C,Dout,Hout,Wout)giy = \sum_{c} \left[ -V_{tnw} \cdot (ix_{bse}-ix)(iz_{bse}-iz) - V_{tne} \cdot (ix-ix_{bsw})(iz_{bsw}-iz) + V_{tsw} \cdot (ix_{bne}-ix)(iz_{bne}-iz) + V_{tse} \cdot (ix-ix_{bnw})(iz_{bnw}-iz) - V_{bnw} \cdot (ix_{tse}-ix)(iz-iz_{tse}) - V_{bne} \cdot (ix-ix_{tsw})(iz-iz_{tsw}) + V_{bsw} \cdot (ix_{tne}-ix)(iz-iz_{tne}) + V_{bse} \cdot (ix-ix_{tnw})(iz-iz_{tnw}) \right] \cdot grad(N, C, D_{out}, H_{out}, W_{out})
giz=∑c[−Vtnw⋅(ixbse−ix)(iybse−iy)−Vtne⋅(ix−ixbsw)(iybsw−iy)−Vtsw⋅(ixbne−ix)(iy−iybne)−Vtse⋅(ix−ixbnw)(iy−iybnw)+Vbnw⋅(ixtse−ix)(iytse−iy)+Vbne⋅(ix−ixtsw)(iytsw−iy)+Vbsw⋅(ixtne−ix)(iy−iytne)+Vbse⋅(ix−ixtnw)(iy−iytnw)]⋅grad(N,C,Dout,Hout,Wout)giz = \sum_{c} \left[ -V_{tnw} \cdot (ix_{bse}-ix)(iy_{bse}-iy) - V_{tne} \cdot (ix-ix_{bsw})(iy_{bsw}-iy) - V_{tsw} \cdot (ix_{bne}-ix)(iy-iy_{bne}) - V_{tse} \cdot (ix-ix_{bnw})(iy-iy_{bnw}) + V_{bnw} \cdot (ix_{tse}-ix)(iy_{tse}-iy) + V_{bne} \cdot (ix-ix_{tsw})(iy_{tsw}-iy) + V_{bsw} \cdot (ix_{tne}-ix)(iy-iy_{tne}) + V_{bse} \cdot (ix-ix_{tnw})(iy-iy_{tnw}) \right] \cdot grad(N, C, D_{out}, H_{out}, W_{out})
其中 Vp=input(N,C,ip,jp,kp)V_p = input(N, C, i_p, j_p, k_p)(仅当角点在边界内时参与计算)。
-
最终:
dgrid(N,Dout,Hout,Wout,0)=gix_mult⋅gixdgrid(N, D_{out}, H_{out}, W_{out}, 0) = gix\_mult \cdot gix
dgrid(N,Dout,Hout,Wout,1)=giy_mult⋅giydgrid(N, D_{out}, H_{out}, W_{out}, 1) = giy\_mult \cdot giy
dgrid(N,Dout,Hout,Wout,2)=giy_mult⋅gizdgrid(N, D_{out}, H_{out}, W_{out}, 2) = giy\_mult \cdot giz
-
Nearest(最邻近插值)
- dx:将上游梯度直接累加到最近邻位置
dx(N,C,round(iz),round(iy),round(ix))+=grad(N,C,Dout,Hout,Wout)dx(N, C, \text{round}(iz), \text{round}(iy), \text{round}(ix)) \mathrel{+}= grad(N, C, D_{out}, H_{out}, W_{out})
- dgrid:最邻近插值对坐标不可导,因此 dgrid = 0。
-
-
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| grad | 输入 | 表示反向传播过程中上一层的输出梯度,对应公式描述中的`grad`。数据类型与`x`的数据类型一致。当数据类型DOUBLE时,数据格式不支持NDHWC。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| x | 输入 | 表示反向传播的输入张量,对应公式描述中的`input`。shape仅支持五维,且需满足`x`和`grad`的N轴和C轴的值保持一致,x的D,H,W值不可为0。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| grid | 输入 | 表示采用像素位置的张量,对应公式描述中的`grid`。shape仅支持五维,且需满足`grid`和`grad`的N轴、D轴、H轴、W轴的值保持一致,最后一维的值等于3。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NDHWC |
| interpolation_mode | 可选属性 |
|
STRING | - |
| padding_mode | 可选属性 |
|
STRING | - |
| align_corners | 可选属性 |
|
BOOL | - |
| dx | 输出 | 表示反向传播的输出梯度,对应公式描述中的`dx`。数据类型、数据格式和shape与`x`的数据类型、数据格式和shape保持一致。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NCDHW、NDHWC |
| dgrid | 输出 | 表示`grid`梯度,对应公式描述中的`dgrid`。数据类型、数据格式和shape与`grid`的数据类型、数据格式和shape保持一致。 | FLOAT16、FLOAT32、DOUBLE、BFLOAT16 | NDHWC |
Atlas 训练系列产品:输入参数和输出参数的数据类型不支持DOUBLE、BFLOAT16。
约束说明
无
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_grid_sampler3_d_backward | 通过aclnnGridSampler3DBackward接口方式调用GridSampler3DGrad算子。 |
| 图模式 | - | 通过算子IR构图方式调用GridSampler3DGrad算子。 |