import time
import argparse
from typing import Tuple, Optional
import torch
from torch import nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
from FlagEmbedding import FlagReranker
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaEmbeddings
def parse_args():
    """Parse the command-line options of the bge-reranker-v2-m3 inference script.

    Returns:
        argparse.Namespace with the attributes ``model_path``, ``warmup``,
        ``loop`` and ``devices`` (``devices`` is kept as the raw string).
    """
    cli = argparse.ArgumentParser(description="bge-reranker-v2-m3 infer")
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ("--model_path", {
            "type": str,
            "default": "./",
            "help": "model local path (either local directory or huggingface-Hub)",
        }),
        ('--warmup', {"type": int, "default": 4, "help": "Warm up times"}),
        ('--loop', {"type": int, "default": 10, "help": "loop times"}),
        ("--devices", {"type": str, "default": "['npu:0']", "help": "target devices"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def create_model(args):
    """Build the FlagReranker and attach the requested target devices.

    Args:
        args: parsed command-line namespace; ``args.model_path`` is passed to
            ``FlagReranker`` and ``args.devices`` is a string holding a Python
            expression (default ``"['npu:0']"``).

    Returns:
        The ``FlagReranker`` instance with ``target_devices`` overwritten.
    """
    reranker = FlagReranker(args.model_path, trust_remote_code=True)
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. If only list/str literals are ever passed here, consider
    # ast.literal_eval instead -- confirm with callers before switching.
    device_list = eval(args.devices)
    reranker.target_devices = device_list
    return reranker
class MyXLMRobertaEmbeddings(XLMRobertaEmbeddings):
    """
    Replacement for the model's embedding layer.

    Overrides ``create_position_ids_from_input_ids`` of ``XLMRobertaEmbeddings``
    so that ``padding_idx`` is converted into a tensor living on the same device
    and having the same dtype as ``input_ids``.
    """

    def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
        """Derive position ids from ``input_ids``.

        Non-padding tokens receive a running count along dim 1 (shifted by
        ``past_key_values_length``) plus ``padding_idx``; padding tokens end up
        with ``padding_idx`` itself.
        """
        pad = torch.tensor(padding_idx, device=input_ids.device, dtype=input_ids.dtype)
        not_pad = input_ids.ne(pad).int()
        running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        # Multiplying by the mask zeroes the count at padding positions.
        positions = (running_count + past_key_values_length) * not_pad
        return positions.long() + pad

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Compute the embeddings from token ids or from pre-computed input embeddings.

        Word (or supplied) embeddings and token-type embeddings are summed,
        position embeddings are added when ``position_embedding_type`` is
        ``"absolute"``, then LayerNorm and dropout are applied.
        """
        ids_given = input_ids is not None

        if position_ids is None:
            position_ids = (
                self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
                if ids_given
                else self.create_position_ids_from_inputs_embeds(inputs_embeds)
            )

        input_shape = input_ids.size() if ids_given else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Reuse the registered buffer, cut to the sequence length and
                # expanded over the batch dimension.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        return self.dropout(self.LayerNorm(embeddings))
def rewrite_self_attention_forward(model):
    """
    Fuse the query/key/value projections of a self-attention module in place.

    Optimization: one ``nn.Linear`` (``model.qkv``) replaces the three original
    Linear layers. ``model.query``, ``model.key`` and ``model.value`` are
    deleted and ``model.forward`` is replaced by a function that runs the fused
    projection followed by ``scaled_dot_product_attention``.

    Args:
        model: attention module exposing ``query``/``key``/``value`` Linear
            layers (with bias), ``transpose_for_scores``, ``all_head_size``
            and ``is_decoder``.

    Returns:
        None. ``model`` is modified in place.
    """
    wq = model.query.weight
    wk = model.key.weight
    wv = model.value.weight

    # nn.Linear stores its weight as (out_features, in_features): dim 1 is the
    # input width, dim 0 the output width of each projection.
    in_features = wq.shape[1]
    q_size, k_size, v_size = wq.shape[0], wk.shape[0], wv.shape[0]
    kv_start = q_size + k_size

    model.qkv = nn.Linear(in_features, kv_start + v_size)
    model.qkv.weight = nn.Parameter(torch.concat([wq, wk, wv], dim=0), requires_grad=False)
    model.qkv.bias = nn.Parameter(torch.concat([
        model.query.bias, model.key.bias, model.value.bias
    ], dim=0), requires_grad=False)
    del model.query
    del model.key
    del model.value

    def forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # head_mask, encoder_hidden_states, encoder_attention_mask and
        # output_attentions are accepted for signature compatibility only;
        # this implementation does not use them.
        qkv_layers = model.qkv(hidden_states)
        # Split the fused output by the projections' OUTPUT widths. Only plain
        # ints are captured here, so the closure does not keep the deleted
        # weight tensors alive.
        query_layer = qkv_layers[:, :, :q_size]
        key_layer = qkv_layers[:, :, q_size:kv_start]
        value_layer = qkv_layers[:, :, kv_start:]

        bsz, tgt_len, _ = hidden_states.size()
        query_layer = model.transpose_for_scores(query_layer)
        key_layer = model.transpose_for_scores(key_layer)
        value_layer = model.transpose_for_scores(value_layer)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False
        )
        # Swap dims 1 and 2 back, then merge the last two dims into all_head_size.
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, model.all_head_size)

        outputs = (attn_output,)
        if model.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    model.forward = forward
def modify_model(model):
    """Patch the reranker's XLM-RoBERTa backbone in place and return it.

    The embedding layer is replaced by a ``MyXLMRobertaEmbeddings`` built from
    the model config and loaded with the original layer's state dict (then put
    in eval mode and converted to half precision). Every encoder layer's
    self-attention is rewritten via ``rewrite_self_attention_forward``.

    Args:
        model: wrapper object whose ``model`` attribute exposes ``config`` and
            ``roberta`` (with ``embeddings`` and ``encoder.layer``).

    Returns:
        The same ``model`` object.
    """
    backbone = model.model.roberta
    config = model.model.config
    previous_embeddings = backbone.embeddings

    backbone.embeddings = MyXLMRobertaEmbeddings(config)
    backbone.embeddings.load_state_dict(previous_embeddings.state_dict())
    backbone.embeddings.eval().half()

    for encoder_layer in backbone.encoder.layer:
        rewrite_self_attention_forward(encoder_layer.attention.self)

    return model
if __name__ == '__main__':
    # Script entry point: build the patched reranker, compile its forward with
    # the NPU backend, then run warm-up passes followed by a timed loop.
    args = parse_args()

    # NOTE(review): presumably disables JIT operator compilation on the NPU --
    # confirm against the torch_npu documentation.
    torch_npu.npu.set_compile_mode(jit_compile=False)
    config = CompilerConfig()
    config.experimental_config.frozen_parameter = True
    npu_backend = tng.get_npu_backend(compiler_config=config)

    model = create_model(args)
    model = modify_model(model)
    model.model.eval().half()
    model.model.forward = torch.compile(model.model.forward, dynamic=True, fullgraph=True, backend=npu_backend)

    sentences = [
        ['what is panda?', 'hi'],
        ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']
    ]

    with torch.inference_mode():
        # Stays None when --warmup is 0, so the report below is skipped
        # instead of raising NameError.
        output = None
        for _ in range(args.warmup):
            output = model.compute_score(sentences, normalize=True)
        if output is not None:
            print(f"the similarity of {sentences[0]} is: ", output[0])
            print(f"the similarity of {sentences[1]} is: ", output[1])

        # perf_counter is monotonic, unlike time.time(), so the measured
        # interval cannot be distorted by system clock adjustments.
        start_time = time.perf_counter()
        for _ in range(args.loop):
            output = model.compute_score(sentences, normalize=True)
        elapsed = time.perf_counter() - start_time
        # Skip the average when no timed iteration ran (avoids division by zero).
        if args.loop > 0:
            print(f'E2E time = {elapsed / args.loop * 1000}ms')