from typing import Optional
import torch
from tqdm.auto import tqdm
class WanFlowMatchScheduler():
def __init__(
self,
num_inference_timesteps=None,
num_train_timesteps=1000,
shift=3.0,
sigma_max=1.0,
sigma_min=0.003 / 1.002,
inverse_timesteps=False,
extra_one_step=False,
reverse_sigmas=False,
guidance_scale=0.0,
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
**kwargs
):
self.num_inference_timesteps = num_inference_timesteps
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.guidance_scale = guidance_scale
self.do_classifier_free_guidance = guidance_scale > 1.0
if self.num_inference_timesteps is not None:
self._set_timesteps(self.num_inference_timesteps, training=False)
else:
self._set_timesteps(self.num_train_timesteps, training=True)
if not (0 <= min_timestep_boundary < max_timestep_boundary <= 1):
raise ValueError("min_timestep_boundary and max_timestep_boundary must satisfy 0 <= min < max <= 1")
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def training_losses(
self,
model_output: torch.Tensor,
x_start: Optional[torch.Tensor] = None,
x_t: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
t: Optional[torch.Tensor] = None,
**kwargs
):
reduction = kwargs.get("reduction", "mean")
target = noise - x_start
loss = torch.nn.functional.mse_loss(model_output.float(), target.float(), reduction=reduction)
loss *= self._training_weight(t)
return loss
def sample(
self,
model,
latents,
model_kwargs
):
num_inference_steps = self.num_inference_timesteps or self.num_train_timesteps
num_warmup_steps = len(self.timesteps) - num_inference_steps
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(self.timesteps):
latent_model_input = latents
timestep = t.expand(latent_model_input.shape[0]).to(latents.device)
noise_pred = model(
latent_model_input,
timestep,
model_kwargs.get('prompt_embeds'),
**model_kwargs
)
if self.do_classifier_free_guidance:
noise_uncond = model(
latent_model_input,
timestep,
model_kwargs.get('negative_prompt_embeds'),
**model_kwargs
)
noise_pred = noise_uncond + self.guidance_scale * (noise_pred - noise_uncond)
latents = self._step(noise_pred, t, latents)
if i == len(self.timesteps) - 1 or ((i + 1) > num_warmup_steps):
progress_bar.update()
return latents
def q_sample(self, latents, noise=None, t=None, **kwargs):
if noise is None:
noise = torch.randn_like(latents)
model_kwargs = kwargs.get("model_kwargs", {})
first_frame_latents = model_kwargs.get("first_frame_latents", None)
if first_frame_latents is not None:
noise[:, :, 0: 1] = first_frame_latents
if t is None:
max_timestep_boundary = int(self.max_timestep_boundary * self.num_train_timesteps)
min_timestep_boundary = int(self.min_timestep_boundary * self.num_train_timesteps)
timestep_idx = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.timesteps[timestep_idx].to(latents.device)
else:
timestep = t
timestep_idx = (t.to("cpu") == self.timesteps).nonzero()[0]
sigma = self.sigmas[timestep_idx].to(latents.device)
noised_latents = (1 - sigma) * latents + sigma * noise
return noised_latents, noise, timestep
def _set_timesteps(self, num_steps=100, training=False):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_steps + 1)[:-1]
else:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_steps)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
y = torch.exp(-2 * ((self.timesteps - num_steps / 2) / num_steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def _step(self, model_output, timestep, sample):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_idx = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_idx]
if timestep_idx + 1 >= len(self.timesteps):
sigma_next = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_next = self.sigmas[timestep_idx + 1]
prev_sample = sample + model_output * (sigma_next - sigma)
return prev_sample
def _training_weight(self, timestep):
if len(timestep) > 1:
timestep = timestep[0]
timestep_idx = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
weights = self.linear_timesteps_weights[timestep_idx]
return weights