import os
import torch
from torch import distributed as dist
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V
from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank
# Turn off NPU JIT compile mode before any model work.
torch.npu.set_compile_mode(jit_compile=False)
# When True, every rank saves its own rank-suffixed weight/description files
# after calibration; when False, only rank 0 saves, using default file names.
use_fa3 = True
world_size = int(os.getenv('WORLD_SIZE', 1))
if world_size > 1:
    # Half the ranks go to the sequence-parallel degree, with CFG parallelism
    # enabled. NOTE(review): this presumably assumes world_size is even
    # (sp_degree * 2 == world_size) -- confirm against ParallelConfig.
    sp_degree = world_size // 2
    parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=True,
                                     world_size=world_size)
    init_parallel_env(parallel_config)
    rank = dist.get_rank()
else:
    # Single-process run: no process group is initialized here, so
    # dist.get_rank() cannot be used. Treat this process as rank 0 so the
    # QuantConfig/save steps that read `rank` still work.
    rank = 0
model_path = '/home/Open-Sora-Plan-v1.2.0/93x720p'
# NOTE(review): dev_id is not read by the quantization code in this script
# (QuantConfig is given dev_id=rank); kept in case other code references it.
dev_id = 8
# Load the OpenSora transformer in bfloat16 and move it onto the NPU.
model = OpenSoraT2V.from_pretrained(model_path, cache_dir="../cache_dir",
                                    low_cpu_mem_usage=False, device_map=None,
                                    torch_dtype=torch.bfloat16).to("npu")
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# calibration files from a trusted source.
calib_datas = torch.load("/home/quant_model/calib_datas.pt", map_location='cpu')
# Move every tensor in each calibration sample onto the NPU and keep
# non-tensor entries unchanged. Each sample is rebuilt as a new list rather
# than mutated by index, so tuple samples work too.
calib_datas = [
    [item.npu() if torch.is_tensor(item) else item for item in calib_data]
    for calib_data in calib_datas
]
# 8-bit activations and weights (a_bit=8, w_bit=8), symmetric weights and
# dynamic activation quantization. fa_quant() presumably additionally
# enables attention (FA) quantization -- confirm against msmodelslim docs.
# NOTE(review): dev_id uses the global rank; on multi-node runs this may
# need to be the local rank -- verify.
quant_config = QuantConfig(
    a_bit=8,
    w_bit=8,
    disable_names=None,
    dev_type='npu',
    dev_id=rank,
    act_method=3,
    pr=1.0,
    w_sym=True,
    mm_tensor=False,
    is_dynamic=True,
).fa_quant(fa_amp=0)
calibrator = Calibrator(model, quant_config, calib_data=calib_datas, disable_level='L0',
                        torch_dtype=torch.bfloat16)
calibrator.run()
if use_fa3:
    # Every rank writes its own rank-suffixed weight and description files.
    calibrator.save('/home/quant_model', safetensors_name=f'quant_model_weight_w8a8_dynamic_{rank}.safetensors',
                    save_type=["safe_tensor"],
                    json_name=f'quant_model_description_w8a8_dynamic_{rank}.json')
elif rank == 0:
    # Only rank 0 writes, using the Calibrator's default file names.
    calibrator.save('/home/quant_model', save_type=["safe_tensor"])
if world_size > 1:
    # Release the distributed environment created by init_parallel_env in the
    # multi-process path.
    finalize_parallel_env()