ResizeNearestNeighborV2Grad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:ResizeNearestNeighborV2的反向传播,根据最近邻插值的映射关系,将输出梯度散射回输入空间并累加。
-
计算公式:
正向ResizeNearestNeighborV2中,每个输出像素值复制于某个输入像素。反向传播时,输入像素的梯度为所有映射到该位置的输出像素梯度的累加和:
y(N,C,h,w)+=grads(N,C,h′,w′)y(N, C, h, w) += \text{grads}(N, C, h', w')
其中 (h′,w′)(h', w') 为输出坐标,通过最近邻映射回输入坐标:
h=min(⌊h′×scaleH⌋,Hin−1)h = \min(\lfloor h' \times \text{scaleH} \rfloor, H_{in} - 1)
w=min(⌊w′×scaleW⌋,Win−1)w = \min(\lfloor w' \times \text{scaleW} \rfloor, W_{in} - 1)
scaleH=HinHout,scaleW=WinWout\text{scaleH} = \frac{H_{in}}{H_{out}}, \quad \text{scaleW} = \frac{W_{in}}{W_{out}}
当属性取值不同时,映射公式变化如下:
- align_corners=True,half_pixel_centers=False:h=min(round(h′×scaleH),Hin−1)h = \min(\text{round}(h' \times \text{scaleH}), H_{in} - 1)
- align_corners=False,half_pixel_centers=True:h=min(⌊(h′+0.5)×scaleH⌋,Hin−1)h = \min(\lfloor (h' + 0.5) \times \text{scaleH} \rfloor, H_{in} - 1)
- align_corners=False,half_pixel_centers=False:h=min(⌊h′×scaleH⌋,Hin−1)h = \min(\lfloor h' \times \text{scaleH} \rfloor, H_{in} - 1)
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| grads | 输入 | 正向ResizeNearestNeighborV2输出的梯度Tensor,对应公式中`grads`。 | FLOAT16、FLOAT32、BFLOAT16 | NCHW、NHWC |
| size | 输入 | 原始图像的高和宽,即正向输入的空间尺寸,用于确定输出梯度Tensor的H和W维度。 | INT32 | - |
| y | 输出 | 正向ResizeNearestNeighborV2输入的梯度Tensor,对应公式中`y`。 | FLOAT16、FLOAT32、BFLOAT16 | NCHW、NHWC |
属性说明:
| 属性名 | 类型 | 默认值 | 描述 |
|---|---|---|---|
| align_corners | Bool | false | 决定是否对齐输入和输出图像的角点像素。为True时角点像素对齐。 |
| half_pixel_centers | Bool | false | 决定插值时像素中心位置。为True时像素中心在半像素坐标处。 |
| scales | ListFloat | {0.0f, 0.0f} | 空间尺寸乘数,可选参数。 |
约束说明
- 仅支持4维Tensor(NCHW或NHWC格式)
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_upsample_nearest2d_grad | 通过aclnnUpsampleNearest2dBackward接口方式调用ResizeNearestNeighborV2Grad算子。 |
| 图模式 | - | 通过算子IR构图方式调用ResizeNearestNeighborV2Grad算子。 |