"""
Unit tests for the mask_utils module.
Covers the MaskType enum and its string lookup tables, every mask generator,
MaskProcessor, and MaskCompressor.
"""
import torch
import random
from unittest.mock import patch, MagicMock
from abc import ABC, abstractmethod
from tests.ut.utils import judge_expression, TestConfig
from mindspeed_mm.utils.mask_utils import (
MaskType,
TYPE_TO_STR,
STR_TO_TYPE,
BaseMaskGenerator,
T2IVMaskGenerator,
I2VMaskGenerator,
TransitionMaskGenerator,
ContinuationMaskGenerator,
ClearMaskGenerator,
RandomTemporalMaskGenerator,
MaskProcessor,
MaskCompressor
)
class TestMaskType:
    """Checks the MaskType enum and the name <-> member lookup tables."""

    def test_mask_type_enum(self):
        """Every MaskType member name is one of the known names, and none are missing."""
        known_names = [
            "t2iv", "i2v", "transition",
            "continuation", "clear", "random_temporal"
        ]
        member_names = [member.name for member in MaskType]
        for name in member_names:
            judge_expression(name in known_names)
        # Same member count + every member known => exactly the known set.
        judge_expression(len(member_names) == len(known_names))

    def test_type_to_str(self):
        """TYPE_TO_STR maps each member to its own ``name``."""
        for member in MaskType:
            expected_name = member.name
            judge_expression(TYPE_TO_STR[member] == expected_name)

    def test_str_to_type(self):
        """STR_TO_TYPE maps each member's ``name`` back to that member."""
        for member in MaskType:
            looked_up = STR_TO_TYPE[member.name]
            judge_expression(looked_up == member)
class ConcreteMaskGenerator(BaseMaskGenerator):
    """Minimal BaseMaskGenerator subclass used to exercise the base-class helpers.

    It only supplies the abstract ``process`` hook so the class can be instantiated.
    """

    def process(self, mask):
        """Identity hook: hand the input mask back unchanged."""
        unchanged = mask
        return unchanged
class TestBaseMaskGenerator:
    """Tests for the shared behaviour provided by BaseMaskGenerator."""

    def test_create_system_mask(self):
        """create_system_mask builds an all-ones (T, 1, H, W) tensor and rejects None sizes."""
        gen = ConcreteMaskGenerator()
        frames, h, w = 4, 32, 32
        target_device = "cpu"
        target_dtype = torch.float32

        system_mask = gen.create_system_mask(frames, h, w, target_device, target_dtype)
        judge_expression(system_mask.shape == (frames, 1, h, w))
        judge_expression(system_mask.device.type == target_device)
        judge_expression(system_mask.dtype == target_dtype)
        judge_expression(torch.all(system_mask == 1.0))

        # Each size argument is replaced by None in turn; every variant must raise ValueError.
        bad_argument_sets = (
            (None, h, w),
            (frames, None, w),
            (frames, h, None),
        )
        for bad_frames, bad_h, bad_w in bad_argument_sets:
            try:
                gen.create_system_mask(bad_frames, bad_h, bad_w, target_device, target_dtype)
            except ValueError:
                pass
            else:
                judge_expression(False)

    def test_abstract_method_enforcement(self):
        """Instantiating the abstract base directly must raise a TypeError mentioning `process`."""
        try:
            BaseMaskGenerator()
        except TypeError as err:
            message = str(err).lower()
            judge_expression("abstract class" in message)
            judge_expression("process" in message)
        else:
            judge_expression(False)
class TestT2IVMaskGenerator:
    """Tests for T2IVMaskGenerator, which keeps every position at 1.0."""

    def test_process(self):
        """process() leaves an all-ones mask all ones and keeps its shape."""
        gen = T2IVMaskGenerator()
        frames, h, w = 4, 32, 32
        ones = torch.ones(frames, 1, h, w)
        out = gen.process(ones)
        judge_expression(torch.all(out == 1.0))
        judge_expression(out.shape == ones.shape)

    def test_call(self):
        """Calling the generator produces a (T, 1, H, W) mask filled with 1.0."""
        gen = T2IVMaskGenerator()
        frames, h, w = 4, 32, 32
        out = gen(frames, h, w, device="cpu")
        judge_expression(out.shape == (frames, 1, h, w))
        judge_expression(torch.all(out == 1.0))
class TestI2VMaskGenerator:
    """Tests for I2VMaskGenerator."""

    def test_process(self):
        """process() zeroes only the first frame and keeps the later frames at 1.0."""
        gen = I2VMaskGenerator()
        frames, h, w = 4, 32, 32
        ones = torch.ones(frames, 1, h, w)
        out = gen.process(ones)
        first_frame, later_frames = out[0], out[1:]
        judge_expression(torch.all(first_frame == 0.0))
        judge_expression(torch.all(later_frames == 1.0))
        judge_expression(out.shape == ones.shape)
class TestTransitionMaskGenerator:
    """Tests for TransitionMaskGenerator."""

    def test_process(self):
        """process() zeroes the first and last frames and keeps the middle frames at 1.0."""
        gen = TransitionMaskGenerator()
        frames, h, w = 4, 32, 32
        ones = torch.ones(frames, 1, h, w)
        out = gen.process(ones)
        for boundary_frame in (out[0], out[-1]):
            judge_expression(torch.all(boundary_frame == 0.0))
        judge_expression(torch.all(out[1:-1] == 1.0))
        judge_expression(out.shape == ones.shape)
class TestContinuationMaskGenerator:
    """Tests for ContinuationMaskGenerator."""

    def test_init(self):
        """Clear-ratio bounds default to 0.0 / 1.0 and accept custom values."""
        default_gen = ContinuationMaskGenerator()
        judge_expression(default_gen.min_clear_ratio == 0.0)
        judge_expression(default_gen.max_clear_ratio == 1.0)

        lo, hi = 0.2, 0.8
        custom_gen = ContinuationMaskGenerator(min_clear_ratio=lo, max_clear_ratio=hi)
        judge_expression(custom_gen.min_clear_ratio == lo)
        judge_expression(custom_gen.max_clear_ratio == hi)

    def test_process(self):
        """With random.randint pinned to 2, the first two frames are zeroed and the rest kept."""
        with patch('random.randint', return_value=2) as fake_randint:
            gen = ContinuationMaskGenerator(min_clear_ratio=0.0, max_clear_ratio=1.0)
            frames, h, w = 4, 32, 32
            ones = torch.ones(frames, 1, h, w)
            out = gen.process(ones)
            judge_expression(torch.all(out[0:2] == 0.0))
            judge_expression(torch.all(out[2:] == 1.0))
            judge_expression(out.shape == ones.shape)
            # Ratios 0.0..1.0 over 4 frames should request a count in [0, 4].
            fake_randint.assert_called_once_with(0, 4)
class TestClearMaskGenerator:
    """Tests for ClearMaskGenerator."""

    def test_process(self):
        """process() turns an all-ones mask into all zeros and keeps its shape."""
        gen = ClearMaskGenerator()
        frames, h, w = 4, 32, 32
        ones = torch.ones(frames, 1, h, w)
        out = gen.process(ones)
        judge_expression(torch.all(out == 0.0))
        judge_expression(out.shape == ones.shape)
class TestRandomTemporalMaskGenerator:
    """Tests for RandomTemporalMaskGenerator."""

    def test_init(self):
        """Clear-ratio bounds default to 0.0 / 1.0 and accept custom values."""
        default_gen = RandomTemporalMaskGenerator()
        judge_expression(default_gen.min_clear_ratio == 0.0)
        judge_expression(default_gen.max_clear_ratio == 1.0)

        lo, hi = 0.3, 0.7
        custom_gen = RandomTemporalMaskGenerator(min_clear_ratio=lo, max_clear_ratio=hi)
        judge_expression(custom_gen.min_clear_ratio == lo)
        judge_expression(custom_gen.max_clear_ratio == hi)

    def test_process(self):
        """With randint -> 2 and sample -> [1, 3], exactly frames 1 and 3 are zeroed."""
        randint_patch = patch('random.randint', return_value=2)
        sample_patch = patch('random.sample', return_value=[1, 3])
        with randint_patch as fake_randint, sample_patch as fake_sample:
            gen = RandomTemporalMaskGenerator(min_clear_ratio=0.0, max_clear_ratio=1.0)
            frames, h, w = 4, 32, 32
            ones = torch.ones(frames, 1, h, w)
            out = gen.process(ones)
            expected_per_frame = (1.0, 0.0, 1.0, 0.0)
            for idx, value in enumerate(expected_per_frame):
                judge_expression(torch.all(out[idx] == value))
            judge_expression(out.shape == ones.shape)
            fake_randint.assert_called_once_with(0, 4)
            fake_sample.assert_called_once_with(range(4), 2)
class TestMaskProcessor:
    """Tests for MaskProcessor construction, per-type mask lookup, and __call__."""

    def test_init(self):
        """Defaults are 640x640 with ratios 0.0..1.0; overrides are stored; all types registered."""
        defaults = MaskProcessor()
        judge_expression(defaults.max_height == 640)
        judge_expression(defaults.max_width == 640)
        judge_expression(defaults.min_clear_ratio == 0.0)
        judge_expression(defaults.max_clear_ratio == 1.0)

        overrides = {
            "max_height": 512,
            "max_width": 512,
            "min_clear_ratio": 0.1,
            "max_clear_ratio": 0.9,
        }
        custom = MaskProcessor(**overrides)
        for attr_name, expected in overrides.items():
            judge_expression(getattr(custom, attr_name) == expected)

        for member in MaskType:
            judge_expression(member in custom.mask_generators)

    def test_get_mask(self):
        """get_mask yields a (T, 1, H, W) mask matching the input's device and dtype for every type."""
        proc = MaskProcessor()
        frames, chans, h, w = 4, 3, 32, 32
        video = torch.randn(frames, chans, h, w)
        for member in MaskType:
            mask = proc.get_mask(member, video, device="cpu")
            # Tolerate a mask returned without the singleton channel axis.
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            judge_expression(mask.shape == (frames, 1, h, w))
            judge_expression(mask.device == video.device)
            judge_expression(mask.dtype == video.dtype)

    def test_call_with_mask_type_ratio_dict(self):
        """__call__ samples the mask type via random.choices(keys, weights) and returns both outputs."""
        ratios = {
            MaskType.t2iv: 0.2,
            MaskType.i2v: 0.3,
            MaskType.transition: 0.5
        }
        with patch('random.choices', return_value=[MaskType.transition]) as fake_choices:
            proc = MaskProcessor()
            frames, chans, h, w = 4, 3, 32, 32
            video = torch.randn(frames, chans, h, w)
            outputs = proc(video, mask_type_ratio_dict=ratios)
            judge_expression("mask" in outputs)
            judge_expression("masked_pixel_values" in outputs)
            judge_expression(outputs["masked_pixel_values"].shape == video.shape)
            fake_choices.assert_called_once_with(
                list(ratios.keys()),
                list(ratios.values())
            )

    def test_call_with_invalid_parameters(self):
        """__call__ without any mask-type selection must raise ValueError."""
        proc = MaskProcessor()
        frames, chans, h, w = 4, 3, 32, 32
        video = torch.randn(frames, chans, h, w)
        try:
            proc(video)
        except ValueError:
            pass
        else:
            judge_expression(False)
class TestMaskCompressor:
    """Tests for MaskCompressor."""

    def test_init(self):
        """Autoencoder strides default to (h=8, w=8, t=4) and accept overrides."""
        default_compressor = MaskCompressor()
        default_strides = (("ae_stride_h", 8), ("ae_stride_w", 8), ("ae_stride_t", 4))
        for attr_name, expected in default_strides:
            judge_expression(getattr(default_compressor, attr_name) == expected)

        custom_strides = {"ae_stride_h": 4, "ae_stride_w": 4, "ae_stride_t": 2}
        custom_compressor = MaskCompressor(**custom_strides)
        for attr_name, expected in custom_strides.items():
            judge_expression(getattr(custom_compressor, attr_name) == expected)