# Copyright 2024 Huawei Technologies Co., Ltd

from diffusers import DiTPipeline, DPMSolverMultistepScheduler
import torch
import torch_npu

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("npu")

# pick words from Imagenet class labels
print(pipe.labels)  # to print all available words

# pick words that exist in ImageNet
words = ["white shark", "umbrella"]

class_ids = pipe.get_label_ids(words)

generator = torch.manual_seed(33)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

image = output.images[0]  # label 'white shark'
image.save("white_shark.png")