import unittest
from unittest.mock import MagicMock, patch, mock_open
import os
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch_npu.testing.testcase import TestCase, run_tests
from mx_driving.dataset.utils import balanced_resize, BalancedRandomResize
class TestBalancedResize(TestCase):
    """Exercise ``balanced_resize`` on a 100x200 (width x height) RGB image."""

    def setUp(self):
        """Build the image and detection-style target shared by every test."""
        super().setUp()
        # PIL sizes are (width, height): the fixture is 100 wide and 200 tall.
        self.test_image = Image.new('RGB', (100, 200), color='red')
        boxes = torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]])
        areas = torch.tensor([200, 400])
        # NOTE(review): presumably (height, width) of the image -- confirm
        # against what balanced_resize expects.
        image_size = torch.tensor([200, 100])
        self.test_target = {"boxes": boxes, "area": areas, "size": image_size}

    def test_balanced_resize_no_target(self):
        """With no target, the image is resized to 50 by 100 and None is returned."""
        out_image, out_target = balanced_resize(self.test_image, None, 50)
        short_side, long_side = min(out_image.size), max(out_image.size)
        self.assertEqual(short_side, 50)
        self.assertEqual(long_side, 100)
        self.assertIsNone(out_target)

    def test_balanced_resize_with_target(self):
        """With a target, the image is resized and the target keeps its keys."""
        out_image, out_target = balanced_resize(
            self.test_image, self.test_target, 50
        )
        short_side, long_side = min(out_image.size), max(out_image.size)
        self.assertEqual(short_side, 50)
        self.assertEqual(long_side, 100)
        self.assertIsNotNone(out_target)
        for expected_key in ("boxes", "area", "size"):
            self.assertIn(expected_key, out_target)

    def test_balanced_resize_with_max_size(self):
        """With ``max_size=60`` the longer side of the result is 60."""
        out_image, _ = balanced_resize(
            self.test_image, self.test_target, 50, max_size=60
        )
        longest = max(out_image.size)
        self.assertEqual(longest, 60)

    def test_balanced_resize_with_masks_raises_error(self):
        """A target carrying a ``"masks"`` entry is rejected with RuntimeError."""
        # Shallow copy plus the extra key, so the shared fixture is not mutated.
        masked_target = {**self.test_target, "masks": torch.ones(1, 200, 100)}
        with self.assertRaises(RuntimeError):
            balanced_resize(self.test_image, masked_target, 50)
class TestBalancedRandomResize(TestCase):
    """Exercise the ``BalancedRandomResize`` transform with three candidate sizes."""

    def setUp(self):
        """Create a transform that picks from the sizes 50, 100 and 150."""
        super().setUp()
        self.sizes = [50, 100, 150]
        self.transform = BalancedRandomResize(self.sizes)

    def test_init_with_invalid_sizes(self):
        """Constructing the transform from a plain string raises TypeError."""
        with self.assertRaises(TypeError):
            BalancedRandomResize("invalid")

    def test_call_without_target(self):
        """Without a target, the shorter side is always one of the sizes."""
        # PIL sizes are (width, height): 100 wide, 200 tall.
        test_image = Image.new('RGB', (100, 200), color='red')
        # Repeat the call because the transform chooses its size at random.
        for _ in range(10):
            resized_image, target = self.transform(test_image, None)
            # Check the shorter side rather than the width (size[0]) so the
            # assertion does not depend on the fixture being portrait and
            # matches test_call_with_target.
            self.assertIn(min(resized_image.size), self.sizes)
            self.assertIsNone(target)

    def test_call_with_target(self):
        """With a target, the shorter side is one of the sizes and a target is returned."""
        test_image = Image.new('RGB', (100, 200), color='red')
        test_target = {
            "boxes": torch.tensor([[10, 20, 30, 40]]),
            "area": torch.tensor([200]),
            "size": torch.tensor([200, 100]),
        }
        resized_image, target = self.transform(test_image, test_target)
        self.assertIn(min(resized_image.size), self.sizes)
        self.assertIsNotNone(target)
# Run the suite through the torch_npu test runner when executed as a script.
if "__main__" == __name__:
    run_tests()