import json
import logging
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.utils.data
from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
# Spacing between consecutive F0 frames. get_f0 multiplies it by 1000 before
# passing it to pYAAPT as `frame_space`, and CodeDataset multiplies it by the
# sampling rate, so it is presumably expressed in seconds (5 ms) -- TODO confirm.
F0_FRAME_SPACE = 0.005
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ExpressiveCodeDataConfig(object):
    """Read-only accessor around a JSON data-configuration file.

    The JSON document is parsed once in the constructor and kept in
    ``self.config``; every property below is a lookup into that dictionary.
    """

    def __init__(self, json_path):
        """Parse ``json_path`` and cache its mandatory ``manifests`` entry."""
        with open(json_path, "r") as handle:
            loaded = json.load(handle)
        self.config = loaded
        self._manifests = loaded["manifests"]

    @property
    def manifests(self):
        """Value of the ``manifests`` entry."""
        return self._manifests

    @property
    def n_units(self):
        """Value of the mandatory ``n_units`` entry."""
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        """Value of the mandatory ``sampling_rate`` entry."""
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        """Value of the mandatory ``code_hop_size`` entry."""
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """Path of the pre-computed f0 statistics, or ``None`` when absent."""
        return self.config.get("f0_stats")

    @property
    def f0_vq_type(self):
        """F0 quantisation flavour: ``naive`` or ``precomp``."""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        """Value of the mandatory ``f0_vq_name`` entry."""
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        """Look up the naive-quantizer entry matching the given F0 options.

        The lookup key is ``"log"`` or ``"linear"`` followed by
        ``"_mean_std_norm"`` (mean and std normalisation), ``"_mean_norm"``
        (mean only) or ``"_none_norm"`` (anything else, including std
        normalisation without mean normalisation).
        """
        scale = "log" if log else "linear"
        if norm_mean:
            suffix = "_mean_std_norm" if norm_std else "_mean_norm"
        else:
            suffix = "_none_norm"
        table = self.config["f0_vq_naive_quantizer"]
        return table[scale + suffix]

    @property
    def f0_vq_n_units(self):
        """Value of the mandatory ``f0_vq_n_units`` entry."""
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """How to parse the speaker label from an audio path (``None`` if unset)."""
        return self.config.get("multispkr")
def get_f0(audio, rate=16000):
    """Estimate the F0 contour of a one-dimensional waveform with pYAAPT.

    Args:
        audio: 1-D numpy array holding the waveform.
        rate: sampling rate of ``audio``, in samples per second.

    Returns:
        The ``samp_values`` array of the pYAAPT pitch object.

    Raises:
        ImportError: if ``amfm_decompy`` or ``librosa`` is not installed.
    """
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the installation hint reaches the user.
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
        ) from err

    assert audio.ndim == 1
    frame_length = 20.0  # divided by 1000 below, i.e. expressed in milliseconds
    # Half an analysis frame of zeros is added on each side of the signal.
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    return pitch.samp_values
def interpolate_f0(f0):
    """Linearly interpolate the zero-valued frames of an F0 contour.

    Frames equal to 0 that lie between two non-zero frames are replaced by a
    linear interpolation of their non-zero neighbours. Frames before the first
    or after the last non-zero frame stay 0 (``fill_value=0``). When fewer
    than two frames are non-zero there is nothing to interpolate and a copy
    of the input is returned.

    Args:
        f0: 1-D torch tensor.

    Returns:
        A new tensor with the dtype and device of ``f0``; the input is never
        modified.

    Raises:
        ImportError: if ``scipy`` is not installed.
    """
    try:
        from scipy.interpolate import interp1d
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ImportError("Please install scipy (`pip install scipy`)") from err

    voiced = f0 != 0
    if voiced.sum() <= 1:
        return f0.clone()

    # scipy works on numpy arrays, so convert explicitly instead of relying on
    # implicit tensor/array coercion.
    values = f0.detach().cpu().numpy()
    voiced_np = voiced.cpu().numpy()
    orig_t = np.arange(values.shape[0])
    interpolated = interp1d(
        orig_t[voiced_np],
        values[voiced_np],
        bounds_error=False,
        kind="linear",
        fill_value=0,
    )(orig_t)
    return torch.as_tensor(interpolated, dtype=f0.dtype, device=f0.device)
def naive_quantize(x, edges):
    """Map every element of ``x`` to the number of ``edges`` it strictly exceeds.

    ``x`` is flattened into a column and ``edges`` into a row; the comparison
    broadcasts to one row per element of ``x``, and the row sums are returned
    as a 1-D long tensor.
    """
    exceeds = torch.gt(x.view(-1, 1), edges.view(1, -1))
    return exceeds.long().sum(dim=1)
def load_wav(full_path):
    """Read an audio file with ``soundfile``.

    Args:
        full_path: path handed to ``soundfile.read``.

    Returns:
        The ``(data, sampling_rate)`` pair returned by ``soundfile.read``.

    Raises:
        ImportError: if ``soundfile`` is not installed.
    """
    try:
        import soundfile as sf
    except ImportError as err:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ImportError(
            "Please install soundfile (`pip install SoundFile`)"
        ) from err
    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate
def parse_code(code_str, dictionary, append_eos):
    """Collapse a whitespace-separated code string into (codes, durations).

    Consecutive repetitions of the same integer code are merged; the run
    lengths become the durations. The de-duplicated codes are re-joined into a
    string and encoded with ``dictionary.encode_line``. When ``append_eos`` is
    true a zero duration is appended to the durations.

    Returns:
        ``(code, duration)``, both converted with ``.short()``.
    """
    raw_units = torch.ShortTensor([int(tok) for tok in code_str.split()])
    units, run_lengths = torch.unique_consecutive(raw_units, return_counts=True)

    line = " ".join(str(unit) for unit in units.tolist())
    encoded = dictionary.encode_line(line, append_eos).short()

    if append_eos:
        eos_dur = run_lengths.new_zeros((1,))
        run_lengths = torch.cat((run_lengths, eos_dur), dim=0)
    return encoded, run_lengths.short()
def parse_manifest(manifest, dictionary):
    """Parse a manifest file with one Python dict literal per line.

    Each sample must contain an ``"audio"`` entry and a code string under one
    of the keys ``"cpc_km100"``, ``"hubert_km100"`` or ``"phone"`` (checked in
    that order). A ``"speaker"`` entry is optional. Blank lines are skipped.

    Args:
        manifest: path of the manifest file.
        dictionary: forwarded to ``parse_code`` to encode the code strings.

    Returns:
        ``(audio_files, codes, durations, speakers)`` as four parallel lists;
        ``speakers`` holds ``None`` for samples without a ``"speaker"`` entry.

    Raises:
        AssertionError: if a sample has none of the known code keys.
    """
    import ast

    code_keys = ("cpc_km100", "hubert_km100", "phone")

    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info:
            line = line.strip()
            if not line:
                continue
            # NOTE(security): this used to be `eval`, which executes arbitrary
            # code found in the manifest. `literal_eval` accepts the same dict
            # literals but refuses anything that is not a plain literal.
            sample = ast.literal_eval(line)

            k = next((name for name in code_keys if name in sample), None)
            if k is None:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError("unknown format")

            code, duration = parse_code(sample[k], dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers
def parse_speaker(path, method):
    """Derive a speaker label from an audio path.

    Args:
        path: ``str`` (converted to ``pathlib.Path``) or a path-like object
            exposing ``parent`` and ``name``.
        method: one of
            ``"parent_name"``        -> name of the parent directory,
            ``"parent_parent_name"`` -> name of the grand-parent directory,
            ``"_"``                  -> file name up to its first underscore,
            ``"single"``             -> the constant label ``"A"``,
            or a callable, which is invoked with the path.

    Returns:
        The speaker label (whatever the callable returns in the last case).

    Raises:
        NotImplementedError: if ``method`` is none of the above.
    """
    # isinstance (rather than an exact type comparison) also converts str
    # subclasses.
    if isinstance(path, str):
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    if method == "parent_parent_name":
        return path.parent.parent.name
    if method == "_":
        return path.name.split("_")[0]
    if method == "single":
        return "A"
    if callable(method):
        return method(path)
    raise NotImplementedError(f"unknown speaker parsing method: {method!r}")
def get_f0_by_filename(filename, tgt_sampling_rate):
    """Load an audio file and return its F0 contour as a float32 tensor.

    Raises:
        ValueError: if the file's sampling rate differs from
            ``tgt_sampling_rate`` (no resampling is attempted).
    """
    waveform, file_rate = load_wav(filename)
    if file_rate != tgt_sampling_rate:
        message = "{} SR doesn't match target {} SR".format(
            file_rate, tgt_sampling_rate
        )
        raise ValueError(message)

    contour = get_f0(waveform, rate=tgt_sampling_rate)
    return torch.from_numpy(contour.astype(np.float32))
def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    """Reduce a frame-level F0 track to one value per code unit.

    The expected frame count is ``int(f0_code_ratio * durations.sum())``. A
    mismatch of at most ``tol`` frames is repaired by trimming the tail or by
    repeating the last frame. Each unit then receives the mean of the non-zero
    frames inside its window, or 0 when the window has no non-zero frame.

    Returns:
        A tensor with one entry per element of ``durations``.
    """
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )

    if diff > 0:
        # Too many frames: drop the surplus at the end.
        f0 = f0[:targ_len]
    elif diff < 0:
        # Too few frames: extend with copies of the last frame.
        tail = f0.new_full((-diff,), f0[-1])
        f0 = torch.cat((f0, tail), 0)

    # The cursor stays a float so fractional ratios accumulate before being
    # truncated to frame indices.
    cursor = 0.0
    per_unit = []
    for dur in durations:
        span = dur.item() * f0_code_ratio
        window = f0[int(cursor) : int(cursor + span)]
        voiced = window[window != 0]
        if len(voiced) > 0:
            per_unit.append(voiced.mean())
        else:
            per_unit.append(torch.tensor(0).type(voiced.type()))
        cursor += span

    assert int(cursor) == f0.size(0), f"{cursor} {f0.size()} {durations.sum()}"
    return torch.tensor(per_unit)
class Paddings(object):
    """Holder for the padding values of the code, duration and F0 streams."""

    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        # One padding value per stream, stored under the stream's short name.
        self.code, self.dur, self.f0 = code_val, dur_val, f0_val
class Shifts(object):
    """Delay the duration and F0 streams relative to the code stream.

    ``shifts_str`` has the form ``"<dur_shift>,<f0_shift>"`` with two
    non-negative integers. Every stream is padded to the same final length
    (original length plus the larger of the two shifts), and a boolean mask
    marks the padded positions.
    """

    def __init__(self, shifts_str, pads):
        parsed = [int(tok) for tok in shifts_str.split(",")]
        self._shifts = parsed
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        # Number of positions every stream grows by.
        self.extra_length = max(parsed)
        self.pads = pads

    @property
    def dur(self):
        """Shift applied to the duration stream (first number)."""
        return self._shifts[0]

    @property
    def f0(self):
        """Shift applied to the F0 stream (second number)."""
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        """Pad a 1-D ``seq`` on both sides and return ``(padded, mask)``.

        ``mask`` is True on the padded positions and False on the original
        ones.
        """
        assert seq.ndim == 1
        pieces = [
            seq.new_full((left_pad_num,), pad),
            seq,
            seq.new_full((right_pad_num,), pad),
        ]
        padded = torch.cat(pieces)
        mask = torch.ones_like(padded).bool()
        mask[left_pad_num : len(padded) - right_pad_num] = 0
        return padded, mask

    def __call__(self, code, dur, f0):
        """Return ``(code, code_mask, dur, dur_mask, f0, f0_mask)``.

        With no shift at all the inputs are returned untouched together with
        all-False masks.
        """
        extra = self.extra_length
        if extra == 0:
            return (
                code,
                torch.zeros_like(code).bool(),
                dur,
                torch.zeros_like(dur).bool(),
                f0,
                torch.zeros_like(f0).bool(),
            )

        # Codes are never delayed: all of their padding goes to the right.
        code, code_mask = self.shift_one(code, 0, extra, self.pads.code)
        dur_shift = self.dur
        dur, dur_mask = self.shift_one(
            dur, dur_shift, extra - dur_shift, self.pads.dur
        )
        f0_shift = self.f0
        f0, f0_mask = self.shift_one(f0, f0_shift, extra - f0_shift, self.pads.f0)
        return code, code_mask, dur, dur_mask, f0, f0_mask
class CodeDataset(FairseqDataset):
    """Dataset yielding aligned code, duration and F0 streams per utterance.

    All data is read from files that share the ``manifest`` prefix:

    * ``{manifest}.leng.txt``    -- one integer length per utterance,
    * ``{manifest}.path.txt``    -- one file name per utterance,
    * ``{manifest}.speaker.txt`` -- one speaker label per utterance (only read
      when ``config.multispkr`` is set),
    * ``{manifest}.f0_stat.pt``  -- optional F0 statistics (falls back to
      ``config.f0_stats``),
    * ``{manifest}.code.npy``, ``{manifest}.dur.npy`` and ``{manifest}.f0.npy``
      (or ``{manifest}.{config.f0_vq_name}.npy`` for pre-computed F0 units) --
      flat arrays holding all utterances back to back.

    The flat arrays are memory-mapped lazily on first item access (see
    ``get_data_handlers``); the cumulative sums of the lengths give each
    utterance's ``[start, end)`` slice.
    """

    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        """Read the light-weight index files; heavy arrays are loaded lazily.

        Args:
            manifest: path prefix of the data files (see class docstring).
            dictionary: dictionary of the code stream (provides ``bos``,
                ``pad`` and ``eos``).
            dur_dictionary: used to encode durations when ``discrete_dur``.
            f0_dictionary: used for F0 ``bos``/``pad`` and to encode
                pre-computed F0 units when ``discrete_f0``.
            config: object exposing ``code_hop_size``, ``sampling_rate``,
                ``f0_stats``, ``multispkr``, ``f0_vq_type``, ``f0_vq_name``,
                ``f0_vq_n_units`` and ``get_f0_vq_naive_quantizer``.
            discrete_dur: encode durations with ``dur_dictionary`` instead of
                returning floats.
            discrete_f0: return quantised F0 instead of continuous values.
            log_f0: take the log of non-zero F0 values.
            normalize_f0_mean: subtract the mean from non-zero F0 values.
            normalize_f0_std: divide non-zero F0 values by the std.
            interpolate_f0: interpolate zero-valued F0 frames first.
            return_filename: add a ``"filename"`` entry to each item.
            strip_filename: reduce that file name to its stem.
            shifts: ``"<dur_shift>,<f0_shift>"`` string passed to ``Shifts``.
            return_continuous_f0: additionally return the pre-quantisation F0
                as ``"raw_f0"`` (naive discrete F0 only).
        """
        # NOTE: seeds the process-wide `random` module, not a private RNG.
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        # duration config
        self.discrete_dur = discrete_dur

        # pitch config
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # Memory-mapped arrays; populated by get_data_handlers().
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None

        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        # Always define the attribute so a missing statistics file yields a
        # clear error later instead of an AttributeError.
        # NOTE(security): torch.load unpickles the file; only load statistics
        # from trusted sources.
        self.f0_stats = None
        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            # NOTE(review): `spkrs` has one entry per line of speaker.txt, so
            # repeated labels stay repeated in `id_to_spkr` and `spkr_to_id`
            # keeps the last index of each label -- confirm this is intended.
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # use 0 for duration padding
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0

    def get_data_handlers(self):
        """Memory-map the code, duration and F0 arrays of this manifest.

        For naive discrete F0 the quantizer edges are loaded as well.

        Raises:
            ValueError: if ``discrete_f0`` is set and ``config.f0_vq_type`` is
                neither ``"precomp"`` nor ``"naive"``.
        """
        # Use the module logger, like the rest of this file.
        logger.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                # NOTE(security): torch.load unpickles the file.
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")

    def preprocess_f0(self, f0, stats):
        """
        1. interpolate
        2. log transform (keep unvoiced frame 0)
        3. mean / std normalisation of the non-zero frames

        Works on a clone, so the ``f0`` argument is not modified. ``stats``
        must provide ``logf0_mean``/``logf0_std`` (when ``log_f0``) or
        ``f0_mean``/``f0_std`` for the normalisations that are enabled.
        """
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        # Frames equal to 0 are left untouched by every step below.
        mask = f0 != 0
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0

    def _f0_stats_for(self, index):
        """Return the F0 statistics to use for item ``index``.

        With ``multispkr`` the statistics are looked up by the item's speaker
        label; otherwise the loaded statistics object is used as is. Returns
        ``None`` when no statistics were loaded and no normalisation needs
        them.

        Raises:
            ValueError: if normalisation is enabled but no statistics exist.
        """
        if self.f0_stats is None:
            if self.normalize_f0_mean or self.normalize_f0_std:
                raise ValueError(
                    "F0 normalization requested but no F0 statistics were found "
                    f"({self.manifest}.f0_stat.pt is missing and the config has "
                    "no f0_stats entry)"
                )
            return None
        if self.multispkr:
            return self.f0_stats[self.spkrs[index]]
        return self.f0_stats

    def _get_raw_item(self, index):
        """Slice the flat arrays for item ``index`` and return tensor copies."""
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        # np.array(...) copies the memory-mapped slice before wrapping it.
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0

    def __getitem__(self, index):
        """Build the source/target/mask entries of item ``index``.

        Every stream gets a leading start element, is shifted and padded by
        ``self.shifts``, and is then split into a source (all but the last
        position) and a target (all but the first position).

        Returns:
            A dict with ``source``/``target``/``mask``, the ``dur_*`` and
            ``f0_*`` counterparts, and optionally ``raw_f0`` and ``filename``.
        """
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # use 0 for eos and bos
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        # TODO: find a more elegant approach
        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                # Previously this branch always indexed the statistics by
                # speaker, which failed without `multispkr` because
                # `self.spkrs` is only defined in the multi-speaker case.
                f0 = self.preprocess_f0(f0, self._f0_stats_for(index))
                if self.return_continuous_f0:
                    raw_f0 = f0
                    # NOTE(review): the continuous track is prefixed with the
                    # F0 dictionary's bos *index* -- presumably only to keep
                    # lengths aligned; that position is dropped below.
                    raw_f0 = torch.cat(
                        [raw_f0.new([self.f0_dictionary.bos()]), raw_f0]
                    )
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            f0 = self.preprocess_f0(f0, self._f0_stats_for(index))
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None

        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # A position is masked when either its source or its target element
        # is padding introduced by the shift.
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats

    def __len__(self):
        """Number of utterances listed in ``{manifest}.leng.txt``."""
        return len(self.starts)

    def size(self, index):
        """Length of item ``index``: its stored length plus the shift padding."""
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        """Same as ``size(index)``."""
        return self.size(index)

    def collater(self, samples):
        """Merge a list of items from ``__getitem__`` into one padded batch.

        Returns an empty dict for an empty list. Code, duration and F0 streams
        are padded with their respective padding values, masks with 1.
        ``filename``, ``prefix`` and ``raw_f0`` are forwarded when the first
        sample contains them.
        """
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,  # appending padding, eos is there already
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        # Batch padding counts as masked, hence the pad value 1.
        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        # TODO: remove this hack into the inference dataset
        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result