from typing import Optional, Tuple, List
import numpy as np
import torch
import torch_npu
from torch import Tensor
from ._utils import deal_with_tensor_batch
# Integer interpolation code -> mode name (crop-and-resize naming).
_interpolation_crop_and_resize_int2str = {
    0: "bilinear",
    1: "nearest",
    2: "bicubic",
}

# Integer interpolation code -> mode name (resize naming).
_interpolation_resize_int2str = {
    0: "linear",
    1: "nearest",
    2: "cubic",
}

# Integer padding-mode code -> padding mode name.
_padding_mode_int2str = {
    0: "constant",
    1: "edge",
    2: "reflect",
    3: "symmetric",
}

# Kernel sizes accepted (per axis) by _gaussian_blur.
_gb_kernel_size = [1, 3, 5]
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Check that dimension 1 of ``img`` holds a permitted channel count.

    Args:
        img: Image tensor; its channel count is read from ``img.shape[1]``.
        permitted: Channel counts that are accepted.

    Raises:
        TypeError: If ``img.shape[1]`` is not contained in ``permitted``.
    """
    channels = img.shape[1]
    if channels in permitted:
        return
    raise TypeError(
        "Input image tensor permitted channel values are {}, but found {}".format(permitted, channels)
    )
def _assert_mode(mode: int, supported_modes: List[int]) -> None:
    """Validate an integer interpolation mode.

    Args:
        mode: Interpolation mode code to check.
        supported_modes: Mode codes that are accepted.

    Raises:
        ValueError: If ``mode`` is not contained in ``supported_modes``.
    """
    if mode not in supported_modes:
        # Format the offending ``mode``; the message previously referenced an
        # undefined name, which surfaced as NameError instead of ValueError.
        raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(mode))
@deal_with_tensor_batch
def _normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Run the ``_normalize_aclnn`` operator on ``tensor`` with ``mean`` and ``std``.

    Args:
        tensor: Input tensor handed to the operator.
        mean: Mean values forwarded as the operator's ``mean`` argument.
        std: Standard-deviation values forwarded as the operator's ``std`` argument.
        inplace: When false, the operator output is returned as is. When true,
            a clone of the operator output is returned.

    Returns:
        The operator output, or a clone of it when ``inplace`` is true.

    Note:
        This wrapper itself never writes into ``tensor``; ``inplace`` only
        selects whether the result is cloned before being returned.
    """
    result = torch.ops.torchvision._normalize_aclnn(tensor, mean=mean, std=std)
    if inplace:
        return result.clone()
    return result
@deal_with_tensor_batch
def _vflip(img: Tensor) -> Tensor:
    """Return the output of the ``_vertical_flip_aclnn`` operator for ``img``."""
    flipped = torch.ops.torchvision._vertical_flip_aclnn(img)
    return flipped
@deal_with_tensor_batch
def _hflip(img: Tensor) -> Tensor:
    """Return the output of the ``_horizontal_flip_aclnn`` operator for ``img``."""
    flipped = torch.ops.torchvision._horizontal_flip_aclnn(img)
    return flipped
@deal_with_tensor_batch
def _resized_crop(img: Tensor, crop_param: List[int], size: List[int], interpolation: int = 0) -> Tensor:
    """Crop a region of ``img`` and resize it via the ``_crop_and_resize_aclnn`` operator.

    Args:
        img: Input image tensor.
        crop_param: Four values unpacked, in order, as ``top``, ``left``,
            ``height`` and ``width`` of the crop region.
        size: Forwarded unchanged as the operator's ``size`` argument.
        interpolation: Integer code forwarded as ``interpolation_mode``.

    Returns:
        The operator output.

    Raises:
        ValueError: If ``crop_param`` does not contain exactly four items
            (raised by the unpacking).
    """
    # Direct unpacking; no intermediate list copy is needed.
    top, left, height, width = crop_param
    return torch.ops.torchvision._crop_and_resize_aclnn(
        img,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation_mode=interpolation,
    )
@deal_with_tensor_batch
def _to_tensor(pic) -> Tensor:
    """Return the output of the ``_img_to_tensor_aclnn`` operator for ``pic``."""
    converted = torch.ops.torchvision._img_to_tensor_aclnn(pic)
    return converted
@deal_with_tensor_batch
def _resize(img: Tensor, size: List[int], interpolation: int = 0) -> Tensor:
    """Resize ``img`` through the ``_resize_aclnn`` operator.

    Args:
        img: Input image tensor; width is read from ``img.shape[-1]`` and
            height from ``img.shape[-2]``.
        size: Either an int, a one-element sequence, or a two-element
            ``(height, width)`` sequence. A single value is applied to the
            shorter image edge and the other edge is scaled to keep the
            aspect ratio.
        interpolation: Integer code forwarded as ``interpolation_mode``.

    Returns:
        ``img`` itself when a single target size already equals the shorter
        edge, otherwise the operator output for the size
        ``[int(height), int(width)]``.

    Raises:
        TypeError: If ``size`` is not an int, tuple or list.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2.
    """
    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")

    single_edge = isinstance(size, int)
    if not single_edge:
        size = list(size)
        if len(size) not in [1, 2]:
            raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
                             "{} element tuple/list".format(len(size)))
        single_edge = len(size) < 2

    width, height = img.shape[-1], img.shape[-2]

    if single_edge:
        edge = size if isinstance(size, int) else size[0]
        target_w = target_h = edge
        # Scale the longer edge so the aspect ratio is kept.
        if width < height:
            target_h = target_w * height / width
        else:
            target_w = target_h * width / height

        # Nothing to do when the shorter edge already has the requested size.
        if width <= height and width == target_w:
            return img
        if height <= width and height == target_h:
            return img
    else:
        target_h, target_w = size[0], size[1]

    target = [int(target_h), int(target_w)]
    return torch.ops.torchvision._resize_aclnn(img, size=target, interpolation_mode=interpolation)
@deal_with_tensor_batch
def _crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Return the output of the ``_crop_aclnn`` operator for the given region.

    ``top``, ``left``, ``height`` and ``width`` are forwarded unchanged as
    keyword arguments of the same names.
    """
    region = {"top": top, "left": left, "height": height, "width": width}
    return torch.ops.torchvision._crop_aclnn(img, **region)
@deal_with_tensor_batch
def _pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: int = 0) -> Tensor:
    """Pad ``img`` through the ``_pad_aclnn`` operator.

    Args:
        img: Input image tensor.
        padding: An int (same padding on every side), a one-element sequence
            (same as int), a two-element sequence ``(left/right, top/bottom)``
            or a four-element sequence ``(left, top, right, bottom)``.
        fill: Int or float fill value; it is replicated into a three-element
            list before being handed to the operator.
        padding_mode: Integer code forwarded as ``padding_mode``.

    Returns:
        The operator output.

    Raises:
        TypeError: If ``padding`` is not an int/tuple/list or ``fill`` is not
            an int/float.
        ValueError: If ``padding`` has an unsupported length, is an int while
            TorchScript scripting is active, or contains a negative value.
    """
    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        borders = [padding, padding, padding, padding]
    else:
        values = list(padding)
        count = len(values)
        if count == 1:
            borders = [values[0], values[0], values[0], values[0]]
        elif count == 2:
            horizontal, vertical = values
            borders = [horizontal, vertical, horizontal, vertical]
        elif count == 4:
            borders = values
        else:
            raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                             "{} element tuple".format(count))

    # borders is ordered [left, top, right, bottom].
    smallest = min(borders)
    if smallest < 0:
        raise ValueError("pad value ({}) is not non-negative.".format(smallest))

    fill_values = [fill, fill, fill]
    return torch.ops.torchvision._pad_aclnn(img, padding=borders, padding_mode=padding_mode, fill=fill_values)
@deal_with_tensor_batch
def _adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Return the output of ``_adjust_brightness_aclnn`` for ``img``.

    Args:
        img: Input image tensor.
        brightness_factor: Value forwarded as the operator's ``factor``.

    Raises:
        ValueError: If ``brightness_factor`` is less than zero.
    """
    if brightness_factor < 0:
        message = 'brightness_factor ({}) is not non-negative.'.format(brightness_factor)
        raise ValueError(message)
    adjust = torch.ops.torchvision._adjust_brightness_aclnn
    return adjust(img, factor=brightness_factor)
@deal_with_tensor_batch
def _adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Return the output of ``_adjust_contrast_aclnn`` for ``img``.

    Args:
        img: Input image tensor.
        contrast_factor: Value forwarded as the operator's ``factor``.

    Raises:
        ValueError: If ``contrast_factor`` is less than zero.
    """
    if contrast_factor < 0:
        message = 'contrast_factor ({}) is not non-negative.'.format(contrast_factor)
        raise ValueError(message)
    adjust = torch.ops.torchvision._adjust_contrast_aclnn
    return adjust(img, factor=contrast_factor)
@deal_with_tensor_batch
def _adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Return the output of ``_adjust_hue_aclnn`` for ``img``.

    Args:
        img: Input image tensor; its channel count is read from
            ``img.shape[-3]``.
        hue_factor: Value in ``[-0.5, 0.5]`` forwarded as the operator's
            ``factor``.

    Returns:
        ``img`` unchanged when it has a single channel, otherwise the
        operator output.

    Raises:
        ValueError: If ``hue_factor`` is not within ``[-0.5, 0.5]``.
    """
    in_range = -0.5 <= hue_factor <= 0.5
    if not in_range:
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    # Single-channel input is handed back without calling the operator.
    if img.shape[-3] == 1:
        return img
    return torch.ops.torchvision._adjust_hue_aclnn(img, factor=hue_factor)
@deal_with_tensor_batch
def _adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Return the output of ``_adjust_saturation_aclnn`` for ``img``.

    Args:
        img: Input image tensor.
        saturation_factor: Value forwarded as the operator's ``factor``.

    Raises:
        ValueError: If ``saturation_factor`` is less than zero.
    """
    if saturation_factor < 0:
        message = 'saturation_factor ({}) is not non-negative.'.format(saturation_factor)
        raise ValueError(message)
    adjust = torch.ops.torchvision._adjust_saturation_aclnn
    return adjust(img, factor=saturation_factor)
@deal_with_tensor_batch
def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur ``img`` through the ``_gaussian_blur_aclnn`` operator.

    Args:
        img: Input image tensor.
        kernel_size: Two kernel sizes; each of the first two entries must be
            one of the values in ``_gb_kernel_size``.
        sigma: Forwarded unchanged as the operator's ``sigma`` argument.

    Returns:
        The operator output.

    Raises:
        ValueError: If either of the first two ``kernel_size`` entries is not
            in ``_gb_kernel_size``.
    """
    if kernel_size[0] not in _gb_kernel_size or kernel_size[1] not in _gb_kernel_size:
        # The check is on kernel_size, so the message must name kernel_size
        # (it previously blamed sigma).
        raise ValueError("kernel_size value must be in {}.".format(_gb_kernel_size))
    # Fixed padding mode code passed to the operator; 2 is the "reflect" entry
    # of _padding_mode_int2str (assuming the operator shares those codes).
    padding_mode = 2
    return torch.ops.torchvision._gaussian_blur_aclnn(img, kernel_size=kernel_size, sigma=sigma,
                                                      padding_mode=padding_mode)
@deal_with_tensor_batch
def _rotate(img: Tensor, rotate_force_param: List,
            center: Optional[List[int]] = None, fill: Optional[List[float]] = None
            ) -> Tensor:
    """Rotate ``img`` through the ``_rotate_aclnn`` operator.

    Args:
        img: Input image tensor.
        rotate_force_param: Three values unpacked, in order, as ``angle``,
            ``interpolation`` and ``expand``.
        center: Optional rotation center. ``None`` is sent to the operator as
            an empty list; otherwise every entry is converted with ``int``.
        fill: Optional fill value. A float is replicated into a three-element
            list; for a tuple/list the first element is replicated three
            times. Any other value (including ``None``) is passed unchanged.

    Returns:
        The operator output.

    Raises:
        ValueError: If ``rotate_force_param`` does not hold exactly three
            items, or the interpolation mode is not 0 or 1.
    """
    # Direct unpacking; no intermediate list copy is needed.
    angle, interpolation, expand = rotate_force_param
    _assert_mode(interpolation, [0, 1])

    center = [] if center is None else [int(x) for x in center]

    if isinstance(fill, float):
        fill = [fill, fill, fill]
    elif isinstance(fill, (tuple, list)):
        # Only the first fill component is used, for all three entries.
        fill = [fill[0], fill[0], fill[0]]

    return torch.ops.torchvision._rotate_aclnn(img, angle=angle, interpolation_mode=interpolation, expand=expand,
                                               center=center, padding_mode=0, fill=fill)
@deal_with_tensor_batch
def _affine(img: Tensor, matrix: List[float], interpolation: int = 1, fill: Optional[List[float]] = None
            ) -> Tensor:
    """Apply an affine warp to ``img`` through the ``_warp_affine_aclnn`` operator.

    Args:
        img: Input image tensor.
        matrix: Six coefficients forwarded as the operator's ``matrix``.
        interpolation: Integer mode code; only 0 and 1 are accepted.
        fill: Optional fill value. A float is replicated into a three-element
            list; for a tuple/list the first element is replicated three
            times. Any other value (including ``None``) is passed unchanged.

    Returns:
        The operator output.

    Raises:
        ValueError: If ``matrix`` does not have six entries or the
            interpolation mode is not 0 or 1.
    """
    _assert_mode(interpolation, [0, 1])

    coefficient_count = len(matrix)
    if coefficient_count != 6:
        raise ValueError(f"Argument matrix should have 6 float values but have {coefficient_count}")

    if isinstance(fill, float):
        fill = [fill, fill, fill]
    elif isinstance(fill, (tuple, list)):
        # Only the first fill component is used, for all three entries.
        first = fill[0]
        fill = [first, first, first]

    return torch.ops.torchvision._warp_affine_aclnn(img, matrix=matrix, interpolation_mode=interpolation,
                                                    padding_mode=0, fill=fill)
@deal_with_tensor_batch
def _perspective(img: Tensor, matrix: List[float], interpolation: int = 0, fill: Optional[List[float]] = None
                 ) -> Tensor:
    """Apply a perspective warp to ``img`` through the ``_warp_perspective_aclnn`` operator.

    Args:
        img: Input image tensor.
        matrix: Eight perspective coefficients. A ninth entry ``1`` is added
            to a copy before the operator is called; the caller's sequence is
            left untouched.
        interpolation: Integer mode code; only 0 and 1 are accepted.
        fill: Optional fill value. A float is replicated into a three-element
            list; for a tuple/list the first element is replicated three
            times. Any other value (including ``None``) is passed unchanged.

    Returns:
        The operator output.

    Raises:
        ValueError: If ``matrix`` does not have eight entries or the
            interpolation mode is not 0 or 1.
    """
    _assert_mode(interpolation, [0, 1])
    if len(matrix) != 8:
        raise ValueError(f"Argument matrix should have 8 float values but have {len(matrix)}")

    # Build a new nine-element list instead of appending to the argument:
    # mutating the caller's list made any repeated call with the same list
    # fail the length check above.
    full_matrix = [*matrix, 1]

    if isinstance(fill, float):
        fill = [fill, fill, fill]
    elif isinstance(fill, (tuple, list)):
        # Only the first fill component is used, for all three entries.
        fill = [fill[0], fill[0], fill[0]]

    return torch.ops.torchvision._warp_perspective_aclnn(img, matrix=full_matrix, interpolation_mode=interpolation,
                                                         padding_mode=0, fill=fill)
@deal_with_tensor_batch
def _rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Return the output of ``_rgb_to_grayscale_aclnn`` for a three-channel ``img``.

    Args:
        img: Input image tensor; must pass ``_assert_channels(img, [3])``.
        num_output_channels: 1 or 3, forwarded as ``output_channels_num``.

    Raises:
        TypeError: If the channel check on ``img`` fails.
        ValueError: If ``num_output_channels`` is neither 1 nor 3.
    """
    _assert_channels(img, [3])
    allowed_outputs = (1, 3)
    if num_output_channels not in allowed_outputs:
        raise ValueError('num_output_channels should be either 1 or 3')
    to_gray = torch.ops.torchvision._rgb_to_grayscale_aclnn
    return to_gray(img, output_channels_num=num_output_channels)
@deal_with_tensor_batch
def _posterize(img: Tensor, bits: int) -> Tensor:
    """Return the output of ``_posterize_aclnn`` for a ``torch.uint8`` image.

    Args:
        img: Input image tensor of dtype ``torch.uint8`` that passes
            ``_assert_channels(img, [1, 3])``.
        bits: Forwarded unchanged as the operator's ``bits`` argument.

    Raises:
        TypeError: If ``img`` is not ``torch.uint8`` or the channel check fails.
    """
    dtype = img.dtype
    if dtype != torch.uint8:
        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(dtype))
    _assert_channels(img, [1, 3])
    posterize = torch.ops.torchvision._posterize_aclnn
    return posterize(img, bits=bits)
@deal_with_tensor_batch
def _solarize(img: Tensor, threshold: float) -> Tensor:
    """Return the output of ``_solarize_aclnn`` for ``img``.

    Args:
        img: Input image tensor that passes ``_assert_channels(img, [1, 3])``.
        threshold: Scalar threshold; it is wrapped in a one-element list
            before being handed to the operator.

    Raises:
        TypeError: If the channel check on ``img`` fails.
    """
    _assert_channels(img, [1, 3])
    solarize = torch.ops.torchvision._solarize_aclnn
    return solarize(img, threshold=[threshold])
@deal_with_tensor_batch
def _invert(img: Tensor) -> Tensor:
    """Return the output of ``_invert_aclnn`` for ``img``.

    Raises:
        TypeError: If ``img`` fails ``_assert_channels(img, [1, 3])``.
    """
    _assert_channels(img, [1, 3])
    inverted = torch.ops.torchvision._invert_aclnn(img)
    return inverted