import os
import torch
import torch.distributed as dist
from diffusers.image_processor import VaeImageProcessor
from tqdm.auto import tqdm
from mindspeed_mm.tasks.rl.soragrpo.sora_grpo_trainer import SoraGRPOTrainer
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_model import FluxGRPOModel
class FluxGRPOTrainer(SoraGRPOTrainer):
    """GRPO trainer for the FLUX text-to-image transformer.

    Extends ``SoraGRPOTrainer`` with:

    - the FLUX model provider;
    - the SDE rollout loop that collects reference samples;
    - HPSv2 reward scoring;
    - advantage normalisation;
    - the FLUX latent packing/unpacking helpers.

    Members used from the base class (``grpo_step``, ``sd3_time_shift``,
    ``gather_tensor``, ``assert_eq``, ``hyper_model``, ``device``, ``args``)
    are defined elsewhere and not visible here.
    """

    # VAE spatial downsampling factor and latent channel count used for rollouts.
    _SPATIAL_DOWNSAMPLE = 8
    _IN_CHANNELS = 16
    # Guidance value fed to every FLUX transformer call.
    _GUIDANCE_SCALE = 3.5
    # Applied as `latents / scale + shift` before VAE decoding; presumably the
    # FLUX VAE scaling_factor / shift_factor -- confirm against the VAE config.
    _VAE_SCALING_FACTOR = 0.3611
    _VAE_SHIFT_FACTOR = 0.1159

    def model_provider(self, args):
        """Build the FLUX GRPO model wrapper on this trainer's device."""
        return FluxGRPOModel(args, device=self.device)

    def _guidance(self, device):
        """Return the 1-element bf16 guidance tensor passed to the transformer."""
        return torch.tensor([self._GUIDANCE_SCALE], device=device, dtype=torch.bfloat16)

    def grpo_one_step(self, sample, perm, sigma_schedule, index):
        """Recompute the log-probability of one stored denoising transition.

        Args:
            sample: one element of the list built by ``sample_reference``; every
                tensor in it carries a leading dimension of size 1.
            perm: forwarded to ``grpo_step`` as ``config["index"]`` (presumably
                the position in ``sigma_schedule`` of the replayed step --
                confirm against the caller).
            sigma_schedule: the schedule returned by ``sample_reference``.
            index: column of the stored per-step tensors to replay.

        Returns:
            The log-probability that ``grpo_step`` assigns, under the current
            transformer weights, to moving from ``latents`` to the stored
            ``next_latents``.
        """
        args = self.args
        latents = sample["latents"][:, index]
        pre_latents = sample["next_latents"][:, index]
        timesteps = sample["timesteps"][:, index]
        encoder_hidden_states = sample["encoder_hidden_states"]
        transformer = self.hyper_model.diffuser
        transformer.train()
        with torch.autocast("cuda", torch.bfloat16):
            pred = transformer(
                hidden_states=latents,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timesteps / 1000,
                guidance=self._guidance(latents.device),
                # One text-id row per text token, mirroring run_sample_step.
                txt_ids=sample["text_ids"].repeat(encoder_hidden_states.shape[1], 1),
                pooled_projections=sample["pooled_prompt_embeds"],
                img_ids=sample["image_ids"].squeeze(0),
                joint_attention_kwargs=None,
                return_dict=False,
            )[0]
        config = {"grpo": True, "sde_solver": True, "eta": args.eta, "index": perm}
        # Passing the stored next latent as prev_sample makes grpo_step score
        # that transition instead of drawing a new one.
        _, _, log_prob = self.grpo_step(
            pred,
            latents.to(torch.float32),
            sigma_schedule,
            pre_latents.to(torch.float32),
            config,
        )
        return log_prob

    def sample_reference(self, dataloader):
        """Roll out the policy on one batch and build per-sample GRPO inputs.

        Side effects: on rank 0 it prints the gathered rewards and appends
        their mean to ``args.hps_reward_save``.

        Returns:
            A 4-tuple ``(samples_batched_list, train_timesteps, sigma_schedule,
            perms)``:

            - ``samples_batched_list``: one dict per sample, every tensor with
              a leading dimension of 1.
            - ``train_timesteps``: ``int(num_steps * args.timestep_fraction)``.
            - ``sigma_schedule``: as returned by ``sample_reference_model``.
            - ``perms``: the ``(batch, num_steps)`` permutation applied to the
              per-step tensors.
        """
        encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = next(dataloader)
        args = self.args
        if args.use_group:
            # Each prompt is expanded into num_generations consecutive copies so
            # that its generations form one contiguous advantage group.
            num_generations = args.num_generations

            def repeat_tensor(tensor):
                if tensor is None:
                    return None
                return torch.repeat_interleave(tensor, num_generations, dim=0)

            encoder_hidden_states = repeat_tensor(encoder_hidden_states)
            pooled_prompt_embeds = repeat_tensor(pooled_prompt_embeds)
            text_ids = repeat_tensor(text_ids)
            if isinstance(caption, str):
                caption = [caption] * num_generations
            elif isinstance(caption, (list, tuple)):
                caption = [item for item in caption for _ in range(num_generations)]
            else:
                raise ValueError(f"Unsupported caption type: {type(caption)}")

        reward, all_latents, all_log_probs, sigma_schedule, all_image_ids = self.sample_reference_model(
            args,
            caption,
            encoder_hidden_states,
            pooled_prompt_embeds,
            text_ids,
        )
        batch_size = all_latents.shape[0]
        device = all_latents.device

        # Integer timesteps (sigma * 1000) for the first sampling_steps sigmas,
        # identical for every sample in the batch.
        timestep_value = [int(sigma * 1000) for sigma in sigma_schedule][:args.sampling_steps]
        timesteps = torch.tensor(
            [list(timestep_value) for _ in range(batch_size)],
            device=device,
            dtype=torch.long,
        )
        # all_latents holds sampling_steps + 1 states (initial noise included).
        # The final transition is dropped from every per-step tensor so that
        # latents[k] -> next_latents[k] pairs with timesteps[k] and log_probs[k].
        samples = {
            "timesteps": timesteps.detach().clone()[:, :-1],
            "latents": all_latents[:, :-2],
            "next_latents": all_latents[:, 1:-1],
            "log_probs": all_log_probs[:, :-1],
            "rewards": reward.to(torch.float32),
            "image_ids": all_image_ids,
            "text_ids": text_ids,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_prompt_embeds": pooled_prompt_embeds,
        }

        gathered_reward = self.gather_tensor(samples["rewards"])
        if dist.get_rank() == 0:
            print("gathered_hps_reward", gathered_reward)
            print("gathered_hps_reward_mean=", gathered_reward.mean().item())
            with open(args.hps_reward_save, 'a', encoding="utf-8") as f:
                f.write(f"{gathered_reward.mean().item()}\n")

        samples["advantages"] = self._compute_flux_advantages(samples["rewards"], gathered_reward)

        # Shuffle the step order independently for every sample.
        num_steps = samples["timesteps"].shape[1]
        perms = torch.stack([torch.randperm(num_steps) for _ in range(batch_size)]).to(device)
        rows = torch.arange(batch_size).to(device)[:, None]
        for key in ("timesteps", "latents", "next_latents", "log_probs"):
            samples[key] = samples[key][rows, perms]

        # Split into one dict per sample, keeping a leading dim of 1 on each tensor.
        samples_batched = {k: v.unsqueeze(1) for k, v in samples.items()}
        samples_batched_list = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())]
        train_timesteps = int(num_steps * args.timestep_fraction)
        return samples_batched_list, train_timesteps, sigma_schedule, perms

    def _compute_flux_advantages(self, rewards, gathered_reward):
        """Standardise rewards into advantages.

        With ``args.use_group`` every consecutive run of ``num_generations``
        rewards is standardised by its own mean and unbiased std. Rewards in a
        trailing incomplete group keep advantage 0. Otherwise all rewards are
        standardised with the mean/std of ``gathered_reward``, which is the
        output of ``gather_tensor`` (presumably all ranks' rewards).
        """
        args = self.args
        if not args.use_group:
            return (rewards - gathered_reward.mean()) / (gathered_reward.std() + 1e-8)
        group_size = args.num_generations
        num_groups = len(rewards) // group_size
        usable = num_groups * group_size
        grouped = rewards[:usable].reshape(num_groups, group_size)
        normalised = (grouped - grouped.mean(dim=1, keepdim=True)) / (grouped.std(dim=1, keepdim=True) + 1e-8)
        advantages = torch.zeros_like(rewards)
        advantages[:usable] = normalised.reshape(-1)
        return advantages

    def sample_reference_model(self, args, caption, encoder_hidden_states, pooled_prompt_embeds, text_ids):
        """Generate images with the current transformer and score them.

        Prompts are processed in chunks of ``args.sample_batch_size``; the last
        chunk may be smaller. Every decoded image is saved to
        ``./images/flux_{rank}_{idx}.png``.

        Returns:
            A 5-tuple ``(rewards, all_latents, all_log_probs, sigma_schedule,
            all_image_ids)``:

            - ``rewards``: one HPSv2 score per sample.
            - ``all_latents``: ``run_sample_step`` latents stacked over steps
              (``sampling_steps + 1`` states).
            - ``all_log_probs``: the per-step log-probs.
            - ``sigma_schedule``: the time-shifted schedule.
            - ``all_image_ids``: one image-id tensor per sample.

        Raises:
            RuntimeError: if no reward was computed (``args.use_hpsv2`` is off).
        """
        transformer = self.hyper_model.diffuser
        vae = self.hyper_model.ae
        device = self.device
        w, h = args.w, args.h
        sample_steps = args.sampling_steps

        sigma_schedule = torch.linspace(1, 0, sample_steps + 1)
        sigma_schedule = self.sd3_time_shift(args.shift, sigma_schedule)
        FluxGRPOTrainer.assert_eq(
            len(sigma_schedule),
            sample_steps + 1,
            "sigma_schedule must have length sample_steps + 1",
        )

        # torch.split gives chunks of exactly sample_batch_size plus a remainder,
        # so the batch need not be a multiple of (or at least) that size.
        batch_indices = torch.split(torch.arange(encoder_hidden_states.shape[0]), args.sample_batch_size)
        latent_w, latent_h = w // self._SPATIAL_DOWNSAMPLE, h // self._SPATIAL_DOWNSAMPLE

        shared_noise = None
        if args.init_same_noise:
            # A single noise draw (made on CPU) reused for every sample.
            shared_noise = torch.randn(
                (1, self._IN_CHANNELS, latent_h, latent_w),
                dtype=torch.bfloat16,
            ).to(device)

        vae.enable_tiling()
        image_processor = VaeImageProcessor(vae_scale_factor=16)
        rank = int(os.environ.get("RANK", "0"))
        os.makedirs("./images", exist_ok=True)

        all_latents = []
        all_log_probs = []
        all_rewards = []
        all_image_ids = []
        for batch_idx in batch_indices:
            chunk_size = len(batch_idx)
            if shared_noise is not None:
                input_latents = shared_noise.repeat(chunk_size, 1, 1, 1)
            else:
                input_latents = torch.randn(
                    (chunk_size, self._IN_CHANNELS, latent_h, latent_w),
                    device=device,
                    dtype=torch.bfloat16,
                )
            progress_bar = tqdm(range(0, sample_steps), desc="Sampling Progress")
            image_ids = self.prepare_latent_image_ids(
                chunk_size, latent_h // 2, latent_w // 2, device, torch.bfloat16
            )
            with torch.no_grad():
                sample_input = {
                    "pack_input_latents": self.pack_latents(
                        input_latents, chunk_size, self._IN_CHANNELS, latent_h, latent_w
                    ),
                    "sigma_schedule": sigma_schedule,
                    "encoder_hidden_states": encoder_hidden_states[batch_idx],
                    "pooled_prompt_embeds": pooled_prompt_embeds[batch_idx],
                    "text_ids": text_ids[batch_idx],
                    "image_ids": image_ids,
                }
                _, final_pred, batch_latents, batch_log_probs = self.run_sample_step(
                    args, progress_bar, transformer, sample_input
                )
            # One image-id entry per sample so all_image_ids lines up with all_latents.
            all_image_ids.extend([image_ids] * chunk_size)
            all_latents.append(batch_latents)
            all_log_probs.append(batch_log_probs)

            decoded_images = self._decode_to_pil(vae, image_processor, final_pred, h, w)
            indices = batch_idx.tolist()
            for idx, image in zip(indices, decoded_images):
                image.save(f"./images/flux_{rank}_{idx}.png")
            if args.use_hpsv2:
                batch_caption = [caption[i] for i in indices]
                all_rewards.extend(self._hpsv2_scores(decoded_images, batch_caption))

        if not all_rewards:
            raise RuntimeError("No rewards were computed (is args.use_hpsv2 enabled?)")
        all_latents = torch.cat(all_latents, dim=0)
        all_log_probs = torch.cat(all_log_probs, dim=0)
        all_rewards = torch.cat(all_rewards, dim=0)
        all_image_ids = torch.stack(all_image_ids, dim=0)
        return all_rewards, all_latents, all_log_probs, sigma_schedule, all_image_ids

    def _decode_to_pil(self, vae, image_processor, packed_latents, h, w):
        """Unpack latents, undo the latent normalisation, decode and postprocess.

        Returns whatever ``image_processor.postprocess`` returns for the decoded
        batch (PIL images under its default output type).
        """
        with torch.inference_mode():
            with torch.autocast("cuda", dtype=torch.bfloat16):
                latents = self.unpack_latents(packed_latents, h, w, self._SPATIAL_DOWNSAMPLE)
                latents = (latents / self._VAE_SCALING_FACTOR) + self._VAE_SHIFT_FACTOR
                image = vae.decode(latents, return_dict=False)[0]
                return image_processor.postprocess(image)

    def _hpsv2_scores(self, images, captions):
        """Score each (image, caption) pair with the HPSv2 reward model.

        Returns one 1-element tensor per pair: the diagonal of
        ``image_features @ text_features.T`` for that single pair.
        """
        reward_model = self.hyper_model.reward['model']
        tokenizer = self.hyper_model.reward['processor']
        preprocess_val = self.hyper_model.reward['preprocess_val']
        device = self.device
        scores = []
        with torch.inference_mode():
            for image, prompt in zip(images, captions):
                image_input = preprocess_val(image).unsqueeze(0).to(device=device, non_blocking=True)
                text_input = tokenizer([prompt]).to(device=device, non_blocking=True)
                with torch.amp.autocast('cuda'):
                    outputs = reward_model(image_input, text_input)
                    image_features, text_features = outputs["image_features"], outputs["text_features"]
                    logits_per_image = image_features @ text_features.T
                    scores.append(torch.diagonal(logits_per_image))
        return scores

    def run_sample_step(self, args, progress_bar, transformer, sample_input):
        """Run the SDE sampling loop for one chunk of prompts.

        Args:
            progress_bar: iterable of step indices into ``sigma_schedule``
                (a tqdm over ``range(sampling_steps)`` at the call site).
            sample_input: dict with keys ``pack_input_latents``,
                ``sigma_schedule``, ``encoder_hidden_states``,
                ``pooled_prompt_embeds``, ``text_ids`` and ``image_ids``.

        Returns:
            A 4-tuple ``(z, latents, all_latents, all_log_probs)``:

            - ``z``: the last sampled latent.
            - ``latents``: the ``pred_original`` of the last step, which the
              caller decodes.
            - ``all_latents``: every state (initial one included) stacked on
              dim 1.
            - ``all_log_probs``: the per-step log-probs stacked on dim 1.
        """
        z = sample_input["pack_input_latents"]
        sigma_schedule = sample_input["sigma_schedule"]
        encoder_hidden_states = sample_input["encoder_hidden_states"]
        pooled_prompt_embeds = sample_input["pooled_prompt_embeds"]
        image_ids = sample_input["image_ids"]
        # Loop-invariant: the first sample's text ids repeated once per text
        # token (assumes text_ids is 2-D per batch -- the repeat requires it).
        txt_ids = sample_input["text_ids"].unsqueeze(1).repeat(1, encoder_hidden_states.shape[1], 1)[0]
        num_prompts = encoder_hidden_states.shape[0]

        transformer.eval()
        all_latents = [z]
        all_log_probs = []
        latents = None
        for i in progress_bar:
            timestep_value = int(sigma_schedule[i] * 1000)
            timesteps = torch.full([num_prompts], timestep_value, device=z.device, dtype=torch.long)
            with torch.autocast("cuda", torch.bfloat16):
                pred = transformer(
                    hidden_states=z,
                    encoder_hidden_states=encoder_hidden_states,
                    timestep=timesteps / 1000,
                    guidance=self._guidance(z.device),
                    txt_ids=txt_ids,
                    pooled_projections=pooled_prompt_embeds,
                    img_ids=image_ids,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]
            config = {"grpo": True, "sde_solver": True, "eta": args.eta, "index": i}
            # z is kept exactly as grpo_step returns it (no dtype cast), so the
            # stored latent is the one this step's log_prob was computed for.
            z, latents, log_prob = self.grpo_step(pred, z.to(torch.float32), sigma_schedule, None, config)
            all_latents.append(z)
            all_log_probs.append(log_prob)
        return z, latents, torch.stack(all_latents, dim=1), torch.stack(all_log_probs, dim=1)

    def pack_latents(self, latents, batch_size, num_channels_latents, height, width):
        """Pack ``(batch, C, H, W)`` latents into 2x2 patch tokens.

        Returns a ``(batch, (H // 2) * (W // 2), C * 4)`` tensor; H and W must
        be even for the ``view`` to succeed.
        """
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    def unpack_latents(self, latents, height, width, vae_scale_factor):
        """Inverse of ``pack_latents`` for an image of ``height`` x ``width`` pixels.

        Takes ``(batch, num_patches, channels)`` tokens and returns a
        ``(batch, channels // 4, H, W)`` tensor. Here
        ``H = 2 * (height // (vae_scale_factor * 2))``, and W is computed the
        same way from ``width``.
        """
        batch_size, _, channels = latents.shape
        latent_height = 2 * (int(height) // (vae_scale_factor * 2))
        latent_width = 2 * (int(width) // (vae_scale_factor * 2))
        latents = latents.view(batch_size, latent_height // 2, latent_width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        return latents.reshape(batch_size, channels // (2 * 2), latent_height, latent_width)

    def prepare_latent_image_ids(self, batch_size, height, width, device, dtype):
        """Build per-token position ids for a ``height`` x ``width`` token grid.

        Returns a ``(height * width, 3)`` tensor in row-major order. Column 0 is
        zero, column 1 is the row index and column 2 is the column index.
        ``batch_size`` is accepted but not used; the ids are shared by every
        sample.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
        id_height, id_width, id_channels = latent_image_ids.shape
        latent_image_ids = latent_image_ids.reshape(id_height * id_width, id_channels)
        return latent_image_ids.to(device=device, dtype=dtype)