import torch
import torch_npu
import mindspeed.megatron_adaptor

from mindspeed_mm.models.diffusion.wan_flow_match_scheduler import WanFlowMatchScheduler
from tests.ut.utils import judge_expression


class TestWanFlowMatchScheduler:
    @staticmethod
    def test_training_init():
        train_num_timesteps = 500
        scheduler = WanFlowMatchScheduler(num_train_timesteps=train_num_timesteps)
        judge_expression(not scheduler.do_classifier_free_guidance)
        judge_expression(isinstance(scheduler.sigmas, torch.Tensor))
        judge_expression(
            scheduler.num_train_timesteps == train_num_timesteps and
            scheduler.timesteps.shape[0] == train_num_timesteps and
            scheduler.linear_timesteps_weights.shape == (train_num_timesteps,)
        )
    
    @staticmethod
    def test_inference_init():
        test_num_timesteps = 50
        scheduler = WanFlowMatchScheduler(num_inference_timesteps=test_num_timesteps)
        judge_expression(scheduler.timesteps.shape[0] == test_num_timesteps)
        judge_expression(not hasattr(scheduler, "linear_timesteps_weights"))
    
    @staticmethod
    def test_sigma_conditions():
        scheduler = WanFlowMatchScheduler(
            num_inference_timesteps=100,
            inverse_timesteps=True,
            extra_one_step=True,
            reverse_sigmas=True,
            shift=3.0
        )
        extra_sigmas = torch.linspace(1.0, 0.003 / 1.002, 101)[:-1]
        flip_sigmas = torch.flip(extra_sigmas, dims=[0])
        scaled_sigmas = 3.0 * flip_sigmas / (1 + 2.0 * flip_sigmas)
        judge_expression(torch.allclose(scheduler.sigmas, 1 - scaled_sigmas))

    @staticmethod
    def test_training_weight_loss():
        scheduler = WanFlowMatchScheduler(num_train_timesteps=200)
        timestep = torch.Tensor([50])
        timestep_idx = torch.argmin((scheduler.timesteps - timestep).abs())
        expect_weights = scheduler.linear_timesteps_weights[timestep_idx]
        weights = scheduler._training_weight(timestep)
        judge_expression(torch.isclose(weights, expect_weights))

        model_output = torch.ones([1, 16, 16, 30, 50])
        loss = scheduler.training_losses(
            model_output=model_output,
            x_start=model_output,
            noise=torch.zeros_like(model_output),
            t=timestep
        )
        expect_loss = torch.nn.functional.mse_loss(model_output, -model_output) * expect_weights
        judge_expression(torch.allclose(loss, expect_loss))
    
    @staticmethod
    def test_sample():
        class Dummy_Model:
            def __call__(self, x, t, emb, **kwargs):
                return torch.zeros_like(x)
        
        scheduler_classifier = WanFlowMatchScheduler(
            num_inference_timesteps=5,
            guidance_scale=0.0
        )
        scheduler_classifier_free = WanFlowMatchScheduler(
            num_train_timesteps=10,
            guidance_scale=5.0
        )
        latents = torch.randn([1, 16, 16, 30, 50])
        output_classifer = scheduler_classifier.sample(Dummy_Model(), latents, {})
        output_classifer_free = scheduler_classifier_free.sample(Dummy_Model(), latents, {})
        judge_expression(output_classifer.shape == latents.shape)
        judge_expression(output_classifer_free.shape == latents.shape)

    @staticmethod
    def test_boundary_sigma():
        scehduler = WanFlowMatchScheduler(num_inference_timesteps=3, reverse_sigmas=False)
        last_timestep = scehduler.timesteps[-1]
        sample_before = torch.randn([1, 16, 16, 30, 50])
        model_output = torch.randn([1, 16, 16, 30, 50])
        sample_after = scehduler._step(model_output, last_timestep, sample_before)
        
        current_sigma = scehduler.sigmas[-1]
        next_sigma = 0
        expect_sample = sample_before + model_output * (next_sigma - current_sigma)
        judge_expression(torch.allclose(sample_after, expect_sample))
    
    @staticmethod
    def test_q_sample_randomness():
        torch.manual_seed(42)
        scheduler = WanFlowMatchScheduler()
        latents = torch.ones([1, 16, 16, 30, 50])
        noised_1, _, __ = scheduler.q_sample(latents)
        torch.manual_seed(42)
        noised_2, _, __ = scheduler.q_sample(latents)  
        judge_expression(torch.allclose(noised_1, noised_2))