import torch
import torch_npu
import torchvision
from torch import Tensor
def roi_pool(input, boxes, output_size, spatial_scale: float = 1.0) -> Tensor:
    """Run RoI max-pooling, pre-scaling the boxes when ``input`` lives on an NPU.

    TorchVision's native kernels and the NPU operator derive the RoI extent
    differently, which leads to mismatching results when ``spatial_scale`` is
    applied inside the operator:

    * TorchVision (see ``torchvision/csrc/ops/cuda/roi_pool_kernel.cu``)::

          roi_width = round(boxes[:, 3] * spatial_scale) - round(boxes[:, 1] * spatial_scale) + 1

    * NPU, which follows MMCV (see
      ``mmcv/ops/csrc/common/cuda/roi_pool_cuda_kernel.cuh``)::

          roi_width = (boxes[:, 3] + 1) * spatial_scale - boxes[:, 1] * spatial_scale

    To reconcile the two, the box coordinates are scaled and rounded here
    first, and the operator is then invoked with ``spatial_scale=1``.
    For inputs on any other device the arguments are forwarded unchanged.

    The actual pooling is delegated to ``torchvision.ops.tv_roi_pool``, the
    attribute under which ``patch_roi_pool`` stores the original TorchVision
    implementation; it must exist before this function is called.

    Args:
        input (Tensor[N, C, H, W]): input tensor.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in
            (x1, y1, x2, y2) format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, its first column should contain the
            batch index. If a list of Tensors is passed, each Tensor
            corresponds to the boxes for element i of the batch.
            The argument is never modified.
        output_size (int or Tuple[int, int]): the size of the output after the
            cropping is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the input
            coordinates to the box coordinates. Default: 1.0

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: the pooled regions.
    """
    if input.device.type == "npu":
        # Work on copies: writing the scaled values back into the caller's
        # tensor would corrupt it, and a second call with the same boxes would
        # then apply ``spatial_scale`` twice.
        # NOTE(review): torch.round rounds exact halves to the nearest even
        # integer, whereas C ``round`` rounds them away from zero, so results
        # may differ from the reference kernel for coordinates that land
        # exactly on .5 after scaling -- confirm whether this matters.
        if isinstance(boxes, (list, tuple)):
            # Per-image boxes carry no batch-index column, so every column is
            # a coordinate and gets scaled.
            boxes = [
                torch.round(per_image * spatial_scale).to(per_image.dtype)
                for per_image in boxes
            ]
        else:
            scaled = boxes.clone()
            # Column 0 is the batch index and must stay untouched.
            scaled[:, 1:] = torch.round(scaled[:, 1:] * spatial_scale)
            boxes = scaled
        # The scale is already baked into the coordinates.
        spatial_scale = 1.0
    return torchvision.ops.tv_roi_pool(input, boxes, output_size, spatial_scale)
def patch_roi_pool():
    """Replace ``torchvision.ops.roi_pool`` with the NPU-aware ``roi_pool`` wrapper.

    The implementation found in ``torchvision.ops.roi_pool`` at patch time is
    kept as ``torchvision.ops.tv_roi_pool`` so that the wrapper can delegate
    to it. Calling this function again after the patch is in place is a no-op.
    """
    if torchvision.ops.roi_pool is roi_pool:
        # Already patched. Saving the current attribute again would bind
        # ``tv_roi_pool`` to the wrapper itself and make every call recurse
        # without end.
        return
    torchvision.ops.tv_roi_pool = torchvision.ops.roi_pool
    torchvision.ops.roi_pool = roi_pool