from datasets import Audio, load_dataset
from torch.utils.data import Dataset
from transformers import WhisperProcessor
from megatron.training import get_args
class AudioDataset(Dataset):
    """Map-style dataset of Whisper-ready speech examples.

    On construction, loads a Hugging Face speech dataset (Common Voice 11,
    Hindi, by default). It then drops unused metadata columns and resamples
    the audio column. Finally, it precomputes ``input_features`` (via the
    Whisper feature extractor) and ``labels`` (token ids of the ``sentence``
    column) for every example.
    """

    # Common Voice metadata columns not needed for training. Only those that
    # actually exist in the loaded dataset are removed.
    UNUSED_COLUMNS = (
        "accent",
        "age",
        "client_id",
        "down_votes",
        "gender",
        "locale",
        "path",
        "segment",
        "up_votes",
    )

    def __init__(
        self,
        basic_param: dict,
        preprocess_param: dict,
        **kwargs,
    ):
        """Build the dataset eagerly.

        Args:
            basic_param: dataset options (see ``get_whisper_dataset``).
            preprocess_param: processor options (see ``get_whisper_dataset``).
            **kwargs: accepted for interface compatibility; ignored.
        """
        super().__init__()
        self.dataset = self.get_whisper_dataset(basic_param, preprocess_param)

    @staticmethod
    def _select_columns(column_names, candidates):
        """Return the entries of *candidates* present in *column_names*, in order."""
        present = set(column_names)
        return [name for name in candidates if name in present]

    def get_whisper_dataset(self, basic_param, preprocess_param):
        """Load, clean, and featurize the training split.

        Keys read from ``basic_param`` (defaults in parentheses):
            ``dataset_name_or_path`` ("mozilla-foundation/common_voice_11_0"),
            ``language`` ("hi"), the dataset config name,
            ``split`` ("train+validation").
        Keys read from ``preprocess_param``:
            ``processor_name_or_path`` ("openai/whisper-large-v3"),
            ``language`` ("Hindi"), the processor language,
            ``task`` ("transcribe"),
            ``local_files_only`` (True).

        Returns:
            The mapped dataset. Each example gains ``input_features`` and
            ``labels`` columns.
        """
        dataset_name_or_path = basic_param.get(
            "dataset_name_or_path", "mozilla-foundation/common_voice_11_0"
        )
        language = basic_param.get("language", "hi")
        split = basic_param.get("split", "train+validation")
        processor_name_or_path = preprocess_param.get(
            "processor_name_or_path", "openai/whisper-large-v3"
        )
        processor_language = preprocess_param.get("language", "Hindi")
        task = preprocess_param.get("task", "transcribe")
        local_files_only = preprocess_param.get("local_files_only", True)

        train_dataset = load_dataset(
            dataset_name_or_path,
            language,
            split=split,
            trust_remote_code=get_args().trust_remote_code,
        )

        # remove_columns raises on unknown names, so only drop the metadata
        # columns this particular dataset actually has.
        to_drop = self._select_columns(
            train_dataset.column_names, self.UNUSED_COLUMNS
        )
        if to_drop:
            train_dataset = train_dataset.remove_columns(to_drop)

        processor = WhisperProcessor.from_pretrained(
            processor_name_or_path,
            language=processor_language,
            task=task,
            local_files_only=local_files_only,
        )
        feature_extractor = processor.feature_extractor
        tokenizer = processor.tokenizer

        # Resample to the rate the feature extractor expects (16 kHz for
        # Whisper) so the two can never disagree.
        train_dataset = train_dataset.cast_column(
            "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
        )

        def prepare_example(example):
            """Add ``input_features`` and ``labels`` to a single example."""
            audio = example["audio"]
            # The extractor returns a batch; take the only item.
            example["input_features"] = feature_extractor(
                audio["array"], sampling_rate=audio["sampling_rate"]
            ).input_features[0]
            example["labels"] = tokenizer(example["sentence"]).input_ids
            return example

        return train_dataset.map(prepare_example)

    def __getitem__(self, index):
        """Return the preprocessed example at *index*."""
        return self.dataset[index]

    def __len__(self):
        """Return the number of examples."""
        return len(self.dataset)