import collections
import os
import random
import unittest
from unittest.mock import patch
import cv2
import torch
import torch_npu
import torchvision
import torchvision.io as io
import torchvision_npu
# Flag the current NPU stream through set_data_preprocess_stream(True) once,
# at import time, i.e. before any test in this module runs.
torch_npu.npu.current_stream().set_data_preprocess_stream(True)

# PyAV is treated as optional: if importing it, or torchvision's own
# availability check, raises ImportError, ``av`` is bound to None instead.
try:
    import av
    io.video._check_av_available()
except ImportError:
    av = None
# Sample clips are read from <parent of this file's directory>/DataVideo/readvideo.
VIDEO_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "DataVideo",
    "readvideo",
)

# Names of the reference properties recorded for every clip.
CheckerConfig = ["duration_sec", "duration_pts", "video_fps", "audio_sample_rate"]

# One immutable record per clip; the fields are exactly the names listed above.
GroundTruth = collections.namedtuple("GroundTruth", CheckerConfig)

# Record with every field set to 0.
all_check_config = GroundTruth(**dict.fromkeys(CheckerConfig, 0))

# Clips compared between the 'npu' and 'pyav' backends.
# Positional order: (duration_sec, duration_pts, video_fps, audio_sample_rate)
test_videos = {
    "R6llTwEh07w.mp4": GroundTruth(10.0, 154624, 30.0, 44100),
    "WUzgd7C1pWA.mp4": GroundTruth(11.0, 326326, 29.97, 48000),
}

# Clips used by the "not support type" test; only the file names are consumed
# there, which is why the second entry carries no reference record.
no_support_videos = {
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(2.0, 0, 30.0, None),
    "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi": None,
}
class TestReadVideo(unittest.TestCase):
    """Check ``torchvision.io.read_video`` on the 'npu' backend against 'pyav'.

    Most tests decode clips from ``VIDEO_DIR`` with both backends and require
    the decoded tensors to agree within ``MAX_DIFF_THRESHOLD`` (or exactly,
    for the clips listed in ``no_support_videos``).
    """

    # Largest absolute per-element difference tolerated between the backends.
    MAX_DIFF_THRESHOLD = 3

    def setUp(self):
        """Schedule restoration of the process-wide video backend.

        Every test switches the backend through ``torchvision.set_video_backend``;
        putting the previous value back keeps the tests independent of the
        order in which they (or other test modules) run.
        """
        self.addCleanup(torchvision.set_video_backend, torchvision.get_video_backend())

    def _check_info(self, info, config):
        """Assert that ``info`` reports the rates recorded in ``config``.

        Args:
            info: metadata dict returned by ``read_video``; its ``'video_fps'``
                and ``'audio_fps'`` entries are read.
            config: ``GroundTruth`` record holding the expected values.
        """
        self.assertIsNotNone(info)
        self.assertAlmostEqual(info['video_fps'], config.video_fps, places=1)
        self.assertAlmostEqual(info['audio_fps'], config.audio_sample_rate, places=1)

    def _frames_compare(self, frames1, frames2):
        """Assert that two decoded tensors match within ``MAX_DIFF_THRESHOLD``.

        The tensors must have the same length and, when non-empty, the same
        per-item shape. Empty results are accepted without further checks.

        Args:
            frames1: tensor decoded by one backend (already on the CPU).
            frames2: tensor decoded by the other backend.
        """
        self.assertEqual(len(frames1), len(frames2))
        if len(frames1) == 0:
            return
        self.assertEqual(frames1[0].shape, frames2[0].shape)
        # Subtract the smaller value from the larger one so the difference
        # never goes negative, which would wrap around for unsigned dtypes.
        abs_diff = torch.where(frames1 > frames2, frames1 - frames2, frames2 - frames1)
        if abs_diff.numel() == 0:
            return
        # .item() yields a plain number, which gives a readable failure message.
        self.assertLessEqual(abs_diff.max().item(), self.MAX_DIFF_THRESHOLD)

    def _read_with_both_backends(self, full_path, *args, **kwargs):
        """Decode ``full_path`` with the 'npu' backend, then with 'pyav'.

        Args:
            full_path: path of the clip to decode.
            *args, **kwargs: forwarded unchanged to ``torchvision.io.read_video``.

        Returns:
            ``(npu_result, pyav_result)``, each being the
            ``(video, audio, info)`` tuple returned by ``read_video``.
        """
        torchvision.set_video_backend('npu')
        npu_result = torchvision.io.read_video(full_path, *args, **kwargs)
        torchvision.set_video_backend('pyav')
        ref_result = torchvision.io.read_video(full_path, *args, **kwargs)
        return npu_result, ref_result

    def _assert_backends_agree(self, full_path, *args, compare_audio=True, **kwargs):
        """Decode with both backends and compare the decoded tensors.

        Args:
            full_path: path of the clip to decode.
            *args, **kwargs: forwarded unchanged to ``torchvision.io.read_video``.
            compare_audio: when False only the video tensors are compared.
        """
        (video_npu, audio_npu, _), (video_ori, audio_ori, _) = self._read_with_both_backends(
            full_path, *args, **kwargs)
        self._frames_compare(video_npu.cpu(), video_ori)
        if compare_audio:
            self._frames_compare(audio_npu.cpu(), audio_ori)

    def _check_random_ranges(self, pts_unit, draw, duration_field, num_iter=2):
        """Compare both backends on randomly drawn ``[start, end]`` windows.

        Args:
            pts_unit: value forwarded as ``pts_unit`` to ``read_video``.
            draw: callable used as ``draw(low, high)`` to pick the window
                bounds (``random.uniform`` or ``random.randint``).
            duration_field: name of the ``GroundTruth`` field used as the
                upper bound of the window.
            num_iter: number of windows drawn per clip.
        """
        for test_video, config in test_videos.items():
            full_path = os.path.join(VIDEO_DIR, test_video)
            duration = getattr(config, duration_field)
            starts = [draw(0, duration) for _ in range(num_iter)]
            for start in starts:
                end = draw(start, duration)
                # No seed is fixed, so attach the drawn window to any failure
                # to make it reproducible.
                with self.subTest(video=test_video, start_pts=start, end_pts=end,
                                  pts_unit=pts_unit):
                    self._assert_backends_agree(full_path, start, end, pts_unit=pts_unit)

    def _check_output_format(self, output_format):
        """Compare the video tensors of both backends for ``output_format``.

        Each clip is decoded once in full and once for the window
        ``[-15, -10]`` in pts units.
        """
        for test_video in test_videos:
            full_path = os.path.join(VIDEO_DIR, test_video)
            with self.subTest(video=test_video, output_format=output_format):
                self._assert_backends_agree(full_path, output_format=output_format,
                                            compare_audio=False)
                self._assert_backends_agree(full_path, -15, -10, pts_unit="pts",
                                            output_format=output_format,
                                            compare_audio=False)

    def test_read_video_npu_pts0(self):
        """Whole-clip decode: metadata matches the ground truth, tensors match pyav."""
        for test_video, config in test_videos.items():
            full_path = os.path.join(VIDEO_DIR, test_video)
            print("test_read_video_npu_pts0 video path: ", full_path)
            with self.subTest(video=test_video):
                npu_result, ref_result = self._read_with_both_backends(full_path)
                video_npu, audio_npu, info_npu = npu_result
                video_ori, audio_ori, _ = ref_result
                self._check_info(info_npu, config)
                self._frames_compare(video_npu.cpu(), video_ori)
                self._frames_compare(audio_npu.cpu(), audio_ori)

    def test_read_video_npu_random_pts_uint_sec(self):
        """Random windows expressed in seconds decode identically on both backends."""
        self._check_random_ranges("sec", random.uniform, "duration_sec")

    def test_read_video_npu_random_pts_uint_pts(self):
        """Random windows expressed in pts decode identically on both backends."""
        self._check_random_ranges("pts", random.randint, "duration_pts")

    def test_read_video_npu_equal_pts(self):
        """A window whose start equals its end decodes identically on both backends."""
        for test_video, config in test_videos.items():
            full_path = os.path.join(VIDEO_DIR, test_video)
            seek = random.randint(0, config.duration_pts)
            with self.subTest(video=test_video, pts=seek):
                self._assert_backends_agree(full_path, seek, seek)

    def test_read_video_npu_pts_out_of_range(self):
        """Windows outside the clip yield no frames; a reversed window raises."""
        test_video = "R6llTwEh07w.mp4"
        full_path = os.path.join(VIDEO_DIR, test_video)
        duration_pts = test_videos[test_video].duration_pts
        torchvision.set_video_backend('npu')

        # Window entirely before the start of the clip.
        video_npu, _, _ = torchvision.io.read_video(full_path, -15, -10, pts_unit="pts")
        self.assertEqual(len(video_npu), 0)

        # Window entirely past the end of the clip.
        video_npu, _, _ = torchvision.io.read_video(
            full_path, duration_pts + 10000, duration_pts + 15000, pts_unit="pts")
        self.assertEqual(len(video_npu), 0)

        # start_pts greater than end_pts must be rejected.
        with self.assertRaises(ValueError):
            torchvision.io.read_video(full_path, 15, 10, pts_unit="pts")

    def test_read_video_npu_chw(self):
        """``output_format="TCHW"`` gives matching video tensors on both backends."""
        self._check_output_format("TCHW")

    def test_read_video_npu_hwc(self):
        """``output_format="THWC"`` gives matching video tensors on both backends."""
        self._check_output_format("THWC")

    def test_read_video_npu_not_support_type(self):
        """Clips in ``no_support_videos`` give exactly equal results on both backends."""
        for test_video in no_support_videos:
            full_path = os.path.join(VIDEO_DIR, test_video)
            with self.subTest(video=test_video):
                npu_result, cpu_result = self._read_with_both_backends(full_path)
                video_npu, audio_npu, info_npu = npu_result
                video_cpu, audio_cpu, info_cpu = cpu_result
                self.assertTrue(torch.equal(video_npu.cpu(), video_cpu))
                self.assertTrue(torch.equal(audio_npu.cpu(), audio_cpu))
                self.assertEqual(info_cpu, info_npu)

    def test_read_video_mem(self):
        """Peak NPU memory of one decode stays within 110% of the raw frame size."""
        torchvision.set_video_backend('npu')
        full_path = os.path.join(VIDEO_DIR, next(iter(test_videos)))

        cap = cv2.VideoCapture(full_path)
        try:
            # Without this check an unreadable file would report a frame count
            # of 0 and the memory bound below would fail with a misleading message.
            self.assertTrue(cap.isOpened(), "cannot open video: " + full_path)
            frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            frame_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            frame_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        finally:
            cap.release()

        # Start the peak counter afresh so allocations made earlier in the
        # process are not attributed to this decode.
        torch_npu.npu.reset_peak_memory_stats()
        vframes, _, _ = torchvision.io.read_video(full_path)
        max_mem = torch_npu.npu.max_memory_allocated()

        # Three bytes per pixel for every frame of the clip.
        cal_mem = 3 * frame_height * frame_width * frame_count
        reserve_mem_ratio = 0.1
        self.assertLess(max_mem, (1 + reserve_mem_ratio) * cal_mem)
# Run the test cases of this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()