"""
A minimal training script for DiT using PyTorch DDP.
"""
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os
from models import DiT_models
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from utils.adamw import AdamW
from utils.device_utils import is_npu_available
if is_npu_available():
from torch_npu.contrib import transfer_to_npu
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move each EMA parameter towards the matching parameter of `model`.

    Every parameter of `model` is looked up by name in `ema_model` and the EMA
    tensor is updated in place as `ema = decay * ema + (1 - decay) * param`.
    With `decay=0` the EMA parameters become a copy of the model parameters.
    """
    ema_lookup = dict(ema_model.named_parameters())
    blend = 1 - decay
    for key, weight in model.named_parameters():
        ema_lookup[key].mul_(decay).add_(weight.data, alpha=blend)
def requires_grad(model, flag=True):
    """
    Assign `flag` to the `requires_grad` attribute of every parameter of `model`.
    """
    for weight in model.parameters():
        weight.requires_grad = flag
def cleanup():
    """
    Tear down the distributed process group at the end of DDP training.
    """
    dist.destroy_process_group()
def create_logger(logging_dir):
    """
    Build the logger for the current process.

    On rank 0 the root logging configuration is set up with a stream handler
    and a file handler writing to `{logging_dir}/log.txt`. On every other rank
    the returned logger only has a NullHandler attached, and `logging_dir` is
    not used.
    """
    if dist.get_rank() != 0:
        # Non-zero ranks: same logger name, but only a no-op handler is attached.
        quiet_logger = logging.getLogger(__name__)
        quiet_logger.addHandler(logging.NullHandler())
        return quiet_logger

    # Rank 0: log to the console and to a file inside the experiment directory.
    sinks = [logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=sinks,
    )
    return logging.getLogger(__name__)
def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop a PIL image to `image_size` x `image_size` (ADM recipe).
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is halved with a box filter while its shorter side is at least
    twice `image_size`, then bicubically resized so the shorter side equals
    `image_size`, and finally cropped around the center.
    """
    # Step 1: repeated 2x downsampling with a box filter.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Step 2: bicubic resize so the shorter side matches the target size.
    ratio = image_size / min(*pil_image.size)
    resized = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(resized, resample=Image.BICUBIC)

    # Step 3: crop a centered square window out of the pixel array.
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
def main(args):
    """
    Trains a new DiT model.

    Initializes an NCCL process group, builds the DiT model, its EMA copy and
    a pretrained VAE, then runs the diffusion training loop over an
    ImageFolder dataset. Rank 0 creates the experiment directory, writes the
    log file and saves checkpoints; all ranks synchronize after a checkpoint.

    Args:
        args: Parsed command-line namespace. This function reads
            global_batch_size, global_seed, results_dir, model, image_size,
            num_classes, precision, vae, data_path, num_workers, epochs,
            log_every and ckpt_every.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Adjust the cyclic GC thresholds for generations 0/1/2.
    # NOTE(review): presumably meant to make full collections rarer during training — confirm.
    import gc
    gc.set_threshold(700, 10, 1000)

    # Setup DDP:
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()
    assert args.global_batch_size % world_size == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    # Local device index: the global rank wrapped around the devices visible to this process.
    device = rank % torch.cuda.device_count()
    # Each rank gets its own seed derived from the global seed.
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Setup an experiment folder (rank 0 only; other ranks get a no-op logger):
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)
        # The experiment index is the number of entries already in results_dir.
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"
        checkpoint_dir = f"{experiment_dir}/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # Create model:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )
    if args.precision == "bf16":
        print("Enable bfloat16...")
        model.to(torch.bfloat16)
    # The EMA copy is taken after the optional bf16 cast, so it shares the model's dtype.
    # NOTE(review): in bf16 a decay of 0.9999 may be below the dtype's resolution,
    # so EMA updates could be rounded away — confirm the EMA still tracks the model.
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)
    # device_ids must name the local device index. The global rank is only a
    # valid device index on a single node, so use `device` rather than `rank`.
    model = DDP(model.to(device), device_ids=[device])
    diffusion = create_diffusion(timestep_respacing="")
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer:
    opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0, betas=(0.9, 0.998))

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // world_size),
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    # Prepare models for training:
    # With decay=0 the EMA parameters are overwritten with the current model parameters.
    update_ema(ema, model.module, decay=0)
    model.train()
    ema.eval()

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        # Reseed the sampler's shuffle for this epoch.
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Encode images with the VAE and scale the sampled latents by 0.18215.
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(y=y)
            loss_dict = diffusion.training_losses(model, x, t, args.precision, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Average the loss over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / world_size
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save DiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                # All ranks wait here until rank 0 has finished writing the checkpoint.
                dist.barrier()

    # Leave the model in eval mode once training has finished.
    model.eval()

    logger.info("Done!")
    cleanup()
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), in registration order.
    option_table = [
        ("--data-path", dict(type=str, required=True)),
        ("--results-dir", dict(type=str, default="results")),
        ("--model", dict(type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")),
        ("--image-size", dict(type=int, choices=[256, 512], default=256)),
        ("--num-classes", dict(type=int, default=1000)),
        ("--epochs", dict(type=int, default=1400)),
        ("--global-batch-size", dict(type=int, default=256)),
        ("--global-seed", dict(type=int, default=0)),
        ("--vae", dict(type=str, choices=["ema", "mse"], default="ema")),
        ("--num-workers", dict(type=int, default=4)),
        ("--log-every", dict(type=int, default=1)),
        ("--ckpt-every", dict(type=int, default=50_000)),
        ("--precision", dict(type=str, choices=["fp32", "bf16"], default="fp32")),
    ]
    parser = argparse.ArgumentParser()
    for option_flag, option_kwargs in option_table:
        parser.add_argument(option_flag, **option_kwargs)
    args = parser.parse_args()
    main(args)