from typing import List, Union
import torch
import mx_driving._C
class BEVPoolV3(torch.autograd.Function):
    """
    BEVPoolV3 adapts BEVPoolV1 and BEVPoolV2 for best performance on NPU.

    Two input modes are handled:

    * ``depth is None``: ``ranks_bev`` must be a 2-D tensor of per-point
      coordinates. It is flattened here into linear indices (see
      ``_flatten_ranks_bev``) before the kernel is called.
    * ``depth`` given: ``ranks_bev`` is forwarded to the kernel unchanged,
      together with ``ranks_depth`` and ``ranks_feat``.

    The pooling itself is done by ``mx_driving._C.npu_bev_pool_v3``. Its
    gradient is computed by ``mx_driving._C.npu_bev_pool_v3_backward``.
    """

    @staticmethod
    def _flatten_ranks_bev(ranks_bev: torch.Tensor, D: int, H: int, W: int) -> torch.Tensor:
        """Flatten 2-D per-point coordinates into linear indices.

        The result is ``col3 * D*H*W + col2 * H*W + col0 * W + col1``.

        Args:
            ranks_bev: 2-D tensor. Columns 0..3 are read, so it needs at
                least 4 columns.
            D, H, W: the corresponding entries of ``bev_feat_shape``.

        Returns:
            1-D tensor with one linear index per row of ``ranks_bev``.

        Raises:
            ValueError: if ``ranks_bev`` is not 2-D.
        """
        if ranks_bev.dim() != 2:
            raise ValueError("ranks_bev must be 2D when running without depth")
        # NOTE(review): column 0 gets stride W and column 1 gets stride 1.
        # Confirm this column order matches the layout the NPU kernel expects.
        return ranks_bev[:, 3] * D * H * W + ranks_bev[:, 2] * H * W + ranks_bev[:, 0] * W + ranks_bev[:, 1]

    @staticmethod
    def forward(
        ctx,
        depth: Union[torch.Tensor, None],
        feat: torch.Tensor,
        ranks_depth: Union[torch.Tensor, None],
        ranks_feat: Union[torch.Tensor, None],
        ranks_bev: torch.Tensor,
        bev_feat_shape: List[int],
    ) -> torch.Tensor:
        """Run the NPU BEV pooling kernel.

        Args:
            ctx: autograd context. Saves the tensors needed by ``backward``.
            depth: depth tensor, or ``None`` to use the coordinate-based path.
                Shape is not checked here.
            feat: feature tensor forwarded to the kernel.
            ranks_depth: index tensor (or ``None``) forwarded to the kernel as-is.
            ranks_feat: index tensor (or ``None``) forwarded to the kernel as-is.
            ranks_bev: when ``depth`` is ``None``, a 2-D coordinate tensor that
                is flattened here. Otherwise it is forwarded as-is.
            bev_feat_shape: unpacked as ``(B, D, H, W, C)``. ``C`` is not used
                by this wrapper.

        Returns:
            The kernel output with its last axis moved to position 1
            (``permute(0, 4, 1, 2, 3)``), made contiguous. Presumably this is
            ``(B, D, H, W, C) -> (B, C, D, H, W)``. TODO: confirm the kernel's
            output layout.

        Raises:
            ValueError: if ``bev_feat_shape`` does not have exactly 5 entries,
                or if ``depth`` is ``None`` and ``ranks_bev`` is not 2-D.
        """
        (B, D, H, W, _) = bev_feat_shape
        if depth is None:
            ranks_bev = BEVPoolV3._flatten_ranks_bev(ranks_bev, D, H, W)
        out = mx_driving._C.npu_bev_pool_v3(depth, feat, ranks_depth, ranks_feat, ranks_bev, B, D, H, W)
        out = out.permute(0, 4, 1, 2, 3).contiguous()
        # Save the (possibly flattened) ranks_bev so backward uses exactly the
        # indices the forward kernel saw.
        ctx.save_for_backward(depth, feat, ranks_feat, ranks_depth, ranks_bev)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Compute gradients for ``depth`` and ``feat``.

        Index tensors and ``bev_feat_shape`` receive no gradient.
        """
        depth, feat, ranks_feat, ranks_depth, ranks_bev = ctx.saved_tensors
        # Invert forward's permute(0, 4, 1, 2, 3) so the kernel receives
        # gradients in the layout its forward produced.
        grad_out = grad_out.permute(0, 2, 3, 4, 1).contiguous()
        grad_depth, grad_feat = mx_driving._C.npu_bev_pool_v3_backward(
            grad_out,
            depth,
            feat,
            ranks_depth,
            ranks_feat,
            ranks_bev,
        )
        # Autograd rejects a non-None gradient for a forward input that was not
        # a tensor (e.g. depth=None). Only return the gradients that were
        # actually requested.
        if not ctx.needs_input_grad[0]:
            grad_depth = None
        if not ctx.needs_input_grad[1]:
            grad_feat = None
        return grad_depth, grad_feat, None, None, None, None
def bev_pool_v3(
    depth: Union[torch.Tensor, None],
    feat: torch.Tensor,
    ranks_depth: Union[torch.Tensor, None],
    ranks_feat: Union[torch.Tensor, None],
    ranks_bev: torch.Tensor,
    bev_feat_shape: List[int],
) -> torch.Tensor:
    """Public entry point for the NPU BEV pooling op.

    This is a thin wrapper around ``BEVPoolV3.apply``. All arguments are
    passed through positionally and unchanged, so the autograd graph is
    recorded by ``BEVPoolV3``.

    Args:
        depth: depth tensor, or ``None`` for the coordinate-based path.
        feat: feature tensor.
        ranks_depth: index tensor or ``None``.
        ranks_feat: index tensor or ``None``.
        ranks_bev: index/coordinate tensor.
        bev_feat_shape: output shape description, unpacked by the op as
            ``(B, D, H, W, C)``.

    Returns:
        Whatever ``BEVPoolV3.apply`` returns for these inputs.
    """
    return BEVPoolV3.apply(depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape)