GridSampler2DGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品

功能说明

  • 算子功能:GridSampler中2D场景的反向传播,完成张量input与张量grid的梯度计算。

  • 计算流程:

    1. 根据grid存储的(x, y)值,计算出映射到input上的坐标,这些坐标和align_corners、padding_mode有关。
    2. 根据输入的interpolationMode,选择使用bilinear、nearest、bicubic不同插值模式计算该坐标周围点分配到梯度的权重值。
    3. 根据grad存储的梯度值乘上对应点的权重值,计算出最终dx、dgrid的结果。
  • 计算公式:

    grad、input、grid、dx、dgrid的尺寸如下:

    grad:(N,C,Hout,Wout)input:(N,C,Hin,Win)grid:(N,Hout,Wout,2)dx:(N,C,Hin,Win)dgrid:(N,Hout,Wout,2)grad: (N, C, H_{out}, W_{out})\\ input: (N, C, H_{in}, W_{in})\\ grid: (N, H_{out}, W_{out}, 2)\\ dx: (N, C, H_{in}, W_{in})\\ dgrid: (N, H_{out}, W_{out}, 2)

    其中grad、input、grid、dx、dgrid中的N均相同,grad、input和dx中的C相同,input和dx中的HinH_{in}WinW_{in}相同,grad、grid和dgrid中的HoutH_{out}WoutW_{out}相同,grid最后一维大小为2,表示input像素位置信息为(x, y)。x和y的取值范围归一化到[-1, 1],(-1, 1)表示左上角坐标,(1, -1)表示右下角坐标。

    1. 坐标反归一化: grid中的(x, y)需要先反归一化到input像素坐标(ix, iy),同时计算梯度乘子gix_multgiy_mult(用于后续dgrid计算):

      • align_corners = true:

        ix=(x+1)2⋅(Win−1),gix_mult=Win−12iy=(y+1)2⋅(Hin−1),giy_mult=Hin−12ix = \frac{(x+1)}{2} \cdot (W_{in} - 1), \quad gix\_mult = \frac{W_{in} - 1}{2} \\ iy = \frac{(y+1)}{2} \cdot (H_{in} - 1), \quad giy\_mult = \frac{H_{in} - 1}{2}

      • align_corners = false:

        ix=(x+1)⋅Win−12,gix_mult=Win2iy=(y+1)⋅Hin−12,giy_mult=Hin2ix = \frac{(x+1) \cdot W_{in} - 1}{2}, \quad gix\_mult = \frac{W_{in}}{2} \\ iy = \frac{(y+1) \cdot H_{in} - 1}{2}, \quad giy\_mult = \frac{H_{in}}{2}

    2. padding_mode对梯度乘子的影响:

      • padding_mode="zeros",gix_mult不变
      • padding_mode="border",gix_mult=gix_mult×grad_clipgix\_mult = gix\_mult × grad\_clip(坐标在边界外时grad_clip=0,否则=1)
      • padding_mode="reflection",gix_mult=gix_mult×grad_refl×grad_clipgix\_mult = gix\_mult × grad\_refl × grad\_clip(grad_refl是反射坐标变换函数对输入坐标的导数,表示反射后输出坐标随输入坐标变化的方向和速率。取值为-1,0,1)。
    3. 各插值模式的梯度公式:

      • Bilinear(双线性插值)

        四个角点坐标和权重为:

        角点 坐标(ip,jp)(i_p, j_p) 权重wpw_p
        nw(西北) (iynw,ixnw)(iy_{nw}, ix_{nw}) (ixse−ix)×(iyse−iy)(ix_{se} - ix) × (iy_{se} - iy)
        ne(东北) (iyne,ixne)(iy_{ne}, ix_{ne}) (ix−ixsw)×(iysw−iy)(ix - ix_{sw}) × (iy_{sw} - iy)
        sw(西南) (iysw,ixsw)(iy_{sw}, ix_{sw}) (ixne−ix)×(iy−iyne)(ix_{ne} - ix) × (iy - iy_{ne})
        se(东南) (iyse,ixse)(iy_{se}, ix_{se}) (ix−ixnw)×(iy−iynw)(ix - ix_{nw}) × (iy - iy_{nw})

        其中:

        ixnw=floor(ix)iynw=floor(iy)ixne=floor(ix)+1iyne=floor(iy)ixsw=floor(ix)iysw=floor(iy)+1ixse=floor(ix)+1iyse=floor(iy)+1ix_{nw} = floor(ix) \\ iy_{nw} = floor(iy) \\ ix_{ne} = floor(ix) + 1 \\ iy_{ne} = floor(iy)\\ ix_{sw} = floor(ix) \\ iy_{sw} = floor(iy) + 1 \\ ix_{se} = floor(ix) + 1 \\ iy_{se} = floor(iy) + 1

        • dx(input梯度):将上游梯度按权重散射到input对应位置

          dx(N,C,ip,jp)+=wp⋅grad(N,C,Hout,Wout)dx(N, C, i_p, j_p) \mathrel{+}= w_p \cdot grad(N, C, H_{out}, W_{out})

          即对每个输出像素(h, w),将其梯度乘以双线性权重,累加到input的四个相邻像素位置(越界位置不累加)。

        • dgrid(grid梯度):对(ix, iy)的偏导

          gix=∑c[−Vnw⋅(iyse−iy)+Vne⋅(iysw−iy)−Vsw⋅(iy−iyne)+Vse⋅(iy−iynw)]⋅grad(N,C,Hout,Wout)gix = \sum_{c} \left[ -V_{nw} \cdot (iy_{se} - iy) + V_{ne} \cdot (iy_{sw} - iy) - V_{sw} \cdot (iy - iy_{ne}) + V_{se} \cdot (iy - iy_{nw}) \right] \cdot grad(N, C, H_{out}, W_{out})

          giy=∑c[−Vnw⋅(ixse−ix)−Vne⋅(ix−ixsw)+Vsw⋅(ixne−ix)+Vse⋅(ix−ixnw)]⋅grad(N,C,Hout,Wout)giy = \sum_{c} \left[ -V_{nw} \cdot (ix_{se} - ix) - V_{ne} \cdot (ix - ix_{sw}) + V_{sw} \cdot (ix_{ne} - ix) + V_{se} \cdot (ix - ix_{nw}) \right] \cdot grad(N, C, H_{out}, W_{out})

          其中 Vp=input(N,C,ip,jp)V_p = input(N, C, i_p, j_p)(仅当角点在边界内时参与计算)。

        • 最终:

          dgrid(N,Hout,Wout,0)=gix_mult⋅gixdgrid(N, H_{out}, W_{out}, 0) = gix\_mult \cdot gix

          dgrid(N,Hout,Wout,1)=giy_mult⋅giydgrid(N, H_{out}, W_{out}, 1) = giy\_mult \cdot giy

      • Nearest(最邻近插值)

        • dx:将上游梯度直接累加到最近邻位置

          dx(N,C,round(iy),round(ix))+=grad(N,C,Hout,Wout)dx(N, C, \text{round}(iy), \text{round}(ix)) \mathrel{+}= grad(N, C, H_{out}, W_{out})

        • dgrid:最邻近插值对坐标不可导,因此 dgrid = 0

      • Bicubic(双三次插值)

        • dx:

          dx(N,C,iy′,ix′)+=grad(N,C,Hout,Wout)⋅x_coeffs[i]⋅y_coeffs[j]dx(N, C, iy', ix') \mathrel{+}= grad(N, C, H_{out}, W_{out}) \cdot x\_coeffs[i] \cdot y\_coeffs[j]

          其中:

          (ix′,iy′)=(ixnw−1+i,iynw−1+j)(ix', iy') = (ix_{nw}-1+i, iy_{nw}-1+j)i,j∈{0,1,2,3}i,j \in \{0,1,2,3\},越界位置根据padding_mode处理。

          A=−0.75x0=x+1.0x_coeffs[0]=((A∗x0−5∗A)∗x0+8∗A)∗x0−4∗AA = -0.75 \\ x_0 = x + 1.0 \\ x\_coeffs[0] = ((A * x_0 - 5* A) * x_0 + 8 * A) * x_0 - 4 * A

          x1=xx_coeffs[1]=((A+2)∗x1−(A+3))∗x1∗x1+1x_1 = x \\ x\_coeffs[1] = ((A + 2)* x_1 - (A + 3)) * x_1 * x_1 + 1

          x2=1−xx_coeffs[2]=((A+2)∗x2−(A+3))∗x2∗x2+1x_2 = 1 - x \\ x\_coeffs[2] = ((A + 2)* x_2 - (A + 3)) * x_2 * x_2 + 1

          x3=2−xx_coeffs[3]=((A∗x3−5∗A)∗x3+8∗A)∗x3−4∗Ax_3 = 2 - x \\ x\_coeffs[3] = ((A * x_3 - 5* A) * x_3 + 8 * A) * x_3 - 4 * A

          y0=y+1.0y_coeffs[0]=((A∗y0−5∗A)∗y0+8∗A)∗y0−4∗Ay_0 = y + 1.0 \\ y\_coeffs[0] = ((A * y_0 - 5* A) * y_0 + 8 * A) * y_0 - 4 * A

          y1=yy_coeffs[1]=((A+2)∗y1−(A+3))∗y1∗y1+1y_1 = y \\ y\_coeffs[1] = ((A + 2)* y_1 - (A + 3)) * y_1 * y_1 + 1

          y2=1−yy_coeffs[2]=((A+2)∗y2−(A+3))∗y2∗y2+1y_2 = 1 - y \\ y\_coeffs[2] = ((A + 2)* y_2 - (A + 3)) * y_2 * y_2 + 1

          y3=2−yy_coeffs[3]=((A∗y3−5∗A)∗y3+8∗A)∗y3−4∗Ay_3 = 2 - y \\ y\_coeffs[3] = ((A * y_3 - 5* A) * y_3 + 8 * A) * y_3 - 4 * A

        • dgrid:

          gix=−∑C∑i=03∑j=03Vij⋅x_coeffs_grad[i]⋅y_coeffs[j]⋅grad(N,C,Hout,Wout)gix = -\sum_{C}\sum_{i=0}^{3}\sum_{j=0}^{3} V_{ij} \cdot x\_coeffs\_grad[i] \cdot y\_coeffs[j] \cdot grad(N, C, H_{out}, W_{out})

          giy=−∑C∑i=03∑j=03Vij⋅y_coeffs_grad[j]⋅x_coeffs[i]⋅grad(N,C,Hout,Wout)giy = -\sum_{C}\sum_{i=0}^{3}\sum_{j=0}^{3} V_{ij} \cdot y\_coeffs\_grad[j] \cdot x\_coeffs[i] \cdot grad(N, C, H_{out}, W_{out})

          其中:

          Vij=get_value_bounded(input(N,C,Hin,Win),ixnw−1+i,iynw−1+j)V_{ij} = get\_value\_bounded(input(N, C, H_{in}, W_{in}), ix_{nw}-1+i, iy_{nw}-1 +j)x_coeffs_grady_coeffs_grad是三次插值系数对tx/ty的导数:

          tx=ix−floor(ix)ty=iy−floor(iy)tx = ix - floor(ix) \\ ty = iy - floor(iy) \\

          x_coeffs_grad[0]=(−3A⋅x−10A)⋅x−8Ax=∣−1−tx∣x\_coeffs\_grad[0] = (-3A \cdot x - 10A) \cdot x - 8A \\ \quad x = |-1 - tx|

          x_coeffs_grad[1]=(−3(A+2)⋅x−2(A+3))⋅xx=∣0−tx∣x\_coeffs\_grad[1] = (-3(A+2) \cdot x - 2(A+3)) \cdot x \\ \quad x = | 0 - tx|

          x_coeffs_grad[2]=(3(A+2)⋅x−2(A+3))⋅xx=∣1−tx∣x\_coeffs\_grad[2] = (3(A+2) \cdot x - 2(A+3)) \cdot x \\ \quad x = |1 - tx|

          x_coeffs_grad[3]=(3A⋅x−10A)⋅x+8Ax=∣2−tx∣x\_coeffs\_grad[3] = (3A \cdot x - 10A) \cdot x + 8A \\ \quad x = |2 - tx|

          y_coeffs_grad[0]=(−3A⋅y−10A)⋅y−8Ay=∣−1−ty∣y\_coeffs\_grad[0] = (-3A \cdot y - 10A) \cdot y - 8A \\ \quad y = |-1 - ty|

          y_coeffs_grad[1]=(−3(A+2)⋅y−2(A+3))⋅yy=∣0−ty∣y\_coeffs\_grad[1] = (-3(A+2) \cdot y - 2(A+3)) \cdot y \\ \quad y = | 0 - ty|

          y_coeffs_grad[2]=(3(A+2)⋅y−2(A+3))⋅yy=∣1−ty∣y\_coeffs\_grad[2] = (3(A+2) \cdot y - 2(A+3)) \cdot y \\ \quad y = |1 - ty|

          y_coeffs_grad[3]=(3A⋅y−10A)⋅y+8Ay=∣2−ty∣y\_coeffs\_grad[3] = (3A \cdot y - 10A) \cdot y + 8A \\ \quad y = |2 - ty|

          最终:

          dgrid(N,Hout,Wout,0)=gix_mult⋅gixdgrid(N, H_{out}, W_{out}, 0) = gix\_mult \cdot gix

          dgrid(N,Hout,Wout,1)=giy_mult⋅giydgrid(N, H_{out}, W_{out}, 1) = giy\_mult \cdot giy

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
grad 输入 表示反向传播过程中上一层的输出梯度,对应公式描述中的`grad`。 FLOAT16、FLOAT32、DOUBLE、BFLOAT16 NHWC
x 输入 表示反向传播的输入张量,对应公式描述中的`input`。shape仅支持四维,且需满足`x`和`grad`的N轴和C轴的值保持一致,`x`最后两维的维度值不可为0。 FLOAT16、FLOAT32、DOUBLE、BFLOAT16 ND
grid 输入 表示采用像素位置的张量,对应公式描述中的`grid`。shape仅支持四维,且需满足`grid`和`grad`的N轴、H轴、W轴的值保持一致,`grid`最后一维的值等于2。 FLOAT16、FLOAT32、DOUBLE、BFLOAT16 ND
interpolation_mode 可选属性
  • 表示插值模式,对应公式描述中的`interpolation_mode`。支持"bilinear"(双线性插值)、"nearest"(最邻近插值)、"bicubic"(双三次插值)。
  • 默认值为"bilinear"。
STRING -
padding_mode 可选属性
  • 用于表示填充模式,对应公式描述中的`padding_mode`。支持"zeros"、"border"、"reflection"三种模式。
  • 默认值为"zeros"。
STRING -
align_corners 可选属性
  • 表示设定特征图坐标与特征值的对应方式,对应公式描述中的`align_corners`。设定为true时,特征值位于像素中心;设定为false时,特征值位于像素的角点。
  • 默认值为false。
BOOL -
dx 输出 表示反向传播的输出梯度,对应公式描述中的`dx`。数据类型、数据格式和shape与`x`的数据类型、数据格式和shape保持一致。 FLOAT16、FLOAT32、DOUBLE、BFLOAT16 ND
dgrid 输出 表示grid梯度,对应公式描述中的`dgrid`。数据类型、数据格式和shape与`grid`的数据类型、数据格式和shape保持一致。 FLOAT16、FLOAT32、DOUBLE、BFLOAT16 ND

Atlas 训练系列产品:输入参数和输出参数的数据类型不支持DOUBLE、BFLOAT16。

约束说明

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_grid_sampler2_d_backward 通过aclnnGridSampler2DBackward接口方式调用GridSampler2DGrad算子。
图模式 - 通过算子IR构图方式调用GridSampler2DGrad算子。