import os
from copy import deepcopy
from datetime import timedelta
from pprint import pformat
import time
import gc
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device, set_seed
from tqdm import tqdm
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.lr_scheduler import LinearWarmupLR
from opensora.utils.misc import (
Timer,
all_reduce_mean,
create_logger,
create_tensorboard_writer,
format_numel_str,
get_model_numel,
requires_grad,
to_torch_dtype,
)
from opensora.utils.train_utils import MaskGenerator, create_colossalai_plugin, update_ema
from opensora.utils.device_utils import is_npu_available
# Ascend NPU setup, executed once at import time when an NPU is available.
if is_npu_available():
    # NOTE(review): `transfer_to_npu` is never referenced in the visible code, so this
    # import appears to exist only for its side effects (presumably routing the
    # torch.cuda / "nccl" calls made in main() to the NPU) — confirm against torch_npu docs.
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
    # NOTE(review): presumably keeps tensors in standard formats instead of NPU-internal
    # layouts — confirm.
    torch.npu.config.allow_internal_format = False
def _reached_max_train_steps(global_step: int, max_train_steps: int) -> bool:
    """Return True once ``max_train_steps`` optimizer steps have completed.

    ``global_step`` is the 0-based index of the step that just finished, so
    ``global_step + 1`` steps have run; this matches the ``global_step + 1``
    convention used for logging and checkpointing. A non-positive
    ``max_train_steps`` disables the limit. Using ``>=`` also stops a run that
    was resumed from a checkpoint already past the limit.
    """
    return max_train_steps > 0 and global_step + 1 >= max_train_steps


def _format_step_timings(rank, epoch, step, timers) -> str:
    """Build the per-step timing line printed when ``print_details`` is enabled.

    Each timer (an object with ``name`` and ``elapsed_time``) contributes
    ``"<name>: <seconds with 3 decimals>s | "``.
    """
    parts = [f"Rank {rank} | Epoch {epoch} | Step {step} | "]
    parts.extend(f"{timer.name}: {timer.elapsed_time:.3f}s | " for timer in timers)
    return "".join(parts)


def main():
    """Train the diffusion model described by the parsed training config.

    In order, this:
      1. parses the config, initialises the process group and the experiment
         directory, and sets up logging (TensorBoard/wandb on the master rank);
      2. builds the ColossalAI booster, the dataset/dataloader, the optional
         text encoder and VAE, the diffusion model, an fp32 EMA copy, the
         optimizer and an optional linear-warmup LR scheduler;
      3. boosts model/optimizer/dataloader and optionally resumes from
         ``cfg.load``;
      4. runs the epoch/step loop: encode inputs, compute the diffusion loss,
         backward, optimizer/LR step, EMA update, loss all-reduce, logging,
         periodic checkpointing and an optional ``max_train_steps`` early stop.
    """
    # ======================================================
    # 1. configs & runtime variables
    # ======================================================
    cfg = parse_configs(training=True)
    # Raise the generation-2 threshold (CPython's default is 10) so full
    # collections run less often during training.
    gc.set_threshold(700, 10, 1000)
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    cfg_dtype = cfg.get("dtype", "bf16")
    assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
    dtype = to_torch_dtype(cfg_dtype)
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    set_seed(cfg.get("seed", 1024))
    coordinator = DistCoordinator()
    device = get_current_device()

    # == experiment workspace (written by the master rank only) ==
    exp_name, exp_dir = define_experiment_workspace(cfg)
    coordinator.block_all()
    if coordinator.is_master():
        os.makedirs(exp_dir, exist_ok=True)
        save_training_config(cfg.to_dict(), exp_dir)
    coordinator.block_all()

    # == logging: TensorBoard / wandb exist on the master rank only ==
    logger = create_logger(exp_dir)
    logger.info("Experiment directory created at %s", exp_dir)
    logger.info("Training configuration:\n %s", pformat(cfg.to_dict()))
    use_wandb = cfg.get("wandb", False)
    tb_writer = None
    if coordinator.is_master():
        tb_writer = create_tensorboard_writer(exp_dir)
        if use_wandb:
            wandb.init(project="Open-Sora", name=exp_name, config=cfg.to_dict(), dir="./outputs/wandb")

    # == booster ==
    plugin = create_colossalai_plugin(
        plugin=cfg.get("plugin", "zero2"),
        dtype=cfg_dtype,
        grad_clip=cfg.get("grad_clip", 0),
        sp_size=cfg.get("sp_size", 1),
    )
    booster = Booster(plugin=plugin)
    torch.set_num_threads(1)

    # ======================================================
    # 2. dataset and dataloader
    # ======================================================
    logger.info("Building dataset...")
    dataset = build_module(cfg.dataset, DATASETS)
    logger.info("Dataset contains %s samples.", len(dataset))
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.get("batch_size", None),
        num_workers=cfg.get("num_workers", 4),
        seed=cfg.get("seed", 1024),
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    dataloader, sampler = prepare_dataloader(
        bucket_config=cfg.get("bucket_config", None),
        num_bucket_build_workers=cfg.get("num_bucket_build_workers", 1),
        **dataloader_args,
    )
    num_steps_per_epoch = len(dataloader)

    # ======================================================
    # 3. models
    # ======================================================
    logger.info("Building models...")
    # The text encoder may be absent from the config; the training loop only
    # calls it when `load_text_features` is false.
    text_encoder = build_module(cfg.get("text_encoder", None), MODELS, device=device, dtype=dtype)
    if text_encoder is not None:
        text_encoder_output_dim = text_encoder.output_dim
        text_encoder_model_max_length = text_encoder.model_max_length
    else:
        text_encoder_output_dim = cfg.get("text_encoder_output_dim", 4096)
        text_encoder_model_max_length = cfg.get("text_encoder_model_max_length", 300)

    # The VAE may be absent as well; the loop only calls it when
    # `load_video_features` is false.
    vae = build_module(cfg.get("vae", None), MODELS)
    if vae is not None:
        vae = vae.to(device, dtype).eval()
        input_size = (dataset.num_frames, *dataset.image_size)
        latent_size = vae.get_latent_size(input_size)
        vae_out_channels = vae.out_channels
    else:
        latent_size = (None, None, None)
        vae_out_channels = cfg.get("vae_out_channels", 4)

    model = (
        build_module(
            cfg.model,
            MODELS,
            input_size=latent_size,
            in_channels=vae_out_channels,
            caption_channels=text_encoder_output_dim,
            model_max_length=text_encoder_model_max_length,
            enable_sequence_parallelism=cfg.get("sp_size", 1) > 1,
        )
        .to(device, dtype)
        .train()
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        "[Diffusion] Trainable model params: %s, Total model params: %s",
        format_numel_str(model_numel_trainable),
        format_numel_str(model_numel),
    )

    # EMA copy: fp32, excluded from autograd, initialised to the model (decay=0).
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)
    ema.eval()
    update_ema(ema, model, decay=0, sharded=False)

    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # == optimizer: MindSpeed AdamW on NPU, ColossalAI HybridAdam (AdamW mode) otherwise ==
    if is_npu_available():
        from mindspeed.optimizer.adamw import AdamW

        optimizer = AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.get("lr", 1e-4),
            weight_decay=cfg.get("weight_decay", 0),
            eps=cfg.get("adam_eps", 1e-8),
        )
    else:
        optimizer = HybridAdam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.get("lr", 1e-4),
            weight_decay=cfg.get("weight_decay", 0),
            adamw_mode=True,
            eps=cfg.get("adam_eps", 1e-8),
        )

    warmup_steps = cfg.get("warmup_steps", None)
    lr_scheduler = None if warmup_steps is None else LinearWarmupLR(optimizer, warmup_steps=warmup_steps)

    if cfg.get("grad_checkpoint", False):
        set_grad_checkpoint(model)

    mask_ratios = cfg.get("mask_ratios", None)
    mask_generator = None if mask_ratios is None else MaskGenerator(mask_ratios)

    # ======================================================
    # 4. distributed training preparation
    # ======================================================
    logger.info("Preparing for distributed training...")
    # Floating-point tensors created while boosting default to the training
    # dtype; the fp32 default is restored right afterwards.
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    torch.set_default_dtype(torch.float)
    logger.info("Boosting model for distributed training")

    cfg_epochs = cfg.get("epochs", 1000)
    start_epoch = start_step = log_step = acc_step = 0
    running_loss = 0.0
    logger.info("Training for %s epochs with %s steps per epoch", cfg_epochs, num_steps_per_epoch)

    # == optional resume ==
    if cfg.get("load", None) is not None:
        logger.info("Loading checkpoint")
        start_from_scratch = cfg.get("start_from_scratch", False)
        ret = load(
            booster,
            cfg.load,
            model=model,
            ema=ema,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            # With start_from_scratch the sampler state is not restored and the
            # returned epoch/step are ignored.
            sampler=None if start_from_scratch else sampler,
        )
        if not start_from_scratch:
            start_epoch, start_step = ret
        logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)

    model_sharding(ema)
    dist.barrier()

    # ======================================================
    # 5. training loop
    # ======================================================
    # Config values consulted every step are read once here.
    load_video_features = cfg.get("load_video_features", False)
    load_text_features = cfg.get("load_text_features", False)
    ema_decay = cfg.get("ema_decay", 0.9999)
    log_every = cfg.get("log_every", 1)
    ckpt_every = cfg.get("ckpt_every", 0)
    print_details = cfg.get("print_details", False)
    max_train_steps = cfg.get("max_train_steps", 0)
    batch_size = cfg.get("batch_size", None)

    early_stop_flag = False
    for epoch in range(start_epoch, cfg_epochs):
        # Checked before creating the epoch's iterator so no dataloader workers
        # are started (and no "Beginning epoch" is logged) for an epoch that
        # will not run.
        if early_stop_flag:
            break
        sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info("Beginning epoch %s...", epoch)

        with tqdm(
            enumerate(dataloader_iter, start=start_step),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            initial=start_step,
            total=num_steps_per_epoch,
        ) as pbar:
            for step, batch in pbar:
                start_step_time = time.perf_counter()
                timer_list = []

                with Timer("move data") as move_data_t:
                    x = batch.pop("video").to(device, dtype)
                    y = batch.pop("text")
                timer_list.append(move_data_t)

                with Timer("encode") as encode_t:
                    with torch.no_grad():
                        # == visual inputs: precomputed features or VAE encoding ==
                        if load_video_features:
                            x = x.to(device, dtype)
                        else:
                            x = vae.encode(x)
                        # == text inputs: precomputed features (+ optional mask) or text encoder ==
                        if load_text_features:
                            model_args = {"y": y.to(device, dtype)}
                            mask = batch.pop("mask")
                            if isinstance(mask, torch.Tensor):
                                mask = mask.to(device, dtype)
                            model_args["mask"] = mask
                        else:
                            model_args = text_encoder.encode(y)
                timer_list.append(encode_t)

                with Timer("mask") as mask_t:
                    # `mask` is rebound to the MaskGenerator output passed to the
                    # scheduler; any text mask already lives in model_args["mask"].
                    mask = None
                    if mask_generator is not None:
                        mask = mask_generator.get_masks(x)
                        model_args["x_mask"] = mask
                timer_list.append(mask_t)

                # Remaining tensor entries of the batch are forwarded to the model.
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        model_args[k] = v.to(device, dtype)

                with Timer("diffusion") as loss_t:
                    loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
                timer_list.append(loss_t)

                with Timer("backward") as backward_t:
                    loss = loss_dict["loss"].mean()
                    booster.backward(loss=loss, optimizer=optimizer)
                timer_list.append(backward_t)

                with Timer("optimizer") as optimizer_t:
                    optimizer.step()
                    optimizer.zero_grad()
                    if lr_scheduler is not None:
                        lr_scheduler.step()
                timer_list.append(optimizer_t)

                with Timer("update_ema") as ema_t:
                    update_ema(ema, model.module, optimizer=optimizer, decay=ema_decay)
                timer_list.append(ema_t)

                with Timer("reduce_loss") as reduce_loss_t:
                    all_reduce_mean(loss)
                    # Read the loss once; later uses reuse this host value instead
                    # of calling .item() again (each call synchronises with the device).
                    loss_value = loss.item()
                    running_loss += loss_value
                    global_step = epoch * num_steps_per_epoch + step
                    log_step += 1
                    acc_step += 1
                timer_list.append(reduce_loss_t)
                train_step_time = time.perf_counter()

                # == logging (master rank only) ==
                if coordinator.is_master() and (global_step + 1) % log_every == 0:
                    logger.info("E2E train time %s | loss %s", train_step_time - start_step_time, loss_value)
                    avg_loss = running_loss / log_step
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    tb_writer.add_scalar("loss", loss_value, global_step)
                    if use_wandb:
                        wandb.log(
                            {
                                "iter": global_step,
                                "acc_step": acc_step,
                                "epoch": epoch,
                                "loss": loss_value,
                                "avg_loss": avg_loss,
                                "lr": optimizer.param_groups[0]["lr"],
                                "debug/move_data_time": move_data_t.elapsed_time,
                                "debug/encode_time": encode_t.elapsed_time,
                                "debug/mask_time": mask_t.elapsed_time,
                                "debug/diffusion_time": loss_t.elapsed_time,
                                "debug/backward_time": backward_t.elapsed_time,
                                "debug/update_ema_time": ema_t.elapsed_time,
                                "debug/reduce_loss_time": reduce_loss_t.elapsed_time,
                            },
                            step=global_step,
                        )
                    running_loss = 0.0
                    log_step = 0

                # == periodic checkpoint ==
                if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                    model_gathering(ema, ema_shape_dict)
                    save_dir = save(
                        booster,
                        exp_dir,
                        model=model,
                        ema=ema,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        sampler=sampler,
                        epoch=epoch,
                        step=step + 1,
                        global_step=global_step + 1,
                        batch_size=batch_size,
                    )
                    # NOTE(review): only rank 0 re-shards the EMA; presumably
                    # model_gathering materialises full parameters on rank 0 only — confirm.
                    if dist.get_rank() == 0:
                        model_sharding(ema)
                    logger.info(
                        "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
                        epoch,
                        step + 1,
                        global_step + 1,
                        save_dir,
                    )

                if print_details:
                    print(_format_step_timings(dist.get_rank(), epoch, step, timer_list))

                if _reached_max_train_steps(global_step, max_train_steps):
                    early_stop_flag = True
                    break

        sampler.reset()
        start_step = 0


if __name__ == "__main__":
    main()