from transformers import Wav2Vec2Processor, HubertModel
import soundfile as sf
import numpy as np
import torch
import os
from utils.commons.hparams import set_hparams, hparams
# Process-wide cache for the pretrained HuBERT encoder and its input
# processor. Both start out empty and are filled in by
# get_hubert_from_16k_speech() the first time it runs, so importing this
# module does not load any weights.
hubert_model = None
wav2vec2_processor = None
def get_hubert_from_16k_wav(wav_16k_name):
    """Read an audio file and return its HuBERT hidden states.

    The file is loaded with ``soundfile.read``; the sample rate it reports is
    discarded, so the caller is responsible for supplying 16 kHz audio (as the
    function name indicates).

    Args:
        wav_16k_name: Path of the audio file to read.

    Returns:
        Whatever ``get_hubert_from_16k_speech`` returns for the decoded
        samples, using that function's default device.
    """
    waveform, _sample_rate = sf.read(wav_16k_name)
    return get_hubert_from_16k_speech(waveform)
@torch.no_grad()
def get_hubert_from_16k_speech(speech, device="cuda:0"):
    """Extract HuBERT last-layer hidden states from a 16 kHz waveform.

    The pretrained model and processor are loaded on first use and cached in
    the module-level ``hubert_model`` / ``wav2vec2_processor`` globals. Long
    inputs are processed in fixed-size chunks and the per-chunk outputs are
    concatenated along the time axis.

    Args:
        speech: Array-like waveform exposing ``ndim``. A 2-D input is reduced
            to its first column (presumably ``[samples, channels]`` as
            returned by ``soundfile.read`` -- confirm against callers).
        device: Torch device the model and inputs are moved to.

    Returns:
        A 2-D CPU tensor with exactly ``expected_T`` rows (one per output
        frame), where ``expected_T = (num_samples - (kernel - stride)) // stride``.

    Raises:
        RuntimeError: If the audio is too short to produce a single frame.
        AssertionError: If the number of frames produced differs from
            ``expected_T`` by more than one.
    """
    global hubert_model, wav2vec2_processor
    local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
    # Reload when either half of the cache is missing so a previous partial
    # failure cannot leave a model without its processor.
    if hubert_model is None or wav2vec2_processor is None:
        print("Loading the HuBERT Model...")
        print("Loading the Wav2Vec2 Processor...")
        # Prefer the local snapshot when it exists; otherwise resolve by hub id.
        model_source = local_path if os.path.exists(local_path) else "facebook/hubert-large-ls960-ft"
        hubert_model = HubertModel.from_pretrained(model_source)
        wav2vec2_processor = Wav2Vec2Processor.from_pretrained(model_source)
    # Done on every call so the cached model follows the requested device.
    hubert_model = hubert_model.to(device)

    if speech.ndim == 2:
        speech = speech[:, 0]

    input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    input_values_all = input_values_all.to(device)
    total_samples = input_values_all.shape[1]

    # NOTE(review): kernel/stride are presumably the effective receptive field
    # and hop (in samples) of HuBERT's convolutional front end -- confirm
    # against the model config.
    kernel = 400
    stride = 320
    clip_length = stride * 1000
    # Each chunk spans (clip_length - stride + kernel) samples while chunk
    # starts advance by clip_length, so neighbouring chunks overlap by
    # (kernel - stride) samples.
    window = clip_length - stride + kernel
    num_iter = total_samples // clip_length
    expected_T = (total_samples - (kernel - stride)) // stride

    res_lst = []
    for i in range(num_iter):
        start_idx = clip_length * i
        input_values = input_values_all[:, start_idx: start_idx + window]
        hidden_states = hubert_model(input_values).last_hidden_state
        res_lst.append(hidden_states[0])

    # Remainder after the last full chunk (or the whole input when it is
    # shorter than one chunk). Skipped when shorter than one kernel.
    input_values = input_values_all[:, clip_length * num_iter:]
    if input_values.shape[1] >= kernel:
        hidden_states = hubert_model(input_values).last_hidden_state
        res_lst.append(hidden_states[0])

    if not res_lst:
        raise RuntimeError(
            f"Audio too short for HuBERT feature extraction: got {total_samples} "
            f"samples, need at least {kernel}."
        )

    ret = torch.cat(res_lst, dim=0).cpu()
    assert abs(ret.shape[0] - expected_T) <= 1, (
        f"HuBERT produced {ret.shape[0]} frames, expected about {expected_T}"
    )
    missing = expected_T - ret.shape[0]
    if missing > 0:
        # ret is 2-D here, so pad along the time axis (dim 0) by repeating the
        # final frame.
        ret = torch.cat([ret, ret[-1:].repeat(missing, 1)], dim=0)
    else:
        ret = ret[:expected_T]
    return ret
if __name__ == '__main__':
    # Command-line entry point: extract HuBERT features for one processed
    # video's audio track and store them next to it as a .npy file.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument('--video_id', type=str, default='May', help='')
    video_dir = f"data/processed/videos/{cli.parse_args().video_id}"

    wav_path = f"{video_dir}/aud.wav"
    npy_path = f"{video_dir}/aud_hubert.npy"

    # Same read-then-extract sequence as doing sf.read() followed by
    # get_hubert_from_16k_speech() by hand.
    features = get_hubert_from_16k_wav(wav_path)
    np.save(npy_path, features.detach().numpy())
    print(f"Saved at {npy_path}")