import argparse
import os
import time
from PIL import Image
import torch
from diffusers.utils import logging
from mindiesd import CacheConfig, CacheAgent
# Cache switches read from the environment. The value is parsed with int(), so it
# must be an integer string (a non-integer value raises ValueError at import time);
# any non-zero value enables the switch. main() uses them to decide whether a
# CacheAgent is attached as transformer.cache_cond / transformer.cache_uncond.
COND_CACHE = bool(int(os.environ.get('COND_CACHE', 0)))
UNCOND_CACHE = bool(int(os.environ.get('UNCOND_CACHE', 0)))
# Module-level logger obtained through diffusers' logging helper.
logger = logging.get_logger(__name__)
def custom_op(
    name,
    fn=None,
    /,
    *,
    mutates_args,
    device_types=None,
    schema=None,
    tags=None,
):
    """No-op stand-in that this script installs as ``torch.library.custom_op``.

    Every argument except ``fn`` is accepted and ignored; nothing is
    registered anywhere.

    Args:
        name: Ignored.
        fn: Optional function to "decorate" directly (positional-only).
        mutates_args: Required keyword, ignored.
        device_types: Ignored.
        schema: Ignored.
        tags: Ignored.

    Returns:
        ``fn`` itself when it is given, otherwise an identity decorator that
        returns whatever function it is applied to.
    """

    def _passthrough(target):
        # Hand the function back untouched.
        return target

    return _passthrough if fn is None else _passthrough(fn)
def register_fake(
    op,
    fn=None,
    /,
    *,
    lib=None,
    _stacklevel: int = 1,
    allow_override: bool = False,
):
    """No-op stand-in that this script installs as ``torch.library.register_fake``.

    Every argument except ``fn`` is accepted and ignored; nothing is
    registered anywhere.

    Args:
        op: Ignored.
        fn: Optional function to "decorate" directly (positional-only).
        lib: Ignored.
        _stacklevel: Ignored.
        allow_override: Ignored.

    Returns:
        ``fn`` itself when it is given, otherwise an identity decorator that
        returns whatever function it is applied to.
    """

    def _passthrough(target):
        # Hand the function back untouched.
        return target

    return _passthrough if fn is None else _passthrough(fn)
# Replace torch.library's custom_op / register_fake with the no-op stubs defined
# above. This is done before the qwenimage_edit imports that follow, presumably
# because those modules apply these decorators at import time (TODO confirm).
torch.library.custom_op = custom_op
torch.library.register_fake = register_fake
from qwenimage_edit.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage_edit.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
def read_prompts(file_path):
    """Read a prompt file and return its non-blank lines.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped.

    Args:
        file_path: Path of a UTF-8 text file holding one prompt per line.

    Returns:
        A non-empty list of prompt strings in file order.

    Raises:
        FileNotFoundError: ``file_path`` is empty/None or does not exist.
        ValueError: The file contains no non-blank line.
    """
    if not file_path or not os.path.exists(file_path):
        raise FileNotFoundError(f"提示词文件不存在: {file_path}")
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM.
    # str.strip() does not remove U+FEFF, so with plain "utf-8" a BOM would
    # end up glued to the first prompt.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        prompts = [line for line in stripped_lines if line]
    if not prompts:
        raise ValueError(f"提示词文件内容为空: {file_path}")
    return prompts
def _parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description="使用 Qwen-Image-Edit-2509 模型生成编辑图像")
parser.add_argument("--model_path", type=str, default="/home/weight/Qwen-Image-Edit-2509/",
help="模型本地路径")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"],
help="模型数据类型")
parser.add_argument("--device", type=str, default="npu", help="运行设备(npu/cuda/cpu)")
parser.add_argument("--device_id", type=int, default=0, help="设备ID(如昇腾芯片索引)")
parser.add_argument("--img_paths", type=str, required=True,
help="输入图像路径(多图用逗号分隔,如 'img1.png,img2.png')")
parser.add_argument("--prompt_file", type=str, default="edit_prompts.txt",
help="提示词文件路径(每行一个提示词)")
parser.add_argument("--negative_prompt_file", type=str, default=None,
help="负面提示词文件路径(每行一个)")
parser.add_argument("--num_inference_steps", type=int, default=40,
help="推理步数")
parser.add_argument("--true_cfg_scale", type=float, default=4.0,
help="真实CFG缩放系数")
parser.add_argument("--guidance_scale", type=float, default=1.0,
help="引导缩放系数(Qwen特有)")
parser.add_argument("--seed", type=int, default=0,
help="随机种子(确保 reproducibility)")
parser.add_argument("--num_images_per_prompt", type=int, default=1,
help="每个提示词生成的图像数量")
parser.add_argument("--output_dir", type=str, default="output_images",
help="生成图像保存目录")
parser.add_argument(
"--quant_desc_path",
type=str,
default=None,
help="Path to quantization description file (e.g., quant_model_description_*.json). "
"Enables quantization if provided (applies to Text Encoder and Transformer)."
)
args = parser.parse_args()
if args.quant_desc_path:
if not os.path.exists(args.quant_desc_path):
raise FileNotFoundError(f"Quantization description file not found: {args.quant_desc_path}")
if not args.quant_desc_path.endswith(".json") or "quant_model_description" not in args.quant_desc_path:
raise ValueError(f"Invalid quantization file: {args.quant_desc_path}. "
"Expected format: 'quant_model_description_*.json'")
return args
def _load_images(img_paths):
    """Load every comma-separated path in ``img_paths`` as an RGB PIL image."""
    # Skip empty entries so a trailing or doubled comma is not treated as a path.
    img_path_list = [p.strip() for p in img_paths.split(",") if p.strip()]
    if not img_path_list:
        raise ValueError(f"No input image path given in --img_paths: {img_paths!r}")
    images = []
    for img_path in img_path_list:
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"图像文件不存在: {img_path}")
        # convert() returns a converted copy, so the source file handle can be
        # closed as soon as the copy exists.
        with Image.open(img_path) as src:
            images.append(src.convert("RGB"))
    return images


def _align_negative_prompts(prompts, neg_prompts):
    """Return exactly one negative prompt per prompt (pad with " ", drop extras)."""
    if len(neg_prompts) == len(prompts):
        return neg_prompts
    logger.warning(
        f"Negative prompt count ({len(neg_prompts)}) differs from prompt count "
        f"({len(prompts)}); missing entries use ' ' and extra entries are ignored"
    )
    aligned = list(neg_prompts[:len(prompts)])
    aligned.extend([" "] * (len(prompts) - len(aligned)))
    return aligned


def main():
    """Run Qwen-Image-Edit inference for every prompt and save the results.

    Steps: parse arguments, load the transformer (quantizing it when
    ``--quant_desc_path`` is given), build the pipeline, load the input
    images, optionally attach cache agents, then run the pipeline once per
    prompt, saving each output image as
    ``<output_dir>/edit_result_<prompt_idx>_<img_idx>.png`` and logging the
    per-prompt latency. The first three runs are excluded from the reported
    average.

    Raises:
        FileNotFoundError: A prompt file or input image does not exist.
        ValueError: A prompt file is empty or ``--img_paths`` names no image.
    """
    warmup_runs = 3  # leading runs excluded from the average latency

    args = _parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    device = f"{args.device}:{args.device_id}"
    # NOTE(review): torch.npu is called unconditionally here and below, so
    # --device values other than "npu" are presumably unsupported — confirm.
    torch.npu.set_device(args.device_id)
    logger.warning(f"使用设备: {device}")
    torch_dtype = torch.bfloat16 if args.torch_dtype == "bfloat16" else torch.float32

    logger.warning(f"从 {args.model_path} 加载模型...")
    transformer = QwenImageTransformer2DModel.from_pretrained(
        os.path.join(args.model_path, 'transformer'),
        torch_dtype=torch_dtype,
        device_map=None,
        low_cpu_mem_usage=True
    )

    if args.quant_desc_path:
        # Imported lazily: only needed when quantization is requested.
        from mindiesd import quantize
        logger.warning("Quantizing Transformer (单独量化核心组件)...")
        quantize(
            model=transformer,
            quant_des_path=args.quant_desc_path,
            use_nz=True,
        )
        torch.npu.empty_cache()

    pipeline = QwenImageEditPlusPipeline.from_pretrained(
        args.model_path,
        transformer=transformer,
        torch_dtype=torch_dtype,
        device_map=None,
        low_cpu_mem_usage=True
    )
    pipeline.vae.use_slicing = True
    pipeline.vae.use_tiling = True
    pipeline.to(device)
    pipeline.set_progress_bar_config(disable=None)

    images = _load_images(args.img_paths)
    logger.warning(f"加载完成 {len(images)} 张输入图像")

    if COND_CACHE or UNCOND_CACHE:
        # NOTE(review): block and step bounds are hard-coded; presumably tuned
        # for the default 40-step run — confirm they stay valid when
        # --num_inference_steps is changed.
        cache_config = CacheConfig(
            method="dit_block_cache",
            blocks_count=60,
            steps_count=args.num_inference_steps,
            step_start=10,
            step_interval=3,
            step_end=35,
            block_start=10,
            block_end=50
        )
        pipeline.transformer.cache_cond = CacheAgent(cache_config) if COND_CACHE else None
        pipeline.transformer.cache_uncond = CacheAgent(cache_config) if UNCOND_CACHE else None
        logger.warning("启用缓存配置")

    prompts = read_prompts(args.prompt_file)
    if args.negative_prompt_file:
        # A bare zip() would silently drop prompts when the negative file is
        # shorter, so align the two lists explicitly first.
        neg_prompts = _align_negative_prompts(prompts, read_prompts(args.negative_prompt_file))
    else:
        neg_prompts = [" "] * len(prompts)
    logger.warning(f"加载完成 {len(prompts)} 个提示词")

    total_time = 0.0
    timed_runs = 0
    for prompt_idx, (prompt, neg_prompt) in enumerate(zip(prompts, neg_prompts)):
        inputs = {
            "image": images,
            "prompt": prompt,
            "negative_prompt": neg_prompt,
            # Re-seeded for every prompt so each run starts from the same seed.
            "generator": torch.Generator(device=device).manual_seed(args.seed),
            "true_cfg_scale": args.true_cfg_scale,
            "guidance_scale": args.guidance_scale,
            "num_inference_steps": args.num_inference_steps,
            "num_images_per_prompt": args.num_images_per_prompt,
        }

        # Synchronize on both sides of the call so the measured interval
        # covers the device work of this run only.
        torch.npu.synchronize()
        start_time = time.perf_counter()
        with torch.inference_mode():
            output = pipeline(**inputs)
        torch.npu.synchronize()
        infer_time = time.perf_counter() - start_time
        logger.warning(f"提示词 {prompt_idx + 1}/{len(prompts)} 推理完成,耗时: {infer_time:.2f}秒")

        for img_idx, img in enumerate(output.images):
            save_path = os.path.join(
                args.output_dir,
                f"edit_result_{prompt_idx}_{img_idx}.png"
            )
            img.save(save_path)
            logger.warning(f"图像保存至: {save_path}")

        if prompt_idx >= warmup_runs:
            total_time += infer_time
            timed_runs += 1

    # Divide by the number of runs actually timed, not by a count derived
    # from the prompt list.
    if timed_runs:
        avg_time = total_time / timed_runs
        logger.warning(f"排除前{warmup_runs}次预热后,平均推理时间: {avg_time:.2f}秒")
# Run the inference workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()