import os
import warnings
import logging
warnings.filterwarnings("ignore")
from torch.cuda.amp import autocast
import time, cv2, torch, wandb
import torch.distributed as dist
from config.diffconfig import DiffusionConfig, get_model_conf
from config.dataconfig import Config as DataConfig
from tensorfn import load_config as DiffConfig
from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
from tensorfn.optim import lr_scheduler
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import data as deepfashion_data
from model import UNet
import torch_npu
from torch_npu.contrib import transfer_to_npu
# Switch off torch_npu's "internal format" option for the whole process.
# NOTE(review): presumably this keeps NPU tensors in standard layouts — confirm
# against the torch_npu documentation.
torch.npu.config.allow_internal_format = False
# Write log records of level INFO and above to "train.log", message text only.
logging.basicConfig(filename="train.log", level=logging.INFO, format='%(message)s')
def init_distributed():
    """Join the default process group described by the launcher's environment.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from ``os.environ``,
    initialises ``torch.distributed`` through the ``env://`` rendezvous, selects
    the local device, waits for all ranks at a barrier, and finally restricts
    ``print`` to rank 0 via ``setup_for_distributed``.

    Raises:
        KeyError: if any of the three environment variables is not set.
    """
    env = os.environ
    global_rank = int(env["RANK"])
    n_procs = int(env['WORLD_SIZE'])
    device_index = int(env['LOCAL_RANK'])

    # NOTE(review): the backend string is "nccl"; the file imports
    # torch_npu.contrib.transfer_to_npu, which presumably remaps it to the NPU
    # collective backend — confirm.
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=n_procs,
        rank=global_rank,
    )
    torch.cuda.set_device(device_index)
    dist.barrier()
    setup_for_distributed(global_rank == 0)
def setup_for_distributed(is_master):
    """Install a process-aware replacement for the built-in ``print``.

    After this call ``print`` produces output only when ``is_master`` is true
    or when the caller passes ``force=True``. The ``force`` keyword is consumed
    here and is never forwarded to the real ``print``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print
def is_main_process():
    """Return ``True`` if this process should act as the main (rank 0) process.

    When ``torch.distributed`` is unavailable or no default process group has
    been initialised, the process is treated as the main one, so the script
    also behaves sensibly in a single-process run.

    Returns:
        bool: ``True`` for rank 0 or for a non-distributed run, else ``False``.
    """
    # Ask explicitly whether a process group exists instead of calling
    # get_rank() under a bare ``except:``, which would also swallow
    # KeyboardInterrupt/SystemExit and hide unrelated failures.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0
def sample_data(loader):
    """Yield ``(epoch, batch)`` pairs from ``loader`` forever.

    ``loader`` is re-iterated every time it is exhausted, and ``epoch`` counts
    how many complete passes have finished (it starts at 0).

    Args:
        loader: any re-iterable object, e.g. a ``DataLoader`` or a list.

    Yields:
        tuple: ``(epoch, batch)``.

    Raises:
        RuntimeError: if a pass over ``loader`` produces no batch at all.
            Without this check the generator would either spin forever or leak
            an opaque "generator raised StopIteration" error.
    """
    epoch = 0
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield epoch, batch
        if not produced:
            raise RuntimeError("sample_data: loader yielded no batches")
        epoch += 1
def accumulate(model1, model2, decay=0.9999):
    """Blend the parameters of ``model2`` into ``model1`` in place.

    Every named parameter ``p1`` of ``model1`` becomes
    ``decay * p1 + (1 - decay) * p2``, where ``p2`` is the parameter of
    ``model2`` with the same name. ``model2`` is left untouched.

    Raises:
        KeyError: if ``model2`` lacks a parameter name present in ``model1``.
    """
    incoming = dict(model2.named_parameters())
    blend = 1 - decay
    for name, averaged in model1.named_parameters():
        averaged.data.mul_(decay).add_(incoming[name].data, alpha=blend)
def _save_checkpoint(path, conf, model, ema, scheduler, optimizer):
    """Write model, EMA, scheduler and optimizer state plus ``conf`` to ``path``."""
    # Save the wrapped network's own weights so the checkpoint can be loaded
    # into an unwrapped model as well.
    model_module = model.module if conf.distributed else model
    torch.save(
        {
            "model": model_module.state_dict(),
            "ema": ema.state_dict(),
            "scheduler": scheduler.state_dict(),
            "optimizer": optimizer.state_dict(),
            "conf": conf,
        },
        path,
    )


def _sample_images(algorithm, conf, ema, diffusion, betas, cond_scale, val_img, val_pose):
    """Generate images from the EMA model with the requested sampling algorithm."""
    if algorithm == 'ddpm':
        print('Sampling algorithm used: DDPM')
        return diffusion.p_sample_loop(ema, x_cond=[val_img, val_pose], progress=True, cond_scale=cond_scale)
    if algorithm == 'ddim':
        print('Sampling algorithm used: DDIM')
        nsteps = 50
        # Derive the step sequence from the configured schedule length rather
        # than a literal 1000, so the indices stay inside ``betas``.
        n_timestep = conf.diffusion.beta_schedule["n_timestep"]
        noise = torch.randn(val_img.shape).cuda()
        seq = range(0, n_timestep, max(1, n_timestep // nsteps))
        xs, x0_preds = ddim_steps(noise, seq, ema, betas.cuda(), [val_img, val_pose])
        return xs[-1].cuda()
    raise ValueError(f"unknown sample_algorithm {algorithm!r}; expected 'ddpm' or 'ddim'")


def train(conf, loader, val_loader, model, ema, diffusion, betas, optimizer, scheduler, guidance_prob, cond_scale, device, wandb, num_epochs=300):
    """Run the training loop with periodic checkpointing and sample logging.

    Every batch is used in both directions: source and target images are
    concatenated along dim 0 in both orders, so each pair is trained as
    source->target and target->source within the same step.

    NOTE: this function also reads the module-level ``args`` namespace
    (``use_bf16``, ``batch_size``, ``save_checkpoints_every_iters``,
    ``save_wandb_images_every_epochs``, ``sample_algorithm``), so it only works
    when the file is run as a script.

    Args:
        conf: diffusion config; ``conf.diffusion.beta_schedule["n_timestep"]``,
            ``conf.training.scheduler.warmup``, ``conf.training.ckpt_path`` and
            ``conf.distributed`` are read.
        loader: iterable of training batches; each is a mapping with
            ``source_image``, ``target_image``, ``source_skeleton`` and
            ``target_skeleton`` tensors.
        val_loader: iterator advanced with ``next()`` to get a validation batch.
        model: network being optimised (DDP-wrapped when ``conf.distributed``).
        ema: exponential-moving-average copy of the network; used for sampling.
        diffusion: object providing ``training_losses`` and ``p_sample_loop``.
        betas: beta schedule tensor handed to ``ddim_steps``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
        guidance_prob: ``1 - guidance_prob`` is passed as ``prob`` to
            ``training_losses``.
        cond_scale: forwarded to ``p_sample_loop`` when sampling with DDPM.
        device: device the training tensors are moved to.
        wandb: object exposing ``log`` and ``Image`` (the wandb module here).
        num_epochs: number of passes over ``loader`` (defaults to 300).

    Raises:
        ValueError: if ``args.sample_algorithm`` is neither ``'ddpm'`` nor
            ``'ddim'``. Checked up front so a typo does not cost a whole epoch.
    """
    import time
    from contextlib import nullcontext

    if args.sample_algorithm not in ('ddpm', 'ddim'):
        raise ValueError(
            f"unknown sample_algorithm {args.sample_algorithm!r}; expected 'ddpm' or 'ddim'"
        )

    # The EMA must track the bare network whether or not it is DDP-wrapped.
    model_module = model.module if conf.distributed else model
    n_timestep = conf.diffusion.beta_schedule["n_timestep"]

    step = 0
    loss_list = []
    loss_mean_list = []
    loss_vb_list = []

    for epoch in range(num_epochs):
        if is_main_process():
            print('#Epoch - ' + str(epoch))

        # A distributed sampler reuses the same ordering every epoch unless its
        # epoch is advanced; do that whenever the loader exposes such a sampler.
        sampler = getattr(loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        start_time = time.time()
        step_end_time = time.time()

        for batch in tqdm(loader):
            step += 1

            img = torch.cat([batch['source_image'], batch['target_image']], 0).to(device)
            target_img = torch.cat([batch['target_image'], batch['source_image']], 0).to(device)
            target_pose = torch.cat([batch['target_skeleton'], batch['source_skeleton']], 0).to(device)

            time_t = torch.randint(0, n_timestep, (img.shape[0],), device=device)

            # Run the forward pass under bf16 autocast only when requested.
            amp_context = autocast(dtype=torch.bfloat16) if args.use_bf16 else nullcontext()
            with amp_context:
                loss_dict = diffusion.training_losses(
                    model, x_start=target_img, t=time_t, cond_input=[img, target_pose], prob=1 - guidance_prob
                )

            loss = loss_dict['loss'].mean()
            loss_mse = loss_dict['mse'].mean()
            loss_vb = loss_dict['vb'].mean()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1)
            # NOTE(review): the scheduler is stepped before the optimizer, as in
            # the original ordering — confirm this is what the project's
            # scheduler expects (torch's built-in schedulers expect the reverse).
            scheduler.step()
            optimizer.step()

            # decay == 0 during warmup makes the EMA a plain copy of the model.
            accumulate(ema, model_module, 0 if step < conf.training.scheduler.warmup else 0.9999)

            main_process = is_main_process()

            if main_process:
                # Only rank 0 logs, so only rank 0 keeps the running lists;
                # otherwise they would grow without bound on the other ranks.
                loss_list.append(loss.detach().item())
                loss_mean_list.append(loss_mse.detach().item())
                loss_vb_list.append(loss_vb.detach().item())

            if step % args.save_checkpoints_every_iters == 0 and main_process:
                _save_checkpoint(
                    f"{conf.training.ckpt_path}/model_{step:06d}.pt",
                    conf, model, ema, scheduler, optimizer,
                )

            if main_process:
                step_total_time = time.time() - step_end_time
                logging.info({'loss': (sum(loss_list) / len(loss_list)),
                              'loss_vb': (sum(loss_vb_list) / len(loss_vb_list)),
                              'loss_mean': (sum(loss_mean_list) / len(loss_mean_list)),
                              'epoch': epoch, 'steps': step, 'step_train_time': step_total_time,
                              'FPS': args.batch_size / step_total_time})
                loss_list = []
                loss_mean_list = []
                loss_vb_list = []
            step_end_time = time.time()

        if is_main_process():
            print('Epoch Time ' + str(int(time.time() - start_time)) + ' secs')
            # Save first so the success message is only printed once it is true.
            _save_checkpoint(f"{conf.training.ckpt_path}/last.pt", conf, model, ema, scheduler, optimizer)
            print('Model Saved Successfully for #epoch ' + str(epoch) + ' #steps ' + str(step))

        if epoch % args.save_wandb_images_every_epochs == 0:
            print('Generating samples at epoch number ' + str(epoch))

            val_batch = next(val_loader)
            val_img = val_batch['source_image'].cuda()
            val_pose = val_batch['target_skeleton'].cuda()

            with torch.no_grad():
                samples = _sample_images(
                    args.sample_algorithm, conf, ema, diffusion, betas, cond_scale, val_img, val_pose
                )

            # Place input image, first three pose channels and the generated
            # sample next to each other along the last dimension.
            grid = torch.cat([val_img, val_pose[:, :3], samples], -1)

            # Every rank takes part in the gather; only rank 0 uploads the result.
            gathered_samples = [torch.zeros_like(grid) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_samples, grid)

            if is_main_process():
                wandb.log({'samples': wandb.Image(torch.cat(gathered_samples, -2))})
def main(settings, EXP_NAME):
    """Build data, models, optimizer and diffusion objects, then start training.

    Args:
        settings: sequence ``[args, DiffConf, DataConf]`` holding the parsed
            command-line namespace, the diffusion config and the data config.
        EXP_NAME: run name passed to ``wandb.init`` on the main process.

    When ``DiffConf.ckpt`` is set, model, EMA and scheduler state are restored
    from that file and the scheduler warmup is set to 0.
    NOTE(review): the optimizer state stored in checkpoints is not restored
    here — confirm that this is intended.
    """
    args, DiffConf, DataConf = settings

    if is_main_process():
        wandb.init(project="person-synthesis", name=EXP_NAME, settings=wandb.Settings(code_dir="."))

    resuming = DiffConf.ckpt is not None
    if resuming:
        DiffConf.training.scheduler.warmup = 0

    DiffConf.distributed = True
    local_rank = int(os.environ['LOCAL_RANK'])

    # Half the requested size is given to the data config; train() concatenates
    # source and target tensors along dim 0, which doubles it again.
    DataConf.data.train.batch_size = args.batch_size // 2

    val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(
        DataConf.data, labels_required=True, distributed=True
    )

    def endless(batches):
        # Restart the validation loader whenever it runs out.
        while True:
            yield from batches

    val_dataset = iter(endless(val_dataset))

    model = get_model_conf().make_model().to(args.device)
    ema = get_model_conf().make_model().to(args.device)

    if DiffConf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            find_unused_parameters=True,
        )

    optimizer = torch_npu.optim.NpuFusedAdam(model.parameters(), lr=2e-5)
    scheduler = DiffConf.training.scheduler.make(optimizer)

    if resuming:
        # map_location returns the storage unchanged, as deserialised.
        ckpt = torch.load(DiffConf.ckpt, map_location=lambda storage, loc: storage)

        weights_target = model.module if DiffConf.distributed else model
        weights_target.load_state_dict(ckpt["model"])
        ema.load_state_dict(ckpt["ema"])
        scheduler.load_state_dict(ckpt["scheduler"])

        if is_main_process():
            print ('model loaded successfully')

    betas = DiffConf.diffusion.beta_schedule.make()
    diffusion = create_gaussian_diffusion(betas, predict_xstart=False)

    train(
        DiffConf,
        train_dataset,
        val_dataset,
        model,
        ema,
        diffusion,
        betas,
        optimizer,
        scheduler,
        args.guidance_prob,
        args.cond_scale,
        args.device,
        wandb,
    )
if __name__ == "__main__":
    # Script entry point: parse the command line, join the process group,
    # load both configs, create the checkpoint directory and start training.
    import argparse

    parser = argparse.ArgumentParser(description='help')
    parser.add_argument('--exp_name', type=str, default='pidm_deepfashion')
    parser.add_argument('--DiffConfigPath', type=str, default='./config/diffusion.conf')
    parser.add_argument('--DataConfigPath', type=str, default='./config/data.yaml')
    parser.add_argument('--dataset_path', type=str, default='./dataset/deepfashion')
    parser.add_argument('--save_path', type=str, default='checkpoints')
    # Both guidance options are real-valued. Declaring them as int made
    # ``--guidance_prob 0.2`` fail to parse and ruled out fractional scales.
    parser.add_argument('--cond_scale', type=float, default=2.0)
    parser.add_argument('--guidance_prob', type=float, default=0.1)
    parser.add_argument('--sample_algorithm', type=str, default='ddim')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--save_wandb_logs_every_iters', type=int, default=50)
    parser.add_argument('--save_checkpoints_every_iters', type=int, default=2000)
    parser.add_argument('--save_wandb_images_every_epochs', type=int, default=10)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--n_gpu', type=int, default=8)
    parser.add_argument('--n_machine', type=int, default=1)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
    parser.add_argument("--use_bf16", action='store_true')

    # ``args`` must stay a module-level name: train() reads it as a global.
    # Parsing happens before the process group is created so that ``--help``
    # and argument errors do not need the launcher's environment variables.
    args = parser.parse_args()

    init_distributed()

    print ('Experiment: '+ args.exp_name)

    DiffConf = DiffConfig(DiffusionConfig, args.DiffConfigPath, args.opts, False)
    DataConf = DataConfig(args.DataConfigPath)

    DiffConf.training.ckpt_path = os.path.join(args.save_path, args.exp_name)
    DataConf.data.path = args.dataset_path

    if is_main_process():
        # ckpt_path lives under save_path, so one recursive call creates both;
        # exist_ok avoids failing when the directory is already there.
        os.makedirs(DiffConf.training.ckpt_path, exist_ok=True)

    main(settings = [args, DiffConf, DataConf], EXP_NAME = args.exp_name)