import os
import sys
import argparse
from pathlib import Path
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from torch import nn
from tqdm import tqdm
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import (
sanity_check_args, add_device_args, add_network_args, add_extra_models_args, add_denoise_schedule_args,
add_inference_args, add_parallel_args, add_ditcache_args, add_attentioncache_args
)
from hyvideo.inference import HunyuanVideoSampler
from mindiesd import CacheConfig, CacheAgent
cur_file_dir = os.path.dirname(os.path.abspath(__file__))
example_base_dir = os.path.abspath(os.path.join(cur_file_dir, "..", "..", ".."))
sys.path.append(example_base_dir)
from example.common.security.pytorch import safe_torch_load
from example.common.security.path import get_write_directory, get_valid_read_path
from msmodelslim.quant import quant_model, SessionConfig, FA3ProcessorConfig, W8A8DynamicQuantConfig, \
W8A8DynamicProcessorConfig, M3ProcessorConfig, M4ProcessorConfig, M6ProcessorConfig, M6Config
from msmodelslim.quant import W8A8TimeStepProcessorConfig, W8A8TimeStepQuantConfig, \
SaveProcessorConfig
from msmodelslim import logger
torch_npu.npu.set_compile_mode(jit_compile=False)
torch.npu.config.allow_internal_format = False
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
parser = add_device_args(parser)
parser = add_network_args(parser)
parser = add_extra_models_args(parser)
parser = add_denoise_schedule_args(parser)
parser = add_inference_args(parser)
parser = add_parallel_args(parser)
parser = add_ditcache_args(parser)
parser = add_attentioncache_args(parser)
parser.add_argument("--do_quant", action="store_true")
parser.add_argument("--quant_type", choices=["w8a8_timestep", "w8a8_dynamic_fa3", "w8a8_dynamic"],
default="w8a8_timestep", )
parser.add_argument("--anti_method", choices=["m3", "m4", "m6"], default=None)
parser.add_argument("--quant_weight_save_folder", type=str)
parser.add_argument("--quant_dump_calib_folder", type=str)
parser.add_argument("--do_save_video", action="store_true", help="whether to save video output")
args = parser.parse_args(namespace=namespace)
args = sanity_check_args(args)
args.quant_weight_save_folder = get_write_directory(args.quant_weight_save_folder)
args.quant_dump_calib_folder = get_write_directory(args.quant_dump_calib_folder)
return args
def main():
args = parse_args()
models_root_path = Path(args.model_base)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
save_path = get_write_directory(save_path)
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
transformer = hunyuan_video_sampler.pipeline.transformer
args = hunyuan_video_sampler.args
if args.prompt.endswith('txt'):
args.prompt = get_valid_read_path(args.prompt)
with open(args.prompt, 'r') as file:
text_prompt = file.readlines()
prompts = [line.strip() for line in text_prompt]
else:
prompts = [args.prompt]
if args.use_cache:
config_single = CacheConfig(
method="dit_block_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps,
step_start=args.cache_start_steps,
step_interval=args.cache_interval,
step_end=args.infer_steps - 1,
block_start=args.single_block_start,
block_end=args.single_block_end
)
cache_single = CacheAgent(config_single)
hunyuan_video_sampler.pipeline.transformer.cache_single = cache_single
if args.use_cache_double:
config_double = CacheConfig(
method="dit_block_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps,
step_start=args.cache_start_steps,
step_interval=args.cache_interval,
step_end=args.infer_steps - 1,
block_start=args.double_block_start,
block_end=args.double_block_end
)
cache_dual = CacheAgent(config_double)
hunyuan_video_sampler.pipeline.transformer.cache_dual = cache_dual
if args.use_attentioncache:
config_double = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
config_single = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
else:
config_double = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.double_blocks),
steps_count=args.infer_steps
)
config_single = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.single_blocks),
steps_count=args.infer_steps
)
cache_double = CacheAgent(config_double)
cache_single = CacheAgent(config_single)
for block in transformer.double_blocks:
block.cache = cache_double
for block in transformer.single_blocks:
block.cache = cache_single
def sample(save_path, desc='', **kwargs):
for idx in tqdm(range(len(prompts)), desc=desc):
if idx == 0:
outputs = hunyuan_video_sampler.predict(
prompt=prompts[0],
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=2,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale
)
outputs = hunyuan_video_sampler.predict(
prompt=prompts[idx],
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale
)
samples = outputs['samples']
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
video_path = f"{save_path}/sample_{idx}.mp4"
save_videos_grid(sample, video_path, fps=24)
logger.info('Sample save to: %r', video_path)
if args.do_quant:
do_multimodal_quant(
args,
transformer,
infer_func=sample,
infer_args=[],
infer_kwargs=dict(
save_path=os.path.join(args.save_path, 'calib_fp'),
desc='Dump calib data by float model inference'
)
)
if args.do_save_video:
sample(save_path=os.path.join(args.save_path, 'calib_quant'),
desc='Run fake quant using calib data')
else:
raise ValueError("Please --do_quant to True")
return
def do_multimodal_quant(args, model, infer_func, infer_args, infer_kwargs):
from example.multimodal_sd.utils import get_disable_layer_names, get_rank, DumperManager, get_rank_suffix_file
dump_calib_folder = args.quant_dump_calib_folder
safe_tensor_folder = args.quant_weight_save_folder
rank = get_rank()
is_distributed = rank >= 0
dump_data_path = os.path.join(dump_calib_folder, get_rank_suffix_file(base_name="calib_data", ext="pth",
is_distributed=is_distributed, rank=rank))
if not isinstance(model, nn.Module):
raise ValueError("model must be a nn.Module")
if not os.path.exists(dump_data_path):
os.makedirs(os.path.dirname(dump_data_path), exist_ok=True)
capture_mode = 'timestep' if 'timestep' in args.quant_type else 'args'
dumper_manager = DumperManager(model, capture_mode=capture_mode)
infer_func(*infer_args, **infer_kwargs)
dumper_manager.save(dump_data_path)
calib_dataset = safe_torch_load(dump_data_path, map_location=f'npu:{rank if is_distributed else 0}')
def get_timestep_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_timestep', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_timestep', ext='json',
is_distributed=is_distributed, rank=rank)
_cfg = SessionConfig(
processor_cfg_map={
"w8a8_timestep": W8A8TimeStepProcessorConfig(
cfg=W8A8TimeStepQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*', '*fc2*'),
),
timestep_sep=25,
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=['safe_tensor'],
part_file_size=None
)
},
calib_data=calib_dataset,
device='npu'
)
return _cfg
def get_fa3_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_dynamic', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_dynamic', ext='json',
is_distributed=is_distributed, rank=rank)
_cfg = SessionConfig(
processor_cfg_map={
"fa3": FA3ProcessorConfig(),
"w8a8_dynamic": W8A8DynamicProcessorConfig(
cfg=W8A8DynamicQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*'),
),
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=["safe_tensor"],
part_file_size=None,
)
},
calib_data=calib_dataset[:20],
device="npu",
)
return _cfg
def get_w8a8_dynamic_cfg():
safetensors_name = get_rank_suffix_file(base_name='quant_model_weight_w8a8_dynamic', ext='safetensors',
is_distributed=is_distributed, rank=rank)
json_name = get_rank_suffix_file(base_name='quant_model_description_w8a8_dynamic', ext='json',
is_distributed=is_distributed, rank=rank)
processor_cfg_map = {
"w8a8_dynamic": W8A8DynamicProcessorConfig(
cfg=W8A8DynamicQuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=('*double_blocks*', '*single_blocks*'),
layer_exclude=('*img_mod*', '*modulation*', '*fc2*'),
),
),
"save": SaveProcessorConfig(
output_path=safe_tensor_folder,
safetensors_name=safetensors_name,
json_name=json_name,
save_type=['safe_tensor'],
part_file_size=None
)
}
if args.anti_method == 'm3':
processor_cfg_map['m3'] = M3ProcessorConfig()
elif args.anti_method == 'm4':
processor_cfg_map['m4'] = M4ProcessorConfig()
elif args.anti_method == 'm6':
processor_cfg_map['m6'] = M6ProcessorConfig(
cfg=M6Config(
alpha=0.8,
beta=0.2
)
)
_cfg = SessionConfig(
processor_cfg_map=processor_cfg_map,
calib_data=calib_dataset,
device='npu'
)
return _cfg
if 'timestep' in args.quant_type:
session_cfg = get_timestep_cfg()
elif 'fa3' in args.quant_type:
session_cfg = get_fa3_cfg()
elif args.quant_type == 'w8a8_dynamic':
model.config.model_type = 'hunyuan_video'
session_cfg = get_w8a8_dynamic_cfg()
else:
raise ValueError("quant_type must be timestep or fa3")
session_cfg.model_validate(session_cfg)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
quant_model(model, session_cfg)
if __name__ == "__main__":
main()