import sde
import ml_collections
import torch
from torch import multiprocessing as mp
from datasets import get_dataset
from torchvision.utils import make_grid, save_image
import utils
import einops
from torch.utils._pytree import tree_map
import accelerate
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
import tempfile
from tools.fid_score import calculate_fid_given_paths
from absl import logging
import builtins
import os
import wandb
import time
from torch_npu.contrib import transfer_to_npu
def train(config):
if config.get('benchmark', False):
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
mp.set_start_method('spawn', force=True)
accelerator = accelerate.Accelerator()
device = accelerator.device
accelerate.utils.set_seed(config.seed, device_specific=True)
logging.info(f'Process {accelerator.process_index} using device: {device}')
config.mixed_precision = accelerator.mixed_precision
config = ml_collections.FrozenConfigDict(config)
assert config.train.batch_size % accelerator.num_processes == 0
mini_batch_size = config.train.batch_size // accelerator.num_processes
if accelerator.is_main_process:
os.makedirs(config.ckpt_root, exist_ok=True)
os.makedirs(config.sample_dir, exist_ok=True)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
name=config.hparams, job_type='train', mode='offline')
utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
logging.info(config)
else:
utils.set_logger(log_level='error')
builtins.print = lambda *args: None
dataset = get_dataset(**config.dataset)
assert os.path.exists(dataset.fid_stat)
train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
num_workers=8, pin_memory=True, persistent_workers=True)
train_state = utils.initialize_train_state(config, device)
nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
lr_scheduler = train_state.lr_scheduler
train_state.resume(config.ckpt_root)
def get_data_generator():
while True:
for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
yield data
data_generator = get_data_generator()
score_model = sde.ScoreModel(nnet, pred=config.pred, sde=sde.VPSDE())
score_model_ema = sde.ScoreModel(nnet_ema, pred=config.pred, sde=sde.VPSDE())
def train_step(_batch):
step_start_time = time.time()
_metrics = dict()
optimizer.zero_grad()
if config.train.mode == 'uncond':
loss = sde.LSimple(score_model, _batch, pred=config.pred)
elif config.train.mode == 'cond':
loss = sde.LSimple(score_model, _batch[0], pred=config.pred, y=_batch[1])
else:
raise NotImplementedError(config.train.mode)
_metrics['loss'] = accelerator.gather(loss.detach()).mean()
accelerator.backward(loss.mean())
if 'grad_clip' in config and config.grad_clip > 0:
accelerator.clip_grad_norm_(nnet.parameters(), max_norm=config.grad_clip)
optimizer.step()
lr_scheduler.step()
train_state.ema_update(config.get('ema_rate', 0.9999))
train_state.step += 1
step_train_time = time.time() - step_start_time
fps = config.train.batch_size / step_train_time
return dict(lr=train_state.optimizer.param_groups[0]['lr'], fps=fps, **_metrics)
def eval_step(n_samples, sample_steps, algorithm):
logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm={algorithm}, '
f'mini_batch_size={config.sample.mini_batch_size}')
def sample_fn(_n_samples):
_x_init = torch.randn(_n_samples, *dataset.data_shape, device=device)
if config.train.mode == 'uncond':
kwargs = dict()
elif config.train.mode == 'cond':
kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
else:
raise NotImplementedError
if algorithm == 'euler_maruyama_sde':
return sde.euler_maruyama(sde.ReverseSDE(score_model_ema), _x_init, sample_steps, **kwargs)
elif algorithm == 'euler_maruyama_ode':
return sde.euler_maruyama(sde.ODE(score_model_ema), _x_init, sample_steps, **kwargs)
elif algorithm == 'dpm_solver':
noise_schedule = NoiseScheduleVP(schedule='linear')
model_fn = model_wrapper(
score_model_ema.noise_pred,
noise_schedule,
time_input_type='0',
model_kwargs=kwargs
)
dpm_solver = DPM_Solver(model_fn, noise_schedule)
return dpm_solver.sample(
_x_init,
steps=sample_steps,
eps=1e-4,
adaptive_step_size=False,
fast_version=True,
)
else:
raise NotImplementedError
with tempfile.TemporaryDirectory() as temp_path:
path = config.sample.path or temp_path
if accelerator.is_main_process:
os.makedirs(path, exist_ok=True)
utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
_fid = 0
if accelerator.is_main_process:
_fid = calculate_fid_given_paths((dataset.fid_stat, path))
logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
_fid = torch.tensor(_fid, device=device)
_fid = accelerator.reduce(_fid, reduction='sum')
return _fid.item()
logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
step_fid = []
while train_state.step < config.train.n_steps:
nnet.train()
batch = tree_map(lambda x: x.to(device), next(data_generator))
metrics = train_step(batch)
nnet.eval()
if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
logging.info(config.workdir)
wandb.log(metrics, step=train_state.step)
if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
logging.info('Save a grid of images...')
x_init = torch.randn(100, *dataset.data_shape, device=device)
if config.train.mode == 'uncond':
samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50)
elif config.train.mode == 'cond':
y = einops.repeat(torch.arange(10, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
samples = sde.euler_maruyama(sde.ODE(score_model_ema), x_init=x_init, sample_steps=50, y=y)
else:
raise NotImplementedError
samples = make_grid(dataset.unpreprocess(samples), 10)
save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
logging.info(f'Save and eval checkpoint {train_state.step}...')
if accelerator.local_process_index == 0:
train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
accelerator.wait_for_everyone()
fid = eval_step(n_samples=10000, sample_steps=50, algorithm='dpm_solver')
step_fid.append((train_state.step, fid))
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
logging.info(f'Finish fitting, step={train_state.step}')
logging.info(f'step_fid: {step_fid}')
step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
logging.info(f'step_best: {step_best}')
train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
del metrics
accelerator.wait_for_everyone()
eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps, algorithm=config.sample.algorithm)
from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")
def get_config_name():
argv = sys.argv
for i in range(1, len(argv)):
if argv[i].startswith('--config='):
return Path(argv[i].split('=')[-1]).stem
def get_hparams():
argv = sys.argv
lst = []
for i in range(1, len(argv)):
assert '=' in argv[i]
if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
hparam, val = argv[i].split('=')
hparam = hparam.split('.')[-1]
if hparam.endswith('path'):
val = Path(val).stem
lst.append(f'{hparam}={val}')
hparams = '-'.join(lst)
if hparams == '':
hparams = 'default'
return hparams
def main(argv):
config = FLAGS.config
config.config_name = get_config_name()
config.hparams = get_hparams()
config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
config.ckpt_root = os.path.join(config.workdir, 'ckpts')
config.sample_dir = os.path.join(config.workdir, 'samples')
train(config)
if __name__ == "__main__":
app.run(main)