from copy import deepcopy
import time
import colossalai
import torch
import torch.distributed as dist
import wandb
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from tqdm import tqdm
import adaptor
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import (
get_data_parallel_group,
set_data_parallel_group,
set_sequence_parallel_group,
initialize_sequence_parallel_group_for_send_recv_overlap,
get_sequence_parallel_group_for_send_recv_overlap,
)
from opensora.acceleration.plugin import ZeroSeqParallelPlugin
from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save
from opensora.utils.config_utils import (
create_experiment_workspace,
create_tensorboard_writer,
parse_configs,
save_training_config,
)
from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype
from opensora.utils.train_utils import update_ema, AdamW
from opensora.utils.device_utils import is_npu_available
# Ascend NPU setup, executed at import time (before main() runs) only when an NPU is available.
if is_npu_available():
    # Not referenced elsewhere in this file; presumably imported for its side effect of
    # redirecting torch.cuda calls to the NPU backend — NOTE(review): confirm against torch_npu docs.
    from torch_npu.contrib import transfer_to_npu
    # NOTE(review): presumably disables NPU-private internal tensor formats so tensors keep
    # standard layouts — confirm intent.
    torch.npu.config.allow_internal_format = False
def _build_plugin(cfg):
    """Create the booster plugin named by ``cfg.plugin`` and register its process groups.

    Args:
        cfg: training config; reads ``plugin``, ``dtype``, ``grad_clip`` and, for the
            sequence-parallel plugin, ``sp_size`` and ``use_cp_send_recv_overlap``.

    Returns:
        The constructed ColossalAI plugin.

    Raises:
        ValueError: if ``cfg.plugin`` is neither ``"zero2"`` nor ``"zero2-seq"``.
    """
    if cfg.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        # Without sequence parallelism every rank is a data-parallel rank.
        set_data_parallel_group(dist.group.WORLD)
    elif cfg.plugin == "zero2-seq":
        # NOTE(review): despite the "zero2" name this plugin is configured with ZeRO stage 1 — confirm.
        plugin = ZeroSeqParallelPlugin(
            sp_size=cfg.sp_size,
            stage=1,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        initialize_sequence_parallel_group_for_send_recv_overlap(cfg.use_cp_send_recv_overlap, cfg.sp_size)
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {cfg.plugin}")
    return plugin


def _resolve_grad_accum_steps(cfg, total_batch_size):
    """Return how many micro-batches to accumulate before each optimizer step.

    If ``cfg`` defines ``global_batch_size`` it must be a positive multiple of
    ``total_batch_size``. Otherwise ``cfg.global_batch_size`` is set to
    ``total_batch_size`` and 1 is returned (no accumulation).

    Raises:
        ValueError: if ``global_batch_size`` is not a positive multiple of ``total_batch_size``.
    """
    if "global_batch_size" not in cfg:
        cfg.global_batch_size = total_batch_size
        return 1
    if cfg.global_batch_size <= 0 or cfg.global_batch_size % total_batch_size != 0:
        raise ValueError(
            f"global_batch_size ({cfg.global_batch_size}) must be a positive multiple of "
            f"total_batch_size ({total_batch_size})."
        )
    # Integer division: the step count is used with % and must stay an int.
    return cfg.global_batch_size // total_batch_size


def main():
    """Train the configured model with ColossalAI ZeRO (optionally sequence parallel).

    Parses the training config via ``parse_configs(training=True)``, builds the VAE,
    text encoder, model, an fp32 EMA copy, the optimizer and the dataloader, then runs
    the epoch/step loop with optional gradient accumulation, master-rank logging and
    periodic checkpointing.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` reports no device.
        ValueError: if ``cfg.dtype`` is not ``"fp16"``/``"bf16"``, ``cfg.plugin`` is
            unknown, or ``cfg.global_batch_size`` is not a multiple of the per-step
            batch size.
    """
    # ---------------- configs & runtime ----------------
    cfg = parse_configs(training=True)
    print(cfg)
    exp_name, exp_dir = create_experiment_workspace(cfg)
    save_training_config(cfg._cfg_dict, exp_dir)

    import gc

    # Python's default thresholds are (700, 10, 10); raising the third makes full
    # (generation-2) collections far rarer, presumably to avoid GC stalls mid-training.
    gc.set_threshold(700, 10, 1000)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("Training currently requires at least one GPU.")
    if cfg.dtype not in ("fp16", "bf16"):
        raise ValueError(f"Unknown mixed precision {cfg.dtype}")

    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    device = get_current_device()
    dtype = to_torch_dtype(cfg.dtype)

    # Only the master rank writes the experiment log, tensorboard and wandb.
    writer = None
    if not coordinator.is_master():
        logger = create_logger(None)
    else:
        logger = create_logger(exp_dir)
        logger.info(f"Experiment directory created at {exp_dir}")
        writer = create_tensorboard_writer(exp_dir)
        if cfg.wandb:
            wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)

    # ---------------- plugin & booster ----------------
    plugin = _build_plugin(cfg)
    booster = Booster(plugin=plugin)

    # ---------------- dataset & dataloader ----------------
    dataset = DatasetFromCSV(
        cfg.data_path,
        transform=(
            get_transforms_video(cfg.image_size[0])
            if not cfg.use_image_transform
            else get_transforms_image(cfg.image_size[0])
        ),
        num_frames=cfg.num_frames,
        frame_interval=cfg.frame_interval,
        root=cfg.root,
    )
    dataloader = prepare_dataloader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")
    # Ranks in the same sequence-parallel group share one sample, hence the division by sp_size.
    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
    logger.info(f"Total batch size: {total_batch_size}")

    # ---------------- models ----------------
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
    )

    # EMA: a separate fp32 copy initialised from the model weights and never trained directly.
    model_state_dict = model.state_dict()
    ema = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
    )
    ema = ema.to(torch.float32).to(device)
    ema.load_state_dict(model_state_dict)
    requires_grad(ema, False)
    ema_shape_dict = record_model_param_shape(ema)

    vae = vae.to(device, dtype)
    model = model.to(device, dtype)
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # ---------------- optimizer ----------------
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    if is_npu_available():
        optimizer = AdamW(trainable_params, lr=cfg.lr, weight_decay=0)
    else:
        optimizer = HybridAdam(trainable_params, lr=cfg.lr, weight_decay=0, adamw_mode=True)
    lr_scheduler = None

    if cfg.grad_checkpoint:
        set_grad_checkpoint(model)
    model.train()
    update_ema(ema, model, decay=0, sharded=False)
    ema.eval()

    # ---------------- boost ----------------
    # Default dtype is switched only for the duration of boost() and restored right after;
    # presumably so anything created inside boost() uses the training dtype.
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
    )
    torch.set_default_dtype(torch.float)
    num_steps_per_epoch = len(dataloader)
    logger.info("Boost model for distributed training")

    # ---------------- resume ----------------
    start_epoch = start_step = log_step = sampler_start_idx = 0
    running_loss = 0.0
    if cfg.load is not None:
        logger.info("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, ema, optimizer, lr_scheduler, cfg.load)
        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")
    dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)

    grad_accum_steps = _resolve_grad_accum_steps(cfg, total_batch_size)
    grad_accum_time = 0.0
    early_stopping_flag = False

    # ---------------- training loop ----------------
    for epoch in range(start_epoch, cfg.epochs):
        # Check before building a new dataloader iterator that would be thrown away.
        if early_stopping_flag:
            break
        dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info(f"Beginning epoch {epoch}...")

        with tqdm(
            range(start_step, num_steps_per_epoch),
            desc=f"Epoch {epoch}",
            disable=not coordinator.is_master(),
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                global_step = epoch * num_steps_per_epoch + step
                start_step_time = time.time()
                batch = next(dataloader_iter)
                data_step_time = time.time()
                x = batch["video"].to(device, dtype)
                y = batch["text"]

                # VAE and text encoder are frozen feature extractors here.
                with torch.no_grad():
                    x = vae.encode(x)
                    model_args = text_encoder.encode(y)

                t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
                loss_dict = scheduler.training_losses(model, x, t, model_args)
                loss = loss_dict["loss"].mean()
                # Scale so gradients summed over one accumulation window average the
                # micro-batch losses (keeps grad clipping meaningful); a no-op when 1.
                booster.backward(loss=loss / grad_accum_steps, optimizer=optimizer)

                # Step after the last micro-batch of each window (0-based global_step).
                is_update_step = (global_step + 1) % grad_accum_steps == 0
                if is_update_step:
                    optimizer.step()
                    optimizer.zero_grad()
                    # The EMA only needs to follow the weights when they actually change.
                    update_ema(ema, model.module, optimizer=optimizer)

                # The unscaled loss is what gets logged.
                all_reduce_mean(loss)
                running_loss += loss.item()
                log_step += 1
                train_step_time = time.time()
                step_time = train_step_time - start_step_time

                if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0:
                    print(
                        f"data time {data_step_time - start_step_time} | E2E train time {step_time} | "
                        f"FPS {total_batch_size / step_time}",
                        flush=True,
                    )
                    avg_loss = running_loss / log_step
                    pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                    running_loss = 0
                    log_step = 0
                    writer.add_scalar("loss", loss.item(), global_step)
                    if cfg.wandb:
                        wandb.log(
                            {
                                "iter": global_step,
                                "num_samples": global_step * total_batch_size,
                                "epoch": epoch,
                                "loss": loss.item(),
                                "avg_loss": avg_loss,
                            },
                            step=global_step,
                        )

                if grad_accum_steps != 1:
                    grad_accum_time += step_time
                    if is_update_step:
                        if coordinator.is_master():
                            print(f"gradient accumulation step time {grad_accum_time}")
                        grad_accum_time = 0

                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                    save(
                        booster,
                        model,
                        ema,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        global_step + 1,
                        cfg.batch_size,
                        coordinator,
                        exp_dir,
                        ema_shape_dict,
                    )
                    logger.info(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                    )

                # global_step is 0-based: stop once exactly max_train_steps steps have run;
                # `>=` also stops a resumed run that already starts past the cap.
                if cfg.max_train_steps > 0 and global_step + 1 >= cfg.max_train_steps:
                    early_stopping_flag = True
                    break

        # Only the resumed epoch starts mid-way; later epochs start from the beginning.
        dataloader.sampler.set_start_index(0)
        start_step = 0


if __name__ == "__main__":
    main()