"""
Copyright (c) OpenMMLab. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
Modification by: Huawei Developers
Modification date: 2024-06-04
Modification Description:
Modification 1. Add support for Ascend NPU
"""
import torch
import torch_npu
from torch.autograd import Function
import mx_driving._C
class MaxPool2d(Function):
    """Autograd ``Function`` wrapping the ``mx_driving._C.npu_max_pool2d`` kernel.

    Only the forward computation is defined in this class; no ``backward`` is
    provided here.
    """

    @staticmethod
    def forward(ctx, x, kernel_size, stride, padding):
        """Run the custom NPU max-pool kernel and hand back its result.

        ``ctx`` is not used: nothing is saved for a backward pass.
        ``x``, ``kernel_size``, ``stride`` and ``padding`` are forwarded
        positionally, unchanged, to the extension kernel.
        """
        return mx_driving._C.npu_max_pool2d(x, kernel_size, stride, padding)
def npu_max_pool2d(x, kernel_size, stride, padding):
    """Apply 2-D max pooling on an Ascend NPU, choosing a backend by chip name.

    Args:
        x: Input tensor. ``x.device.index`` is used to look up the NPU's
            device name.
        kernel_size: Pooling window size, forwarded unchanged to the backend.
        stride: Pooling stride, forwarded unchanged to the backend.
        padding: Pooling padding, forwarded unchanged to the backend.

    Returns:
        The pooled tensor produced by the selected backend.

    Raises:
        NotImplementedError: If the device name contains neither
            ``"Ascend910"`` nor ``"Ascend950"``.
    """
    device_name = torch_npu.npu.get_device_name(x.device.index)

    # Any name containing "Ascend910" (e.g. 910B, 910C) takes the custom
    # mx_driving kernel via the autograd wrapper.
    if "Ascend910" in device_name:
        return MaxPool2d.apply(x, kernel_size, stride, padding)

    # Ascend950 uses the stock PyTorch operator instead of the custom kernel.
    if "Ascend950" in device_name:
        return torch.nn.functional.max_pool2d(
            x, kernel_size, stride=stride, padding=padding
        )

    raise NotImplementedError("The npu_max_pool2d currently only supports Ascend910B, Ascend910C and Ascend950.")