#!/usr/bin/env python
# -*- coding: UTF-8 -*-

"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.

MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

         http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""

"""
导入相关依赖
"""
import os
import json
import numpy as np
import torch
from tqdm import tqdm as tqdm
from torchvision.transforms import ToTensor
from diffusers import StableDiffusion3Pipeline
 
from msmodelslim.pytorch.quant.ptq_tools import Calibrator, QuantConfig
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.quant.ptq_tools.quant_modules import TensorQuantizer


# Process-wide NPU runtime configuration; executed once when the module loads.
# NOTE(review): `torch.npu` is presumably registered by torch_npu, pulled in
# through the msmodelslim imports above — confirm.
torch.npu.set_compile_mode(jit_compile=False)

# Options forwarded to torch.npu.set_option: "ReduceProd" is listed under the
# fuzzy-compile blacklist key.
option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceProd"}
torch.npu.set_option(option)

"""
导入相关模型
"""
 
def inference(save_path="imgs", categories=None):
    """Quantize the SD3 transformer (8-bit weights/activations) and render a sample.

    The function performs, in order:

    1. Load the Stable Diffusion 3 pipeline in float16 and move it to "npu".
    2. Read ``PartiPrompts.tsv`` (see the NOTE in the body: the result is
       currently not used for generation).
    3. Load the calibration data and cast tensors found in dict samples to
       float16.
    4. Run ``AntiOutlier`` (method ``'m4'``) on ``pipe.transformer`` using the
       first calibration sample.
    5. Run ``Calibrator`` with the ``QuantConfig`` below and export the result
       as safetensors under ``$PROJECT_PATH/output/ptq-tools/quant_sd3``.
    6. Generate images for one fixed prompt and save them as JPEG files named
       ``<prompt index>_<image index>.jpg`` inside ``save_path``.

    Args:
        save_path: Directory that receives the generated images; created if it
            does not exist.
        categories: Optional collection of category names used to filter the
            prompt file by its second tab-separated column. ``None`` or an
            empty collection selects every row except the first (header) row.

    Raises:
        KeyError: If the ``PROJECT_PATH`` environment variable is not set.
    """
    os.makedirs(save_path, exist_ok=True)

    # Resolve the shared locations once instead of rebuilding them per use.
    project_root = os.environ['PROJECT_PATH']
    resource_dir = f"{project_root}/resource/multi_modal/sd3_project"

    torch.manual_seed(42)

    pipe = StableDiffusion3Pipeline.from_pretrained(
        f"{resource_dir}/stable-diffusion-3-medium-diffusers/",
        torch_dtype=torch.float16,
    )
    pipe.to("npu")
    pipe.set_progress_bar_config(disable=True)

    # Only the transformer of the pipeline is smoothed and quantized below.
    model = pipe.transformer

    # Prompt dataset: column 0 is the prompt text, column 1 the category.
    # An explicit encoding keeps decoding independent of the host locale.
    with open(f"{resource_dir}/PartiPrompts.tsv", encoding="utf-8") as f:
        if not categories:
            # Drop the first row (header) and keep every prompt.
            prompt_list = [line.split("\t")[0] for line in f][1:]
        else:
            rows = (line.split("\t") for line in f)
            prompt_list = [cols[0] for cols in rows if cols[1] in categories]
    # NOTE(review): the prompts read above are replaced by a single fixed
    # prompt before generation, so `categories` has no effect on the output.
    # The read is kept so that a missing dataset file still fails early.

    # NOTE: torch.load unpickles its input; only load calibration files from a
    # trusted source.
    calib_dataset = torch.load(
        f"{resource_dir}/sd3_calib_data_v3.pth", map_location="npu"
    )

    # Cast tensor entries of dict-style samples to float16, matching the dtype
    # the pipeline was loaded with. Samples that are not dicts are left as-is.
    for data in tqdm(calib_dataset):
        if isinstance(data, dict):
            for key, value in data.items():
                if isinstance(value, torch.Tensor):
                    data[key] = value.to(torch.float16)

    # When the activations of linear operators span too wide a range, or
    # contain too many "spike" outliers, apply the anti-outlier feature first.
    smooth_config = AntiOutlierConfig(
        anti_method='m4',
        dev_type="npu",
        dev_id=0,
    )
    anti_outlier = AntiOutlier(
        model, calib_dataset[:1], smooth_config, norm_class_name="layernorm"
    )
    anti_outlier.process()

    # Quantization: 8-bit signed weights (symmetric) and 8-bit signed
    # activations (asymmetric), calibrated on the first calibration sample.
    q_config = QuantConfig(
        w_bit=8,
        a_bit=8,
        w_signed=True,
        a_signed=True,
        w_sym=True,
        a_sym=False,
        act_quant=True,
        act_method=1,
        quant_mode=1,
        disable_names=None,
        amp_num=0,
        keep_acc=None,
        sigma=25,
        device="npu"
    )
    calibrator = Calibrator(model, q_config, calib_dataset[:1])
    calibrator.run()

    calibrator.export_quant_safetensor(f"{project_root}/output/ptq-tools/quant_sd3")

    # Fixed prompt used for the sample generation (overrides the dataset).
    prompt_list = [
        "Portrait of a tiger wearing a train conductor's hat and holding a skateboard that has a yin-yang symbol on it"]
    for count, prompt in enumerate(tqdm(prompt_list)):
        images = pipe(
            prompt=[prompt],
            negative_prompt=[""],
            num_inference_steps=28,
            height=1024,
            width=1024,
            guidance_scale=7.0,
        ).images
        for i, img in enumerate(images):
            img.save(os.path.join(save_path, f"{count}_{i}.jpg"))
    
 
if __name__ == '__main__':
    # Script entry point: samples are written beneath the quantization output
    # directory, and no category filter is applied.
    project_root = os.environ['PROJECT_PATH']
    inference(
        save_path=f"{project_root}/output/ptq-tools/quant_sd3/samples/",
        categories=[],
    )