import os
import json
import torch
from diffusers import FluxPipeline
from tqdm import tqdm
import torch.distributed as dist
import mindspeed.megatron_adaptor
from mindspeed_mm.tasks.rl.soragrpo.preprocess.data_preprocess import DataPreprocess
class FluxDataPreprocess(DataPreprocess):
    """Pre-compute FLUX prompt encodings and write them to disk.

    For every sample yielded by ``self.dataloader`` the three tensors returned
    by ``FluxPipeline.encode_prompt`` are saved as ``<filename>.pt`` under
    ``prompt_embed/``, ``text_ids/`` and ``pooled_prompt_embeds/`` inside
    ``args.output_dir``.  Once all ranks are done, rank 0 writes a combined
    ``videos2caption.json`` index.

    ``self.args``, ``self.rank``, ``self.world_size``, ``self.device`` and
    ``self.dataloader`` are read here but not defined in this class; they are
    presumably set up by ``DataPreprocess.__init__`` -- verify against the base.
    """

    # Sub-directories of ``args.output_dir``; each receives one ``.pt`` per sample.
    _EMBED_SUBDIRS = ("prompt_embed", "pooled_prompt_embeds", "text_ids")

    def __init__(self):
        """Defer entirely to the base-class initialisation."""
        super().__init__()

    def preprocess(self):
        """Encode every caption, save the tensors, and write the JSON index.

        Every rank processes the batches of its own ``self.dataloader`` and
        collects one metadata record per sample.  The records of all ranks are
        gathered with ``all_gather_object`` and rank 0 dumps them to
        ``videos2caption.json`` in ``args.output_dir``.
        """
        args = self.args
        rank = self.rank
        world_size = self.world_size

        os.makedirs(args.output_dir, exist_ok=True)
        for subdir in self._EMBED_SUBDIRS:
            os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

        pipe = FluxPipeline.from_pretrained(args.load, torch_dtype=torch.bfloat16).to(self.device)

        json_data = []
        # Only rank 0 shows a progress bar.
        for data in tqdm(self.dataloader, disable=rank != 0):
            filenames = data["filename"]
            captions = data["caption"]
            if not filenames:
                # Nothing to encode; skip the text-encoder call altogether.
                continue

            with torch.inference_mode():
                # Encode the whole batch ONCE.  The call already receives every
                # caption of the batch, so repeating it per sample only redid
                # the same text-encoder forward pass.
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=captions, prompt_2=captions
                )
                # NOTE(review): ``text_ids[idx]`` assumes ``encode_prompt``
                # returns ``text_ids`` with a leading per-sample dimension --
                # confirm for the installed diffusers version.
                batch_outputs = {
                    "prompt_embed": prompt_embeds,
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                    "text_ids": text_ids,
                }
                for idx, video_name in enumerate(filenames):
                    pt_name = video_name + ".pt"
                    for subdir in self._EMBED_SUBDIRS:
                        torch.save(
                            batch_outputs[subdir][idx],
                            os.path.join(args.output_dir, subdir, pt_name),
                        )
                    # Paths are stored relative to their respective sub-directory.
                    json_data.append(
                        {
                            "prompt_embed_path": pt_name,
                            "text_ids": pt_name,
                            "pooled_prompt_embeds_path": pt_name,
                            "caption": captions[idx],
                        }
                    )

        # Wait until every rank has finished writing its tensors.
        dist.barrier()

        gathered_data = [None] * world_size
        dist.all_gather_object(gathered_data, json_data)

        if rank == 0:
            # Flatten the per-rank lists into one index, in rank order.
            all_json_data = [item for sublist in gathered_data for item in sublist]
            index_path = os.path.join(args.output_dir, "videos2caption.json")
            with open(index_path, "w", encoding="utf-8") as f:
                json.dump(all_json_data, f, indent=4)
if __name__ == "__main__":
    # Script entry point: build the preprocessor and run the full pass.
    preprocessor = FluxDataPreprocess()
    preprocessor.preprocess()