"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import argparse
import logging
import os
import time
import sys
import imageio
import torch
from torch import nn
import torch_npu
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from transformers import T5Tokenizer, MT5EncoderModel
from opensora.models.causalvideovae import ae_stride_config, CausalVAEModelWrapper
from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V
from opensora.sample.pipeline_opensora_sp import OpenSoraPipeline
from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank
from opensora.models.causalvideovae.model.causal_vae.parallel_layers import (
register_vae_decode, parallel_full_model_warp)
from utils.file_utils import standardize_path
from msmodelslim.quant import quant_model, SessionConfig
from msmodelslim.quant import W8A8ProcessorConfig, W8A8QuantConfig, SaveProcessorConfig
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_t2v_checkpoint(model_path):
logger.info('load_t2v_checkpoint, %r', model_path)
transformer_model = OpenSoraT2V.from_pretrained(model_path, cache_dir=args.cache_dir,
low_cpu_mem_usage=False, device_map=None,
torch_dtype=weight_dtype, local_files_only=True).to("npu")
transformer_model.eval()
pipeline = OpenSoraPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer_model).to("npu")
if args.algorithm == "dit_cache":
from opensora.models.diffusion.opensora.cache_mgr import CacheManager, DitCacheConfig
config = DitCacheConfig(step_start=20, step_interval=2, block_start=7, num_blocks=21)
cache = CacheManager(config)
pipeline.transformer.cache = cache
return pipeline
def run_model_and_save_images(pipeline, args, save_path):
positive_prompt = """
(masterpiece), (best quality), (ultra-detailed), (unwatermarked),
{}.
emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo,
sharp focus, high budget, cinemascope, moody, epic, gorgeous
"""
negative_prompt = """
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
"""
if not isinstance(args.text_prompt, list):
args.text_prompt = [positive_prompt.format(args.text_prompt)]
if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
text_prompt = open(args.text_prompt[0], 'r').readlines()
args.text_prompt = [positive_prompt.format(i.strip()) for i in text_prompt]
if args.batch_size > 1:
prompt_list = []
group = len(args.text_prompt) // args.batch_size
tail = len(args.text_prompt) % args.batch_size
for index in range(group):
prompt_list.append(args.text_prompt[index * args.batch_size: (index + 1) * args.batch_size])
if tail > 0:
prompt_list.append(args.text_prompt[: -tail])
else:
prompt_list = args.text_prompt
kwargs = {}
if args.algorithm == "sampling_optimize":
kwargs["sampling_optimize"] = True
if not os.path.exists(save_path):
os.makedirs(save_path, mode=0o750)
if not args.test_time:
for index, prompt in enumerate(prompt_list):
videos = pipeline(prompt,
negative_prompt=negative_prompt,
num_frames=args.num_frames,
height=args.height,
width=args.width,
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale,
num_images_per_prompt=1,
mask_feature=True,
max_sequence_length=args.max_sequence_length,
seed=args.seed,
**kwargs
).images
logger.info(videos.shape)
if get_sequence_parallel_rank() <= 0:
for i in range(len(prompt) if args.batch_size > 1 else 1):
imageio.mimwrite(
os.path.join(
save_path,
f'EulerAncestralDiscrete_{index * args.batch_size + i}'
+ f'_final__gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'
), videos[i],
fps=args.fps, quality=6, codec='libx264',
output_params=['-threads', '20'])
else:
for _ in range(2):
start_time = time.time()
videos = pipeline(prompt_list[0],
negative_prompt=negative_prompt,
num_frames=args.num_frames,
height=args.height,
width=args.width,
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale,
num_images_per_prompt=1,
mask_feature=True,
max_sequence_length=args.max_sequence_length,
**kwargs
).images
torch.npu.synchronize()
use_time = time.time() - start_time
logger.info("========= use time %s", str(use_time))
logger.info(videos.shape)
def do_multimodal_quant(args, model, infer_func, infer_args, infer_kwargs):
from example.multimodal_sd.utils import get_disable_layer_names, get_rank, DumperManager
dump_calib_folder = args.quant_dump_calib_folder
safe_tensor_folder = args.quant_weight_save_folder
dump_data_path = os.path.join(dump_calib_folder, f'calib_data_{get_rank()}.pth')
safe_tensor_path = os.path.join(safe_tensor_folder, f'rank_{get_rank()}.safetensors')
if not isinstance(model, nn.Module):
raise ValueError("model must be a nn.Module")
if not os.path.exists(dump_data_path):
os.makedirs(os.path.dirname(dump_data_path), exist_ok=True)
dumper_manager = DumperManager(model, capture_mode='args')
infer_func(*infer_args, **infer_kwargs)
dumper_manager.save(dump_data_path)
calib_dataset = torch.load(dump_data_path, map_location=f'npu:{get_rank()}')
def get_w8a8_cfg():
_cfg = SessionConfig(
processor_cfg_map={
"w8a8": W8A8ProcessorConfig(
cfg=W8A8QuantConfig(
act_method='minmax'
),
disable_names=get_disable_layer_names(
model,
layer_include=None,
layer_exclude=('*net.2*', '*adaln_single*')
)
),
"save": SaveProcessorConfig(
output_path=os.path.dirname(safe_tensor_path),
safetensors_name=os.path.basename(safe_tensor_path),
json_name=None,
save_type=['safe_tensor'],
part_file_size=None
)
},
calib_data=calib_dataset,
device='npu'
)
return _cfg
if args.quant_type == 'w8a8':
session_cfg = get_w8a8_cfg()
else:
raise ValueError("quant_type must be w8a8")
session_cfg.model_validate(session_cfg)
quant_model(model, session_cfg)
if __name__ == "__main__":
cur_file_dir = os.path.dirname(os.path.abspath(__file__))
example_base_dir = os.path.abspath(os.path.join(cur_file_dir, "..", "..", ".."))
sys.path.append(example_base_dir)
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')
parser.add_argument("--version", type=str, default=None, choices=[None, '65x512x512', '65x256x256', '17x256x256'])
parser.add_argument("--num_frames", type=int, default=93)
parser.add_argument("--height", type=int, default=720)
parser.add_argument("--width", type=int, default=1280)
parser.add_argument('--dtype', type=str, default='bf16', help='Data type used in inference')
parser.add_argument("--cache_dir", type=str, default='./cache_dir')
parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8')
parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8')
parser.add_argument("--text_encoder_name", type=str, default='google/mt5-xxl')
parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v")
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--num_sampling_steps", type=int, default=50)
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_sequence_length", type=int, default=512)
parser.add_argument("--text_prompt", nargs='+')
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
parser.add_argument("--algorithm", type=str, default=None, choices=[None, 'dit_cache', 'sampling_optimize'])
parser.add_argument("--use_cfg_parallel", action='store_true')
parser.add_argument("--test_time", action='store_true')
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--vae_parallel", action='store_true')
parser.add_argument("--do_quant", action="store_true")
parser.add_argument("--quant_type", choices=["w8a8"], default="w8a8", )
parser.add_argument("--quant_weight_save_folder", type=str)
parser.add_argument("--quant_dump_calib_folder", type=str)
parser.add_argument("--save_image_flag", action="store_true")
args = parser.parse_args()
if args.dtype not in ['bf16', 'fp16']:
logger.error("Not supported.")
weight_dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float16
torch.npu.config.allow_internal_format = False
world_size = int(os.getenv('WORLD_SIZE', 1))
if world_size > 1:
sp_degree = world_size // 2 if args.use_cfg_parallel else world_size
parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=args.use_cfg_parallel,
world_size=world_size)
init_parallel_env(parallel_config)
args.ae_path = standardize_path(args.ae_path)
args.text_encoder_name = standardize_path(args.text_encoder_name)
args.model_path = standardize_path(args.model_path)
vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16, local_files_only=True).to("npu")
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = args.tile_overlap_factor
vae.vae.tile_sample_min_size = 256
vae.vae.tile_latent_min_size = 32
vae.vae.tile_sample_min_size_t = 29
vae.vae.tile_latent_min_size_t = 8
vae.vae_scale_factor = ae_stride_config[args.ae]
vae.eval()
VAE_PARALLEL = args.vae_parallel
if VAE_PARALLEL:
parallel_dim = -1
parallel_overlap = True
parallel_full_model_warp(vae.vae, parallel_dim)
vae = register_vae_decode(vae, parallel_dim, parallel_overlap)
text_encoder = MT5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir,
low_cpu_mem_usage=True, torch_dtype=weight_dtype,
local_files_only=True).to("npu")
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir, local_files_only=True)
text_encoder.eval()
scheduler = EulerAncestralDiscreteScheduler()
if not os.path.exists(args.save_img_path):
os.makedirs(args.save_img_path, exist_ok=True)
pipeline = load_t2v_checkpoint(args.model_path)
logger.info('load model')
if args.do_quant:
do_multimodal_quant(
args,
pipeline.transformer,
infer_func=run_model_and_save_images,
infer_args=[
pipeline,
args,
],
infer_kwargs=dict(
save_path=os.path.join(args.save_img_path, 'calib_fp'),
)
)
if args.save_image_flag:
run_model_and_save_images(
pipeline,
args,
save_path=os.path.join(args.save_img_path, 'calib_quant')
)
else:
if args.save_image_flag:
run_model_and_save_images(
pipeline,
args,
save_path=args.save_img_path
)
if world_size > 0:
finalize_parallel_env()