# Commit c0e2cd76, created on 2025-12-23 (commit history).
import numpy as np
from rknn.api import RKNN
import argparse
import soundfile as sf
import onnxruntime
import scipy

# Duration, in seconds, of each audio chunk fed to the model.
CHUNK_LENGTH = 3  # 3 seconds
# Number of samples in one chunk, assuming 16000 samples per second
# (the default target rate used by ensure_sample_rate).
MAX_N_SAMPLES = CHUNK_LENGTH * 16000


def ensure_sample_rate(waveform, original_sample_rate, desired_sample_rate=16000):
    """Resample *waveform* to *desired_sample_rate* when the rates differ.

    Args:
        waveform: Array of audio samples; resampling is applied along axis 0.
        original_sample_rate: Sample rate of *waveform* in Hz.
        desired_sample_rate: Target sample rate in Hz.

    Returns:
        A ``(waveform, desired_sample_rate)`` tuple. The waveform is returned
        unchanged when it is already at the desired rate.
    """
    if original_sample_rate != desired_sample_rate:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.signal`` is loaded on every SciPy version.
        from scipy.signal import resample

        print("resample_audio: {} HZ -> {} HZ".format(original_sample_rate, desired_sample_rate))
        desired_length = int(round(float(len(waveform)) / original_sample_rate * desired_sample_rate))
        waveform = resample(waveform, desired_length)
    return waveform, desired_sample_rate

def ensure_channels(waveform, original_channels, desired_channels=1):
    """Down-mix *waveform* by averaging over axis 1 when channel counts differ.

    Args:
        waveform: Audio array; when a conversion happens it must have an
            axis 1 to average over (presumably the channel axis — confirm
            against the caller).
        original_channels: Channel count reported for *waveform*.
        desired_channels: Channel count requested by the caller.

    Returns:
        A ``(waveform, desired_channels)`` tuple. The waveform is passed
        through untouched when the two counts already match.
    """
    if original_channels == desired_channels:
        return waveform, desired_channels

    print("convert_channels: {} -> {}".format(original_channels, desired_channels))
    mixed = np.mean(waveform, axis=1)
    return mixed, desired_channels

def init_model(model_path, target=None, device_id=None):
    """Load a model from *model_path* and return a ready-to-run handle.

    The backend is selected by file extension: ``.rknn`` files are loaded
    with ``RKNN`` and have their runtime initialised, ``.onnx`` files are
    opened as an ONNX Runtime session on the CPU execution provider.

    Args:
        model_path: Path to a ``.rknn`` or ``.onnx`` model file.
        target: Forwarded to ``RKNN.init_runtime`` (RKNN models only).
        device_id: Forwarded to ``RKNN.init_runtime`` (RKNN models only).

    Returns:
        The initialised ``RKNN`` object or ``onnxruntime.InferenceSession``.

    Raises:
        SystemExit: With the RKNN return code when loading the model or
            initialising the runtime fails.
        ValueError: If *model_path* has an unsupported extension.
    """
    if model_path.endswith(".rknn"):
        # Create RKNN object
        model = RKNN()

        # Load RKNN model
        print('--> Loading model')
        ret = model.load_rknn(model_path)
        if ret != 0:
            print('Load RKNN model \"{}\" failed!'.format(model_path))
            # ``exit`` is a helper injected by ``site`` and may be missing;
            # raising SystemExit has the same effect everywhere.
            raise SystemExit(ret)
        print('done')

        # init runtime environment
        print('--> Init runtime environment')
        ret = model.init_runtime(target=target, device_id=device_id)
        if ret != 0:
            print('Init runtime environment failed')
            raise SystemExit(ret)
        print('done')
        return model

    if model_path.endswith(".onnx"):
        return onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])

    # Previously this fell through to ``return model`` with ``model`` unbound.
    raise ValueError(
        'Unsupported model format "{}": expected a .rknn or .onnx file'.format(model_path)
    )

def run_model(model, audio):
    """Run one inference on *audio* with an RKNN or ONNX Runtime model.

    The backend is detected from the model's type name: a type whose string
    representation contains ``rknn`` is called through ``inference``, one
    containing ``onnx`` through ``run`` with *audio* bound to the session's
    first input.

    Args:
        model: Handle returned by ``init_model``.
        audio: Input data passed to the backend as-is.

    Returns:
        Whatever the backend's inference call returns.

    Raises:
        TypeError: If *model* is neither an RKNN nor an ONNX Runtime object.
    """
    model_type = str(type(model))
    if 'rknn' in model_type:
        return model.inference(inputs=audio)
    if 'onnx' in model_type:
        input_name = model.get_inputs()[0].name
        return model.run(None, {input_name: audio})

    # Previously this reached ``return outputs`` with ``outputs`` unbound.
    raise TypeError('Unsupported model type: {}'.format(model_type))

def release_model(model):
    """Release backend resources held by *model*.

    Only objects whose type name contains ``rknn`` have an explicit
    ``release()`` call made on them. For any other object nothing is done
    here: rebinding or deleting the local name would not affect the caller's
    reference, so the caller is responsible for dropping it.
    """
    backend = str(type(model))
    if 'rknn' in backend:
        model.release()

def post_process(outputs):
    """Return the index of the class with the highest mean score.

    Args:
        outputs: Sequence of model outputs; element 2 is averaged over
            axis 0 (assumed to be per-frame class scores — TODO confirm
            against the model's output layout).

    Returns:
        The argmax of the axis-0 mean of ``outputs[2]``.
    """
    mean_scores = outputs[2].mean(axis=0)
    return mean_scores.argmax()

def pad_or_trim(array, length, axis=-1):
    """Force *array* to have exactly *length* entries along *axis*.

    Longer arrays keep their first *length* entries along *axis*; shorter
    ones are zero-padded at the end of that axis. An array that already has
    the requested size is returned as-is.
    """
    current = array.shape[axis]

    if current > length:
        return array.take(indices=range(length), axis=axis)

    if current < length:
        # Pad only the trailing side of the selected axis.
        padding = [(0, 0)] * array.ndim
        padding[axis] = (0, length - current)
        return np.pad(array, padding)

    return array

def read_txt_to_dict(filename):
    """Parse a text file of ``<key> <value>`` lines into a dict.

    Each line is stripped of surrounding whitespace and split at the first
    space: the text before it is the key, everything after it is the value
    (an empty string when the line has no space). Later lines overwrite
    earlier ones that share a key.
    """
    mapping = {}
    with open(filename, 'r') as handle:
        for raw_line in handle:
            name, _, remainder = raw_line.strip().partition(' ')
            mapping[name] = remainder
    return mapping



if __name__ == '__main__':
    # Script entry point: classify a long audio file chunk by chunk and print
    # the dominant class for every CHUNK_LENGTH-second window.
    parser = argparse.ArgumentParser(description='Yamnet Long Audio Demo', add_help=True)
    parser.add_argument('--model_path', type=str, required=True, help='model path')
    parser.add_argument('--target', type=str, default='rk3588', help='target platform')
    parser.add_argument('--device_id', type=str, default=None, help='device id')
    # Optional overrides; the defaults are the previously hard-coded paths.
    parser.add_argument('--audio_path', type=str,
                        default="/home/orangepi/Desktop/videoplayback(2).wav",
                        help='audio file to analyse')
    parser.add_argument('--class_map_path', type=str,
                        default="../model/yamnet_class_map.txt",
                        help='class index to class name map')
    args = parser.parse_args()

    # 1. Load the audio and preprocess it (mono, target sample rate, float32).
    audio_path = args.audio_path
    audio_data, sample_rate = sf.read(audio_path)

    # soundfile yields a 1-D array for mono and (frames, channels) otherwise,
    # so the channel count is shape[1]; ndim would report 2 for any
    # multi-channel file.
    channels = 1 if audio_data.ndim == 1 else audio_data.shape[1]
    audio_data, _ = ensure_channels(audio_data, channels)
    audio_data, sample_rate = ensure_sample_rate(audio_data, sample_rate)
    audio_array = np.array(audio_data, dtype=np.float32).flatten()

    # 2. Load the class index -> label mapping.
    label = read_txt_to_dict(args.class_map_path)

    # 3. Initialise the model.
    model = init_model(args.model_path, args.target, args.device_id)

    # try/finally makes sure the model is released even if inference fails.
    try:
        # Process the long recording in consecutive, non-overlapping chunks.
        print(f"\n开始分析音频: {audio_path}")
        print("-" * 40)

        # MAX_N_SAMPLES = CHUNK_LENGTH * 16000 = 48000 samples per chunk.
        for start in range(0, len(audio_array), MAX_N_SAMPLES):
            # Timestamp of the chunk start, in seconds.
            chunk_start_sec = start / sample_rate

            # Take one chunk; the final chunk may be short, so zero-pad it
            # to the fixed input length.
            chunk = pad_or_trim(audio_array[start:start + MAX_N_SAMPLES], MAX_N_SAMPLES)

            # Add a leading batch axis: shape becomes [1, MAX_N_SAMPLES].
            input_data = np.expand_dims(chunk, 0)

            # 4. Inference.
            outputs = run_model(model, input_data)

            # 5. Post-process: class with the highest mean score over the
            # chunk, reported immediately.
            top_class_index = post_process(outputs)
            class_name = label.get(str(top_class_index), "Unknown")

            print(f"[{chunk_start_sec:6.1f}s - {chunk_start_sec + CHUNK_LENGTH:6.1f}s] 主要声音: {class_name}")

        print("-" * 40)
        print("分析完成")
    finally:
        # 6. Release the model.
        release_model(model)