import os
import argparse
import csv
import json
import time
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from mindiesd import CacheAgent, CacheConfig
from FLUX1dev import FluxPipeline
from FLUX1dev import get_local_rank, get_world_size, initialize_torch_distributed
from FLUX1dev.utils import check_prompts_valid, check_param_valid, check_dir_safety, check_file_safety
from transformers import T5EncoderModel
# Import-time side effect: configure the NPU backend with jit_compile=False
# before any model or pipeline is created below.
# NOTE(review): presumably this makes torch_npu use pre-built operator binaries
# instead of compiling operators on the fly — confirm against the torch_npu docs.
torch_npu.npu.set_compile_mode(jit_compile=False)
class PromptLoader:
    """Iterator that serves prompts in fixed-size batches.

    Prompts are read once, at construction time, from one of three sources
    selected by ``prompt_file_type``:

    * ``'plain'`` - one prompt per line of ``prompt_file``;
    * ``'parti'`` - tab-separated file with a header row, prompt in column 0
      and category in column 1;
    * ``'hpsv2'`` - the JSON file ``hpsv2_benchmark_prompts.json`` in the
      current working directory, mapping a style name to a list of prompts.

    Every prompt is stored in ``self.prompts`` as ``(prompt, category_index)``
    where the index points into ``self.catagories`` (spelling kept because the
    attribute and the batch key are part of the public interface).

    Each iteration step yields a dict with the keys ``'prompts'``,
    ``'catagories'``, ``'save_names'`` (lists of length ``batch_size``) and
    ``'n_prompts'`` (number of real, non-padding entries in the batch).
    Each prompt is emitted ``num_images_per_prompt`` times in a row.
    """

    def __init__(
        self,
        prompt_file: str,
        prompt_file_type: str,
        batch_size: int = 1,
        num_images_per_prompt: int = 1,
        max_num_prompts: int = 0
    ):
        """Validate the numeric parameters and load the prompts.

        Args:
            prompt_file: Path of the prompt file (ignored for ``'hpsv2'``).
            prompt_file_type: ``'plain'``, ``'parti'`` or ``'hpsv2'``. Any
                other value only prints a message and leaves the loader empty.
            batch_size: Number of entries per yielded batch (> 0).
            num_images_per_prompt: How many times each prompt is emitted (> 0).
            max_num_prompts: Upper bound on loaded prompts; 0 means no limit.

        Raises:
            ValueError: If one of the numeric parameters is out of range.
        """
        self.check_input_isvalid(batch_size, num_images_per_prompt, max_num_prompts)
        self.prompts = []
        self.catagories = ['Not_specified']
        self.batch_size = batch_size
        self.num_images_per_prompt = num_images_per_prompt

        if prompt_file_type == 'plain':
            self.load_prompts_plain(prompt_file, max_num_prompts)
        elif prompt_file_type == 'parti':
            self.load_prompts_parti(prompt_file, max_num_prompts)
        elif prompt_file_type == 'hpsv2':
            self.load_prompts_hpsv2(max_num_prompts)
        else:
            print("This operation is not supported!")

        # Iteration cursor: index of the current prompt and how many copies
        # of it have already been emitted.
        self.current_id = 0
        self.inner_id = 0

    def __len__(self):
        """Return the total number of images the loader will request."""
        return len(self.prompts) * self.num_images_per_prompt

    def __iter__(self):
        """Return ``self``; the cursor is not reset, so the loader is single-pass."""
        return self

    def __next__(self):
        """Return the next batch dict, padding the last batch with empty strings.

        Raises:
            StopIteration: When every prompt has been emitted.
        """
        total = len(self.prompts)
        if self.current_id >= total:
            raise StopIteration

        ret = {
            'prompts': [],
            'catagories': [],
            'save_names': [],
            'n_prompts': self.batch_size,
        }
        for _ in range(self.batch_size):
            if self.current_id >= total:
                # Out of prompts: pad so the batch keeps its fixed size and
                # report one fewer real entry.
                ret['prompts'].append('')
                ret['save_names'].append('')
                ret['catagories'].append('')
                ret['n_prompts'] -= 1
                continue

            prompt, catagory_id = self.prompts[self.current_id]
            ret['prompts'].append(prompt)
            ret['catagories'].append(self.catagories[catagory_id])
            ret['save_names'].append(f'{self.current_id}_{self.inner_id}')

            # Advance to the next prompt only after all of its copies were emitted.
            self.inner_id += 1
            if self.inner_id == self.num_images_per_prompt:
                self.inner_id = 0
                self.current_id += 1
        return ret

    def _category_id(self, name):
        """Return the index of ``name`` in ``self.catagories``, registering it if new."""
        if name not in self.catagories:
            self.catagories.append(name)
        return self.catagories.index(name)

    def load_prompts_plain(self, file_path: str, max_num_prompts: int):
        """Load one prompt per line; every prompt gets category index 0."""
        with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f:
            for i, line in enumerate(f):
                if max_num_prompts and i == max_num_prompts:
                    break
                self.prompts.append((line.strip(), 0))

    def load_prompts_parti(self, file_path: str, max_num_prompts: int):
        """Load prompts from a tab-separated file (prompt, category, ...).

        The first line is treated as a header and skipped. Completely empty
        rows (e.g. a trailing blank line) are ignored instead of raising
        ``IndexError``, and an empty file simply yields no prompts.
        """
        with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f:
            # Default of None: an empty file must not leak StopIteration
            # out of the constructor.
            next(f, None)
            loaded = 0
            for row in csv.reader(f, delimiter="\t"):
                if max_num_prompts and loaded == max_num_prompts:
                    break
                if not row:
                    continue
                self.prompts.append((row[0], self._category_id(row[1])))
                loaded += 1

    def load_prompts_hpsv2(self, max_num_prompts: int):
        """Load prompts from ``hpsv2_benchmark_prompts.json`` in the working directory.

        The JSON object maps a style name to a list of prompts; the style is
        used as the category. At most ``max_num_prompts`` prompts are loaded
        (0 means no limit).
        """
        with open('hpsv2_benchmark_prompts.json', 'r', encoding='utf-8') as file:
            all_prompts = json.load(file)

        loaded = 0
        for style, prompts in all_prompts.items():
            for prompt in prompts:
                # Test the limit BEFORE appending so exactly max_num_prompts
                # prompts are kept, and leave both loops at once.
                if max_num_prompts and loaded >= max_num_prompts:
                    return
                self.prompts.append((prompt, self._category_id(style)))
                loaded += 1

    def check_input_isvalid(self, batch_size, num_images_per_prompt, max_num_prompts):
        """Raise ``ValueError`` if any numeric constructor argument is out of range."""
        if batch_size <= 0:
            raise ValueError(f"Param batch_size invalid, expected positive value, but get {batch_size}")
        if num_images_per_prompt <= 0:
            raise ValueError(f"Param num_images_per_prompt invalid, expected positive value, but get {num_images_per_prompt}")
        if max_num_prompts < 0:
            # Implicit literal concatenation keeps source indentation out of the message.
            raise ValueError(
                "Param max_num_prompts invalid, expected greater than or equal to 0, "
                f"but get {max_num_prompts}"
            )
def parse_arguments():
    """Build the command-line parser for the FLUX inference script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: One attribute per option listed in the table below
        (the attribute name is the flag without the leading dashes).
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--path", {"type": str, "default": "./flux", "help": "Path to the flux model directory"}),
        ("--save_path", {"type": str, "default": "./res", "help": "ouput image path"}),
        ("--device_id", {"type": int, "default": 0, "help": "NPU device id"}),
        ("--device", {"choices": ["npu", "cpu"], "default": "npu", "help": "NPU"}),
        ("--prompt_path", {"type": str, "default": "./prompts.txt", "help": "input prompt text path"}),
        ("--prompt_type", {"choices": ["plain", "parti", "hpsv2"], "default": "plain",
                           "help": "specify infer prompt type"}),
        ("--num_images_per_prompt", {"type": int, "default": 1,
                                     "help": "specify image number every prompt generate"}),
        ("--max_num_prompt", {"type": int, "default": 0,
                              "help": "limit the prompt number[0 indicates no limit]"}),
        ("--info_file_save_path", {"type": str, "default": "./image_info.json",
                                   "help": "path to save image info"}),
        ("--width", {"type": int, "default": 1024, "help": 'Image size width'}),
        ("--height", {"type": int, "default": 1024, "help": 'Image size height'}),
        ("--infer_steps", {"type": int, "default": 50, "help": "Inference steps"}),
        ("--seed", {"type": int, "default": 42, "help": "A seed for all the prompts"}),
        ("--use_cache", {"action": "store_true", "help": "turn on dit cache or not"}),
        ("--batch_size", {"type": int, "default": 1, "help": "prompt batch size"}),
        ("--device_type", {"choices": ["A2-32g-single", "A2-32g-dual", "A2-64g"], "default": "A2-64g",
                           "help": "specify device type"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _attach_cache_agents(pipe, infer_steps, use_cache):
    """Create the two "dit_block_cache" agents and attach them to ``pipe.transformer``.

    ``d_stream_agent`` is configured for 19 blocks and ``s_stream_agent`` for
    38 blocks. NOTE(review): presumably these are the double-stream and
    single-stream transformer blocks of FLUX — confirm.

    With ``use_cache`` the cache window starts at step 18; otherwise
    ``step_start`` equals ``infer_steps`` and the block range collapses to the
    last block. NOTE(review): this looks like a "cache never kicks in"
    configuration — confirm against the mindiesd CacheConfig documentation.
    """
    if use_cache:
        step_start = 18
        d_block_start, d_block_end = 5, 13
        s_block_start, s_block_end = 1, 23
    else:
        step_start = infer_steps
        d_block_start, d_block_end = 18, 18
        s_block_start, s_block_end = 37, 37

    d_stream_config = CacheConfig(
        method="dit_block_cache",
        blocks_count=19,
        steps_count=infer_steps,
        step_start=step_start,
        step_interval=2,
        block_start=d_block_start,
        block_end=d_block_end,
    )
    pipe.transformer.d_stream_agent = CacheAgent(d_stream_config)

    s_stream_config = CacheConfig(
        method="dit_block_cache",
        blocks_count=38,
        steps_count=infer_steps,
        step_start=step_start,
        step_interval=2,
        block_start=s_block_start,
        block_end=s_block_end,
    )
    pipe.transformer.s_stream_agent = CacheAgent(s_stream_config)


def infer(args):
    """Generate images for every prompt batch, save them and report timing.

    Steps, in order: build the ``FluxPipeline`` for the selected
    ``args.device_type``, attach the cache agents, seed the RNGs, iterate the
    ``PromptLoader`` batches, save each image as ``<save_name>.png`` under
    ``args.save_path``, dump the prompt -> image-path mapping as JSON to
    ``args.info_file_save_path`` and print the average time per timed image.

    Args:
        args: Namespace produced by ``parse_arguments()``. Reads ``path``,
            ``save_path``, ``device_id``, ``device_type``, ``prompt_path``,
            ``prompt_type``, ``batch_size``, ``num_images_per_prompt``,
            ``max_num_prompt``, ``info_file_save_path``, ``width``,
            ``height``, ``infer_steps``, ``seed`` and ``use_cache``.
    """
    is_dual = args.device_type == "A2-32g-dual"

    if is_dual:
        # Tensor-parallel loading: swap in the project's TP-aware class methods
        # before the pipeline is constructed.
        from FLUX1dev import replace_tp_from_pretrain, replace_tp_extract_init_dict
        FluxPipeline.from_pretrained = classmethod(replace_tp_from_pretrain)
        FluxPipeline.extract_init_dict = classmethod(replace_tp_extract_init_dict)

    check_dir_safety(args.path)

    if is_dual:
        # The standalone T5 encoder is only consumed by the dual-card path, so
        # it is only loaded here instead of for every device type.
        T5_model_path = os.path.join(args.path, "text_encoder_2")
        T5_model = T5EncoderModel.from_pretrained(T5_model_path).to(torch.bfloat16)
        local_rank = get_local_rank()
        world_size = get_world_size()
        initialize_torch_distributed(local_rank, world_size)
        import deepspeed
        T5_model = deepspeed.init_inference(
            T5_model,
            tensor_parallel={"tp_size": get_world_size()},
        )
        T5_model.module.to("cpu")

    pipe = FluxPipeline.from_pretrained(args.path, torch_dtype=torch.bfloat16, local_files_only=True)

    if args.device_type == "A2-32g-single":
        torch.npu.set_device(args.device_id)
        pipe.enable_model_cpu_offload()
    elif args.device_type == "A2-64g":
        torch.npu.set_device(args.device_id)
        pipe.to(f"npu:{args.device_id}")
    else:
        # Dual-card: replace the pipeline's own text_encoder_2 with the
        # deepspeed-wrapped T5 module placed on this rank's device.
        pipe.to(f"npu:{local_rank}")
        pipe.text_encoder_2.to("cpu")
        pipe.text_encoder_2 = T5_model.module.to(f"npu:{local_rank}")

    _attach_cache_agents(pipe, args.infer_steps, args.use_cache)

    torch.manual_seed(args.seed)
    torch.npu.manual_seed(args.seed)
    torch.npu.manual_seed_all(args.seed)

    # A directory needs the execute bit to be traversable, so 0o750 rather
    # than the file-style 0o640; exist_ok closes the exists()/makedirs() race.
    os.makedirs(args.save_path, mode=0o750, exist_ok=True)
    check_dir_safety(args.save_path)

    # The first iterations are treated as warm-up and excluded from timing:
    # a batch is timed only once more than this many slots were processed.
    warmup_images = 3
    infer_num = 0
    time_consume = 0
    timed_images = 0
    current_prompt = None
    image_info = []

    check_file_safety(args.prompt_path)
    prompt_loader = PromptLoader(args.prompt_path,
                                 args.prompt_type,
                                 args.batch_size,
                                 args.num_images_per_prompt,
                                 args.max_num_prompt)
    check_param_valid(args.height, args.width, args.infer_steps)

    for input_info in prompt_loader:
        prompts = input_info['prompts']
        catagories = input_info['catagories']
        save_names = input_info['save_names']
        n_prompts = input_info['n_prompts']
        check_prompts_valid(prompts)

        print(f"[{infer_num+n_prompts}/{len(prompt_loader)}]: {prompts}")
        infer_num += args.batch_size

        timed = infer_num > warmup_images
        if timed:
            start_time = time.time()

        image = pipe(
            prompts,
            height=args.height,
            width=args.width,
            guidance_scale=3.5,
            num_inference_steps=args.infer_steps,
            max_sequence_length=512,
            use_cache=args.use_cache,
        )

        if timed:
            time_consume += time.time() - start_time
            timed_images += n_prompts

        # Only the first n_prompts entries are real; the rest is batch padding.
        for j in range(n_prompts):
            image_save_path = os.path.join(args.save_path, f"{save_names[j]}.png")
            image[0][j].save(image_save_path)

            # Consecutive images of the same prompt share one info record.
            if current_prompt != prompts[j]:
                current_prompt = prompts[j]
                image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]})
            image_info[-1]['images'].append(image_save_path)

    # Recreate the info file so it always gets the 0o640 permissions.
    if os.path.exists(args.info_file_save_path):
        os.remove(args.info_file_save_path)
    with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f:
        json.dump(image_info, f)

    # Average over the images that were actually timed; avoids dividing by
    # zero or a negative count when only warm-up batches ran.
    if timed_images > 0:
        print(f"flux pipeline time is:{time_consume/timed_images}")
    else:
        print("flux pipeline time is not available: no image was generated after warm-up")
    return
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run inference with them.
    infer(parse_arguments())