import unittest
from unittest.mock import patch, MagicMock
from argparse import Namespace
import torch
import mindspeed.megatron_adaptor
from tests.ut.utils import clear_module
class TestSoraGRPOTrainer(unittest.TestCase):
"""SoraGrpoTrainer is an abstract class and cannot be instantiated; use its subclass FluxGrpoTrainer for instantiation testing."""
def setUp(self):
self.mock_dataset_provider = MagicMock()
def test_init(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state') as initialize_sequence_parallel_state,
patch('mindspeed_mm.configs.config.mm_extra_args_provider'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(train_valid_test_dataset_provider=self.mock_dataset_provider)
self.assertEqual(trainer.local_rank, 0)
self.assertEqual(trainer.rank, 0)
self.assertEqual(trainer.world_size, 8)
self.assertEqual(trainer.train_valid_test_dataset_provider, self.mock_dataset_provider)
self.assertEqual(trainer.optimizer, None)
self.assertEqual(trainer.device, 0)
self.assertEqual(trainer.hyper_model, None)
mock_get_args.assert_called_once()
initialize_sequence_parallel_state.assert_called_once()
def test_get_args(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('sys.argv', ['script.py', '--load', 'test.txt'])
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
args = trainer.get_args()
self.assertEqual(args.dataloader_num_workers, 10)
self.assertEqual(args.train_batch_size, 1)
self.assertEqual(args.load, "test.txt")
def test_get_args_required_args_not_set(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args')
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with self.assertRaises(SystemExit) as cm:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
trainer.get_args()
self.assertEqual(cm.exception.code, 2)
def test_sd3_time_shift(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args')
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
sigma_schedule = torch.linspace(1, 0, 17)
res = trainer.sd3_time_shift(3, sigma_schedule)
expected = torch.tensor(
[1.0000, 0.9783, 0.9545, 0.9286, 0.9000, 0.8684, 0.8333, 0.7941, 0.7500, 0.7000, 0.6429, 0.5769,
0.5000, 0.4091, 0.3000, 0.1667, 0.0000])
self.assertTrue(torch.allclose(res, expected, atol=1e-4))
def test_gather_tensor(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
input = torch.zeros(3, 4)
output = trainer.gather_tensor(input)
self.assertEqual(output.size(), (24, 4))
def test_assert_eq(self):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
FluxGRPOTrainer.assert_eq(1, 1, "test")
def test_assert_eq_not_equal(self):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
with self.assertRaises(AssertionError) as err:
FluxGRPOTrainer.assert_eq(1, 2, "test")
self.assertEqual(str(err.exception), "test not equal")
def test_flux_step(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
config = {
"grpo": True,
"sde_solver": True,
"eta": 0.3,
"index": 0,
}
model_output = torch.ones(3, 4)
latents = torch.ones(3, 4)
sigmas = torch.tensor([0.03, 0.02, 0.01])
prev_sample = torch.ones(3, 4) / 2
torch.manual_seed(42)
prev_sample, pred_original_sample, log_prob = trainer.grpo_step(model_output, latents, sigmas,
prev_sample, config)
exp_prev_sample = torch.tensor([[0.5000, 0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000, 0.5000]])
exp_pred_original_sample = torch.tensor([[0.9700, 0.9700, 0.9700, 0.9700],
[0.9700, 0.9700, 0.9700, 0.9700],
[0.9700, 0.9700, 0.9700, 0.9700]])
exp_log_prob = torch.tensor([-117.7857, -117.7857, -117.7857])
self.assertTrue(torch.allclose(prev_sample, exp_prev_sample, atol=1e-4))
self.assertTrue(torch.allclose(pred_original_sample, exp_pred_original_sample, atol=1e-4))
self.assertTrue(torch.allclose(log_prob, exp_log_prob, atol=1e-4))
def test_flux_step_pre_sample_none(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
config = {
"grpo": True,
"sde_solver": True,
"eta": 0.3,
"index": 0,
}
model_output = torch.ones(3, 4)
latents = torch.ones(3, 4)
sigmas = torch.tensor([0.03, 0.02, 0.01])
torch.manual_seed(42)
prev_sample, pred_original_sample, log_prob = trainer.grpo_step(model_output, latents, sigmas,
None, config)
exp_prev_sample = torch.tensor([[0.9706, 0.9643, 0.9675, 0.9674],
[0.9268, 0.9549, 1.0267, 0.9413],
[0.9743, 0.9685, 0.9765, 0.9847]])
exp_pred_original_sample = torch.tensor([[0.9700, 0.9700, 0.9700, 0.9700],
[0.9700, 0.9700, 0.9700, 0.9700],
[0.9700, 0.9700, 0.9700, 0.9700]])
exp_log_prob = torch.tensor([-0.0297, -0.8223, -0.1532])
self.assertTrue(torch.allclose(prev_sample, exp_prev_sample, atol=1e-4))
self.assertTrue(torch.allclose(pred_original_sample, exp_pred_original_sample, atol=1e-4))
self.assertTrue(torch.allclose(log_prob, exp_log_prob, atol=1e-4))
def test_save_checkpoint_rank_0(self):
clear_module("mindspeed_mm")
clear_module("json.dumps")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
patch('torch.distributed.fsdp.FullyShardedDataParallel'),
patch('safetensors.torch.save_file') as mock_save_file,
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
try:
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
transformer = MagicMock()
transformer.state_dict = MagicMock()
rank = 0
output_dir = "sora_grpo_trainer_test"
step = 40
epoch = 1
trainer.save_checkpoint(transformer, rank, output_dir, step, epoch)
transformer.state_dict.assert_called_once()
mock_save_file.assert_called_once()
finally:
import shutil
shutil.rmtree("sora_grpo_trainer_test")
def test_save_checkpoint_other_rank(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
patch('torch.distributed.fsdp.FullyShardedDataParallel'),
patch('safetensors.torch.save_file') as mock_save_file,
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
transformer = MagicMock()
transformer.state_dict = MagicMock()
rank = 1
output_dir = "sora_grpo_trainer_test"
step = 40
epoch = 1
trainer.save_checkpoint(transformer, rank, output_dir, step, epoch)
transformer.state_dict.assert_called_once()
mock_save_file.assert_not_called()
def test_main_print(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
with patch.object(FluxGRPOTrainer, 'get_args', return_value=Namespace(sp_size=1)) as mock_get_args:
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
trainer.main_print("test test")
def test_train(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.distributed.init_process_group'),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
patch('torch.distributed.barrier'),
patch('torch.distributed.fsdp.FullyShardedDataParallel'),
patch('mindspeed_mm.tasks.rl.soragrpo.utils.fsdp_util.apply_fsdp_checkpointing'),
patch('torch.optim.AdamW'),
patch('diffusers.get_scheduler'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
test_args = Namespace(
seed=1234,
save=None,
load_rank=8,
fsdp_sharding_strategy="hybrid_full",
use_cpu_offload=False,
master_weight_type="fp32",
gradient_checkpointing=True,
selective_checkpointing=1.0,
lr=1e-5,
weight_decay=0.01,
lr_scheduler="constant_with_warmup",
lr_warmup_steps=10,
lr_num_cycles=1,
lr_power=1.0,
sampler_seed=1234,
train_batch_size=1,
dataloader_num_workers=10,
gradient_accumulation_steps=4,
sp_size=1,
train_sp_batch_size=1,
train_iters=5,
save_interval=1000,
)
mock_hyper_model = MagicMock()
mock_hyper_model.diffuser = MagicMock()
with (patch.object(FluxGRPOTrainer, 'get_args', return_value=test_args),
patch.object(FluxGRPOTrainer, 'train_one_step', return_value=(0.001, 1)) as mock_train_one_step,
patch.object(FluxGRPOTrainer, 'model_provider', return_value=mock_hyper_model)):
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
trainer.train()
self.assertEqual(mock_train_one_step.call_count, 5)
def test_train_and_train_one_step(self):
clear_module("mindspeed_mm")
with (
patch.dict('os.environ', {
'LOCAL_RANK': '0',
'RANK': '0',
'WORLD_SIZE': '8',
'LOCAL_WORLD_SIZE': '8'
}, clear=True),
patch('torch.cuda.set_device'),
patch('torch.cuda.current_device', return_value=0),
patch(
'mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states.initialize_sequence_parallel_state'),
patch('mindspeed_mm.configs.config.merge_mm_args'),
patch('torch.distributed.init_process_group'),
patch('torch.distributed.is_initialized', return_value=True),
patch('torch.distributed.get_world_size', return_value=8),
patch('torch.distributed.all_gather'),
patch('torch.distributed.barrier'),
patch('torch.distributed.fsdp.FullyShardedDataParallel'),
patch('torch.distributed.all_reduce'),
patch('torch.distributed.get_rank', return_value=0),
patch('mindspeed_mm.tasks.rl.soragrpo.utils.fsdp_util.apply_fsdp_checkpointing'),
patch('torch.optim.AdamW'),
patch('diffusers.get_scheduler'),
):
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer
test_args = Namespace(
seed=1234,
save=None,
load_rank=8,
fsdp_sharding_strategy="hybrid_full",
use_cpu_offload=False,
master_weight_type="fp32",
gradient_checkpointing=True,
selective_checkpointing=1.0,
lr=1e-5,
weight_decay=0.01,
lr_scheduler="constant_with_warmup",
lr_warmup_steps=10,
lr_num_cycles=1,
lr_power=1.0,
sampler_seed=1234,
train_batch_size=1,
dataloader_num_workers=10,
gradient_accumulation_steps=4,
sp_size=1,
train_sp_batch_size=1,
train_iters=5,
save_interval=1000,
clip_range=1e-4,
max_grad_norm=2.0,
adv_clip_max=5.0
)
mock_hyper_model = MagicMock()
mock_hyper_model.diffuser = MagicMock()
samples_batched_list = [{
"timesteps": torch.tensor([[500, 642, 750, 928, 954, 576, 409, 868, 700, 794, 978, 900,
833, 1000, 300]]),
'latents': torch.rand([1, 15, 2025, 64]),
'next_latents': torch.rand([1, 15, 2025, 64]),
'log_probs': torch.tensor([[-0.4984, -0.5021, -0.4993, -0.5006, -0.4955, -0.5030, -0.4986, -0.4998,
-0.5018, -0.5026, -0.4975, -0.4988, -0.5011, -0.4990, -0.4987]]),
'rewards': torch.rand([1]),
'image_ids': torch.rand([1, 2025, 3]),
'text_ids': torch.tensor([[0., 0., 0.]]),
'encoder_hidden_states': torch.rand([1, 512, 4096]),
'pooled_prompt_embeds': torch.rand([1, 768]),
'advantages': torch.rand([1]),
}] * 12
train_timesteps = 9
sigma_schedule = torch.tensor([1.0000, 0.9783, 0.9545, 0.9286, 0.9000, 0.8684, 0.8333, 0.7941, 0.7500,
0.7000, 0.6429, 0.5769, 0.5000, 0.4091, 0.3000, 0.1667, 0.0000])
perms = torch.tensor([[12, 10, 8, 3, 2, 11, 13, 5, 9, 7, 1, 4, 6, 0, 14],
[6, 7, 10, 12, 11, 2, 13, 4, 9, 0, 5, 3, 8, 1, 14],
[5, 7, 2, 13, 3, 11, 10, 8, 4, 9, 0, 6, 14, 12, 1],
[1, 7, 12, 9, 3, 6, 0, 10, 2, 14, 13, 5, 8, 11, 4],
[8, 13, 5, 0, 6, 14, 2, 10, 7, 12, 1, 3, 11, 9, 4],
[12, 2, 5, 7, 11, 13, 9, 3, 14, 0, 10, 1, 6, 4, 8],
[7, 1, 10, 12, 13, 11, 9, 3, 4, 14, 5, 2, 8, 6, 0],
[6, 11, 14, 9, 10, 13, 5, 8, 7, 12, 2, 0, 3, 1, 4],
[2, 3, 4, 12, 13, 11, 10, 1, 9, 6, 7, 5, 0, 14, 8],
[3, 11, 4, 9, 1, 0, 12, 7, 2, 6, 10, 8, 14, 5, 13],
[13, 10, 14, 1, 6, 9, 0, 3, 5, 2, 4, 7, 11, 8, 12],
[7, 0, 2, 10, 9, 13, 5, 6, 4, 8, 14, 11, 3, 1, 12]])
with (patch.object(FluxGRPOTrainer, 'get_args', return_value=test_args),
patch.object(FluxGRPOTrainer, 'model_provider', return_value=mock_hyper_model),
patch.object(FluxGRPOTrainer, 'sample_reference',
return_value=(samples_batched_list, train_timesteps, sigma_schedule,
perms)) as mock_sample_reference,
patch.object(FluxGRPOTrainer, 'grpo_one_step',
return_value=torch.tensor([-0.4985], requires_grad=True)) as mock_grpo_one_step,
):
trainer = FluxGRPOTrainer(self.mock_dataset_provider)
trainer.train()
self.assertEqual(mock_sample_reference.call_count, 5)
self.assertEqual(mock_grpo_one_step.call_count, 12 * 9 * 5)
if __name__ == '__main__':
unittest.main(verbosity=2)