import os
import logging
import pytest
from mm import load_audio
from mm_test.common import TEST_HW_USER_FILE_PATH, TEST_HW_USER_WAV_LOAD_PATH
# Timestamped INFO-level output so the loader's progress messages are visible.
# NOTE(review): logging.basicConfig is a no-op if the root logger already has
# handlers when this module is imported, and otherwise it configures the root
# logger for the whole process.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used by the helpers below.
logger = logging.getLogger(__name__)
def _check_sample_rate(expected_sr, actual_sr, label):
    """Assert that ``actual_sr`` equals ``expected_sr`` when a rate was requested.

    A falsy ``expected_sr`` (e.g. ``None``) means no particular rate was asked
    for, so nothing is checked.
    """
    if expected_sr:
        assert actual_sr == expected_sr, (
            f'{label}: sampling rate mismatch, expected {expected_sr}, got {actual_sr}')


def load_audio_func(wav_name, sampling_fraction):
    """Load test audio through ``load_audio`` and convert it to NumPy arrays.

    How ``wav_name`` is interpreted:

    * a ``str`` ending in ``.wav`` (case-insensitive) is one file, joined onto
      ``TEST_HW_USER_WAV_LOAD_PATH``;
    * a ``list`` of names is several files, each joined onto
      ``TEST_HW_USER_WAV_LOAD_PATH`` and passed to ``load_audio`` together;
    * any other ``str`` is joined onto ``TEST_HW_USER_FILE_PATH`` and passed to
      ``load_audio`` as-is (used for a directory name in this module's tests).

    Args:
        wav_name: file name, list of file names, or other path component.
        sampling_fraction: sampling rate forwarded to ``load_audio``. When
            truthy, every rate returned by ``load_audio`` must equal it.

    Returns:
        ``(samples, sr)`` for a single ``.wav`` name. Otherwise, a list of
        ``(samples, sr)`` tuples, one per item returned by ``load_audio``.
        ``samples`` is the result of calling ``.numpy()`` on the loaded audio.

    Raises:
        TypeError: if ``wav_name`` is neither a ``str`` nor a ``list``.
        AssertionError: if a returned rate differs from a truthy
            ``sampling_fraction``.
    """
    if not isinstance(wav_name, (str, list)):
        raise TypeError("Only str or list types are supported")

    if isinstance(wav_name, str) and wav_name.lower().endswith('.wav'):
        wav_input = os.path.join(TEST_HW_USER_WAV_LOAD_PATH, wav_name)
        logger.info(f'Loading audio file: {wav_input}')
        audio, sr = load_audio(wav_input, sampling_fraction)
        mm_numpydata = audio.numpy()
        logger.info(
            f'Audio loaded successfully - shape: {mm_numpydata.shape}, sampling rate: {sr}, size: {mm_numpydata.size}')
        # Validate the rate on this path too, matching the multi-input branch below.
        _check_sample_rate(sampling_fraction, sr, wav_name)
        return mm_numpydata, sr

    if isinstance(wav_name, list):
        wav_inputs = [os.path.join(TEST_HW_USER_WAV_LOAD_PATH, wav) for wav in wav_name]
        logger.info(f'Loading audio files: {wav_inputs}')
    else:
        # Non-.wav string: resolved against a different base path than single files.
        wav_inputs = os.path.join(TEST_HW_USER_FILE_PATH, wav_name)

    audio_list_load = load_audio(wav_inputs, sampling_fraction)
    total = len(audio_list_load)
    results = []
    for i, item in enumerate(audio_list_load):
        logger.info(f'Processing audio result {i+1}/{total}')
        # Each item is indexed as (audio, sample_rate).
        test_audio = item[0].numpy()
        test_sr = item[1]
        logger.info(
            f'Audio {i+1} - shape: {test_audio.shape}, sampling rate: {test_sr}')
        _check_sample_rate(sampling_fraction, test_sr, f'audio {i}')
        results.append((test_audio, test_sr))
    return results
def test_wav_file_sr_16000():
    """Load one .wav file at 16000 Hz and check its rate and that it has samples."""
    sampling_fraction = 16000
    mm_numpydata, sr = load_audio_func('audio_test0.wav', sampling_fraction)

    assert sr == sampling_fraction, f"Sampling rate mismatch: expected {sampling_fraction}, got {sr}"
    assert mm_numpydata.size > 0, "Audio data should not be empty"
def test_wav_list_sr_16000():
    """Load two .wav files as a list and expect one 16000 Hz, non-empty result each."""
    sampling_fraction = 16000
    wav_list = [f'audio_test{idx}.wav' for idx in range(2)]
    results = load_audio_func(wav_list, sampling_fraction)

    assert len(results) == len(wav_list), f"Result count mismatch: expected {len(wav_list)}, got {len(results)}"
    for i, entry in enumerate(results):
        audio_data, sr = entry
        assert sr == sampling_fraction, f"Sampling rate mismatch for audio {i}: expected {sampling_fraction}, got {sr}"
        assert audio_data.size > 0, f"Audio data {i} should not be empty"
def test_wav_dir_sr_16000():
    """Load the 'wav_load_test' directory; expect at least one 16000 Hz, non-empty result."""
    sampling_fraction = 16000
    results = load_audio_func('wav_load_test', sampling_fraction)

    assert len(results) > 0, "Should have loaded at least one audio file from directory"
    for i, entry in enumerate(results):
        audio_data, sr = entry
        assert sr == sampling_fraction, f"Sampling rate mismatch for audio {i}: expected {sampling_fraction}, got {sr}"
        assert audio_data.size > 0, f"Audio data {i} should not be empty"