import nvidia.dali as dali
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import multiprocessing
import numpy as np
import torch
import math
import itertools
class DaliPipeline():
def __init__(self, *,
train_pipeline: bool,
device_id,
num_threads,
batch_size,
file_root: str,
file_list: str,
sample_rate,
discrete_resample_range: bool,
resample_range: list,
window_size,
window_stride,
nfeatures,
nfft,
frame_splicing_factor,
dither_coeff,
silence_threshold,
preemph_coeff,
pad_align,
max_duration,
mask_time_num_regions,
mask_time_min,
mask_time_max,
mask_freq_num_regions,
mask_freq_min,
mask_freq_max,
mask_both_num_regions,
mask_both_min_time,
mask_both_max_time,
mask_both_min_freq,
mask_both_max_freq,
preprocessing_device="gpu",
is_triton_pipeline=False):
self._dali_init_log(locals())
if torch.distributed.is_initialized():
shard_id = torch.distributed.get_rank()
n_shards = torch.distributed.get_world_size()
else:
shard_id = 0
n_shards = 1
self.preprocessing_device = preprocessing_device.lower()
assert self.preprocessing_device == "cpu" or self.preprocessing_device == "gpu", \
"Incorrect preprocessing device. Please choose either 'cpu' or 'gpu'"
self.frame_splicing_factor = frame_splicing_factor
assert frame_splicing_factor == 1, "Frame splicing is not yet implemented"
self.resample_range = resample_range
self.discrete_resample_range = discrete_resample_range
self.train = train_pipeline
self.sample_rate = sample_rate
self.dither_coeff = dither_coeff
self.nfeatures = nfeatures
self.max_duration = max_duration
self.mask_params = {
'time_num_regions': mask_time_num_regions,
'time_min': mask_time_min,
'time_max': mask_time_max,
'freq_num_regions': mask_freq_num_regions,
'freq_min': mask_freq_min,
'freq_max': mask_freq_max,
'both_num_regions': mask_both_num_regions,
'both_min_time': mask_both_min_time,
'both_max_time': mask_both_max_time,
'both_min_freq': mask_both_min_freq,
'both_max_freq': mask_both_max_freq,
}
self.do_remove_silence = True if silence_threshold is not None else False
@dali.pipeline_def
def dali_jasper_pipe():
if is_triton_pipeline:
assert not self.train, "Pipeline for Triton shall be a validation pipeline"
if torch.distributed.is_initialized():
raise RuntimeError(
"You're creating Triton pipeline, using multi-process mode. Please use single-process mode.")
encoded, label = fn.external_source(device="cpu", name="DALI_INPUT_0", no_copy=True)
else:
encoded, label = fn.readers.file(device="cpu", name="file_reader",
file_root=file_root, file_list=file_list, shard_id=shard_id,
num_shards=n_shards, shuffle_after_epoch=train_pipeline)
speed_perturbation_coeffs = None
if resample_range is not None:
if discrete_resample_range:
values = [self.resample_range[0], 1.0, self.resample_range[1]]
speed_perturbation_coeffs = fn.random.uniform(device="cpu", values=values)
else:
speed_perturbation_coeffs = fn.random.uniform(device="cpu", range=resample_range)
if self.train and speed_perturbation_coeffs is not None:
dec_sample_rate_arg = speed_perturbation_coeffs * self.sample_rate
elif resample_range is None:
dec_sample_rate_arg = self.sample_rate
else:
dec_sample_rate_arg = None
audio, _ = fn.decoders.audio(encoded, sample_rate=dec_sample_rate_arg, dtype=types.FLOAT, downmix=True)
if self.do_remove_silence:
begin, length = fn.nonsilent_region(audio, cutoff_db=silence_threshold)
audio = fn.slice(audio, begin, length, axes=[0])
if self.preprocessing_device == "gpu":
audio = audio.gpu()
if self.dither_coeff != 0.:
audio = audio + fn.random.normal(audio) * self.dither_coeff
audio = fn.preemphasis_filter(audio, preemph_coeff=preemph_coeff)
spec = fn.spectrogram(audio, nfft=nfft,
window_length=window_size * sample_rate, window_step=window_stride * sample_rate)
mel_spec = fn.mel_filter_bank(spec, sample_rate=sample_rate, nfilter=self.nfeatures, normalize=True)
log_features = fn.to_decibels(mel_spec, multiplier=np.log(10), reference=1.0, cutoff_db=math.log(1e-20))
log_features_len = fn.shapes(log_features)
if self.frame_splicing_factor != 1:
log_features_len = self._div_ceil(log_features_len, self.frame_splicing_factor)
log_features = fn.normalize(log_features, axes=[1])
log_features = fn.pad(log_features, axes=[1], fill_value=0, align=pad_align)
if self.train and self._do_spectrogram_masking():
anchors, shapes = fn.external_source(source=self._cutouts_generator, num_outputs=2, cycle=True)
log_features = fn.erase(log_features, anchor=anchors, shape=shapes, axes=[0, 1], fill_value=0,
normalized_anchor=True)
return log_features.gpu(), label.gpu(), log_features_len.gpu()
self.pipe_handle = dali_jasper_pipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id)
def get_pipeline(self):
return self.pipe_handle
@classmethod
def from_config(cls, train_pipeline: bool, device_id, batch_size, file_root: str, file_list: str, config_data: dict,
config_features: dict, device_type: str = "gpu", do_resampling: bool = True,
num_cpu_threads=multiprocessing.cpu_count()):
max_duration = config_data['max_duration']
sample_rate = config_data['sample_rate']
silence_threshold = -60 if config_data['trim_silence'] else None
if do_resampling and config_data['speed_perturbation'] is not None:
resample_range = [config_data['speed_perturbation']['min_rate'],
config_data['speed_perturbation']['max_rate']]
discrete_resample_range = config_data['speed_perturbation']['discrete']
else:
resample_range = None
discrete_resample_range = False
window_size = config_features['window_size']
window_stride = config_features['window_stride']
nfeatures = config_features['n_filt']
nfft = config_features['n_fft']
frame_splicing_factor = config_features['frame_splicing']
dither_coeff = config_features['dither']
pad_align = config_features['pad_align']
pad_to_max_duration = config_features['pad_to_max_duration']
assert not pad_to_max_duration, "Padding to max duration currently not supported in DALI"
preemph_coeff = .97
config_spec = config_features['spec_augment']
if config_spec is not None:
mask_time_num_regions = config_spec['time_masks']
mask_time_min = config_spec['min_time']
mask_time_max = config_spec['max_time']
mask_freq_num_regions = config_spec['freq_masks']
mask_freq_min = config_spec['min_freq']
mask_freq_max = config_spec['max_freq']
else:
mask_time_num_regions = 0
mask_time_min = 0
mask_time_max = 0
mask_freq_num_regions = 0
mask_freq_min = 0
mask_freq_max = 0
config_cutout = config_features['cutout_augment']
if config_cutout is not None:
mask_both_num_regions = config_cutout['masks']
mask_both_min_time = config_cutout['min_time']
mask_both_max_time = config_cutout['max_time']
mask_both_min_freq = config_cutout['min_freq']
mask_both_max_freq = config_cutout['max_freq']
else:
mask_both_num_regions = 0
mask_both_min_time = 0
mask_both_max_time = 0
mask_both_min_freq = 0
mask_both_max_freq = 0
inst = cls(train_pipeline=train_pipeline,
device_id=device_id,
preprocessing_device=device_type,
num_threads=num_cpu_threads,
batch_size=batch_size,
file_root=file_root,
file_list=file_list,
sample_rate=sample_rate,
discrete_resample_range=discrete_resample_range,
resample_range=resample_range,
window_size=window_size,
window_stride=window_stride,
nfeatures=nfeatures,
nfft=nfft,
frame_splicing_factor=frame_splicing_factor,
dither_coeff=dither_coeff,
silence_threshold=silence_threshold,
preemph_coeff=preemph_coeff,
pad_align=pad_align,
max_duration=max_duration,
mask_time_num_regions=mask_time_num_regions,
mask_time_min=mask_time_min,
mask_time_max=mask_time_max,
mask_freq_num_regions=mask_freq_num_regions,
mask_freq_min=mask_freq_min,
mask_freq_max=mask_freq_max,
mask_both_num_regions=mask_both_num_regions,
mask_both_min_time=mask_both_min_time,
mask_both_max_time=mask_both_max_time,
mask_both_min_freq=mask_both_min_freq,
mask_both_max_freq=mask_both_max_freq)
return inst.get_pipeline()
@staticmethod
def _dali_init_log(args: dict):
if (not torch.distributed.is_initialized() or (
torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)):
max_len = max([len(ii) for ii in args.keys()])
fmt_string = '\t%' + str(max_len) + 's : %s'
print('Initializing DALI with parameters:')
for keyPair in sorted(args.items()):
print(fmt_string % keyPair)
@staticmethod
def _div_ceil(dividend, divisor):
return (dividend + (divisor - 1)) // divisor
def _do_spectrogram_masking(self):
return self.mask_params['time_num_regions'] > 0 or self.mask_params['freq_num_regions'] > 0 or \
self.mask_params['both_num_regions'] > 0
@staticmethod
def _interleave_lists(*lists):
"""
[*, **, ***], [1, 2, 3], [a, b, c] -> [*, 1, a, **, 2, b, ***, 3, c]
Returns:
iterator over interleaved list
"""
assert all((len(lists[0]) == len(test_l) for test_l in lists)), "All lists have to have the same length"
return itertools.chain(*zip(*lists))
def _generate_cutouts(self):
"""
Returns:
Generates anchors and shapes of the cutout regions.
Single call generates one batch of data.
The output shall be passed to DALI's Erase operator
anchors = [f0 t0 f1 t1 ...]
shapes = [f0w t0h f1w t1h ...]
"""
MAX_TIME_DIMENSION = 20 * 16000
freq_anchors = np.random.random(self.mask_params['freq_num_regions'])
time_anchors = np.random.random(self.mask_params['time_num_regions'])
both_anchors_freq = np.random.random(self.mask_params['both_num_regions'])
both_anchors_time = np.random.random(self.mask_params['both_num_regions'])
anchors = []
for anch in freq_anchors:
anchors.extend([anch, 0])
for anch in time_anchors:
anchors.extend([0, anch])
for t, f in zip(both_anchors_time, both_anchors_freq):
anchors.extend([f, t])
shapes = []
shapes.extend(
self._interleave_lists(
np.random.randint(self.mask_params['freq_min'], self.mask_params['freq_max'] + 1,
self.mask_params['freq_num_regions']),
[int(MAX_TIME_DIMENSION)] * self.mask_params['freq_num_regions']
)
)
shapes.extend(
self._interleave_lists(
[self.nfeatures] * self.mask_params['time_num_regions'],
np.random.randint(self.mask_params['time_min'], self.mask_params['time_max'] + 1,
self.mask_params['time_num_regions'])
)
)
shapes.extend(
self._interleave_lists(
np.random.randint(self.mask_params['both_min_freq'], self.mask_params['both_max_freq'] + 1,
self.mask_params['both_num_regions']),
np.random.randint(self.mask_params['both_min_time'], self.mask_params['both_max_time'] + 1,
self.mask_params['both_num_regions'])
)
)
return anchors, shapes
def _cutouts_generator(self):
"""
Generator, that wraps cutouts creation in order to randomize inputs
and allow passing them to DALI's ExternalSource operator
"""
def tuples2list(tuples: list):
"""
[(a, b), (c, d)] -> [[a, c], [b, d]]
"""
return map(list, zip(*tuples))
[anchors, shapes] = tuples2list([self._generate_cutouts() for _ in range(self.pipe_handle.max_batch_size)])
yield np.array(anchors, dtype=np.float32), np.array(shapes, dtype=np.float32)
class DaliTritonPipeline(DaliPipeline):
def __init__(self, **kwargs):
kwargs['is_triton_pipeline'] = True
super().__init__(**kwargs)
def serialize_dali_triton_pipeline(output_path: str, config_data: dict, config_features: dict):
pipe = DaliTritonPipeline.from_config(train_pipeline=False, device_id=-1, batch_size=-1, file_root=None,
file_list=None, config_data=config_data, config_features=config_features,
do_resampling=False, num_cpu_threads=-1)
pipe.serialize(filename=output_path)