from torch.nn import Module

from ..ops.roipoint_pool3d import roipoint_pool3d


class RoIPointPool3d(Module):
    """Module wrapper around the ``roipoint_pool3d`` op.

    Stores the number of points to sample per RoI and forwards every call to
    ``roipoint_pool3d(num_sampled_points, points, point_features, boxes3d)``.

    Args:
        num_sampled_points: Number of points sampled for each RoI box.
            Must be a positive integer. Defaults to 512.

    Raises:
        ValueError: If ``num_sampled_points`` is smaller than 1.
    """

    def __init__(self, num_sampled_points: int = 512):
        super().__init__()
        # Fail fast at construction time: sampling zero or a negative number
        # of points per box is meaningless. Rejecting it here gives a clear
        # error instead of one raised later from inside the native op.
        if num_sampled_points < 1:
            raise ValueError(
                f"num_sampled_points must be a positive integer, "
                f"got {num_sampled_points!r}"
            )
        self.num_sampled_points = num_sampled_points

    def extra_repr(self) -> str:
        """Show the sampling size in ``repr(module)`` / ``print(model)``."""
        return f"num_sampled_points={self.num_sampled_points}"

    def forward(self, points, point_features, boxes3d):
        """Pool point coordinates and features inside each 3D box.

        Delegates directly to ``roipoint_pool3d`` using the configured
        ``num_sampled_points`` and returns its result unchanged.

        NOTE(review): presumably ``points`` is (B, N, 3), ``point_features``
        is (B, N, C) and ``boxes3d`` is (B, M, 7). Confirm against the
        ``roipoint_pool3d`` op implementation.
        """
        return roipoint_pool3d(self.num_sampled_points, points, point_features, boxes3d)