import math
import warnings
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from fractions import Fraction
import os
import cv2
import numpy as np
import torch
import torch_npu
from torchvision.utils import _log_api_usage_once
import torchvision
import av
from torchvision import io as IO
import torchvision_npu
from torchvision_npu._utils import PathManager
from torchvision_npu import io as IO_npu
try:
FFmpegError = av.FFmpegError
except AttributeError:
FFmpegError = av.AVError
def patch_io_video():
setattr(torchvision.io, "read_video_ori", torchvision.io.read_video)
torchvision.io.read_video = read_video
def is_neon_supported():
cpu_info = "/proc/cpuinfo"
neon_support = False
if os.path.exists(cpu_info):
with open(cpu_info, "r") as f:
for line in f:
if "neon" in line.lower() or "asimd" in line.lower():
neon_support = True
break
return neon_support
def get_kp_thread_num():
num_thread = -1
env_var = os.environ.get("TORCHVISION_OMP_NUM_THREADS")
if env_var is not None and env_var.isdigit() and is_neon_supported():
num_thread = int(env_var)
return num_thread
def read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames and the audio frames
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The start presentation time of the video
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns:
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
num_thread = get_kp_thread_num()
if torchvision.get_video_backend() == 'npu':
return IO_npu._video._read_video(filename, start_pts, end_pts, pts_unit, output_format)
elif torchvision.get_video_backend() == 'pyav' and num_thread > 0:
return IO_npu._video.kp_read_video(filename, start_pts, end_pts, pts_unit, output_format, thread_count=num_thread)
return IO.read_video_ori(filename, start_pts, end_pts, pts_unit, output_format)
class VideoFrameDvpp:
dts: int
pts: int
positions: int
frame: np.ndarray
def __init__(
self, dts: int = 0, pts: int = 0
) -> None:
self.pts = pts
self.dts = dts
...
class DecodeParams:
def __init__(self, container, start_offset, end_offset, stream):
self.container = container
self.start_offset = start_offset
self.end_offset = end_offset
self.stream = stream
def get_frame_by_cv(filename: str, container: "av.container.Container", stream: "av.stream.Stream") -> Tuple[Dict[
int, VideoFrameDvpp], int]:
cap = cv2.VideoCapture(filename)
if not cap.isOpened():
raise RuntimeError(f"filename is not opened in cv2.VideoCapture.")
cap.set(cv2.CAP_PROP_FORMAT, -1)
frames: Dict[int, VideoFrameDvpp] = {}
pts_list = []
pts_per_frame = round(1 / cap.get(cv2.CAP_PROP_POS_AVI_RATIO) / cap.get(cv2.CAP_PROP_FPS), 0)
for packet in container.demux(stream):
if packet.pts is not None:
frame = VideoFrameDvpp(packet.dts, packet.pts)
ret, frame.frame = cap.read()
pts_list.append(packet.pts)
frames[frame.pts] = frame
cap.release()
pts_list.sort()
position_list = {value: index for index, value in enumerate(pts_list)}
for frame in frames.values():
frame.positions = position_list[frame.pts]
return frames, pts_per_frame
def _decode_video_dvpp(
decode_params: DecodeParams,
frames: Dict[int, VideoFrameDvpp],
pts_per_frame: int
) -> torch.Tensor:
container = decode_params.container
start_offset = decode_params.start_offset
end_offset = decode_params.end_offset
stream = decode_params.stream
codecs_type = stream.name
hi_pt_h264 = 96
hi_pt_h265 = 265
if codecs_type == "h264":
codec_id = hi_pt_h264
elif codecs_type == "hevc":
codec_id = hi_pt_h265
else:
raise ValueError(f"The video codecs_type should be either 'h264' or 'hevc', got {codecs_type}.")
start_offset = int(start_offset / pts_per_frame) * pts_per_frame
frame_width = stream.width
frame_height = stream.height
end_offset_real = end_offset
if end_offset_real % pts_per_frame == 0:
end_offset_real += 1
end_offset_real = min(end_offset_real, len(frames) * pts_per_frame)
start_frame = math.ceil(start_offset / pts_per_frame)
total_frame = math.ceil((end_offset_real - start_offset) / pts_per_frame)
if end_offset_real < start_offset or total_frame == 0:
return torch.empty(0, dtype=torch.uint8, device='npu')
ret_tensor = torch.empty([total_frame, 3, frame_height, frame_width], dtype=torch.uint8, device='npu')
chn = torch.ops.torchvision._decode_video_create_chn(codec_id)
if chn == -1:
warnings.warn(f"_decode_video_create_chn failed {chn}")
torch.ops.torchvision._dvpp_sys_exit()
return None
ret = torch.ops.torchvision._decode_video_start_get_frame(chn, total_frame)
if not ret == 0:
warnings.warn(f"_decode_video_start_get_frame failed {ret}")
torch.ops.torchvision._decode_video_destroy_chnl(chn)
torch.ops.torchvision._dvpp_sys_exit()
return None
for packet in container.demux(stream):
if packet.pts is not None:
frame = frames[packet.pts].frame
input_tensor = torch.tensor(frame).npu(non_blocking=True)
if start_offset <= int(packet.pts) <= end_offset:
display = True
output_tensor = ret_tensor[frames[packet.pts].positions - start_frame]
else:
display = False
output_tensor = torch.empty(0, dtype=torch.uint8, device='npu')
ret = torch.ops.torchvision._decode_video_send_stream(chn, input_tensor, 69, display,
output_tensor)
if not ret == 0:
warnings.warn(f"_decode_video_send_stream failed {ret}")
ret_tensor_dvpp = torch.ops.torchvision._decode_video_stop_get_frame(chn, total_frame)
if ret_tensor_dvpp.numel() != 0:
ret_tensor = ret_tensor_dvpp
ret = torch.ops.torchvision._decode_video_destroy_chnl(chn)
if not ret == 0:
warnings.warn(f"_decode_video_destroy_chnl failed {ret}")
return ret_tensor
def _read_from_stream_dvpp(
filename: str,
container: "av.container.Container",
start_offset: float,
end_offset: float,
pts_unit: str,
stream: "av.stream.Stream",
stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> torch.Tensor:
if not stream.type == "video":
raise RuntimeError("_read_from_stream_dvpp only handle video type")
if pts_unit == "sec" and stream.time_base != 0:
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf") and stream.time_base != 0:
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
max_buffer_size = 5
should_buffer = True
extradata = stream.codec_context.extradata
if extradata and b"DivX" in extradata:
pos = extradata.find(b"DivX")
data = extradata[pos:]
obj = re.search(rb"DivX(\d+)Build(\d+)(\w)", data)
if obj is None:
obj = re.search(rb"DivX(\d+)b(\d+)(\w)", data)
if obj is not None:
should_buffer = obj.group(3) == b"p"
seek_offset = start_offset
seek_offset = max(seek_offset - 1, 0)
if should_buffer:
seek_offset = max(seek_offset - max_buffer_size, 0)
frames, pts_per_frame = get_frame_by_cv(filename, container, stream)
try:
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
except FFmpegError:
return []
decode_params = DecodeParams(container, start_offset, end_offset, stream)
frames = _decode_video_dvpp(decode_params, frames, pts_per_frame)
if frames is None:
warnings.warn(f"_decode_video_dvpp failed: {filename}")
return frames
def _read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(_read_video)
if pts_unit not in ("pts", "sec"):
raise ValueError(f"pts_unit should be either 'pts' or 'sec', got {pts_unit}.")
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
filepath = os.path.realpath(filename)
PathManager.check_directory_path_readable(filepath)
torchvision.io.video._check_av_available()
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
audio_frames = []
audio_timebase = torchvision.io._video_opt.default_timebase
try:
with av.open(filepath, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
if container.streams.video[0].name not in ("hevc", "h264"):
warnings.warn(f"This video in {filepath} is coding by {container.streams.video[0].name}, "
"not supported on DVPP backend and will fall back to run on the pyav."
"This may have performance implications.")
torchvision.set_video_backend('pyav')
vframes, aframes, info = IO.video.read_video(
filepath,
start_pts,
end_pts,
pts_unit,
output_format,
)
torchvision.set_video_backend('npu')
return vframes.npu(), aframes.npu(), info
else:
vframes = _read_from_stream_dvpp(
filepath,
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
if video_fps is not None:
info["video_fps"] = float(video_fps)
else:
vframes = torch.empty(0, dtype=torch.uint8, device='npu')
if container.streams.audio:
audio_frames = torchvision.io.video._read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
except FFmpegError:
vframes = torch.empty(0, dtype=torch.uint8, device='npu')
aframes_list = [frame.to_ndarray() for frame in audio_frames]
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec" and audio_timebase != 0:
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf") and audio_timebase != 0:
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = torchvision.io.video._align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframes = aframes.npu(non_blocking=True)
if output_format == "THWC" and vframes is not None and len(vframes) != 0:
vframes = vframes.permute(0, 2, 3, 1)
return vframes, aframes, info
def _kp_yuv2rgb_is_vaild(video_pix_fmt, frame, video_colorspace):
return video_pix_fmt in ["yuv420p", "yuvj420p"] and \
not frame.width % 2 and not frame.height % 2 and \
video_colorspace in [1, 2]
def kp_read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
thread_count: int = 8
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video)
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
torchvision.io.video._check_av_available()
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
video_frames = []
audio_frames = []
audio_timebase = torchvision.io._video_opt.default_timebase
try:
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
container.streams.video[0].thread_type = "AUTO"
container.streams.video[0].thread_count = thread_count
video_frames = torchvision.io.video._read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
video_pix_fmt = container.streams.video[0].pix_fmt
video_color_range = container.streams.video[0].color_range
video_colorspace = container.streams.video[0].colorspace
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = torchvision.io.video._read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
except FFmpegError as e:
warnings.warn(f"[Warning] Error while reading video {filename}: {e}")
vframes_list = []
src_stride = [0, 0, 0, 0]
for index, frame in enumerate(video_frames):
y_plane = torch.from_numpy(np.frombuffer(frame.planes[0], dtype=np.uint8))
u_plane = torch.from_numpy(np.frombuffer(frame.planes[1], dtype=np.uint8))
v_plane = torch.from_numpy(np.frombuffer(frame.planes[2], dtype=np.uint8))
for i in range(3):
src_stride[i] = frame.planes[i].line_size
if _kp_yuv2rgb_is_vaild(video_pix_fmt, frame, video_colorspace):
vframes_list.append(torch.ops.torchvision.yuv2rgb(thread_count,
frame.width,
frame.height,
video_color_range,
video_colorspace,
y_plane,
src_stride,
u_plane,
v_plane).numpy())
else:
vframes_list.append(video_frames[index].to_rgb().to_ndarray())
aframes_list = [frame.to_ndarray() for frame in audio_frames]
if vframes_list:
vframes = torch.as_tensor(np.stack(vframes_list))
else:
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = torchvision.io.video._align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
if output_format == "TCHW":
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info