import os
import torch
from diffusers import HiDreamImagePipeline
from prompt_utils import run_inference
from transformer_patches import apply_patches
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
# Apply the project-local transformer patches before any model is loaded
# further down in this script.
# NOTE(review): presumably these are monkey patches required by the target
# device/backend — confirm in transformer_patches.
apply_patches()
# Identifier passed to HiDreamImagePipeline.from_pretrained.
MODEL_PATH = "HiDream-ai/HiDream-I1-Full"
# Identifier of the Llama checkpoint used for both the tokenizer and the text
# encoder that are handed to the pipeline as tokenizer_4 / text_encoder_4.
# NOTE(review): the name looks like it was meant to read "FOURTH" — confirm
# before renaming, other code may reference it.
FORTH_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Directory passed to run_inference as its output location.
OUTPUT_PATH = "./infer_result"
# Device string used when placing the pipeline.
DEVICE = "npu"
# Create the output directory up front; exist_ok=True keeps reruns from failing
# when it already exists.
os.makedirs(OUTPUT_PATH, exist_ok=True)
# Tokenizer and language-model text encoder, both loaded from the same Llama
# checkpoint; they are later supplied to the image pipeline explicitly.
tokenizer = PreTrainedTokenizerFast.from_pretrained(FORTH_PATH)

# The encoder is loaded in bfloat16 and configured to return hidden states and
# attentions from its forward pass.
# NOTE(review): presumably the pipeline consumes the intermediate hidden
# states as prompt embeddings — confirm against the pipeline implementation.
text_encoder = LlamaForCausalLM.from_pretrained(
    FORTH_PATH,
    torch_dtype=torch.bfloat16,
    output_hidden_states=True,
    output_attentions=True,
)
# Build the HiDream image pipeline from local files only, injecting the
# separately loaded Llama tokenizer / text encoder as the fourth text-encoder
# pair and keeping all weights in bfloat16.
pipe = HiDreamImagePipeline.from_pretrained(
    MODEL_PATH,
    tokenizer_4=tokenizer,
    text_encoder_4=text_encoder,
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)

# Enable model-level CPU offload directly on the configured device.
# The pipeline is intentionally NOT moved with pipe.to(DEVICE) first: doing so
# would place every sub-model on the accelerator at once (the peak-memory
# situation offloading is meant to avoid) only for the offload hook setup to
# move them back to the CPU. Passing device=DEVICE explicitly also avoids
# relying on the library's default accelerator, which may not be the NPU.
pipe.enable_model_cpu_offload(device=DEVICE)

# Run the project's inference routine with the prepared pipeline; results are
# written under OUTPUT_PATH by run_inference.
# NOTE(review): output behaviour is defined in prompt_utils — confirm there.
run_inference(pipe, OUTPUT_PATH)