import os
import argparse
import json
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
import whisper
from torch import Tensor, nn
from whisper.model import (
AudioEncoder,
MultiHeadAttention,
ResidualAttentionBlock,
TextDecoder,
)
def get_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv`` with two
        string attributes: ``model`` (default ``./base.en.pt``) and
        ``output_dir`` (default ``./whisper_base_en``).
    """
    # Every option is a plain string flag, so register them from a table.
    string_options = (
        ("--model", "./base.en.pt"),
        ("--output_dir", "./whisper_base_en"),
    )
    parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        parser.add_argument(flag, type=str, default=default_value)
    return parser.parse_args()
def modified_audio_encoder_forward(self: AudioEncoder, x: torch.Tensor):
    """Forward pass intended to replace ``AudioEncoder.forward``.

    Args:
        x: the mel spectrogram of the audio,
            shape = (batch_size, n_mels, n_ctx).

    Returns:
        The result of ``self.ln_post`` applied after all ``self.blocks``.

    Raises:
        AssertionError: if either of the last two axes of the convolved
            input disagrees with ``self.positional_embedding``.
    """
    # Two GELU-activated convolutions, then swap the last two axes.
    x = F.gelu(self.conv2(F.gelu(self.conv1(x)))).permute(0, 2, 1)

    n_positions = x.shape[1]
    n_state = x.shape[2]
    # Each axis is compared separately rather than as one whole-shape check.
    assert (
        n_state == self.positional_embedding.shape[1]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"
    assert (
        n_positions == self.positional_embedding.shape[0]
    ), f"incorrect audio shape: {x.shape}, {self.positional_embedding.shape}"

    # Cast the sum back to the dtype x had *before* the addition.
    with_positions = x + self.positional_embedding[:n_positions]
    x = with_positions.to(x.dtype)

    for layer in self.blocks:
        x = layer(x)
    return self.ln_post(x)
# Monkey-patch whisper's AudioEncoder class so that its forward pass is the
# function defined above; this takes effect as soon as the module is loaded.
AudioEncoder.forward = modified_audio_encoder_forward
class AudioEncoderTensorCache(nn.Module):
    """Audio encoder that also emits the decoder's cross-attention K/V.

    The module runs the wrapped audio encoder and then projects the result
    through the ``cross_attn.key`` / ``cross_attn.value`` layers of every
    block of the wrapped text decoder.
    """

    def __init__(self, inAudioEncoder: AudioEncoder,
                 inTextDecoder: TextDecoder):
        super().__init__()
        self.audioEncoder = inAudioEncoder
        self.textDecoder = inTextDecoder

    def forward(self, x: Tensor):
        """Encode ``x`` and project it for each decoder block.

        Args:
            x: input passed unchanged to the audio encoder.

        Returns:
            A pair ``(keys, values)``; each stacks the per-block projections
            along a new leading axis, in decoder-block order.
        """
        audio_features = self.audioEncoder(x)
        # Key first, then value, block by block.
        per_block = [
            (
                block.cross_attn.key(audio_features),
                block.cross_attn.value(audio_features),
            )
            for block in self.textDecoder.blocks
        ]
        stacked_keys = torch.stack([key for key, _ in per_block])
        stacked_values = torch.stack([value for _, value in per_block])
        return stacked_keys, stacked_values
class MultiHeadAttentionCross(nn.Module):
    """Cross-attention that takes already-projected keys and values.

    Only the query projection of the wrapped ``MultiHeadAttention`` is applied
    here; ``k`` and ``v`` are handed to ``qkv_attention`` as given.
    """

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Attend from ``x`` over ``k``/``v`` and apply the output projection.

        Args:
            x: input to the query projection.
            k: keys forwarded as-is to ``qkv_attention``.
            v: values forwarded as-is to ``qkv_attention``.
            mask: optional mask forwarded to ``qkv_attention``.

        Returns:
            The output projection of the attention result.
        """
        attention = self.multiHeadAttention
        # The second value returned by qkv_attention is not needed here.
        attended, _ = attention.qkv_attention(attention.query(x), k, v, mask)
        return attention.out(attended)
class MultiHeadAttentionSelf(nn.Module):
    """Self-attention that keeps keys/values in caller-provided cache tensors."""

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,
        k_cache: Tensor,
        v_cache: Tensor,
        mask: Tensor,
    ):
        """Update the caches with ``x`` and attend over their full content.

        Args:
            x: input to the query/key/value projections.
            k_cache: key cache; its trailing positions along axis 1 are
                overwritten in place with the new keys.
            v_cache: value cache, updated the same way with the new values.
            mask: mask forwarded to ``qkv_attention``.

        Returns:
            ``(output, k_cache, v_cache)`` where ``output`` is the projected
            attention result and the caches are the (mutated) input tensors.
        """
        attention = self.multiHeadAttention
        query = attention.query(x)
        fresh_keys = attention.key(x)
        fresh_values = attention.value(x)

        # In-place write: the newest entries occupy the end of each cache.
        k_cache[:, -fresh_keys.shape[1]:, :] = fresh_keys
        v_cache[:, -fresh_values.shape[1]:, :] = fresh_values

        attended, _ = attention.qkv_attention(query, k_cache, v_cache, mask)
        return attention.out(attended), k_cache, v_cache
class ResidualAttentionBlockTensorCache(nn.Module):
    """Residual attention block operating on explicit cache tensors.

    The layer norms and the MLP are taken from the wrapped block; its
    attention layers are wrapped in ``MultiHeadAttentionSelf`` and, when the
    block has one, ``MultiHeadAttentionCross``.
    """

    def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
        super().__init__()
        self.originalBlock = inResidualAttentionBlock
        self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
        # A block without cross-attention gets no cross-attention wrapper.
        if inResidualAttentionBlock.cross_attn:
            self.cross_attn = MultiHeadAttentionCross(
                inResidualAttentionBlock.cross_attn)
        else:
            self.cross_attn = None

    def forward(
        self,
        x: Tensor,
        self_k_cache: Tensor,
        self_v_cache: Tensor,
        cross_k: Tensor,
        cross_v: Tensor,
        mask: Tensor,
    ):
        """Apply self-attention, optional cross-attention and the MLP.

        Args:
            x: block input.
            self_k_cache: self-attention key cache passed to ``self.attn``.
            self_v_cache: self-attention value cache passed to ``self.attn``.
            cross_k: keys passed to the cross-attention wrapper.
            cross_v: values passed to the cross-attention wrapper.
            mask: mask passed to the self-attention wrapper only.

        Returns:
            ``(x, k_cache, v_cache)``: the block output and the two caches
            returned by the self-attention wrapper.
        """
        wrapped = self.originalBlock

        attn_out, k_cache_out, v_cache_out = self.attn(
            wrapped.attn_ln(x),
            self_k_cache,
            self_v_cache,
            mask=mask,
        )
        x = x + attn_out

        if self.cross_attn is not None:
            cross_out = self.cross_attn(wrapped.cross_attn_ln(x), cross_k,
                                        cross_v)
            x = x + cross_out

        x = x + wrapped.mlp(wrapped.mlp_ln(x))
        return x, k_cache_out, v_cache_out
class TextDecoderTensorCache(nn.Module):
    """Text decoder whose self-attention caches are explicit tensors.

    Embeddings, final layer norm and mask come from the wrapped decoder; each
    of its blocks is wrapped in ``ResidualAttentionBlockTensorCache``.
    """

    def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
        super().__init__()
        self.textDecoder = inTextDecoder
        self.n_ctx = in_n_ctx
        # Kept as a plain Python list (not an nn.ModuleList), so the wrappers
        # themselves are not registered as submodules of this module.
        self.blocks = [
            ResidualAttentionBlockTensorCache(wrapped_block)
            for wrapped_block in self.textDecoder.blocks
        ]

    def forward(
        self,
        tokens: Tensor,
        n_layer_self_k_cache: Tensor,
        n_layer_self_v_cache: Tensor,
        n_layer_cross_k: Tensor,
        n_layer_cross_v: Tensor,
        offset: Tensor,
    ):
        """Run one decoding step over ``tokens``.

        Args:
            tokens: token ids fed to the token embedding.
            n_layer_self_k_cache: self-attention key caches, indexed by layer
                on axis 0; updated in place.
            n_layer_self_v_cache: self-attention value caches, same layout;
                updated in place.
            n_layer_cross_k: cross-attention keys, indexed by layer on axis 0.
            n_layer_cross_v: cross-attention values, indexed by layer on
                axis 0.
            offset: tensor whose first element is the position of the first
                entry of ``tokens``.

        Returns:
            ``(logits, n_layer_self_k_cache, n_layer_self_v_cache)`` with the
            logits converted to float.
        """
        decoder = self.textDecoder
        start = offset[0]
        # One past the last position covered by this call.
        stop = start + tokens.shape[-1]

        x = (decoder.token_embedding(tokens) +
             decoder.positional_embedding[start:stop])
        x = x.to(n_layer_cross_k[0].dtype)

        for layer, block in enumerate(self.blocks):
            # Each block only sees the cache prefix filled so far.
            x, layer_k, layer_v = block(
                x,
                self_k_cache=n_layer_self_k_cache[layer, :, :stop, :],
                self_v_cache=n_layer_self_v_cache[layer, :, :stop, :],
                cross_k=n_layer_cross_k[layer],
                cross_v=n_layer_cross_v[layer],
                mask=decoder.mask,
            )
            n_layer_self_k_cache[layer, :, :stop, :] = layer_k
            n_layer_self_v_cache[layer, :, :stop, :] = layer_v

        x = decoder.ln(x)
        # The token-embedding weight doubles as the output projection.
        embedding_weight = decoder.token_embedding.weight.to(x.dtype)
        projected = torch.matmul(embedding_weight, x.permute(0, 2, 1))
        logits = projected.permute(0, 2, 1).float()
        return logits, n_layer_self_k_cache, n_layer_self_v_cache
def convert_tokens(model, output_path):
    """Write the tiktoken vocabulary shipped with whisper as ``tokens.txt``.

    The vocabulary file is looked up in the ``assets`` directory next to the
    installed ``whisper`` package: ``multilingual.tiktoken`` when
    ``model.is_multilingual`` is true, ``gpt2.tiktoken`` otherwise. Every
    non-blank line is expected to hold ``<token> <rank>``.

    Args:
        model: object exposing an ``is_multilingual`` attribute.
        output_path: directory that receives ``tokens.txt``.

    Raises:
        ValueError: if the vocabulary file does not exist, or a non-blank
            line does not consist of a token followed by an integer rank.
    """
    asset_name = ("multilingual.tiktoken"
                  if model.is_multilingual else "gpt2.tiktoken")
    tokenizer = Path(whisper.__file__).parent / "assets" / asset_name
    if not tokenizer.is_file():
        raise ValueError(f"Cannot find {tokenizer}")

    tokens = {}
    with open(tokenizer, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Skip empty and whitespace-only lines.
                continue
            token, rank = fields
            tokens[token] = int(rank)

    with open(Path(output_path) / "tokens.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{t} {i}\n" for t, i in tokens.items())
def get_cfg(model, output_path):
    """Write ``tokens.txt`` and ``model_cfg.json`` for ``model``.

    ``model_cfg.json`` holds the dimensions found in ``model.dims`` together
    with the special-token ids and language tables reported by the whisper
    tokenizer that matches the model. Id sequences are stored as
    comma-separated strings.

    Args:
        model: loaded whisper model; ``dims`` and ``is_multilingual`` are
            read, plus ``num_languages`` when the attribute exists.
        output_path: directory that receives both files.
    """
    convert_tokens(model, output_path)

    # Without num_languages the tokenizer uses its own default language
    # count, which gives wrong special-token ids for checkpoints that have a
    # different number of languages. The value is forwarded only when the
    # model exposes it so that older whisper releases keep working.
    # NOTE(review): assumes model.num_languages and the num_languages keyword
    # of get_tokenizer arrived in the same whisper release - confirm.
    tokenizer_kwargs = {}
    num_languages = getattr(model, "num_languages", None)
    if num_languages is not None:
        tokenizer_kwargs["num_languages"] = num_languages
    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual,
                                                **tokenizer_kwargs)

    def join_ids(ids):
        """Render an iterable of ids as a comma-separated string."""
        return ",".join(map(str, ids))

    dims = model.dims
    # Key order is kept stable because it determines the layout of the file.
    model_cfg = {
        "n_mels": dims.n_mels,
        "n_audio_ctx": dims.n_audio_ctx,
        "n_audio_state": dims.n_audio_state,
        "n_audio_head": dims.n_audio_head,
        "n_audio_layer": dims.n_audio_layer,
        "n_vocab": dims.n_vocab,
        "n_text_ctx": dims.n_text_ctx,
        "n_text_state": dims.n_text_state,
        "n_text_head": dims.n_text_head,
        "n_text_layer": dims.n_text_layer,
        "sot_sequence": join_ids(tokenizer.sot_sequence),
        "all_language_tokens": join_ids(tokenizer.all_language_tokens),
        "all_language_codes": ",".join(tokenizer.all_language_codes),
        "sot": tokenizer.sot,
        "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
        "eot": tokenizer.eot,
        # First token id produced when encoding a single space.
        "blank_id": tokenizer.encode(" ")[0],
        "is_multilingual": int(model.is_multilingual),
        "no_speech": tokenizer.no_speech,
        "non_speech_tokens": join_ids(tokenizer.non_speech_tokens),
        "transcribe": tokenizer.transcribe,
        "translate": tokenizer.translate,
        "sot_prev": tokenizer.sot_prev,
        "sot_lm": tokenizer.sot_lm,
        "no_timestamps": tokenizer.no_timestamps,
    }
    with open(f"{output_path}/model_cfg.json", "w",
              encoding="utf-8") as json_file:
        json.dump(model_cfg, json_file)
def _count_parameters(module):
    """Return the total number of elements over all parameters of *module*."""
    return sum(p.numel() for p in module.parameters())


def _export_encoder(model, mel, batch_size, encoder_filename, opset_version):
    """Run the encoder wrapper once on *mel* and export it to ONNX.

    Returns the cross-attention key and value tensors produced by that run.
    """
    encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
    n_layer_cross_k, n_layer_cross_v = encoder(mel)

    # Sanity-check the wrapper output before tracing it.
    expected_shape = (
        model.dims.n_text_layer,
        batch_size,
        model.dims.n_audio_ctx,
        model.dims.n_text_state,
    )
    assert n_layer_cross_k.shape == expected_shape, (n_layer_cross_k.shape,
                                                     model.dims)
    assert n_layer_cross_v.shape == expected_shape, (n_layer_cross_v.shape,
                                                     model.dims)

    torch.onnx.export(
        encoder,
        mel,
        encoder_filename,
        opset_version=opset_version,
        input_names=["mel"],
        output_names=["n_layer_cross_k", "n_layer_cross_v"],
        dynamic_axes={
            "mel": {
                0: "n_audio",
                2: "T"
            },
            "n_layer_cross_k": {
                1: "n_audio",
                2: "T"
            },
            "n_layer_cross_v": {
                1: "n_audio",
                2: "T"
            },
        },
    )
    return n_layer_cross_k, n_layer_cross_v


def _export_decoder(model, sot, n_audio, device, n_layer_cross_k,
                    n_layer_cross_v, decoder_filename, opset_version):
    """Run the decoder wrapper for two steps, then export it to ONNX."""
    decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)

    cache_shape = (
        len(model.decoder.blocks),
        n_audio,
        model.dims.n_text_ctx,
        model.dims.n_text_state,
    )
    n_layer_self_k_cache = torch.zeros(cache_shape, device=device)
    n_layer_self_v_cache = torch.zeros(cache_shape, device=device)

    # First step: three dummy tokens starting at position 0.
    tokens = torch.tensor([[sot, sot, sot]] * n_audio).to(device)
    offset = torch.zeros(1, dtype=torch.int64).to(device)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)
    expected_cache_shape = (
        model.dims.n_text_layer,
        n_audio,
        model.dims.n_text_ctx,
        model.dims.n_text_state,
    )
    assert n_layer_self_k_cache.shape == expected_cache_shape
    assert n_layer_self_v_cache.shape == expected_cache_shape

    # Second step: a single token placed right after the first three. Its
    # return values are not used; the call exercises the wrapper with a
    # non-zero offset, and the decoder wrapper updates the caches in place.
    offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(device)
    tokens = torch.tensor([[sot]] * n_audio).to(device)
    decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )

    # The export traces the decoder with the same single-token inputs.
    torch.onnx.export(
        decoder,
        (
            tokens,
            n_layer_self_k_cache,
            n_layer_self_v_cache,
            n_layer_cross_k,
            n_layer_cross_v,
            offset,
        ),
        decoder_filename,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "in_n_layer_self_k_cache",
            "in_n_layer_self_v_cache",
            "n_layer_cross_k",
            "n_layer_cross_v",
            "offset",
        ],
        output_names=[
            "logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"
        ],
        dynamic_axes={
            "tokens": {
                0: "n_audio",
                1: "n_tokens"
            },
            "in_n_layer_self_k_cache": {
                1: "n_audio"
            },
            "in_n_layer_self_v_cache": {
                1: "n_audio"
            },
            "n_layer_cross_k": {
                1: "n_audio",
                2: "T"
            },
            "n_layer_cross_v": {
                1: "n_audio",
                2: "T"
            },
        },
    )


@torch.no_grad()
def main():
    """Export a whisper model as ``encoder.onnx`` and ``decoder.onnx``.

    Reads ``--model`` and ``--output_dir`` from the command line, creates the
    output directory if needed, writes ``tokens.txt`` and ``model_cfg.json``
    through ``get_cfg`` and then exports the encoder and decoder wrappers
    using random audio as the tracing input.
    """
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)
    opset_version = 14

    model = whisper.load_model(args.model)
    print("number of model parameters:", _count_parameters(model))
    print("number of encoder parameters:", _count_parameters(model.encoder))
    print("number of decoder parameters:", _count_parameters(model.decoder))

    # Only tokenizer.sot is read below, to build dummy token ids.
    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
    model.eval()
    get_cfg(model, args.output_dir)

    # Two seconds of random samples, padded by whisper to 30 s at 16 kHz.
    audio = whisper.pad_or_trim(torch.rand(16000 * 2))
    assert audio.shape == (16000 * 30, ), audio.shape

    n_mels = model.dims.n_mels
    mel = whisper.log_mel_spectrogram(
        audio, n_mels=n_mels).to(model.device).unsqueeze(0)
    batch_size = 1
    assert mel.shape == (batch_size, n_mels, 30 * 100)

    n_layer_cross_k, n_layer_cross_v = _export_encoder(
        model,
        mel,
        batch_size,
        f"{args.output_dir}/encoder.onnx",
        opset_version,
    )
    _export_decoder(
        model,
        tokenizer.sot,
        mel.shape[0],
        mel.device,
        n_layer_cross_k,
        n_layer_cross_v,
        f"{args.output_dir}/decoder.onnx",
        opset_version,
    )
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()