import json
import os
import time
import argparse
from collections import deque
import math
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
from safetensors.torch import save_file
from torch.utils.data import DistributedSampler, DataLoader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
from diffusers import get_scheduler
from tqdm.auto import tqdm
from accelerate.utils import set_seed
from mindspeed_mm.configs.config import mm_extra_args_provider, merge_mm_args
from mindspeed_mm.tasks.rl.soragrpo.dataset.latent_flux_rl_datasets import latent_collate_function
from mindspeed_mm.tasks.rl.soragrpo.utils.communications_flux import sp_parallel_dataloader_wrapper
from mindspeed_mm.tasks.rl.soragrpo.utils.fsdp_util import get_dit_fsdp_kwargs, apply_fsdp_checkpointing
from mindspeed_mm.tasks.rl.soragrpo.utils.parallel_states import initialize_sequence_parallel_state, \
get_sequence_parallel_state, destroy_sequence_parallel_group
class SoraGRPOTrainer(ABC):
    """Abstract base class driving GRPO fine-tuning of a diffusion model wrapped in FSDP.

    Subclasses must implement :meth:`model_provider`, :meth:`sample_reference`
    and :meth:`grpo_one_step`. This base class owns process-group setup,
    argument parsing, optimizer/scheduler/dataloader construction, the outer
    training loop, the clipped-surrogate loss and checkpoint saving.
    """

    def __init__(self, train_valid_test_dataset_provider):
        """Initialise the process group, bind the device and parse arguments.

        Args:
            train_valid_test_dataset_provider: Callable invoked as
                ``provider(args)`` in :meth:`train` to build the training dataset.

        Raises:
            KeyError: If ``LOCAL_RANK``, ``RANK``, ``WORLD_SIZE`` or
                ``LOCAL_WORLD_SIZE`` is missing from the environment.
            AssertionError: If ``LOCAL_WORLD_SIZE`` differs from the number of
                visible devices.
        """
        self.local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group("hccl")
        torch.cuda.set_device(self.local_rank)
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        if self.local_world_size != int(torch.cuda.device_count()):
            raise AssertionError(f"ASCEND_RT_VISIBLE_DEVICES which is {int(torch.cuda.device_count())} must specify the exact number of devices used per node which is {self.local_world_size}. "
                                 "Please verify its value and whether the current devices are available.")
        self.train_valid_test_dataset_provider = train_valid_test_dataset_provider
        self.optimizer = None
        self.lr_scheduler = None
        self.args = self.get_args()
        merge_mm_args(self.args)
        self.device = torch.cuda.current_device()
        self.hyper_model = None
        # Manual garbage collection is triggered every `gc_iteration` steps in train().
        self.gc_iteration = 5
        initialize_sequence_parallel_state(self.args.sp_size)

    def train(self):
        """Build model, optimizer and dataloader, then run the training loop.

        Automatic garbage collection is disabled for the duration of the run;
        ``gc.collect()`` is called explicitly every ``self.gc_iteration`` steps.
        A checkpoint is written *before* the training step whenever
        ``step % args.save_interval == 0`` and ``args.save`` is set.
        """
        import gc

        gc.disable()
        args = self.args
        rank = self.rank
        world_size = self.world_size
        local_rank = self.local_rank
        device = self.device
        local_world_size = self.local_world_size

        if world_size <= 8:
            os.environ['HCCL_DETERMINISTIC'] = 'true'
        elif world_size >= 16:
            os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'

        if args.seed is not None:
            set_seed(args.seed + rank)
        if rank <= 0 and args.save is not None:
            os.makedirs(args.save, exist_ok=True)

        # Load the model in batches of `load_rank` local ranks at a time; every
        # rank joins the barrier after each batch so the batches run one after another.
        transformer = None
        load_rank_batchsize = args.load_rank
        for start_rank in range(0, local_world_size, load_rank_batchsize):
            # Clamp with the local world size: the loop walks *local* ranks, so
            # the reported rank list must not extend past the ranks on this node.
            end_rank = min(start_rank + load_rank_batchsize, local_world_size)
            load_ranks = list(range(start_rank, end_rank))
            if local_rank in load_ranks:
                if local_rank % load_rank_batchsize == 0:
                    print(f"rank {load_ranks} load start")
                self.hyper_model = self.model_provider(args)
                transformer = self.hyper_model.diffuser
                fsdp_kwargs, split_modules = get_dit_fsdp_kwargs(
                    self.hyper_model,
                    args.fsdp_sharding_strategy,
                    False,
                    args.use_cpu_offload,
                    args.master_weight_type,
                )
                self.hyper_model.diffuser = FSDP(transformer, **fsdp_kwargs, )
                if args.gradient_checkpointing:
                    apply_fsdp_checkpointing(
                        transformer, split_modules, args.selective_checkpointing
                    )
                if local_rank % load_rank_batchsize == 0:
                    print(f"rank {load_ranks} load success")
            dist.barrier()

        # From here on work with the FSDP root wrapper rather than the inner
        # module: FSDP.state_dict_type() (used in save_checkpoint) expects the
        # root module so that the root's own state-dict hooks run as well.
        transformer = self.hyper_model.diffuser

        self.main_print(
            f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_strategy}"
        )
        transformer.train()

        params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=args.lr,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay,
            eps=1e-8,
        )
        init_steps = 0
        self.main_print(f"optimizer: {optimizer}")
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps,
            num_training_steps=1000000,
            num_cycles=args.lr_num_cycles,
            power=args.lr_power,
            last_epoch=init_steps - 1,
        )
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        train_dataset = self.train_valid_test_dataset_provider(args)
        # DistributedSampler adds `seed + epoch` when shuffling, so a None seed
        # would raise; fall back to the sampler's own default of 0.
        sampler_seed = args.sampler_seed if args.sampler_seed is not None else 0
        sampler = DistributedSampler(
            train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=sampler_seed
        )
        train_dataloader = DataLoader(
            train_dataset,
            sampler=sampler,
            collate_fn=latent_collate_function,
            pin_memory=True,
            batch_size=args.train_batch_size,
            num_workers=args.dataloader_num_workers,
            drop_last=True,
        )

        total_batch_size = (
            args.train_batch_size
            * world_size
            * args.gradient_accumulation_steps
            / args.sp_size
            * args.train_sp_batch_size
        )
        self.main_print("***** Running training *****")
        self.main_print(f" Num examples = {len(train_dataset)}")
        self.main_print(f" Dataloader size = {len(train_dataloader)}")
        self.main_print(f" Resume training from step {init_steps}")
        self.main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
        self.main_print(
            f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
        )
        self.main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        self.main_print(f" Total optimization steps per epoch = {args.train_iters}")
        self.main_print(
            f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
        )
        self.main_print(f" Master weight dtype: {next(transformer.parameters()).dtype}")

        # Size the bar to the real number of steps so it can actually complete.
        progress_bar = tqdm(
            range(0, args.train_iters),
            initial=init_steps,
            desc="Steps",
            disable=local_rank > 0,
        )
        loader = sp_parallel_dataloader_wrapper(
            train_dataloader,
            device,
            args.train_batch_size,
            args.sp_size,
            args.train_sp_batch_size,
        )

        for epoch in range(1):
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)
            for step in range(init_steps + 1, args.train_iters + 1):
                start_time = time.time()
                if args.save is not None and step % args.save_interval == 0:
                    self.save_checkpoint(transformer, rank, args.save, step, epoch)
                    dist.barrier()
                loss, grad_norm = self.train_one_step(loader)
                # Note: the reported step time includes any checkpoint written above.
                step_time = time.time() - start_time
                progress_bar.set_postfix(
                    {
                        "loss": f"{loss}",
                        "step_time": f"{step_time:.2f}s",
                        "grad_norm": grad_norm,
                    }
                )
                progress_bar.update(1)
                if step % self.gc_iteration == 0:
                    gc.collect()

        if get_sequence_parallel_state():
            destroy_sequence_parallel_group()

    def train_one_step(self, dataloader):
        """Run one GRPO update: sample references, then accumulate the clipped loss.

        For every sample returned by :meth:`sample_reference` and every training
        timestep, the new log-probability is obtained from :meth:`grpo_one_step`,
        the clipped surrogate loss is back-propagated, and the all-reduced
        (averaged) loss is accumulated. The optimizer steps every
        ``args.gradient_accumulation_steps`` samples.

        Args:
            dataloader: Passed through unchanged to :meth:`sample_reference`.

        Returns:
            tuple[float, float]: ``(total_loss, grad_norm)``. ``grad_norm`` is
            NaN when no optimizer step was taken in this call (for example when
            there were fewer samples than ``gradient_accumulation_steps``).
        """
        hyper_model = self.hyper_model
        args = self.args
        max_grad_norm = args.max_grad_norm
        clip_range = args.clip_range
        adv_clip_max = args.adv_clip_max
        accumulation_steps = args.gradient_accumulation_steps
        optimizer = self.optimizer
        lr_scheduler = self.lr_scheduler

        total_loss = 0.0
        # Stays None unless an accumulation boundary is reached below.
        grad_norm = None
        optimizer.zero_grad()
        samples_batched_list, train_timesteps, sigma_schedule, perms = self.sample_reference(dataloader)

        for i, sample in enumerate(samples_batched_list):
            for j in range(train_timesteps):
                new_log_probs = self.grpo_one_step(
                    sample,
                    perms[i][j],
                    sigma_schedule,
                    j
                )
                ratio = torch.exp(new_log_probs - sample["log_probs"][:, j])
                advantages = torch.clamp(
                    sample["advantages"],
                    -adv_clip_max,
                    adv_clip_max,
                )
                unclipped_loss = -advantages * ratio
                clipped_loss = -advantages * torch.clamp(
                    ratio,
                    1.0 - clip_range,
                    1.0 + clip_range,
                )
                loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) / (
                    accumulation_steps * train_timesteps)
                loss.backward()
                avg_loss = loss.detach().clone()
                dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
                total_loss += avg_loss.item()

            # Debug output from global rank 0 only (values of the last timestep).
            if dist.get_rank() == 0:
                print("hps reward", sample["rewards"].item())
                print("ratio", ratio)
                print("advantage", sample["advantages"].item())
                print("final loss", loss.item())

            if (i + 1) % accumulation_steps == 0:
                grad_norm = hyper_model.diffuser.clip_grad_norm_(max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            dist.barrier()

        if grad_norm is None:
            return total_loss, float("nan")
        return total_loss, grad_norm.item()

    @abstractmethod
    def sample_reference(self, dataloader):
        """Produce reference samples for one update.

        Must return ``(samples_batched_list, train_timesteps, sigma_schedule, perms)``
        as unpacked in :meth:`train_one_step`.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_args(self):
        """Build the command-line parser, extend it with the MM extras and parse ``sys.argv``.

        Returns:
            argparse.Namespace: The parsed arguments.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--dataloader_num_workers",
            type=int,
            default=10,
            help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
        )
        parser.add_argument(
            "--train_batch_size",
            type=int,
            default=1,
            help="Batch size (per device) for the training dataloader.",
        )
        parser.add_argument(
            "--num_latent_t",
            type=int,
            default=1,
            help="number of latent frames",
        )
        parser.add_argument("--cache_dir", type=str, default="./cache_dir")
        parser.add_argument("--ema_decay", type=float, default=0.995)
        parser.add_argument("--ema_start_step", type=int, default=0)
        parser.add_argument("--cfg", type=float, default=0.0)
        parser.add_argument(
            "--precondition_outputs",
            action="store_true",
            help="Whether to precondition the outputs of the model.",
        )
        parser.add_argument(
            "--seed", type=int, default=None, help="A seed for reproducible training."
        )
        parser.add_argument(
            "--output_dir",
            type=str,
            default=None,
            help="The output directory where the model predictions and checkpoints will be written.",
        )
        parser.add_argument(
            "--checkpointing_steps",
            type=int,
            default=500,
            help=(
                "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
                " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
                " training using `--resume_from_checkpoint`."
            ),
        )
        parser.add_argument(
            "--train-iters",
            type=int,
            default=None,
            help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
        )
        parser.add_argument(
            "--gradient_accumulation_steps",
            type=int,
            default=1,
            help="Number of updates steps to accumulate before performing a backward/update pass.",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=1e-4,
            help="Initial learning rate (after the potential warmup period) to use.",
        )
        parser.add_argument(
            "--lr_warmup_steps",
            type=int,
            default=10,
            help="Number of steps for the warmup in the lr scheduler.",
        )
        parser.add_argument(
            "--max_grad_norm", default=2.0, type=float, help="Max gradient norm."
        )
        parser.add_argument(
            "--gradient_checkpointing",
            action="store_true",
            help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
        )
        parser.add_argument("--selective_checkpointing", type=float, default=1.0)
        parser.add_argument(
            "--mixed_precision",
            type=str,
            default=None,
            choices=["no", "fp16", "bf16"],
            help=(
                "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
                " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
                " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
            ),
        )
        parser.add_argument(
            "--use_cpu_offload",
            action="store_true",
            help="Whether to use CPU offload for param & gradient & optimizer states.",
        )
        parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
        parser.add_argument(
            "--train_sp_batch_size",
            type=int,
            default=1,
            help="Batch size for sequence parallel training",
        )
        parser.add_argument("--fsdp_sharding_strategy", default="hybrid_full")
        parser.add_argument(
            "--lr_scheduler",
            type=str,
            default="constant_with_warmup",
            help=(
                'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
                ' "constant", "constant_with_warmup"]'
            ),
        )
        parser.add_argument(
            "--lr_num_cycles",
            type=int,
            default=1,
            help="Number of cycles in the learning rate scheduler.",
        )
        parser.add_argument(
            "--lr_power",
            type=float,
            default=1.0,
            help="Power factor of the polynomial scheduler.",
        )
        parser.add_argument(
            "--weight-decay", type=float, default=0.01, help="Weight decay to apply."
        )
        parser.add_argument(
            "--master_weight_type",
            type=str,
            default="fp32",
            help="Weight type to use - fp32 or bf16.",
        )
        parser.add_argument(
            "--h",
            type=int,
            default=None,
            help="video height",
        )
        parser.add_argument(
            "--w",
            type=int,
            default=None,
            help="video width",
        )
        parser.add_argument(
            "--t",
            type=int,
            default=None,
            help="video length",
        )
        parser.add_argument(
            "--sampling_steps",
            type=int,
            default=None,
            help="sampling steps",
        )
        parser.add_argument(
            "--eta",
            type=float,
            default=None,
            help="noise eta",
        )
        parser.add_argument(
            "--sampler_seed",
            type=int,
            default=None,
            help="seed of sampler",
        )
        parser.add_argument(
            "--loss_coef",
            type=float,
            default=1.0,
            help="the global loss should be divided by",
        )
        parser.add_argument(
            "--use_group",
            action="store_true",
            default=False,
            help="whether compute advantages for each prompt",
        )
        parser.add_argument(
            "--num_generations",
            type=int,
            default=16,
            help="num_generations per prompt",
        )
        parser.add_argument(
            "--use_hpsv2",
            action="store_true",
            default=False,
            help="whether use hpsv2 as reward model",
        )
        parser.add_argument(
            "--ignore_last",
            action="store_true",
            default=False,
            help="whether ignore last step of mdp",
        )
        parser.add_argument(
            "--init_same_noise",
            action="store_true",
            default=False,
            help="whether use the same noise within each prompt",
        )
        parser.add_argument(
            "--shift",
            type=float,
            default=1.0,
            help="shift for timestep scheduler",
        )
        parser.add_argument(
            "--timestep_fraction",
            type=float,
            default=1.0,
            help="timestep downsample ratio",
        )
        parser.add_argument(
            "--clip_range",
            type=float,
            default=1e-4,
            help="clip range for grpo",
        )
        parser.add_argument(
            "--adv_clip_max",
            type=float,
            default=5.0,
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            help="print log train step interval",
        )
        parser.add_argument(
            "--save-interval",
            type=int,
            help="save checkpoint train step interval",
        )
        parser.add_argument(
            "--eval-interval",
            type=int,
            help="evaluation train step interval",
        )
        parser.add_argument(
            "--eval-iters",
            type=int,
            help="evaluation iterations",
        )
        parser.add_argument(
            "--save",
            type=str,
            default=None,
            help="save checkpoint path",
        )
        parser.add_argument(
            "--ckpt-format",
            type=str,
            default="torch",
            help="save checkpoint format",
        )
        parser.add_argument(
            "--distributed-backend",
            type=str,
            default="nccl",
            help="distributed backend",
        )
        parser.add_argument(
            "--load",
            type=str,
            required=True,
            help="pretrained checkpoint path",
        )
        parser.add_argument(
            "--hps_reward_save",
            type=str,
            help="hps reward save path",
        )
        parser.add_argument(
            "--sample_batch_size",
            type=int,
            help="sample reference batch size",
        )
        parser.add_argument(
            "--load_rank",
            type=int,
            default=8,
            help="load rank batch size",
        )
        parser = mm_extra_args_provider(parser)
        return parser.parse_args()

    @abstractmethod
    def model_provider(self, args):
        """Build the model; the result must expose a ``diffuser`` attribute (see :meth:`train`)."""
        raise NotImplementedError("Subclasses must implement this method")

    @abstractmethod
    def grpo_one_step(self, sample, perm, sigma_schedule, index):
        """Return the new log-probabilities of ``sample`` at timestep ``index``."""
        raise NotImplementedError("Subclasses must implement this method")

    def sd3_time_shift(self, shift, t):
        """Return ``shift * t / (1 + (shift - 1) * t)``."""
        return (shift * t) / (1 + (shift - 1) * t)

    def gather_tensor(self, tensor):
        """All-gather ``tensor`` from every rank and concatenate along dim 0.

        Returns ``tensor`` unchanged when no process group is initialised.
        """
        if not dist.is_initialized():
            return tensor
        world_size = dist.get_world_size()
        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, tensor)
        return torch.cat(gathered_tensors, dim=0)

    @staticmethod
    def assert_eq(x, y, msg=None):
        """Raise ``AssertionError(f"{msg} not equal")`` unless ``x == y``."""
        if not x == y:
            raise AssertionError(f"{msg} not equal")

    def grpo_step(
            self,
            model_output: torch.Tensor,
            latents: torch.Tensor,
            sigmas: torch.Tensor,
            prev_sample: torch.Tensor,
            config: dict
    ):
        """Advance the latents by one step and, in GRPO mode, score the transition.

        Args:
            model_output: Model prediction for the current latents.
            latents: Current latents.
            sigmas: Noise schedule; entries ``index`` and ``index + 1`` are read.
            prev_sample: The next-step sample to score. When ``None`` in GRPO
                mode, a fresh sample is drawn around the predicted mean.
            config: Dict with keys ``"grpo"``, ``"sde_solver"``, ``"eta"`` and
                ``"index"``.

        Returns:
            In GRPO mode ``(prev_sample, pred_original_sample, log_prob)`` where
            ``log_prob`` is the Gaussian log-density of ``prev_sample`` under
            ``N(prev_sample_mean, std_dev_t**2)``, averaged over every dimension
            except the first. Otherwise ``(prev_sample_mean, pred_original_sample)``.
        """
        grpo = config["grpo"]
        sde_solver = config["sde_solver"]
        eta = config["eta"]
        index = config["index"]

        sigma = sigmas[index]
        dsigma = sigmas[index + 1] - sigma
        prev_sample_mean = latents + dsigma * model_output
        pred_original_sample = latents - sigma * model_output

        delta_t = sigma - sigmas[index + 1]
        std_dev_t = eta * math.sqrt(delta_t)

        if sde_solver:
            score_estimate = -(latents - pred_original_sample * (1 - sigma)) / sigma ** 2
            log_term = -0.5 * eta ** 2 * score_estimate
            prev_sample_mean = prev_sample_mean + log_term * dsigma

        if not grpo:
            return prev_sample_mean, pred_original_sample

        if prev_sample is None:
            prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t

        residual = prev_sample.detach().to(torch.float32) - prev_sample_mean.to(torch.float32)
        # The normalisation terms must live inside this single expression; on a
        # line of their own they are evaluated and thrown away. They are a
        # constant offset per step, so old/new log-prob differences are unaffected.
        log_prob = (
            -(residual ** 2) / (2 * (std_dev_t ** 2))
            - math.log(std_dev_t)
            - 0.5 * math.log(2 * math.pi)
        )
        log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
        return prev_sample, pred_original_sample, log_prob

    def save_checkpoint(self, transformer, rank, output_dir, step, epoch):
        """Gather the full state dict and write weights plus config on rank 0.

        Every rank must call this method: the state dict is collected under
        ``FULL_STATE_DICT`` with ``offload_to_cpu=True, rank0_only=True``.
        Files are written to ``<output_dir>/checkpoint-<step>-<epoch>/``.

        Args:
            transformer: Module whose state is saved; must expose ``config``.
                Expected to be the FSDP root module.
            rank: Global rank of the caller; only ``rank <= 0`` writes files.
            output_dir: Base directory for checkpoints.
            step: Training step, used in the directory name.
            epoch: Epoch index, used in the directory name.
        """
        self.main_print(f"--> saving checkpoint at step {step}")
        with FSDP.state_dict_type(
                transformer,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            cpu_state = transformer.state_dict()
        if rank <= 0:
            save_dir = os.path.join(output_dir, f"checkpoint-{step}-{epoch}")
            os.makedirs(save_dir, exist_ok=True)
            weight_path = os.path.join(save_dir, "diffusion_pytorch_model.safetensors")
            save_file(cpu_state, weight_path)
            config_dict = dict(transformer.config)
            # The "dtype" entry is dropped before the config is written as JSON.
            if "dtype" in config_dict:
                del config_dict["dtype"]
            config_path = os.path.join(save_dir, "config.json")
            with open(config_path, "w") as f:
                json.dump(config_dict, f, indent=4)
        self.main_print(f"--> checkpoint saved at step {step}")

    def main_print(self, content):
        """Print ``content`` only on processes whose ``LOCAL_RANK`` is 0 (or below)."""
        if int(os.environ["LOCAL_RANK"]) <= 0:
            print(content)