from tqdm import tqdm
from collections import OrderedDict
from apex import amp
import os
import argparse
import json
import numpy as np
import torch
if torch.__version__ >="1.8":
import torch_npu
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils import data
from apex.optimizers import NpuFusedAdam
from dataloader import *
from model_RawNet2 import RawNet2
from parser1 import get_args
from trainer import *
from utils import *
import torch.npu
# Device string used by main() for model/loss placement and by the script
# entry point when selecting the active NPU.
CALCULATE_DEVICE = "npu:0"
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``'module.'`` removed from its keys.

    Only keys that actually start with the prefix are rewritten; all other
    keys are kept untouched. When no key carries the prefix the input mapping
    is returned as-is.
    """
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_checkpoint(path):
    """Load a checkpoint onto the CPU and split it into its parts.

    Two layouts are accepted:

    * the dict written by ``main`` (keys ``'model'``, ``'optimizer'``, ...);
    * a bare model ``state_dict``.

    Returns:
        ``(model_state, optimizer_state)`` where ``optimizer_state`` is
        ``None`` when the checkpoint does not carry one.
    """
    ckpt = torch.load(path, map_location=torch.device('cpu'))
    if isinstance(ckpt, dict) and 'model' in ckpt:
        model_state = ckpt['model']
        optimizer_state = ckpt.get('optimizer')
    else:
        model_state, optimizer_state = ckpt, None
    return _strip_module_prefix(model_state), optimizer_state


def main():
    """Train RawNet2 and run time-augmented evaluation after every epoch.

    Reads the run configuration from ``get_args()``, builds the datasets,
    model, optimizer and (optionally) the learning-rate scheduler, then
    trains for ``args.epoch`` epochs. Each epoch's EER is appended to
    ``eers.txt`` and a checkpoint is written whenever the EER improves.

    Raises:
        NotImplementedError: for an unsupported ``args.optimizer`` or
            ``args.lr_decay`` value.
    """
    args = get_args()

    if args.reproducible:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = CALCULATE_DEVICE

    # ---- utterance lists, labels and trial lists -------------------------
    l_dev = sorted(get_utt_list(args.DB_vox2 + args.dev_wav))
    l_val = sorted(get_utt_list(args.DB + args.val_wav))
    l_eval = sorted(get_utt_list(args.DB + args.eval_wav))
    d_label_vox2 = get_label_dic_Voxceleb(l_dev)
    args.model['nb_classes'] = len(d_label_vox2.keys())

    # Regenerating the validation trial list is disabled, as in the original.
    regenerate_val_trials = False
    if regenerate_val_trials:
        make_validation_trial(l_utt=l_val,
                              nb_trial=args.nb_val_trial,
                              dir_val_trial=args.DB + 'val_trial.txt')
    with open(args.DB + 'val_trial.txt', 'r') as f:
        l_val_trial = f.readlines()
    with open(args.DB + 'veri_test.txt', 'r') as f:
        l_eval_trial = f.readlines()

    # ---- datasets and loaders --------------------------------------------
    devset = Dataset_VoxCeleb2(list_IDs=l_dev,
                               labels=d_label_vox2,
                               nb_samp=args.nb_samp,
                               base_dir=args.DB_vox2 + args.dev_wav)
    devset_gen = data.DataLoader(devset,
                                 batch_size=args.bs,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=args.nb_worker)
    # The validation loader is built but not consumed by the loop below.
    valset = Dataset_VoxCeleb2(list_IDs=l_val,
                               return_label=False,
                               nb_samp=args.nb_samp,
                               base_dir=args.DB)
    valset_gen = data.DataLoader(valset,
                                 batch_size=args.bs,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=args.nb_worker)
    TA_evalset = TA_Dataset_VoxCeleb2(list_IDs=l_eval,
                                      return_label=False,
                                      window_size=args.window_size,
                                      nb_samp=args.nb_samp,
                                      base_dir=args.DB + args.eval_wav)
    TA_evalset_gen = data.DataLoader(TA_evalset,
                                     batch_size=1,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=args.nb_worker)

    # ---- output directories and parameter dump ---------------------------
    save_dir = args.save_dir + args.name + '/'
    for sub_dir in ('', 'results/', 'models/', 'prof/', 'log/'):
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(save_dir + sub_dir, exist_ok=True)

    with open(save_dir + 'f_params.txt', 'w') as f_params:
        for k, v in sorted(vars(args).items()):
            print(k, v)
            f_params.write('{}:\t{}\n'.format(k, v))
        for k, v in sorted(args.model.items()):
            print(k, v)
            f_params.write('{}:\t{}\n'.format(k, v))

    # ---- model -----------------------------------------------------------
    if bool(args.mg):
        model = RawNet2(args.model).npu()
    else:
        model = RawNet2(args.model).to(device)
    nb_params = sum(param.numel() for param in model.parameters())
    print('nb_params: {}'.format(nb_params))

    # The checkpoint is read once; the optimizer part is applied further down,
    # after the optimizer exists (the original touched it before assignment).
    optimizer_state = None
    if args.load_model:
        model_state, optimizer_state = _load_checkpoint(args.load_model_dir)
        model.load_state_dict(model_state)
    else:
        # Initialise exactly once and before amp.initialize, so that the
        # weights amp copies are the initialised ones.
        model.apply(init_weights)

    # ---- objective -------------------------------------------------------
    criterion = {}
    criterion['cce'] = nn.CrossEntropyLoss().to(device)

    # ---- optimizer -------------------------------------------------------
    # Parameters whose name contains 'bn' are exempt from weight decay.
    bn_params, other_params = [], []
    for name, param in model.named_parameters():
        (bn_params if 'bn' in name else other_params).append(param)
    params = [
        {'params': other_params},
        {'params': bn_params, 'weight_decay': 0},
    ]

    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'sgd':
        # NOTE(review): the original referenced `NpuFusedSgd`, which this file
        # never imports. Ascend apex is believed to expose the fused SGD
        # optimizer as `NpuFusedSGD` -- confirm against the installed build.
        from apex.optimizers import NpuFusedSGD
        optimizer = NpuFusedSGD(params,
                                lr=args.lr,
                                momentum=args.opt_mom,
                                weight_decay=args.wd,
                                nesterov=args.nesterov)
    elif optimizer_name == 'adam':
        optimizer = NpuFusedAdam(params,
                                 lr=args.lr,
                                 weight_decay=args.wd,
                                 amsgrad=args.amsgrad)
    else:
        raise NotImplementedError('Add other optimizers if needed')

    if args.load_model:
        if optimizer_state is None:
            # Bare model checkpoints keep the optimizer state in its own file.
            opt_path = getattr(args, 'load_model_opt_dir', None)
            if opt_path:
                optimizer_state = torch.load(
                    opt_path, map_location=torch.device('cpu'))
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
        print("Load Model Successfully!")

    model = model.to(device)
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O2", loss_scale=128.0)

    # ---- learning-rate schedule ------------------------------------------
    # Defaults to None so the train_model call below never hits an unbound
    # local; train_model is presumably a no-op on it when decay is disabled
    # -- TODO confirm in trainer.
    lr_scheduler = None
    if bool(args.do_lr_decay):
        if args.lr_decay == 'keras':
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=keras_lr_decay)
        else:
            raise NotImplementedError('Not implemented yet')

    # ---- train / evaluate ------------------------------------------------
    best_TA_eval_eer = 99.
    # Line-buffered so every epoch's EER reaches disk immediately; the
    # context manager closes the file even if an epoch raises.
    with open(save_dir + 'eers.txt', 'a', buffering=1) as f_eer:
        for epoch in tqdm(range(args.epoch)):
            train_model(model=model,
                        db_gen=devset_gen,
                        args=args,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        criterion=criterion,
                        device=device,
                        epoch=epoch)

            TA_eval_eer = time_augmented_evaluate_model(mode='eval',
                                                        model=model,
                                                        db_gen=TA_evalset_gen,
                                                        l_utt=l_eval,
                                                        save_dir=save_dir,
                                                        epoch=epoch,
                                                        l_trial=l_eval_trial,
                                                        args=args,
                                                        device=device)
            f_eer.write('epoch:%d, TA_eval_eer:%.4f\n' % (epoch, TA_eval_eer))

            # Checkpoint only when the evaluation EER improves.
            if float(TA_eval_eer) < best_TA_eval_eer:
                print('New best TA_EER: %f'%float(TA_eval_eer))
                best_TA_eval_eer = float(TA_eval_eer)
                torch.save({'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()},
                           save_dir + 'models/TA_%d_%.4f.pt'%(epoch, TA_eval_eer))
if __name__ == '__main__':
    # Select the configured NPU for this process before main() runs.
    targets_npu = 'npu' in CALCULATE_DEVICE
    if targets_npu:
        torch.npu.set_device(CALCULATE_DEVICE)
    main()