from glob import glob
import json
import sys
from ECAPA_TDNN.mel2samp_tacotron2 import Mel2SampWaveglow
from ECAPA_TDNN.prepare_batch_loader import struct_meta, write_to_csv, read_from_csv, reduce_meta, build_speaker_dict, collate_function
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
from functools import partial
from tqdm import tqdm
# Module-level setup: runs on import, reads CLI args and the JSON config.

# JSON file (relative to the working directory) holding "mel_config" and "hp".
CONFIGURATION_FILE = 'config.json'
# Threshold constant; presumably the intended t_thres for get_dataloader — confirm.
T_THRES = 19
# First CLI argument: dataset root directory (used as get_dataloader's default).
# NOTE(review): evaluated at import time, so importing this module without
# arguments raises IndexError.
DATA_SET = sys.argv[1]
with open(CONFIGURATION_FILE) as f:
    data = f.read()
json_info = json.loads(data)
# Keyword arguments for the mel-spectrogram extractor.
mel_config = json_info["mel_config"]
MEL2SAMPWAVEGLOW = Mel2SampWaveglow(**mel_config)
# Every key of the "hp" section becomes a module-level global. This is how names
# such as REDUCED_SPEAKER_NUM_TEST and MAX_MEL_LENGTH (read in get_dataloader)
# come into scope. An hp key that matches an existing global overwrites it.
hp = json_info["hp"]
global_scope = sys.modules[__name__]
for key in hp:
    setattr(global_scope, key, hp[key])
def load_meta(dataset, keyword='vox1'):
    """Collect test-set metadata for ``dataset`` and cache it to a CSV file.

    Args:
        dataset: Dataset root directory. For ``'vox1'`` the ``.wav`` files are
            searched exactly three directory levels below it
            (``<dataset>/*/*/*.wav``).
        keyword: Dataset layout selector. Only ``'vox1'`` is supported.

    Returns:
        Whatever ``struct_meta`` builds from the sorted list of wav paths.
        The same metadata is also written to ``vox1_test.csv`` in the
        current working directory.

    Raises:
        ValueError: If ``keyword`` names an unsupported dataset layout.
    """
    if keyword != 'vox1':
        # Previously this fell through silently (returning None / hitting an
        # unbound local), which only surfaced later as an obscure error when
        # the caller tried to iterate the result.
        raise ValueError(
            f"Unsupported dataset keyword {keyword!r}; only 'vox1' is supported"
        )
    # os.path.join avoids '//' in the pattern when `dataset` ends with a separator.
    wav_files_test = sorted(glob(os.path.join(dataset, '*', '*', '*.wav')))
    print(f'Len. wav_files_test {len(wav_files_test)}')
    test_meta = struct_meta(wav_files_test)
    write_to_csv(test_meta, 'vox1_test.csv')
    return test_meta
def get_dataloader(keyword='vox1', t_thres=19, batchsize=16, dataset=DATA_SET):
    """Build a non-shuffling DataLoader over filtered, speaker-reduced test metadata.

    Args:
        keyword: Dataset layout selector, forwarded to ``load_meta``.
        t_thres: A metadata row is kept only if its field at index 2 is
            strictly below this value (presumably utterance duration — confirm
            against ``struct_meta``).
        batchsize: Loader batch size. A trailing incomplete batch is dropped
            (``drop_last=True``).
        dataset: Dataset root, forwarded to ``load_meta``. The default is bound
            once, at definition time, to the first CLI argument.

    Returns:
        ``(loader, speakers)``: the DataLoader and the speaker table from
        ``build_speaker_dict`` (the same table is handed to ``collate_function``).

    Note:
        ``REDUCED_SPEAKER_NUM_TEST`` and ``MAX_MEL_LENGTH`` are read as module
        globals and are expected to be injected from the config's "hp" section.
    """
    full_meta = load_meta(dataset, keyword)

    short_meta = []
    for entry in tqdm(full_meta):
        if entry[2] < t_thres:
            short_meta.append(entry)

    reduced_meta = reduce_meta(short_meta, speaker_num=REDUCED_SPEAKER_NUM_TEST)
    print(f'Meta reduced {len(short_meta)} => {len(reduced_meta)}')

    speaker_table = build_speaker_dict(reduced_meta)
    collate = partial(
        collate_function,
        speaker_table=speaker_table,
        max_mel_length=MAX_MEL_LENGTH,
    )
    # NOTE(review): drop_last=True discards a final partial batch; presumably
    # deliberate so every batch has the same shape — confirm for evaluation use.
    loader = DataLoader(
        reduced_meta,
        batch_size=batchsize,
        shuffle=False,
        num_workers=2,
        collate_fn=collate,
        prefetch_factor=2,
        pin_memory=True,
        drop_last=True,
    )
    return loader, speaker_table
if __name__ == "__main__":
    # Script entry point: dump each test batch's mels (raw float32 .bin) and
    # speaker labels (torch .pt) into two output directories.
    # Usage: <script> <dataset_dir> <mel_output_dir> <speaker_output_dir>
    if len(sys.argv) < 4:
        sys.exit(
            f"Usage: {sys.argv[0]} <dataset_dir> <mel_output_dir> <speaker_output_dir>"
        )
    predata_path = sys.argv[2]
    prespeaker_path = sys.argv[3]
    # One sample per batch, so each output file pair holds a single utterance.
    batchsize = 1
    dataset_test, test_speakers = get_dataloader('vox1', T_THRES, batchsize)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(predata_path, exist_ok=True)
    os.makedirs(prespeaker_path, exist_ok=True)
    # File indices start at 1 to match the previously produced naming scheme.
    for i, (mels, _mel_length, speakers) in enumerate(tqdm(dataset_test), start=1):
        # tofile writes raw bytes with no shape header; the consumer must know
        # the tensor shape (presumably fixed by collate_function — confirm).
        mels = np.array(mels).astype(np.float32)
        mels.tofile(os.path.join(predata_path, f'mels{i}.bin'))
        torch.save(speakers, os.path.join(prespeaker_path, f'speakers{i}.pt'))