import torch
import torch_npu
from torch.autograd import Function

import mx_driving._C


class PixelGroupFunction(Function):
    """Autograd ``Function`` that delegates to ``mx_driving._C.pixel_group``.

    Only ``forward`` is defined in this class; no ``backward`` is provided here.
    """

    @staticmethod
    # pylint: disable=huawei-too-many-arguments
    def forward(
        ctx,
        score,
        mask,
        embedding,
        kernel_label,
        kernel_contour,
        kernel_region_num,
        distance_threshold,
    ):
        """Call the native ``pixel_group`` op and return its result unchanged.

        Every argument except ``ctx`` is passed positionally, in declaration
        order, to ``mx_driving._C.pixel_group``. ``ctx`` is not used.

        NOTE(review): the expected types/shapes of the inputs and the type of
        the returned value are defined by the native extension and are not
        visible from this file -- confirm against the op's documentation.
        """
        return mx_driving._C.pixel_group(
            score,
            mask,
            embedding,
            kernel_label,
            kernel_contour,
            kernel_region_num,
            distance_threshold,
        )


# Functional alias: calling ``pixel_group(...)`` is the same as calling
# ``PixelGroupFunction.apply(...)`` with the same arguments.
pixel_group = PixelGroupFunction.apply