import torch
import torch_npu


def box_dtype_check(box):
    """Return ``box`` in a dtype accepted by the NPU anchor ops.

    Float32 and float16 tensors are returned unchanged. A tensor of any other
    dtype is converted to float32 with ``Tensor.float()``.

    Args:
        box (Tensor): Box tensor whose dtype should be checked.

    Returns:
        Tensor: ``box`` itself when its dtype is float32 or float16,
        otherwise ``box.float()``.
    """
    # Compare the tensor's *dtype*. Testing the tensor itself against a list of
    # dtypes never matches, which previously upcast float16 inputs to float32.
    if box.dtype not in (torch.float, torch.half):
        return box.float()
    return box


def npu_single_level_responsible_flags(featmap_size,
                                       gt_bboxes,
                                       stride,
                                       num_base_anchors):
    """Compute per-anchor responsible flags for one feature-map level on NPU.

    The ground-truth boxes are first normalized by ``box_dtype_check``. The
    result is then forwarded to ``torch_npu.npu_anchor_response_flags``.

    .. note::
        The NPU op limits the output size
        (featmap_size[0] * featmap_size[1] * num_base_anchors) to less than 60000.

    Args:
        featmap_size (tuple[int]): Size of the feature map.
        gt_bboxes (Tensor): Ground-truth boxes, shape (n, 4).
            Supported dtypes: float, half.
        stride (tuple[int]): Stride of the current level.
        num_base_anchors (int): Number of base anchors.

    Returns:
        torch.Tensor: Responsible flags for every anchor of this level, of
            size featmap_size[0] * featmap_size[1] * num_base_anchors.
    """
    boxes = box_dtype_check(gt_bboxes)
    return torch_npu.npu_anchor_response_flags(
        boxes, featmap_size, stride, num_base_anchors)