import argparse
import json
import os
import torch
if torch.__version__ >= "1.8":
import torch_npu
import time
from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from glow import WaveGlow, WaveGlowLoss
from mel2samp import Mel2Samp
def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``'iteration'``,
    ``'optimizer'`` and ``'model'`` (the latter two being state dicts), which
    is the layout written by ``save_checkpoint`` in this file.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        model: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the stored state.

    Returns:
        Tuple ``(model, optimizer, iteration)``; ``model`` and ``optimizer``
        are the same objects that were passed in, updated in place.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            "Checkpoint file not found: '{}'".format(checkpoint_path))

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source. Tensors are mapped to CPU first and copied to the
    # model's own device by load_state_dict.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    model.load_state_dict(checkpoint_dict['model'])
    print("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Write model weights, optimizer state and progress counters to disk.

    A fresh ``WaveGlow`` is built from the module-level ``waveglow_config``
    on ``npu:0``, the weights of ``model`` are copied into it, and that
    copy's state dict is what gets stored.

    Args:
        model: Module providing the weights via ``state_dict()``.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        learning_rate: Value stored under the ``'learning_rate'`` key.
        iteration: Value stored under the ``'iteration'`` key.
        filepath: Destination passed to ``torch.save``.
    """
    message = "Saving model and optimizer state at iteration {} to {}"
    print(message.format(iteration, filepath))

    # Copy the weights into a newly constructed network before serializing.
    snapshot = WaveGlow(**waveglow_config).to("npu:0")
    snapshot.load_state_dict(model.state_dict())

    payload = {
        'model': snapshot.state_dict(),
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(payload, filepath)
def train(num_gpus, rank, group_name, step_stop, output_directory, epochs, learning_rate,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    """Run the WaveGlow training loop on one or several NPUs.

    Reads the module-level ``waveglow_config``, ``data_config`` and (for
    multi-device runs) ``dist_config`` dictionaries, which the script entry
    point populates from the JSON configuration file.

    Args:
        num_gpus: Number of devices taking part in training. ``1`` trains on
            ``npu:0``; ``> 1`` initializes the distributed backend.
        rank: Rank of this process; rank 0 creates the output directory,
            writes tensorboard logs and saves checkpoints.
        group_name: Group name forwarded to ``init_distributed``.
        step_stop: When equal to ``1``, training stops after the first
            iteration (the CLI help describes this as a profiling aid).
        output_directory: Directory receiving checkpoints and ``logs/``.
        epochs: Upper bound (exclusive) of the epoch range.
        learning_rate: Learning rate given to ``NpuFusedAdam``.
        sigma: Argument passed to ``WaveGlowLoss``.
        iters_per_checkpoint: A checkpoint is written whenever the global
            iteration counter is a multiple of this value.
        batch_size: Per-process batch size of the data loader.
        seed: Seed applied to the torch and NPU random generators.
        fp16_run: If true, apex amp (opt level ``O2``, loss scale 128) is used.
        checkpoint_path: Checkpoint to resume from, or ``""`` to start fresh.
        with_tensorboard: If true, rank 0 logs the training loss with
            ``tensorboardX``.
    """
    # Steps at the start of every epoch that are excluded from the FPS figure.
    fps_warmup_steps = 5

    torch.manual_seed(seed)
    torch.npu.manual_seed(seed)

    if num_gpus == 1:
        local_rank = 0
    elif num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
        local_rank = torch.distributed.get_rank()
    device = "npu:{}".format(local_rank)

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).to(device)
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    from apex.optimizers import NpuFusedAdam
    optimizer = NpuFusedAdam(model.parameters(), lr=learning_rate)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2', loss_scale=128.0)

    if num_gpus > 1:
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        broadcast_buffers=False)

    # Resume from a checkpoint if one was given.
    iteration = 0
    if checkpoint_path != "":
        # Checkpoints are written from the unwrapped network (see the save
        # call below), so their keys carry no "module." prefix. Load into the
        # wrapped module rather than the DistributedDataParallel wrapper,
        # whose strict load_state_dict would reject those keys.
        load_target = model.module if num_gpus > 1 else model
        _, optimizer, iteration = load_checkpoint(checkpoint_path, load_target,
                                                  optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config)
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    logger = None
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    model.train()
    epoch_offset = max(0, iteration // len(train_loader))
    for epoch in range(epoch_offset, epochs):
        if num_gpus > 1:
            train_sampler.set_epoch(epoch)
        iter_start_time = time.time()
        # Throughput timing starts at the epoch start and is restarted once
        # the warm-up steps are done. Initializing it here keeps epochs that
        # are shorter than the warm-up from hitting an unbound variable.
        fps_start_time = iter_start_time
        steps_in_epoch = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            if steps_in_epoch == fps_warmup_steps:
                fps_start_time = time.time()
            model.zero_grad()

            mel, audio = batch
            mel = mel.to(device)
            audio = audio.to(device)
            outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\tloss:{:.9f}\ttime:{:.9f}\t".format(iteration,
                                                          reduced_loss,
                                                          time.time() - iter_start_time))
            iter_start_time = time.time()
            if logger is not None:
                logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)

            if iteration % iters_per_checkpoint == 0 and rank == 0:
                # Separate local name so the ``checkpoint_path`` argument is
                # not overwritten.
                save_path = "{}/waveglow_{}".format(output_directory, iteration)
                model_to_save = model if num_gpus == 1 else model.module
                save_checkpoint(model_to_save, optimizer, learning_rate, iteration,
                                save_path)

            iteration += 1
            steps_in_epoch += 1
            if step_stop == 1:
                break
        if step_stop == 1:
            break

        # Samples per second over the timed steps: everything after the
        # warm-up, or the whole epoch when it was no longer than the warm-up.
        if steps_in_epoch > fps_warmup_steps:
            timed_steps = steps_in_epoch - fps_warmup_steps
        else:
            timed_steps = steps_in_epoch
        elapsed = time.time() - fps_start_time
        fps = (batch_size * num_gpus * timed_steps) / elapsed if elapsed > 0 else 0.0
        print("Epoch {} FPS:{:.9f}".format(epoch, fps))

    if logger is not None:
        # Flush pending events and release the writer's file handle.
        logger.close()
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.
            The empty string is accepted as ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got '{}'".format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Without a config file the script cannot run; make argparse report that
    # instead of failing later inside open().
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='JSON file for configuration')
    parser.add_argument('-r', '--rank', type=int, default=0,
                        help='rank of process for distributed')
    parser.add_argument('-g', '--group_name', type=str, default='',
                        help='name of group for distributed')
    parser.add_argument('-step_stop', type=int, default=0,
                        help='step stop for CANNProfiling')
    parser.add_argument('-fp16_run', type=str2bool, default=True,
                        help='for apex mixed training')
    parser.add_argument('-output_directory', type=str, default='checkpoints',
                        help='output directory of checkpoints pth')
    parser.add_argument('-epochs', type=int, default=100000,
                        help='training epochs')
    parser.add_argument('-learning_rate', type=float, default=1e-4,
                        help='learing rate')
    parser.add_argument('-sigma', type=float, default=1.0,
                        help='sigma value')
    parser.add_argument('-iters_per_checkpoint', type=int, default=2000,
                        help='iterations of checkpoints')
    parser.add_argument('-batch_size', type=int, default=12,
                        help='batch size')
    parser.add_argument('-seed', type=int, default=1234,
                        help='random seed')
    parser.add_argument('-checkpoint_path', type=str, default='',
                        help='path of pre-trained models')
    parser.add_argument('-with_tensorboard', action="store_true")
    args = parser.parse_args()

    # Parse the JSON configuration.
    with open(args.config) as f:
        config = json.load(f)

    # These are module-level names (this code runs at module scope); train()
    # and save_checkpoint() read them as globals.
    data_config = config["data_config"]
    dist_config = config["dist_config"]
    waveglow_config = config["waveglow_config"]

    num_gpus = torch.npu.device_count()
    if num_gpus > 1:
        if args.group_name == '':
            print("WARNING: Multiple NPUs detected but no distributed group set")
            print("Only running 1 NPU. Use distributed.py for multiple NPUs")
            num_gpus = 1

    if num_gpus == 1 and args.rank != 0:
        raise Exception("Doing single NPU training on rank > 0")

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    train(num_gpus, args.rank, args.group_name, args.step_stop,
          args.output_directory, args.epochs, args.learning_rate,
          args.sigma, args.iters_per_checkpoint, args.batch_size,
          args.seed, args.fp16_run, args.checkpoint_path, args.with_tensorboard)