import argparse
import logging
import time
import os
import csv
import json
import torch
from cogview3plus import CogView3PlusPipeline, set_random_seed, CogView3PlusTransformer2DModel
from cogview3plus.utils.file_utils import standardize_path
from mindiesd import CacheAgent, CacheConfig
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PromptLoader:
def __init__(
self,
prompt_file: str,
prompt_file_type: str,
batch_size: int,
num_images_per_prompt: int = 1,
max_num_prompts: int = 0
):
self.prompts = []
self.catagories = ['Not_specified']
self.batch_size = batch_size
self.num_images_per_prompt = num_images_per_prompt
if prompt_file_type == 'plain':
self.load_prompts_plain(prompt_file, max_num_prompts)
elif prompt_file_type == 'parti':
self.load_prompts_parti(prompt_file, max_num_prompts)
elif prompt_file_type == 'hpsv2':
self.load_prompts_hpsv2(prompt_file, max_num_prompts)
else:
print("This operation is not supported!")
self.current_id = 0
self.inner_id = 0
def __len__(self):
return len(self.prompts) * self.num_images_per_prompt
def __iter__(self):
return self
def __next__(self):
if self.current_id == len(self.prompts):
raise StopIteration
ret = {
'prompts': [],
'catagories': [],
'save_names': [],
'n_prompts': self.batch_size,
}
for _ in range(self.batch_size):
if self.current_id == len(self.prompts):
ret['prompts'].append('')
ret['save_names'].append('')
ret['catagories'].append('')
ret['n_prompts'] -= 1
else:
prompt, catagory_id = self.prompts[self.current_id]
ret['prompts'].append(prompt)
ret['catagories'].append(self.catagories[catagory_id])
ret['save_names'].append(f'{self.current_id}_{self.inner_id}')
self.inner_id += 1
if self.inner_id == self.num_images_per_prompt:
self.inner_id = 0
self.current_id += 1
return ret
def load_prompts_plain(self, file_path: str, max_num_prompts: int):
with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f:
for i, line in enumerate(f):
if max_num_prompts and i == max_num_prompts:
break
prompt = line.strip()
self.prompts.append((prompt, 0))
def load_prompts_parti(self, file_path: str, max_num_prompts: int):
with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f:
next(f)
tsv_file = csv.reader(f, delimiter="\t")
for i, line in enumerate(tsv_file):
if max_num_prompts and i == max_num_prompts:
break
prompt = line[0]
catagory = line[1]
if catagory not in self.catagories:
self.catagories.append(catagory)
catagory_id = self.catagories.index(catagory)
self.prompts.append((prompt, catagory_id))
def load_prompts_hpsv2(self, file_path: str, max_num_prompts: int):
with open(file_path, 'r') as file:
all_prompts = json.load(file)
count = 0
for style, prompts in all_prompts.items():
for prompt in prompts:
count += 1
if max_num_prompts and count >= max_num_prompts:
break
if style not in self.catagories:
self.catagories.append(style)
catagory_id = self.catagories.index(style)
self.prompts.append((prompt, catagory_id))
def parse_arguments():
parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.")
parser.add_argument(
"--prompt_file",
type=str,
default="./prompts/example_prompts.txt",
help="A text file of prompts for generating images.",
)
parser.add_argument(
"--prompt_file_type",
choices=["plain", "parti", "hpsv2"],
default="plain",
help="Type of prompt file.",
)
parser.add_argument(
"--save_dir",
type=str,
default="./results",
help="Path to save result images.",
)
parser.add_argument(
"--info_file_save_path",
type=str,
default="./image_info.json",
help="Path to save image information file.",
)
parser.add_argument(
"--model_path", type=str, default="/data/CogView3B", help="Path to the pre-trained model."
)
parser.add_argument(
"--guidance_scale", type=float, default=7.0, help="The guidance scale for classifier-free guidance."
)
parser.add_argument(
"--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt."
)
parser.add_argument(
"--max_num_prompts",
default=0,
type=int,
help="Limit the number of prompts (0: no limit).",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size."
)
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.")
parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.")
parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.")
parser.add_argument("--dtype", type=str, default="bf16", help="bf16 or fp16")
parser.add_argument("--seed", type=int, default=None, help="Random seed")
parser.add_argument("--device_id", type=int, default=0, help="NPU device id")
parser.add_argument('--cache_algorithm', type=str, default="None", help="The type of optimization algorithm")
return parser.parse_args()
def infer(args):
torch.npu.set_device(args.device_id)
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
if args.seed is not None:
set_random_seed(args.seed)
args.model_path = standardize_path(args.model_path)
pipe = CogView3PlusPipeline.from_pretrained(args.model_path, torch_dtype=dtype)
transformer = CogView3PlusTransformer2DModel.from_pretrained(os.path.join(args.model_path, 'transformer'), torch_dtype=dtype)
pipe.transformer = transformer
pipe = pipe.to("npu")
if args.cache_algorithm == "attention":
steps_count = args.num_inference_steps
blocks_count = pipe.transformer.config.num_layers
config = CacheConfig(
method="attention_cache",
blocks_count=blocks_count,
steps_count=steps_count,
step_start=15,
step_end=37,
step_interval=2
)
agent = CacheAgent(config)
pipe.transformer.use_cache = True
for block in pipe.transformer.transformer_blocks:
block.cache = agent
use_time = 0
prompt_loader = PromptLoader(args.prompt_file,
args.prompt_file_type,
args.batch_size,
args.num_images_per_prompt,
args.max_num_prompts)
infer_num = 0
image_info = []
current_prompt = None
for i, input_info in enumerate(prompt_loader):
prompts = input_info['prompts']
catagories = input_info['catagories']
save_names = input_info['save_names']
n_prompts = input_info['n_prompts']
print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}")
infer_num += args.batch_size
start_time = time.time()
images = pipe(
prompt=prompts,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
image_size=(args.height, args.width),
)
if i > 1:
use_time += time.time() - start_time
for j in range(n_prompts):
image_save_path = os.path.join(args.save_dir, f"{save_names[j]}.png")
image = images[0][j]
image.save(image_save_path)
if current_prompt != prompts[j]:
current_prompt = prompts[j]
image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]})
image_info[-1]['images'].append(image_save_path)
infer_num = infer_num - 2 * args.batch_size
print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n"
f"average time: {use_time / infer_num:.3f}s\n")
if os.path.exists(args.info_file_save_path):
os.remove(args.info_file_save_path)
with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f:
json.dump(image_info, f)
if __name__ == "__main__":
inference_args = parse_arguments()
infer(inference_args)