import torch
import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error

__all__ = ['npu_fast_condition_index_put']


def npu_fast_condition_index_put(x, condition, value):
    """Return ``x`` with ``value`` written wherever ``condition`` is True.

    Uses an NPU-affinity formulation to replace the native bool-mask
    ``index_put`` (``x[condition] = value``). It builds a tensor filled with
    ``value`` and selects between it and ``x`` with ``torch.where``. ``x``
    itself is not modified; a new tensor is returned.

    Examples::
    >>> x = torch.randn(128, 8192)
    >>> condition = x < 0.5
    >>> value = 0.
    >>> x1 = copy.deepcopy(x)
    >>> x1[condition] = value
    >>> x1_opt = npu_fast_condition_index_put(x, condition, value)

    .. note::
        Because the index operator keeps being optimized, the native
        implementation performs better in some scenarios.

    Args:
        x (torch.Tensor): Input tensor.
        condition (torch.BoolTensor): Selection mask, must have bool dtype;
            it is broadcast against ``x`` by ``torch.where``.
        value (int, float, bool or torch.Tensor): Value written where
            ``condition`` is True. It is cast to ``x.dtype``, matching
            ``x[condition] = value``.

    Returns:
        torch.Tensor: Tensor with ``x``'s dtype, equal to ``value`` where
        ``condition`` is True and to ``x`` elsewhere.

    Raises:
        TypeError: If ``condition`` is not of dtype ``torch.bool``.
    """
    if condition.dtype != torch.bool:
        raise TypeError("Expected condition.dtype in [torch.bool]" + ops_error(ErrCode.TYPE))

    if isinstance(value, torch.Tensor):
        # Keep broadcasting (and autograd through ``value``) for tensor values,
        # but cast back so the result dtype matches ``x[condition] = value``.
        mask = (torch.zeros_like(x) + value).to(x.dtype)
    else:
        # full_like fills with x's dtype/device/layout, so e.g. a float value
        # written into an integer tensor is cast (as native index_put does)
        # instead of promoting the whole result to a floating dtype.
        mask = torch.full_like(x, value)

    return torch.where(condition, mask, x)