"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import sys
import argparse
from pathlib import Path
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from torch import nn
from tqdm import tqdm
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import (
sanity_check_args, add_device_args, add_network_args, add_extra_models_args, add_denoise_schedule_args,
add_inference_args, add_parallel_args, add_ditcache_args, add_attentioncache_args
)
from hyvideo.inference import HunyuanVideoSampler
from mindiesd import CacheConfig, CacheAgent
cur_file_dir = os.path.dirname(os.path.abspath(__file__))
example_base_dir = os.path.abspath(os.path.join(cur_file_dir, "..", "..", ".."))
sys.path.append(example_base_dir)
from example.common.security.pytorch import safe_torch_load
from example.common.security.path import get_write_directory, get_valid_read_path
from msmodelslim.quant import quant_model, SessionConfig, FA3ProcessorConfig, W8A8DynamicQuantConfig, \
W8A8DynamicProcessorConfig, M3ProcessorConfig, M4ProcessorConfig, M6ProcessorConfig, M6Config
from msmodelslim.quant import W8A8TimeStepProcessorConfig, W8A8TimeStepQuantConfig, \
SaveProcessorConfig
from msmodelslim import logger
torch_npu.npu.set_compile_mode(jit_compile=False)
torch.npu.config.allow_internal_format = False
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
parser = add_device_args(parser)
parser = add_network_args(parser)
parser = add_extra_models_args(parser)
parser = add_denoise_schedule_args(parser)
parser = add_inference_args(parser)
parser = add_parallel_args(parser)
parser = add_ditcache_args(parser)
parser = add_attentioncache_args(parser)
parser.add_argument("--do_quant", action="store_true")
parser.add_argument("--quant_type", choices=["w8a8_timestep", "w8a8_dynamic_fa3", "w8a8_dynamic"],
default="w8a8_timestep", )
parser.add_argument("--anti_method", choices=["m3", "m4", "m6"], default=None)
parser.add_argument("--quant_weight_save_folder", type=str)
parser.add_argument("--quant_dump_calib_folder", type=str)
parser.add_argument("--do_save_video", action="store_true", help="whether to save video output")
args = parser.parse_args(namespace=namespace)
args = sanity_check_args(args)
args.quant_weight_save_folder = get_write_directory(args.quant_weight_save_folder)
args.quant_dump_calib_folder = get_write_directory(args.quant_dump_calib_folder)
return args
def main():
args = parse_args()
models_root_path = Path(args.model_base)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
save_path = get_write_directory(save_path)
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
transformer = hunyuan_video_sampler.pipeline.transformer
args = hunyuan_video_sampler.args
if args.prompt.endswith('txt'):
args.prompt = get_valid_read_path(args.prompt)
with open(args.prompt, 'r') as file:
text_prompt = file.readlines()
prompts = [line.strip() for line in text_prompt]
else:
prompts = [args.prompt]
if args.use_cache:
config_single = CacheConfig(
method="dit_block_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps,
step_start=args.cache_start_steps,
step_interval=args.cache_interval,
step_end=args.infer_steps - 1,
block_start=args.single_block_start,
block_end=args.single_block_end
)
cache_single = CacheAgent(config_single)
hunyuan_video_sampler.pipeline.transformer.cache_single = cache_single
if args.use_cache_double:
config_double = CacheConfig(
method="dit_block_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps,
step_start=args.cache_start_steps,
step_interval=args.cache_interval,
step_end=args.infer_steps - 1,
block_start=args.double_block_start,
block_end=args.double_block_end
)
cache_dual = CacheAgent(config_double)
hunyuan_video_sampler.pipeline.transformer.cache_dual = cache_dual
if args.use_attentioncache:
config_double = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
config_single = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
else:
config_double = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps
)
config_single = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps
)
cache_double = CacheAgent(config_double)
cache_single = CacheAgent(config_single)
for block in transformer.double_blocks:
block.cache = cache_double
for block in transformer.single_blocks:
block.cache = cache_single
def sample(save_path, desc='', **kwargs):
for idx in tqdm(range(len(prompts)), desc=desc):
if idx == 0:
outputs = hunyuan_video_sampler.predict(
prompt=prompts[0],
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=2,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale
)
outputs = hunyuan_video_sampler.predict(
prompt=prompts[idx],
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale
)
samples = outputs['samples']
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
video_path = f"{save_path}/sample_{idx}.mp4"
save_videos_grid(sample, video_path, fps=24)
logger.info('Sample save to: %r', video_path)
if args.do_quant:
do_multimodal_quant(
args,
transformer,
infer_func=sample,
infer_args=[],
infer_kwargs=dict(
save_path=os.path.join(args.save_path, 'calib_fp'),
desc='Dump calib data by float model inference'
)
)
if args.do_save_video:
sample(save_path=os.path.join(args.save_path, 'calib_quant'),
desc='Run fake quant using calib data')
else:
raise ValueError("Please --do_quant to True")
return
def do_multimodal_quant(args, model, infer_func, infer_args, infer_kwargs):
from example.multimodal_sd.utils import get_disable_layer_names, get_rank, DumperManager, get_rank_suffix_file
dump_calib_folder = args.quant_dump_calib_folder
safe_tensor_folder = args.quant_weight_save_folder
rank = get_rank()
is_distributed = rank >= 0
dump_data_path = os.path.join(dump_calib_folder, get_rank_suffix_file(base_name="calib_data", ext="pth",
is_distributed=is_distributed, rank=rank))
if not isinstance(model, nn.Module):
raise ValueError("model must be a nn.Module")
if not os.path.exists(dump_data_path):
os.makedirs(os.path.dirname(dump_data_path), exist_ok=True)
capture_mode = 'timestep' if 'timestep' in args.quant_type else 'args'
dumper_manager = DumperManager(model, capture_mode=capture_mode)
infer_func(*infer_args, **infer_kwargs)
dumper_manager.save(dump_data_path)
calib_dataset = safe_torch_load(dump_data_path, map_location=f'npu:{rank if is_distributed else 0}')
def get_timestep_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_timestep', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_timestep', ext='json',
is_distributed=is_distributed, rank=rank)
_cfg = SessionConfig(
processor_cfg_map={
"w8a8_timestep": W8A8TimeStepProcessorConfig(
cfg=W8A8TimeStepQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*', '*fc2*'),
),
timestep_sep=25,
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=['safe_tensor'],
part_file_size=None
)
},
calib_data=calib_dataset,
device='npu'
)
return _cfg
def get_fa3_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_dynamic', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_dynamic', ext='json',
is_distributed=is_distributed, rank=rank)
_cfg = SessionConfig(
processor_cfg_map={
"fa3": FA3ProcessorConfig(),
"w8a8_dynamic": W8A8DynamicProcessorConfig(
cfg=W8A8DynamicQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*'),
),
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=["safe_tensor"],
part_file_size=None,
)
},
calib_data=calib_dataset[:20],
device="npu",
)
return _cfg
def get_w8a8_dynamic_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_dynamic', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_dynamic', ext='json',
is_distributed=is_distributed, rank=rank)
processor_cfg_map = {
"w8a8_dynamic": W8A8DynamicProcessorConfig(
cfg=W8A8DynamicQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*', '*fc2*'),
),
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=['safe_tensor'],
part_file_size=None
)
}
if args.anti_method == 'm3':
processor_cfg_map['m3'] = M3ProcessorConfig()
elif args.anti_method == 'm4':
processor_cfg_map['m4'] = M4ProcessorConfig()
elif args.anti_method == 'm6':
processor_cfg_map['m6'] = M6ProcessorConfig(
cfg=M6Config(
alpha=0.8,
beta=0.2
)
)
_cfg = SessionConfig(
processor_cfg_map=processor_cfg_map,
calib_data=calib_dataset,
device='npu'
)
return _cfg
if 'timestep' in args.quant_type:
session_cfg = get_timestep_cfg()
elif 'fa3' in args.quant_type:
session_cfg = get_fa3_cfg()
elif args.quant_type == 'w8a8_dynamic':
model.config.model_type = 'hunyuan_video'
session_cfg = get_w8a8_dynamic_cfg()
else:
raise ValueError("quant_type must be timestep or fa3")
session_cfg.model_validate(session_cfg)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
quant_model(model, session_cfg)
if __name__ == "__main__":
main()