import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Callable
from tqdm import tqdm
from megatron.core import mpu
import numpy as np
import torch
import torch.distributed as dist
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`.

    The guided prediction is multiplied by the ratio of standard deviations
    (text-conditioned / guided), computed separately for each index of the
    first dimension over all remaining dimensions. The result is blended back
    with the unmodified guided prediction.

    Args:
        noise_cfg: guided prediction (tensor).
        noise_pred_text: text-conditioned prediction; same rank as `noise_cfg`.
        guidance_rescale: blend weight of the std-matched prediction;
            0.0 returns `noise_cfg` unchanged in value.

    Returns:
        ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.
    """
    def _std_over_trailing_dims(t):
        # Reduce over every dimension except the first; keep dims so the
        # result broadcasts back against the full tensor.
        return t.std(dim=list(range(1, t.ndim)), keepdim=True)

    ratio = _std_over_trailing_dims(noise_pred_text) / _std_over_trailing_dims(noise_cfg)
    matched = noise_cfg * ratio
    return (1 - guidance_rescale) * noise_cfg + guidance_rescale * matched
def opensora_linear_quadratic_schedule(num_inference_steps, approximate_steps=1000):
    """
    Build a decreasing sigma schedule that is linear, then quadratic.

    In "noise level" space (before the final ``1 - x`` flip), the first
    ``num_inference_steps // 2`` values grow linearly by ``1 / approximate_steps``
    per step. The remaining values follow a quadratic that continues the linear
    ramp without a jump. The list is closed with 1.0, and every value is then
    mapped to ``1 - x``.

    Args:
        num_inference_steps: positive, even number of inference steps.
        approximate_steps: even step count whose spacing defines the linear
            part; must be >= `num_inference_steps`.

    Returns:
        A list of ``num_inference_steps + 1`` floats starting at 1.0 and
        ending at 0.0.

    Raises:
        ValueError: if either argument is odd, if `num_inference_steps` is not
            positive, or if `num_inference_steps > approximate_steps`.
    """
    if approximate_steps % 2 != 0:
        raise ValueError("approximate_steps must be even")
    if num_inference_steps % 2 != 0:
        raise ValueError("num_inference_steps must be even")
    if num_inference_steps <= 0:
        # Without this check, 0 surfaces as a ZeroDivisionError in the
        # quadratic coefficients below, and negative values silently produce a
        # degenerate one-element schedule.
        raise ValueError("num_inference_steps must be positive")
    if num_inference_steps > approximate_steps:
        raise ValueError("num_inference_steps must be less than or equal to approximate_steps")

    half_steps = num_inference_steps // 2
    half_approx = approximate_steps // 2

    linear_sigmas = [i / (2 * half_approx) for i in range(half_steps)]

    # Coefficients are chosen so the quadratic equals the linear ramp's next
    # value at i == half_steps and reaches 1.0 at i == 2 * half_steps.
    quadratic_a = (half_approx - half_steps) / (half_approx * half_steps ** 2)
    quadratic_b = (5 * half_steps - 4 * half_approx) / (2 * half_approx * half_steps)
    quadratic_c = (half_approx - half_steps) / half_approx
    quadratic_sigmas = [
        quadratic_a * i ** 2 + quadratic_b * i + quadratic_c
        for i in range(half_steps, 2 * half_steps)
    ]

    sigmas = linear_sigmas + quadratic_sigmas + [1.0]
    return [1.0 - x for x in sigmas]
class OpenSoraPlanScheduler:
    """
    In OpenSoraPlan v1.5, we use FlowMatching to train the model.

    Noise is added by straight-line interpolation,
    ``x_sigma = sigma * noise + (1 - sigma) * x_0`` (see `add_noise`).
    The training target is the velocity ``noise - x_0`` (see
    `training_losses`). At inference time, `sample` integrates the predicted
    velocity with first-order Euler steps (`step`) over the schedule built by
    `set_sigmas`.
    """

    order = 1

    def __init__(
        self,
        num_inference_steps: Optional[int] = None,
        guidance_scale: float = 4.5,
        guidance_rescale: float = 0.7,
        use_linear_quadratic_schedule: bool = False,
        use_dynamic_shifting: bool = False,
        weighting_scheme: str = "logit_normal",
        logit_mean: float = 0.0,
        logit_std: float = 1.0,
        mode_scale: float = 1.29,
        base_image_seq: int = 256,
        max_image_seq: int = 4096,
        shift: float = 1.0,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        sigma_eps: Optional[float] = None,
        device: str = "npu",
        **kwargs,
    ):
        """
        Args:
            num_inference_steps: number of Euler steps used by `set_sigmas`
                and `sample`; must be set before a schedule is built.
            guidance_scale: classifier-free guidance scale; guidance is only
                applied when it is > 1.0.
            guidance_rescale: weight passed to `rescale_noise_cfg`; rescaling
                is skipped when it is <= 0.
            use_linear_quadratic_schedule: build inference sigmas with
                `opensora_linear_quadratic_schedule` instead of a uniform grid.
            use_dynamic_shifting: derive the time shift from the latent
                sequence length instead of using the fixed `shift`.
            weighting_scheme: selects the training sigma density
                ("logit_normal", "mode", anything else = uniform) and the loss
                weighting ("sigma_sqrt", "cosmap", anything else = 1).
            logit_mean, logit_std: parameters of the logit-normal density.
            mode_scale: parameter of the "mode" density.
            base_image_seq, max_image_seq, base_shift, max_shift: dynamic
                shifting uses ``shift = exp(k * seq_len + b)``, where the
                exponent equals `base_shift` at `base_image_seq` and
                `max_shift` at `max_image_seq`.
            shift: fixed time shift used when dynamic shifting is off.
            sigma_eps: lower clamp for shifted sigmas, in [0, 1e-2];
                None means 0.
            device: device for inference-time tensors.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if `sigma_eps` is outside [0, 1e-2], or if dynamic
                shifting is enabled with ``max_image_seq == base_image_seq``.
        """
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.guidance_rescale = guidance_rescale
        self.use_linear_quadratic_schedule = use_linear_quadratic_schedule
        self.device = device
        self.shift = shift
        self.use_dynamic_shifting = use_dynamic_shifting
        self.weighting_scheme = weighting_scheme
        self.logit_mean = logit_mean
        self.logit_std = logit_std
        self.mode_scale = mode_scale
        if self.use_dynamic_shifting:
            if max_image_seq == base_image_seq:
                # Otherwise the slope below divides by zero.
                raise ValueError(
                    "max_image_seq must differ from base_image_seq when use_dynamic_shifting is True"
                )
            self.base_image_seq = base_image_seq
            self.max_image_seq = max_image_seq
            self.base_shift = base_shift
            self.max_shift = max_shift
            # Line through (base_image_seq, base_shift) and (max_image_seq, max_shift);
            # `sigma_shift_opensoraplan` exponentiates its value.
            self.shift_k = (self.max_shift - self.base_shift) / (self.max_image_seq - self.base_image_seq)
            self.shift_b = self.base_shift - self.shift_k * self.base_image_seq
        if sigma_eps is None:
            sigma_eps = 0.0
        elif not (sigma_eps >= 0 and sigma_eps <= 1e-2):
            raise ValueError("sigma_eps should be in the range of [0, 1e-2]")
        self._sigma_eps = sigma_eps
        self._sigma_min = 0.0
        self._sigma_max = 1.0
        # Set by `set_sigmas` (1-D inference schedule) or by `q_sample`
        # (per-sample training sigmas, unsqueezed to the sample's rank).
        self.sigmas = None

    @property
    def sigma_eps(self):
        """Lower clamp applied to shifted sigmas."""
        return self._sigma_eps

    @property
    def sigma_min(self):
        """Smallest sigma of the uniform schedule (0.0)."""
        return self._sigma_min

    @property
    def sigma_max(self):
        """Largest sigma of the uniform schedule (1.0)."""
        return self._sigma_max

    @staticmethod
    def add_noise(
        sample: torch.FloatTensor,
        sigmas: torch.FloatTensor,
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow matching: ``sigmas * noise + (1 - sigmas) * sample``.

        The interpolation is computed in float32.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            sigmas (`torch.FloatTensor`):
                Sigma values in flow matching; must broadcast against `sample`.
            noise (`torch.FloatTensor`, *optional*):
                Noise to mix in. When None, standard normal noise shaped like
                `sample` is drawn.

        Returns:
            `torch.FloatTensor`:
                The noised sample, cast back to `sample`'s original dtype.
        """
        sample_dtype = sample.dtype
        sigmas = sigmas.float()
        sample = sample.float()
        noise = torch.randn_like(sample) if noise is None else noise.float()
        noised_sample = sigmas * noise + (1.0 - sigmas) * sample
        return noised_sample.to(sample_dtype)

    def compute_density_for_sigma_sampling(
        self,
        batch_size: int,
    ):
        """Compute the density for sampling the sigmas when doing SD3 training.

        Draws `batch_size` sigmas on the CPU according to `weighting_scheme`:
        "logit_normal" draws sigmoid(N(logit_mean, logit_std)), "mode" uses
        the `mode_scale` transform of a uniform sample, and anything else
        draws uniformly from [0, 1).
        """
        if self.weighting_scheme == "logit_normal":
            sigmas = torch.normal(mean=self.logit_mean, std=self.logit_std, size=(batch_size,), device="cpu")
            sigmas = torch.nn.functional.sigmoid(sigmas)
        elif self.weighting_scheme == "mode":
            sigmas = torch.rand(size=(batch_size,), device="cpu")
            sigmas = 1 - sigmas - self.mode_scale * (torch.cos(math.pi * sigmas / 2) ** 2 - 1 + sigmas)
        else:
            sigmas = torch.rand(size=(batch_size,), device="cpu")
        return sigmas

    def compute_loss_weighting_for_sd3(self, sigmas=None):
        """Computes loss weighting scheme for SD3 training.

        "sigma_sqrt" weights by ``sigmas ** -2``, "cosmap" by
        ``2 / (pi * (1 - 2 sigma + 2 sigma^2))``, and anything else by 1.
        The result has the same shape as `sigmas`.
        """
        if self.weighting_scheme == "sigma_sqrt":
            weighting = (sigmas**-2.0).float()
        elif self.weighting_scheme == "cosmap":
            bot = 1 - 2 * sigmas + 2 * sigmas**2
            weighting = 2 / (math.pi * bot)
        else:
            weighting = torch.ones_like(sigmas)
        return weighting

    def sigma_shift_opensoraplan(
        self,
        sigmas: Union[float, torch.Tensor],
        image_seq_len: Optional[int] = None,
        gamma: Optional[float] = 1.0,
    ):
        """
        Apply the time shift ``s * sigma / (1 + (s - 1) * sigma)``, then clamp
        from below at `sigma_eps`.

        Args:
            sigmas: Python scalar, numpy array, or tensor of sigmas.
            image_seq_len: latent sequence length; required when dynamic
                shifting is on, where ``s = exp(shift_k * image_seq_len + shift_b)``.
            gamma: exponent used only by the dynamic branch; with gamma == 1
                it reduces to the formula above.

        Returns:
            Shifted sigmas of the same kind as the input.

        Raises:
            ValueError: if dynamic shifting is on and `image_seq_len` is None.
        """
        if not self.use_dynamic_shifting:
            sigmas_ = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        else:
            if image_seq_len is None:
                raise ValueError("you have to pass `image_seq_len` when `use_dynamic_shifting` is set to be `True`")
            shift = math.exp(image_seq_len * self.shift_k + self.shift_b)
            if math.isclose(gamma, 1.0):
                sigmas_ = shift * sigmas / (1 + (shift - 1) * sigmas)
            else:
                sigmas_ = shift / (shift + (1 / sigmas - 1) ** gamma)
        if isinstance(sigmas_, torch.Tensor):
            sigmas_ = torch.where(sigmas_ > self.sigma_eps, sigmas_, torch.ones_like(sigmas_) * self.sigma_eps)
        elif isinstance(sigmas_, np.ndarray):
            sigmas_ = np.where(sigmas_ > self.sigma_eps, sigmas_, np.ones_like(sigmas_) * self.sigma_eps)
        else:
            sigmas_ = max(sigmas_, self.sigma_eps)
        return sigmas_

    def set_sigmas(
        self,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        image_seq_len: Optional[int] = None,
        inversion: Optional[bool] = False,
        **kwargs,
    ):
        """
        Build the inference sigma schedule, store it in `self.sigmas`, and
        return it.

        Args:
            device: device of the returned float32 tensor.
            sigmas: explicit schedule (list or array). It is ignored when
                `use_linear_quadratic_schedule` is on. When None, a uniform
                grid of ``num_inference_steps + 1`` values from 1.0 to 0.0 is
                used.
            image_seq_len: forwarded to `sigma_shift_opensoraplan`.
            inversion: reverse the schedule (ascending order).

        Raises:
            ValueError: if a schedule must be generated but
                `num_inference_steps` is None.
        """
        if (self.use_linear_quadratic_schedule or sigmas is None) and self.num_inference_steps is None:
            raise ValueError("num_inference_steps must be set to build a sigma schedule")
        if self.use_linear_quadratic_schedule:
            print("use OpenSoraPlanScheduler and linear quadratic schedule")
            approximate_steps = min(max(self.num_inference_steps * 10, 250), 1000)
            sigmas = opensora_linear_quadratic_schedule(self.num_inference_steps, approximate_steps=approximate_steps)
            sigmas = np.array(sigmas)
        else:
            if sigmas is None:
                sigmas = np.linspace(self._sigma_max, self._sigma_min, self.num_inference_steps + 1)
            else:
                # Accept plain lists as annotated; the shift and
                # `torch.from_numpy` below both need an ndarray.
                sigmas = np.asarray(sigmas, dtype=np.float64)
            if self.shift > 1.0 or self.use_dynamic_shifting:
                print("use OpenSoraPlanScheduler and shifting schedule")
                sigmas = self.sigma_shift_opensoraplan(sigmas, image_seq_len=image_seq_len)
        if inversion:
            sigmas = np.copy(np.flip(sigmas))
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.sigmas = sigmas
        return sigmas

    def step(
        self,
        model_output: torch.FloatTensor,
        step_index: int,
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ):
        """
        One Euler step: ``sample + (sigmas[i + 1] - sigmas[i]) * model_output``.

        The update is computed with `sample` upcast to float32, and the result
        is cast to `model_output.dtype`. `s_churn`, `s_tmin`, `s_tmax`,
        `s_noise` and `generator` are accepted for interface compatibility and
        are unused.

        Raises:
            ValueError: if `step_index` is not an int / integer tensor, if no
                schedule has been set, or if ``step_index + 1`` is not a valid
                index into `self.sigmas`.
        """
        if not (
            isinstance(step_index, int)
            or isinstance(step_index, torch.IntTensor)
            or isinstance(step_index, torch.LongTensor)
        ):
            raise ValueError("step_index should be an integer or a tensor of integer")
        if self.sigmas is None:
            raise ValueError("sigmas are not set; call `set_sigmas` before `step`")
        # The step reads sigmas[step_index + 1], so the last schedule entry
        # cannot be a starting index.
        if not (step_index >= 0 and step_index < len(self.sigmas) - 1):
            raise ValueError("step_index should be in the range of [0, len(sigmas) - 1)")
        sample = sample.to(torch.float32)
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]
        prev_sample = sample + (sigma_next - sigma) * model_output
        return prev_sample.to(model_output.dtype)

    def training_losses(
        self,
        model_output: torch.Tensor,
        x_start: torch.Tensor,
        noise: torch.Tensor = None,
        mask: torch.Tensor = None,
        **kwargs
    ):
        """
        Weighted MSE between `model_output` and the velocity target
        ``noise - x_start``.

        The weighting comes from `compute_loss_weighting_for_sd3(self.sigmas)`,
        so `self.sigmas` is expected to hold the sigmas set by the preceding
        `q_sample` call.

        Args:
            model_output: 5-D prediction (unpacked as ``b, c, _, _, _``).
            x_start: clean sample.
            noise: noise used to build the noised input.
            mask: optional mask, presumably shaped like `model_output` without
                the channel dim (it is unsqueezed at dim 1 and repeated `c`
                times) — TODO confirm. An all-true mask is treated as no mask.

        Returns:
            Scalar loss: the masked mean when a mask is given, otherwise the
            plain mean.
        """
        if mask is not None and torch.all(mask.bool()):
            mask = None
        b, c, _, _, _ = model_output.shape
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, c, 1, 1, 1).float()
            mask = mask.reshape(b, -1)
        weighting = self.compute_loss_weighting_for_sd3(sigmas=self.sigmas)
        target = noise - x_start
        loss_mse = (weighting.float() * (model_output.float() - target.float()) ** 2).reshape(target.shape[0], -1)
        if mask is not None:
            return (loss_mse * mask).sum() / mask.sum()
        return loss_mse.mean()

    def q_sample(
        self,
        x_start: torch.Tensor,
        sigmas: torch.Tensor = None,
        noise: torch.Tensor = None,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).

        Generated noise and sigmas are broadcast over the context- and
        tensor-parallel groups via `broadcast_tensor`. The unsqueezed sigmas
        are stored in `self.sigmas` for `training_losses`.

        :param x_start: the initial data batch (5-D; unpacked as ``b, c, _, _, _``).
        :param sigmas: interpolation factor in flow matching. When None, one
            sigma per sample is drawn and shifted; given sigmas are used as-is.
        :param noise: if specified, the split-out normal noise.
        :return: a tuple ``(x_t, noise, timesteps)``, where ``timesteps`` is
            ``sigmas * 1000`` before the sigmas are unsqueezed.
        """
        b, c, _, _, _ = x_start.shape
        if noise is None:
            noise = torch.randn_like(x_start)
            self.broadcast_tensor(noise)
        if noise.shape != x_start.shape:
            raise ValueError("The shape of noise and x_start must be equal.")
        if sigmas is None:
            sigmas = self.compute_density_for_sigma_sampling(b).to(x_start.device)
            # NOTE(review): H * W // 4 presumably matches a 2x2 spatial patchify — confirm.
            image_seq_len = (noise.shape[-1] * noise.shape[-2]) // 4 if self.use_dynamic_shifting else None
            sigmas = self.sigma_shift_opensoraplan(sigmas, image_seq_len=image_seq_len)
        timesteps = sigmas * 1000
        while sigmas.ndim < x_start.ndim:
            sigmas = sigmas.unsqueeze(-1)
        self.broadcast_tensor(sigmas)
        self.broadcast_tensor(timesteps)
        self.sigmas = sigmas
        x_t = self.add_noise(x_start, sigmas, noise)
        return x_t, noise, timesteps

    def sample(
        self,
        model: Callable,
        shape: Union[List, Tuple],
        latents: torch.Tensor,
        model_kwargs: dict = None,
        added_cond_kwargs: dict = None,
        extra_step_kwargs: dict = None,
        **kwargs
    ):
        """
        Run the Euler sampling loop and return the final latents.

        Args:
            model: called as ``model(x, timestep=..., prompt=..., video_mask=...,
                prompt_mask=...)``, and expected to return a velocity
                prediction shaped like `x`.
            shape: latent shape used when `latents` is None, and for the
                dynamic-shift sequence length ``shape[-1] * shape[-2] // 4``.
            latents: initial latents, or None to draw standard normal ones.
            model_kwargs: must contain "prompt" and "prompt_mask". The
                caller's dict is not modified.
            added_cond_kwargs: merged into a copy of `model_kwargs`.
            extra_step_kwargs: extra keyword arguments forwarded to `step`.

        Raises:
            AssertionError: if `shape` is not a tuple or list.
            ValueError: if the model prediction contains NaNs.
        """
        if not isinstance(shape, (tuple, list)):
            raise AssertionError("param shape is incorrect")
        # Work on copies: the original popped keys from the caller's dict
        # (breaking reuse), and a None `extra_step_kwargs` crashed on `**`.
        model_kwargs = dict(model_kwargs or {})
        extra_step_kwargs = dict(extra_step_kwargs or {})
        if latents is None:
            latents = torch.randn(*shape, device=self.device)
        if added_cond_kwargs:
            # NOTE(review): only "prompt"/"prompt_mask" are forwarded to the
            # model below, so these extra entries are currently unused — confirm.
            model_kwargs.update(added_cond_kwargs)
        image_seq_len = (shape[-1] * shape[-2]) // 4 if self.use_dynamic_shifting else None
        sigmas = self.set_sigmas(device=self.device, sigmas=None, image_seq_len=image_seq_len)
        # One timestep per interval; the final sigma is only a step target.
        timesteps = (sigmas.clone() * 1000)[:-1]
        do_classifier_free_guidance = self.guidance_scale > 1.0
        encoder_hidden_states = model_kwargs.pop("prompt")
        encoder_attention_mask = model_kwargs.pop("prompt_mask")
        with tqdm(total=self.num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                timestep = t.expand(latent_model_input.shape[0])
                # All-ones mask equal to the input with dim 1 dropped; built
                # directly at that size instead of slicing a full-size tensor.
                attention_mask = torch.ones_like(latent_model_input[:, 0]).to(device=self.device)
                noise_pred = model(
                    latent_model_input,
                    timestep=timestep,
                    prompt=encoder_hidden_states,
                    video_mask=attention_mask,
                    prompt_mask=encoder_attention_mask
                )
                if torch.any(torch.isnan(noise_pred)):
                    raise ValueError("noise_pred contains nan values")
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    if self.guidance_rescale > 0.0:
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
                latents = self.step(noise_pred, i, latents, **extra_step_kwargs)
                if i == len(timesteps) - 1 or (i + 1) % self.order == 0:
                    progress_bar.update()
        return latents

    @staticmethod
    def broadcast_tensor(input_: torch.Tensor):
        """
        Broadcast `input_` in place, first from the context-parallel group's
        first global rank, then from the tensor-parallel source rank.

        Each broadcast is skipped when its group has world size 1.
        """
        if mpu.get_context_parallel_world_size() > 1:
            cp_src_rank = list(mpu.get_context_parallel_global_ranks())[0]
            dist.broadcast(input_, cp_src_rank, group=mpu.get_context_parallel_group())
        if mpu.get_tensor_model_parallel_world_size() > 1:
            tp_src_rank = mpu.get_tensor_model_parallel_src_rank()
            dist.broadcast(input_, tp_src_rank, group=mpu.get_tensor_model_parallel_group())