"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import argparse
import os
import sys
import torch
from diffusers import StableDiffusion3Pipeline
from torch import nn
from tqdm import tqdm
cur_file_dir = os.path.dirname(os.path.abspath(__file__))
example_base_dir = os.path.abspath(os.path.join(cur_file_dir, "..", "..", ".."))
sys.path.append(example_base_dir)
from example.common.security.pytorch import safe_torch_load
from example.common.security.path import get_valid_read_path, get_write_directory
from msmodelslim.quant import quant_model, SessionConfig
from msmodelslim.quant import W8A8ProcessorConfig, W8A8QuantConfig, SaveProcessorConfig
def parse_args(namespace=None):
    """Parse and validate the command-line options of the SD3 quantization script.

    Args:
        namespace: Optional ``argparse.Namespace`` to populate; forwarded to
            ``ArgumentParser.parse_args``.

    Returns:
        The parsed namespace. ``sd3_model_path`` and ``prompt_path`` are passed
        through ``get_valid_read_path``. ``save_path`` and any supplied quant
        folders are passed through ``get_write_directory``.

    Exits:
        Via ``parser.error`` (status 2) when ``--do_quant`` is given without both
        ``--quant_weight_save_folder`` and ``--quant_dump_calib_folder``.
    """
    parser = argparse.ArgumentParser(description="SD3 inference script")
    parser.add_argument("--sd3_model_path", type=str, required=True, help='Ckpt path of sd3 model')
    parser.add_argument("--prompt_path", type=str, default="./calib_prompts.txt", help="input prompt text path")
    parser.add_argument("--width", type=int, default=1024, help='Image size width')
    parser.add_argument("--height", type=int, default=1024, help='Image size height')
    parser.add_argument("--infer_steps", type=int, default=28, help="Inference steps")
    parser.add_argument("--seed", type=int, default=42, help="A seed for all the prompts")
    parser.add_argument("--device", type=str, choices=["npu"], default="npu", help="model running device")
    parser.add_argument("--save_path", type=str, default="./results", help="path to save image output")
    parser.add_argument("--do_quant", action="store_true",
                        help="run calibration and quantization of the transformer")
    parser.add_argument("--quant_type", choices=["w8a8"], default="w8a8", help="quantization scheme")
    parser.add_argument("--quant_weight_save_folder", type=str,
                        help="directory for the quantized weights (required with --do_quant)")
    parser.add_argument("--quant_dump_calib_folder", type=str,
                        help="directory for the calibration data dump (required with --do_quant)")
    parser.add_argument("--do_save_img", action="store_true", help="whether to save image output")
    args = parser.parse_args(namespace=namespace)

    # The quant folders have no default. Quantization cannot run without them,
    # so report a clear CLI error instead of failing later on a None path.
    if args.do_quant:
        missing = [
            f"--{name}"
            for name in ("quant_weight_save_folder", "quant_dump_calib_folder")
            if getattr(args, name) is None
        ]
        if missing:
            parser.error(f"--do_quant requires {', '.join(missing)}")

    args.sd3_model_path = get_valid_read_path(args.sd3_model_path, is_dir=True)
    args.prompt_path = get_valid_read_path(args.prompt_path, is_dir=False)
    args.save_path = get_write_directory(args.save_path)
    # Only validate optional directories that were actually supplied.
    if args.quant_weight_save_folder is not None:
        args.quant_weight_save_folder = get_write_directory(args.quant_weight_save_folder)
    if args.quant_dump_calib_folder is not None:
        args.quant_dump_calib_folder = get_write_directory(args.quant_dump_calib_folder)
    return args
def load_prompt(path):
    """Read prompts from a text file, one prompt per line.

    Args:
        path: Path to the prompt file; must end with ``txt``.

    Returns:
        A list of prompts with surrounding whitespace stripped. Blank lines are
        skipped so they do not become empty prompts fed to the pipeline.

    Raises:
        ValueError: If ``path`` does not end with ``txt``.
    """
    if not path.endswith('txt'):
        raise ValueError("prompt path must end with txt")
    # Decode explicitly as UTF-8 so non-ASCII prompts read the same on every platform.
    with open(path, 'r', encoding='utf-8') as file:
        stripped = (line.strip() for line in file)
        return [prompt for prompt in stripped if prompt]
def inference(args):
    """Load the SD3 pipeline, then calibrate and quantize its transformer.

    Args:
        args: Namespace from ``parse_args``. Reads ``do_quant``,
            ``sd3_model_path``, ``device``, ``prompt_path``, ``seed``, ``width``,
            ``height``, ``infer_steps``, ``save_path`` and ``do_save_img``; the
            whole namespace is also forwarded to ``do_multimodal_quant``.

    Raises:
        ValueError: If ``args.do_quant`` is not set. Only the quantization flow
            is supported.
    """
    # Fail fast: reject the call before loading the full pipeline onto the device.
    if not args.do_quant:
        raise ValueError("Please --do_quant to True")

    pipe = StableDiffusion3Pipeline.from_pretrained(args.sd3_model_path)
    pipe.to(args.device)
    model = pipe.transformer
    model.eval()

    def inference_func(save_path, desc=''):
        """Generate images for every prompt and save them as ``{prompt_idx}_{image_idx}.png``."""
        save_path = get_write_directory(save_path)
        prompts = load_prompt(args.prompt_path)
        for cnt, prompt in enumerate(tqdm(prompts, desc=desc)):
            # Re-seed before each prompt so every prompt starts from the same RNG state.
            torch.manual_seed(args.seed)
            torch.npu.manual_seed(args.seed)
            torch.npu.manual_seed_all(args.seed)
            images = pipe(
                prompt=[prompt],
                negative_prompt=[""],
                width=args.width,
                height=args.height,
                num_inference_steps=args.infer_steps,
                guidance_scale=7.0
            ).images
            for i, img in enumerate(images):
                img.save(os.path.join(save_path, f"{cnt}_{i}.png"))

    # do_multimodal_quant calls inference_func with these kwargs only when no
    # cached calibration dump exists yet. That float-model pass also saves images.
    do_multimodal_quant(
        args,
        model,
        infer_func=inference_func,
        infer_args=[],
        infer_kwargs=dict(
            save_path=os.path.join(args.save_path, 'calib_fp'),
            desc='Dump calib data by float model inference'
        )
    )
    if args.do_save_img:
        # NOTE(review): presumably runs the transformer as modified in place by
        # quant_model inside do_multimodal_quant; confirm against msmodelslim.
        inference_func(save_path=os.path.join(args.save_path, 'calib_quant'),
                       desc='Run fake quant using calib data')
def do_multimodal_quant(args, model, infer_func, infer_args, infer_kwargs):
    """Collect calibration inputs for ``model`` (unless cached) and quantize it.

    Args:
        args: Namespace from ``parse_args``. Reads ``quant_type``,
            ``quant_dump_calib_folder`` and ``quant_weight_save_folder``.
        model: The ``nn.Module`` to calibrate and quantize.
        infer_func: Callable that drives inference through ``model``. It is called
            once with ``*infer_args, **infer_kwargs`` while a ``DumperManager``
            is attached, and only when the calibration dump file does not exist yet.
        infer_args: Positional arguments for ``infer_func``.
        infer_kwargs: Keyword arguments for ``infer_func``.

    Raises:
        ValueError: If ``model`` is not an ``nn.Module`` or ``args.quant_type``
            is not ``'w8a8'``.
    """
    # Validate inputs before any filesystem work or expensive calibration inference.
    if not isinstance(model, nn.Module):
        raise ValueError("model must be a nn.Module")
    if args.quant_type != 'w8a8':
        raise ValueError("quant_type must be w8a8")

    from example.multimodal_sd.utils import get_rank, DumperManager, get_rank_suffix_file

    dump_calib_folder = args.quant_dump_calib_folder
    safe_tensor_folder = args.quant_weight_save_folder
    rank = get_rank()
    # A negative rank is treated as a non-distributed run.
    is_distributed = rank >= 0
    # Presumably get_rank_suffix_file adds a per-rank suffix when distributed; verify in utils.
    dump_data_path = os.path.join(dump_calib_folder, get_rank_suffix_file(base_name="calib_data", ext="pth",
                                                                          is_distributed=is_distributed, rank=rank))

    # Reuse an existing dump; otherwise record the model's call arguments during inference.
    if not os.path.exists(dump_data_path):
        os.makedirs(os.path.dirname(dump_data_path), exist_ok=True)
        dumper_manager = DumperManager(model, capture_mode='args')
        infer_func(*infer_args, **infer_kwargs)
        dumper_manager.save(dump_data_path)
    calib_dataset = safe_torch_load(dump_data_path, map_location=f'npu:{rank if is_distributed else 0}')

    def get_w8a8_cfg():
        """Build the W8A8 session config: minmax activation calibration plus safetensors export."""
        safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8', ext='safetensors',
                                                is_distributed=is_distributed, rank=rank)
        json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8', ext='json',
                                         is_distributed=is_distributed, rank=rank)
        _cfg = SessionConfig(
            processor_cfg_map={
                "w8a8": W8A8ProcessorConfig(
                    cfg=W8A8QuantConfig(
                        act_method='minmax'
                    ),
                    # Layers named here are passed to the processor as disable_names.
                    disable_names=['context_embedder']
                ),
                "save": SaveProcessorConfig(
                    output_path=safe_tensor_folder,
                    safetensors_name=safetensors_name,
                    json_name=json_name,
                    save_type=['safe_tensor'],
                    part_file_size=None
                )
            },
            calib_data=calib_dataset,
            device='npu'
        )
        return _cfg

    session_cfg = get_w8a8_cfg()
    # NOTE(review): the return value is discarded. If SessionConfig is a pydantic
    # model this call only validates (raising on error); confirm that is the intent.
    session_cfg.model_validate(session_cfg)
    quant_model(model, session_cfg)
if __name__ == "__main__":
    # Script entry point: parse/validate the CLI, then run the calibration + quantization flow.
    cli_args = parse_args()
    inference(cli_args)