import os
import random
import unittest
import numpy as np
import torch
import torch_npu
import torchvision
from torchvision.transforms import v2
import torchvision_npu
class TestUniformTemporalSubsample(unittest.TestCase):
    """Compare ``v2.functional.uniform_temporal_subsample_video`` with a plain PyTorch reference."""

    def _reference_uniform_temporal_subsample_video(self, video, *, num_samples):
        """Select ``num_samples`` evenly spaced frames along dimension ``-4`` of ``video``.

        Args:
            video: Tensor with at least four dimensions; dimension ``-4`` is
                the one that gets subsampled.
            num_samples: Number of frames to keep (keyword-only, must be > 0).

        Returns:
            The result of ``torch.index_select`` on dimension ``-4`` using the
            computed indices.
        """
        t = video.shape[-4]
        assert num_samples > 0 and t > 0
        # Spread the sample positions evenly over [0, t - 1] in float64 and
        # truncate to integer frame indices.
        indices = torch.linspace(0, t - 1, num_samples, device=video.device, dtype=torch.float64)
        indices = torch.clamp(indices, 0, t - 1).long()
        return torch.index_select(video, -4, indices)

    def generate_random_video_tensor(self, num_frames, num_channels, height, width):
        """Build one random float video and one random uint8 video of the same shape.

        Args:
            num_frames: Size of the first dimension.
            num_channels: Size of the second dimension.
            height: Size of the third dimension.
            width: Size of the fourth dimension.

        Returns:
            A ``(float_tensor, uint8_tensor)`` tuple, both shaped
            ``(num_frames, num_channels, height, width)``.
        """
        shape = (num_frames, num_channels, height, width)
        video_float = np.random.rand(*shape)
        # ``randint``'s upper bound is exclusive, so 256 is needed for the
        # uint8 data to be able to contain the value 255.
        video_int = np.random.randint(0, 256, size=shape, dtype=np.uint8)
        return torch.from_numpy(video_float), torch.from_numpy(video_int)

    def test_video_correctness(self):
        """Subsampled output must be identical to the reference for random videos."""
        from unittest import mock

        test_nums = 10
        num_channels = 3
        # Scope the environment override to this test so it does not leak into
        # other tests running in the same process.
        # NOTE(review): presumably read by torchvision_npu at call time - confirm.
        with mock.patch.dict(os.environ, {"TORCHVISION_OMP_NUM_THREADS": "8"}):
            for _ in range(test_nums):
                num_frames = np.random.randint(2, 50)
                height = np.random.randint(480, 1024)
                width = np.random.randint(480, 1024)
                num_samples = np.random.randint(1, num_frames)
                videos = self.generate_random_video_tensor(num_frames, num_channels, height, width)
                for video in videos:
                    # Report the randomly drawn parameters on failure so the
                    # failing case can be reproduced.
                    with self.subTest(
                        dtype=video.dtype,
                        num_frames=num_frames,
                        height=height,
                        width=width,
                        num_samples=num_samples,
                    ):
                        actual = v2.functional.uniform_temporal_subsample_video(
                            video, num_samples=num_samples
                        )
                        expected = self._reference_uniform_temporal_subsample_video(
                            video, num_samples=num_samples
                        )
                        # Subsampling only selects frames, so the result must
                        # match exactly. ``torch.allclose`` would broadcast
                        # mismatched shapes and still apply its default rtol.
                        self.assertEqual(actual.shape, expected.shape)
                        self.assertTrue(torch.equal(actual, expected))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()