wav2vec2-large-robust-24-ft-age-gender:基于Wav2vec 2.0的24层音频年龄与性别识别模型

该模型通过原始音频信号预测年龄(0-100岁)和性别(儿童、女性、男性概率),并提供最后一层Transformer的池化状态,由Wav2Vec2-Large-Robust微调24层得到。【此简介由AI生成】

分支1Tags0
10d60994创建于 2024年9月19日16次提交
文件最后提交记录最后更新时间
initial commit2 年前
Upload LICENSE2 年前
Fix ONNX link in README1 年前
Upload model2 年前
Adding safetensors variant of this model (#1) - Adding safetensors variant of this model (c465eab72d5aefa27b06d2b5c8962a62897e6968) Co-authored-by: Safetensors convertbot <SFconvertbot@users.noreply.huggingface.co> 2 年前
Create preprocessor_config.json2 年前
Upload model2 年前
Update vocab.json2 年前

datasets:

  • agender
  • mozillacommonvoice
  • timit
  • voxceleb2 inference: true tags:
  • speech
  • audio
  • wav2vec2
  • audio-classification
  • age-recognition
  • gender-recognition license: cc-by-nc-sa-4.0

基于 Wav2vec 2.0(24 层)的年龄与性别识别模型

本模型以原始音频信号作为输入,输出年龄预测值(范围约 0 至 1,对应 0 到 100 周岁)以及儿童/女性/男性性别概率。同时提供最后一层 Transformer 的池化状态输出。该模型通过对 Wav2Vec2-Large-RobustaGenderMozilla Common VoiceTimitVoxceleb 2 数据集上进行微调训练而得。此版本模型完整训练了全部 24 个 Transformer 层。

模型的 ONNX 导出版本可通过 doi:10.5281/zenodo.7761387 获取。更多技术细节请参阅相关论文教程

使用说明

import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)


class ModelHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config, num_labels):

        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, features, **kwargs):

        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        return x


class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Speech emotion classifier."""

    def __init__(self, config):

        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = ModelHead(config, 1)
        self.gender = ModelHead(config, 3)
        self.init_weights()

    def forward(
            self,
            input_values,
    ):

        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)

        return hidden_states, logits_age, logits_gender



# load model from hub
device = 'cpu'
model_name = 'audeering/wav2vec2-large-robust-24-ft-age-gender'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = AgeGenderModel.from_pretrained(model_name)

# dummy signal
sampling_rate = 16000
signal = np.zeros((1, sampling_rate), dtype=np.float32)


def process_func(
    x: np.ndarray,
    sampling_rate: int,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict age and gender or extract embeddings from raw audio signal."""

    # run through processor to normalize signal
    # always returns a batch, so we just get the first entry
    # then we put it on the device
    y = processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)
        if embeddings:
            y = y[0]
        else:
            y = torch.hstack([y[1], y[2]])

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y


print(process_func(signal, sampling_rate))
#    Age        female     male       child
# [[ 0.33793038 0.2715511  0.2275236  0.5009253 ]]

print(process_func(signal, sampling_rate, embeddings=True))
# Pooled hidden states of last transformer layer
# [[ 0.024444    0.0508722   0.04930823 ...  0.07247854 -0.0697901
#   -0.0170537 ]]

项目介绍

该模型通过原始音频信号预测年龄(0-100岁)和性别(儿童、女性、男性概率),并提供最后一层Transformer的池化状态,由Wav2Vec2-Large-Robust微调24层得到。【此简介由AI生成】

定制我的领域

下载使用量

0

项目总下载次数(含Clone、Pull、 zip 包及 release 下载),每日凌晨更新