"""
Copyright (c) OpenMMLab. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
Modification by: Huawei Developers
Modification date: 2025-05-30
Modification Description:
Modification 1. Add support for Ascend NPU
"""
import torch
import torch_npu
from torch.autograd import Function
import mx_driving._C
class GridSamplerFunction(Function):
    """Autograd function wrapping ``grid_sample`` with a custom backward pass.

    The forward pass delegates to ``torch.nn.functional.grid_sample``. The
    backward pass obtains the gradients with respect to ``features`` and
    ``grid`` from ``mx_driving._C.grid_sampler3d_grad_v1``.
    """

    # Integer codes handed to the gradient kernel for each mode it is given a
    # code for. Modes missing from these tables are rejected in ``backward``
    # instead of being silently treated as code 0.
    _INTERPOLATION_CODES = {"bilinear": 0, "nearest": 1}
    _PADDING_CODES = {"zeros": 0, "border": 1, "reflection": 2}

    @staticmethod
    def forward(ctx, features: torch.Tensor, grid: torch.Tensor, interpolation="bilinear", padding="zeros", align=True):
        """Sample ``features`` at the locations given by ``grid``.

        Args:
            ctx: Autograd context; receives the tensors and settings that
                ``backward`` needs.
            features: Input tensor passed to ``grid_sample``.
            grid: Sampling grid passed to ``grid_sample``.
            interpolation: Forwarded as the ``mode`` argument of ``grid_sample``.
            padding: Forwarded as the ``padding_mode`` argument of ``grid_sample``.
            align: Forwarded as the ``align_corners`` argument of ``grid_sample``.

        Returns:
            The tensor produced by ``torch.nn.functional.grid_sample``.
        """
        out = torch.nn.functional.grid_sample(features, grid, interpolation, padding, align)
        ctx.save_for_backward(features, grid)
        # Non-tensor settings are kept as plain attributes on the context.
        ctx.interpolation = interpolation
        ctx.padding = padding
        ctx.align = align
        return out

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        """Compute gradients for ``features`` and ``grid``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad: Gradient of the loss with respect to the forward output.

        Returns:
            A 5-tuple ``(dx, dgrid, None, None, None)``: one entry per
            ``forward`` argument, with ``None`` for the three non-tensor ones.

        Raises:
            ValueError: If the stored interpolation or padding mode has no
                code in the tables above. Falling back to code 0 would make
                the gradient disagree with the forward pass.
        """
        x, grid = ctx.saved_tensors
        interpolation, padding, align = ctx.interpolation, ctx.padding, ctx.align

        interpolation_mode = GridSamplerFunction._INTERPOLATION_CODES.get(interpolation)
        if interpolation_mode is None:
            raise ValueError(
                f"GridSamplerFunction.backward does not support interpolation={interpolation!r}; "
                f"expected one of {sorted(GridSamplerFunction._INTERPOLATION_CODES)}."
            )

        padding_mode = GridSamplerFunction._PADDING_CODES.get(padding)
        if padding_mode is None:
            raise ValueError(
                f"GridSamplerFunction.backward does not support padding={padding!r}; "
                f"expected one of {sorted(GridSamplerFunction._PADDING_CODES)}."
            )

        dx, dgrid = mx_driving._C.grid_sampler3d_grad_v1(
            grad, x, grid, interpolation_mode, padding_mode, align
        )
        return dx, dgrid, None, None, None
def grid_sampler3d_v1(
    features: torch.Tensor, grid: torch.Tensor, interpolation="bilinear", padding="zeros", align=True
):
    """Grid-sample ``features`` at ``grid``, dispatching on the NPU device name.

    The device name is looked up with ``torch_npu.npu.get_device_name`` using
    the device index of ``features``.

    * Name contains ``"Ascend910"``: goes through ``GridSamplerFunction``.
    * Name contains ``"Ascend950"``: calls ``torch.nn.functional.grid_sample``
      directly.

    Args:
        features: Input tensor to sample from.
        grid: Sampling grid.
        interpolation: Interpolation mode, forwarded to the sampler.
        padding: Padding mode, forwarded to the sampler.
        align: ``align_corners`` flag, forwarded to the sampler.

    Returns:
        The sampled tensor.

    Raises:
        NotImplementedError: If the device name matches neither family.
    """
    soc_name = torch_npu.npu.get_device_name(features.device.index)

    # The "Ascend910" check comes first, matching the original branch order.
    if "Ascend910" in soc_name:
        return GridSamplerFunction.apply(features, grid, interpolation, padding, align)

    if "Ascend950" not in soc_name:
        raise NotImplementedError("The grid_sampler3d_v1 currently only supports Ascend910B, Ascend910C and Ascend950.")

    return torch.nn.functional.grid_sample(
        features, grid, mode=interpolation, padding_mode=padding, align_corners=align
    )