from typing import List, Union
import torch
import torchvision
from torchvision.utils import _log_api_usage_once
from torchvision.io.image import ImageReadMode
def patch_io_image():
    """Monkey-patch torchvision's JPEG encode/decode entry points.

    The stock implementations are kept on ``torchvision.io.image`` as
    ``encode_jpeg_ori`` / ``decode_jpeg_ori`` and the public names on both
    ``torchvision.io.image`` and ``torchvision.io`` are redirected to the
    wrappers defined in this module.

    The function is safe to call more than once: the originals are saved
    only the first time.
    """
    io_image = torchvision.io.image

    # Save the stock functions only if they have not been saved yet. Without
    # this guard a second call would store the wrapper itself as "*_ori", and
    # the wrapper's fallback path would then recurse into itself forever.
    if not hasattr(io_image, "encode_jpeg_ori"):
        setattr(io_image, "encode_jpeg_ori", io_image.encode_jpeg)
    io_image.encode_jpeg = encode_jpeg
    torchvision.io.encode_jpeg = encode_jpeg

    if not hasattr(io_image, "decode_jpeg_ori"):
        setattr(io_image, "decode_jpeg_ori", io_image.decode_jpeg)
    io_image.decode_jpeg = _decode_jpeg
    torchvision.io.decode_jpeg = _decode_jpeg
def encode_jpeg(img: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Encode an image tensor into the bytes of a JPEG file.

    Tensors that do not live on an ``npu`` device are handed unchanged to the
    saved stock implementation (``torchvision.io.image.encode_jpeg_ori``).
    Tensors on an ``npu`` device get a leading dimension prepended and are
    encoded by the ``_encode_jpeg_aclnn`` operator.

    Args:
        img (Tensor[channels, image_height, image_width]): 8-bit integer image
            tensor in CHW layout with 1 or 3 channels. (The previous docs said
            "int8"; presumably uint8 is meant -- TODO confirm.)
        quality (int): Quality of the resulting JPEG file, between 1 and 100.
            Default: 75

    Returns:
        output (Tensor[1]): A one dimensional tensor holding the raw bytes of
            the JPEG file.

    Raises:
        ValueError: on the npu path, if ``quality`` is outside ``[1, 100]``.
    """
    # Anything that is not on the NPU goes straight to the stock encoder.
    if img.device.type != 'npu':
        return torchvision.io.image.encode_jpeg_ori(img, quality)

    # Usage logging is skipped while scripting or tracing.
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(encode_jpeg)

    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    # CHW -> 1xCHW before calling the NPU operator.
    batched_img = img.unsqueeze(0)
    return torch.ops.torchvision._encode_jpeg_aclnn(batched_img, quality)
def _decode_jpeg_npu(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """Decode a single encoded JPEG buffer with the NPU decode operator.

    Args:
        input: tensor holding the encoded JPEG data; it is moved to the NPU
            with ``.npu(non_blocking=True)`` after its image shape is queried.
        mode: ``GRAY`` forces 1 output channel, ``RGB`` forces 3; any other
            mode keeps the channel count reported in the queried shape.

    Returns:
        The result of ``_decode_jpeg_aclnn`` with dimension 0 squeezed.
    """
    # Query the image shape first, then move the buffer to the device.
    image_shape = torch.ops.torchvision._get_jpeg_image_shape(input)
    device_buffer = input.npu(non_blocking=True)

    # Element 2 of the reported shape is used as the native channel count.
    native_channels = image_shape[2]
    if mode == ImageReadMode.GRAY:
        out_channels = 1
    elif mode == ImageReadMode.RGB:
        out_channels = 3
    else:
        out_channels = native_channels

    decoded = torch.ops.torchvision._decode_jpeg_aclnn(
        device_buffer, image_shape=image_shape, channels=out_channels)
    return decoded.squeeze(0)
def _decode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    device: Union[str, torch.device] = "cpu",
    apply_exif_orientation: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU, CUDA or NPU.

    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input: one tensor, or a non-empty list of tensors, holding encoded
            JPEG data. For the NPU path every tensor must live on the CPU.
        mode: an ``ImageReadMode`` member, or its name as a
            (case-insensitive) string.
        device: target device. Any device whose type is not ``npu`` is
            delegated to the saved stock implementation
            (``torchvision.io.image.decode_jpeg_ori``).
        apply_exif_orientation: forwarded to the stock implementation when
            truthy. It is not applied on the NPU path.

    Returns:
        A decoded tensor, or a list of decoded tensors when ``input`` is a
        list.

    Raises:
        ValueError: on the NPU path, if the list is empty, contains
            non-tensors, or any input tensor is not on the CPU.
    """
    # torch.device() accepts both strings and torch.device instances, so a
    # single conversion is enough.
    device = torch.device(device)
    if isinstance(mode, str):
        mode = ImageReadMode[mode.upper()]

    if device.type != 'npu':
        if apply_exif_orientation:
            # Forward the flag only when requested, so default calls keep
            # working on torchvision releases that predate this keyword.
            return torchvision.io.image.decode_jpeg_ori(
                input, mode, device, apply_exif_orientation=True)
        return torchvision.io.image.decode_jpeg_ori(input, mode, device)

    # NOTE(review): the NPU path ignores ``apply_exif_orientation`` and does
    # not pass the device (or its index) to ``_decode_jpeg_npu``.
    if isinstance(input, list):
        if not input:
            raise ValueError("Input list must contain at least one element")
        if not all(isinstance(t, torch.Tensor) for t in input):
            raise ValueError("All elements of the input list must be tensors.")
        if not all(t.device.type == "cpu" for t in input):
            raise ValueError("Input list must contain tensors on CPU.")
        return [_decode_jpeg_npu(img, mode) for img in input]

    if input.device.type != "cpu":
        raise ValueError("Input tensor must be a CPU tensor")
    return _decode_jpeg_npu(input, mode)