import argparse
import os
import sys
import torch
from diffusers import StableDiffusion3Pipeline
from torch import nn
from tqdm import tqdm
from msmodelslim.quant import quant_model, SessionConfig
from msmodelslim.quant import W8A8ProcessorConfig, W8A8QuantConfig, SaveProcessorConfig
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="SD3 inference script")
parser.add_argument("--sd3_model_path", type=str, required=True, help='Ckpt path of sd3 model')
parser.add_argument("--prompt_path", type=str, default="./prompts.txt", help="input prompt text path")
parser.add_argument("--width", type=int, default=1024, help='Image size width')
parser.add_argument("--height", type=int, default=1024, help='Image size height')
parser.add_argument("--infer_steps", type=int, default=28, help="Inference steps")
parser.add_argument("--seed", type=int, default=42, help="A seed for all the prompts")
parser.add_argument("--device", type=str, choices=["npu"], default="npu", help="model running device")
parser.add_argument("--save_path", type=str, default="./results", help="path to save image output")
parser.add_argument("--do_quant", action="store_true")
parser.add_argument("--quant_type", choices=["w8a8"], default="w8a8", )
parser.add_argument("--quant_weight_save_folder", type=str)
parser.add_argument("--quant_dump_calib_folder", type=str)
args = parser.parse_args(namespace=namespace)
return args
def load_prompt(path):
if not path.endswith('txt'):
raise ValueError("prompt path must end with txt")
with open(path, 'r') as file:
text_prompt = file.readlines()
prompts = [line.strip() for line in text_prompt]
return prompts
def inference(args):
pipe = StableDiffusion3Pipeline.from_pretrained(args.sd3_model_path)
pipe.to(args.device)
model = pipe.transformer
model.eval()
def inference_func(save_path, desc=''):
os.makedirs(save_path, exist_ok=True)
prompts = load_prompt(args.prompt_path)
for cnt, prompt in enumerate(tqdm(prompts, desc=desc)):
torch.manual_seed(args.seed)
torch.npu.manual_seed(args.seed)
torch.npu.manual_seed_all(args.seed)
images = pipe(
prompt=[prompt],
negative_prompt=[""],
width=args.width,
height=args.height,
num_inference_steps=args.infer_steps,
guidance_scale=7.0
).images
for i, img in enumerate(images):
img.save(os.path.join(save_path, f"{cnt}_{i}.png"))
if args.do_quant:
do_multimodal_quant(
args,
model,
infer_func=inference_func,
infer_args=[],
infer_kwargs=dict(
save_path=os.path.join(args.save_path, 'calib_fp'),
desc='Dump calib data by float model inference'
)
)
inference_func(save_path=os.path.join(args.save_path, 'calib_quant'),
desc='Run fake quant using calib data')
else:
raise ValueError("Please --do_quant to True")
def do_multimodal_quant(args, model, infer_func, infer_args, infer_kwargs):
from example.multimodal_sd.utils import get_disable_layer_names, get_rank, DumperManager
dump_calib_folder = args.quant_dump_calib_folder
safe_tensor_folder = args.quant_weight_save_folder
dump_data_path = os.path.join(dump_calib_folder, f'calib_data_{get_rank()}.pth')
safe_tensor_path = os.path.join(safe_tensor_folder, f'rank_{get_rank()}.safetensors')
if not isinstance(model, nn.Module):
raise ValueError("model must be a nn.Module")
if not os.path.exists(dump_data_path):
os.makedirs(os.path.dirname(dump_data_path), exist_ok=True)
dumper_manager = DumperManager(model, capture_mode='args')
infer_func(*infer_args, **infer_kwargs)
dumper_manager.save(dump_data_path)
calib_dataset = torch.load(dump_data_path, map_location=f'npu:{get_rank()}')
def get_w8a8_cfg():
_cfg = SessionConfig(
processor_cfg_map={
"w8a8": W8A8ProcessorConfig(
cfg=W8A8QuantConfig(
act_method='minmax'
),
disable_names=['context_embedder']
),
"save": SaveProcessorConfig(
output_path=os.path.dirname(safe_tensor_path),
safetensors_name=os.path.basename(safe_tensor_path),
json_name=None,
save_type=['safe_tensor'],
part_file_size=None
)
},
calib_data=calib_dataset,
device='npu'
)
return _cfg
if args.quant_type == 'w8a8':
session_cfg = get_w8a8_cfg()
else:
raise ValueError("quant_type must be w8a8")
session_cfg.model_validate(session_cfg)
quant_model(model, session_cfg)
if __name__ == "__main__":
cur_file_dir = os.path.dirname(os.path.abspath(__file__))
example_base_dir = os.path.abspath(os.path.join(cur_file_dir, "..", "..", ".."))
sys.path.append(example_base_dir)
args = parse_args()
inference(args)