from pathlib import Path
from typing import Union
import torch
import torchaudio
from vrag.constants import AUDIO_SAMPLE_RATE
from vrag.types import AudioChunkExtraction
def chunk_audio(
    audio_path: Union[str, Path],
    chunk_length: float = 30.0,
    min_chunk_threshold: float = 1.0,
    sample_rate: int = AUDIO_SAMPLE_RATE,
) -> AudioChunkExtraction:
    """
    Split an audio file into fixed-length mono chunks.

    The audio is down-mixed by averaging over the channel dimension returned by
    ``torchaudio.load`` and resampled to ``sample_rate`` if it differs. Every
    emitted chunk holds exactly ``int(chunk_length * sample_rate)`` samples: a
    shorter final chunk is zero-padded at the end, or discarded entirely if it
    is shorter than ``min_chunk_threshold`` seconds.

    Args:
        audio_path: Path to the audio file.
        chunk_length: Duration of each chunk in seconds.
        min_chunk_threshold: Minimum duration in seconds a trailing partial
            chunk must have to be kept (zero-padded) instead of dropped.
        sample_rate: Target sample rate in Hz; audio is resampled to it.

    Returns:
        AudioChunkExtraction with ``audio_chunks`` (one NumPy array per chunk)
        and ``durations`` (matching ``(start, end)`` times in seconds). For a
        padded tail, ``end`` marks the end of the real audio, not the padding.

    Raises:
        ValueError: If ``chunk_length`` is smaller than ``min_chunk_threshold``,
            or if ``chunk_length * sample_rate`` yields fewer than one sample
            (e.g. non-positive ``chunk_length`` or ``sample_rate``).
    """
    if chunk_length < min_chunk_threshold:
        raise ValueError(
            f"Audio chunk_length: {chunk_length} must be >= min_chunk_threshold: {min_chunk_threshold}"
        )
    num_samples_per_chunk = int(chunk_length * sample_rate)
    # Validate before the (potentially expensive) load/resample; a step of
    # zero would otherwise surface as an obscure range() error later.
    if num_samples_per_chunk < 1:
        raise ValueError(
            f"Audio chunk_length: {chunk_length} at sample_rate: {sample_rate} "
            "must span at least one sample"
        )
    min_samples_threshold = int(min_chunk_threshold * sample_rate)

    audio_path = Path(audio_path)
    speech, sr = torchaudio.load(audio_path)
    # Average over the channel dimension to obtain a single mono track.
    speech = speech.mean(dim=0)
    if sr != sample_rate:
        speech = torchaudio.transforms.Resample(sr, sample_rate)(speech)

    chunks = []
    durations = []
    for start in range(0, len(speech), num_samples_per_chunk):
        chunk = speech[start : start + num_samples_per_chunk]
        real_length = len(chunk)
        if real_length < num_samples_per_chunk:
            # Only the final chunk can be short: drop it when below the
            # threshold, otherwise zero-pad it to the full chunk size.
            if real_length < min_samples_threshold:
                break
            chunk = torch.nn.functional.pad(chunk, (0, num_samples_per_chunk - real_length))
        chunks.append(chunk.numpy())
        durations.append((start / sample_rate, (start + real_length) / sample_rate))
    return AudioChunkExtraction(audio_chunks=chunks, durations=durations)