import os
import sys
sys.path.append('./fairseq')
import numpy as np
import torch
import torch.nn.functional as F
import argparse
import fairseq
from fairseq.data.audio import hubert_dataset
from fairseq.data import data_utils
from fairseq.tasks.hubert_pretraining import LabelEncoder
import itertools
from tqdm import tqdm
def postprocess(wav):
    """Reduce a waveform to one dimension and layer-normalize it.

    A 2-D input is averaged over its last dimension; the resulting 1-D tensor
    is normalized with ``F.layer_norm`` over its full shape.

    Args:
        wav: Waveform tensor with one or two dimensions.

    Returns:
        The normalized 1-D waveform tensor.

    Raises:
        AssertionError: If the tensor is not 1-D after the optional averaging.
    """
    mono = wav.mean(-1) if wav.dim() == 2 else wav
    assert mono.dim() == 1, mono.dim()
    return F.layer_norm(mono, mono.shape)
def get_audio(audio_root, audio_names, index):
    """Load one audio file and return it as a post-processed float tensor.

    Args:
        audio_root: Directory that the entries of ``audio_names`` are joined to.
        audio_names: Sequence of audio file paths relative to ``audio_root``.
        index: Position in ``audio_names`` of the file to load.

    Returns:
        The waveform tensor returned by ``postprocess``.
    """
    # Function-scope import: soundfile is only needed once audio is read.
    import soundfile

    # The sample rate reported by soundfile is not used.
    samples, _sample_rate = soundfile.read(os.path.join(audio_root, audio_names[index]))
    return postprocess(torch.from_numpy(samples).float())
def get_label(label_processors, label_paths, label_offsets_list, index, label_idx):
    """Read one label entry from a label file using byte offsets.

    Args:
        label_processors: Optional sequence of callables, one per label file,
            applied to the raw label text. ``None`` returns the raw text.
        label_paths: Sequence of label file paths.
        label_offsets_list: For each label file, a sequence of
            ``(start, end)`` byte offsets, one pair per item.
        index: Item whose label is read.
        label_idx: Which label file (and processor) to use.

    Returns:
        The label text for the item, or the processor's output when
        ``label_processors`` is given.
    """
    # The offsets are byte positions, so the file must be read as bytes.
    # In text mode ``read(n)`` counts characters, which over-reads into the
    # following line as soon as a label contains a multi-byte character.
    with open(label_paths[label_idx], "rb") as f:
        start, end = label_offsets_list[label_idx][index]
        f.seek(start)
        label = f.read(end - start).decode("utf-8")
    if label_processors is not None:
        label = label_processors[label_idx](label)
    return label
def getitem(audio_root, audio_names, label_processors, label_paths, label_offsets_list, index):
    """Build the sample dictionary for one item.

    Args:
        audio_root: Directory containing the audio files.
        audio_names: Sequence of audio file paths relative to ``audio_root``.
        label_processors: Optional per-label-file callables, forwarded to
            ``get_labels``.
        label_paths: Sequence of label file paths.
        label_offsets_list: Per-label-file ``(start, end)`` offsets.
        index: Item to load.

    Returns:
        A dict with keys ``"id"`` (the index), ``"source"`` (the waveform from
        ``get_audio``) and ``"label_list"`` (the list from ``get_labels``).
    """
    sample = {"id": index}
    sample["source"] = get_audio(audio_root, audio_names, index)
    sample["label_list"] = get_labels(label_processors, label_paths, label_offsets_list, index)
    return sample
def get_labels(label_processors, label_paths, label_offsets_list, index):
    """Return the labels of one item as a single-element list.

    Only label file 0 is read, regardless of how many paths are supplied.

    Args:
        label_processors: Optional per-label-file callables.
        label_paths: Sequence of label file paths.
        label_offsets_list: Per-label-file ``(start, end)`` offsets.
        index: Item whose label is read.

    Returns:
        A list holding the result of ``get_label`` for label file 0.
    """
    first_label = get_label(label_processors, label_paths, label_offsets_list, index, 0)
    return [first_label]
def ordered_indices(sizes):
    """Return item indices sorted by size, largest first.

    A random permutation is used as the secondary sort key, so the relative
    order of equally sized items depends on NumPy's global random state.

    Args:
        sizes: Sequence of item sizes.

    Returns:
        A NumPy array of indices into ``sizes`` in descending size order.
    """
    tie_breaker = np.random.permutation(len(sizes))
    # lexsort treats the LAST key as the primary one, so sizes dominate.
    ascending = np.lexsort((tie_breaker, sizes))
    return ascending[::-1]
def load_label_offset(label_path, inds, tot):
    """Compute the byte range of each selected line in a label file.

    Args:
        label_path: Path of the label file, one label per line.
        inds: Iterable of line indices to return offsets for.
        tot: Expected total number of lines in the file.

    Returns:
        A list of ``(start, end)`` byte offsets, one pair per entry of
        ``inds`` and in the same order. Each range includes the line ending.

    Raises:
        AssertionError: If the file does not contain exactly ``tot`` lines.
    """
    # newline="" keeps universal-newline line splitting but returns the line
    # endings untranslated, and the explicit encoding removes the dependence
    # on the locale. Together they make the re-encoded length of each line
    # equal to its real size on disk, so the offsets are true byte positions.
    with open(label_path, encoding="utf-8", newline="") as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    # Running sum of line lengths: boundaries[i] is where line i starts.
    boundaries = list(itertools.accumulate([0] + code_lengths))
    return [(boundaries[i], boundaries[i + 1]) for i in inds]
def run_preprocess(args):
    """Dump every audio/label pair of a dataset to raw ``.bin`` files.

    For each item, in the order produced by ``ordered_indices`` under a fixed
    NumPy seed, the waveform is reshaped to ``(1, n)``, zero-padded on the
    right to 580000 samples and written with ``ndarray.tofile`` as
    ``source<i>.bin``. The first entry of the item's label list is written as
    ``label<i>.bin``. ``<i>`` is the position in the processing order, not the
    dataset index.

    Args:
        args: Namespace providing ``model_path``, ``datasets_tsv_path``,
            ``datasets_ltr_path``, ``pre_data_source_save_path`` and
            ``pre_data_label_save_path``.

    Raises:
        ValueError: If a waveform is longer than the fixed padded length.
    """
    # Only the task is used (for its target dictionary); the loaded models and
    # config are ignored here.
    _models, _cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([args.model_path])

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.pre_data_source_save_path, exist_ok=True)
    os.makedirs(args.pre_data_label_save_path, exist_ok=True)

    label_paths = [args.datasets_ltr_path]
    audio_root, audio_names, inds, tot, sizes = hubert_dataset.load_audio(args.datasets_tsv_path, None, None)
    label_processors = [LabelEncoder(task.target_dictionary)]
    label_offsets_list = [load_label_offset(p, inds, tot) for p in label_paths]

    # Fixed seed so the tie-breaking permutation is reproducible between runs.
    with data_utils.numpy_seed(1):
        indices = ordered_indices(sizes)

    padded_len = 580000  # every source file is padded to this many samples
    for i, index in enumerate(tqdm(indices)):
        sample = getitem(audio_root, audio_names, label_processors, label_paths, label_offsets_list, index)
        source = sample["source"].view(1, -1)
        n_samples = source.size(1)
        if n_samples > padded_len:
            # Without this check torch.zeros fails with an opaque
            # negative-dimension error that does not name the file.
            raise ValueError(
                f"audio '{audio_names[index]}' has {n_samples} samples, "
                f"more than the padded length {padded_len}"
            )
        padding = torch.zeros(1, padded_len - n_samples).float().to("cpu")
        sample["source"] = torch.cat((source, padding), 1)

        np_source = np.array(sample["source"])
        np_label = np.array(sample["label_list"][0])
        # Join directory and file name properly so a save path without a
        # trailing separator still writes inside the directory created above.
        np_source.tofile(os.path.join(args.pre_data_source_save_path, f"source{i}.bin"))
        np_label.tofile(os.path.join(args.pre_data_label_save_path, f"label{i}.bin"))
def main():
    """Parse the command-line options and run the preprocessing.

    Every option is an optional string argument with a default path.
    """
    option_defaults = (
        ('--model_path', './data/pt/hubert_large_ll60k_finetune_ls960.pt'),
        ('--datasets_tsv_path', './data/test-clean/train.tsv'),
        ('--datasets_ltr_path', './data/test-clean/train.ltr'),
        ('--pre_data_source_save_path', './pre_data/test-clean/source/'),
        ('--pre_data_label_save_path', './pre_data/test-clean/label/'),
    )
    parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        parser.add_argument(flag, default=default_value)
    run_preprocess(parser.parse_args())
# Run the preprocessing only when this file is executed as a script.
if __name__ == '__main__':
    main()