import argparse
import torch
import time
import os
from tqdm import tqdm
from background_session import BackgroundInferSession
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from ais_bench.infer.interface import InferSession
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--image_size', default=256, type=int, choices=[256, 512], help='image size')
parser.add_argument('--num_classes', type=int, default=1000)
parser.add_argument('--cfg_scale', type=float, default=1.5)
parser.add_argument('--num_sampling_steps', type=int, default=250)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--model', required=True, metavar='DIR', help='path to dit model')
parser.add_argument('--vae', required=True, metavar='DIR', help='path to vae model')
parser.add_argument('--parallel', action="store_true", help='use parallel during inference')
parser.add_argument("--class_label", type=int, default=0)
parser.add_argument(
"--results",
type=str,
default="./results",
help="Path of directory to save all class images"
)
return parser.parse_args()
def infer(x, t, y, cfg_scale, sessions):
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
combined = combined.type(torch.FloatTensor)
t = t.type(torch.LongTensor)
y = y.type(torch.LongTensor)
if args.parallel:
combined, combined_ = combined.chunk(2)
t, t_ = t.chunk(2)
y, y_ = y.chunk(2)
sessions[1].infer_asyn([combined_, t_, y_])
model_out = torch.from_numpy(sessions[0].infer([combined, t, y])[0])
if args.parallel:
out = torch.from_numpy(sessions[1].wait_and_get_outputs()[0])
model_out = torch.cat([model_out, out])
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
def encoder_for_dit(class_labels, sessions, vae, diffusion, device, progress=True):
torch.manual_seed(args.seed)
n = len(class_labels)
latent_size = args.image_size // 8
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)
z = torch.cat([z, z], 0)
y_null = torch.tensor([args.num_classes] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
model_kwargs["sessions"] = sessions
start = time.time()
samples = diffusion.p_sample_loop(
infer, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
)
samples, _ = samples.chunk(2, dim=0)
samples = torch.tensor(vae.infer([(samples / vae.scaling_factor)])[0])
end = time.time()
if progress:
print(f"sample time is: {(end - start):.2f}s")
else:
tqdm.write(f"sample {class_labels[0]} time is: {(end - start):.2f}s")
save_image(
samples,
f"{args.results}/sample_{class_labels[0]:06d}.png",
nrow=4,
normalize=True,
value_range=(-1, 1)
)
def main():
if not os.path.exists(args.results):
os.makedirs(args.results, mode=0o640)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
diffusion = create_diffusion(str(args.num_sampling_steps))
vae = InferSession(args.device_id, args.vae)
vae_path = f"stabilityai/sd-vae-ft-{args.vae.split('.')[-2].split('_')[-1]}"
vae.scaling_factor = AutoencoderKL.from_pretrained(vae_path).to("cpu").config.scaling_factor
sessions = []
session = InferSession(args.device_id, args.model)
if args.parallel:
session_bg = BackgroundInferSession.clone(session, args.device_id + 1, [args.model, None])
sessions = [session, session_bg]
else:
sessions = [session]
if args.class_label == -1 :
for i in tqdm(range(1000)):
encoder_for_dit([i], sessions, vae, diffusion, device, False)
else:
encoder_for_dit([args.class_label], sessions, vae, diffusion, device)
if args.parallel:
session_bg.stop()
if __name__ == "__main__":
args = parse_arguments()
main()