from typing import List, Optional, Tuple
import bentoml
import torch
from pydantic import Field
from transformers import AutoTokenizer, AutoModelForCausalLM
from vrag.logger import logger
from vrag.shared import ArgsBase, vrag_service
from vrag.tools.path_validator import validate_dir_exists
# Task instruction placed after "<Instruct>:" when the caller does not supply one.
DEFAULT_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query."

# Chat-template text emitted before every "<Instruct>/<Query>/<Document>" body:
# a system turn asking for a yes/no judgement, followed by the opening of the user turn.
PROMPT_PREFIX = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    "Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
    "<|im_start|>user\n"
)

# Chat-template text emitted after every body: closes the user turn and opens the
# assistant turn with an empty thinking block, so that the next token the model
# predicts is the "yes"/"no" answer itself. This follows the prompt template
# published with Qwen3-Reranker; without the empty thinking block the last
# prompt position is not the one the yes/no logits should be read from.
# NOTE(review): assumes the deployed checkpoint is a Qwen3-Reranker model - confirm.
PROMPT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
class QwenRerankerArgs(ArgsBase):
    """Settings consumed by the Qwen reranker service.

    Covers where the model is loaded from, which device runs inference, how
    many documents are scored per forward pass, the token budget of one
    prompt, and how many results are returned when the caller does not ask
    for a specific count. The bare string under each field describes it.
    """

    reranker_model_path: str = ""
    """Local path to the Qwen reranker model directory."""

    reranker_device: str = "npu:3"
    """Device for Qwen reranker model inference, e.g. 'npu:3' or 'cpu'."""

    reranker_batch_size: int = Field(default=4, ge=1)
    """Batch size for reranking inference."""

    default_max_length: int = Field(default=8192, gt=0)
    """Maximum token length for reranker input sequences."""

    default_top_k: int = Field(default=5, ge=1)
    """Default number of top documents to return after reranking."""
# Module-level configuration instance read by the service below (model path,
# device, batch size, token budget, default top_k).
# NOTE(review): `use_arguments` and `.override()` are defined outside this file;
# presumably they populate QwenRerankerArgs from deployment arguments - confirm.
args = bentoml.use_arguments(QwenRerankerArgs).override()
@vrag_service(args)
class QwenRerankerService:
    """Rerank documents against a query with a Qwen causal-LM reranker.

    Every document is rendered, together with the query and a task
    instruction, into a chat-style prompt. The relevance score of a document
    is the probability of the "yes" token versus the "no" token, taken from
    the logits at the last prompt position.
    """

    def __init__(self):
        """Load the tokenizer and model from the configured local directory."""
        self.model_path = validate_dir_exists(args.reranker_model_path, "Qwen reranker model")
        # Left padding keeps the final prompt token at position -1 for every row
        # of a batch, which is where _compute_logits reads the yes/no logits.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="left", local_files_only=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, device_map=args.reranker_device, local_files_only=True
        )
        self.model.eval()
        # Vocabulary ids of the two answer tokens compared at scoring time.
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        # The fixed prompt frame is tokenized once and spliced around each body.
        self.prefix_tokens = self.tokenizer.encode(PROMPT_PREFIX, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(PROMPT_SUFFIX, add_special_tokens=False)
        msg = f"QwenRerankerService initialized from {self.model_path}."
        logger.info(msg)

    @staticmethod
    def _format_instruction_pairs(
        query: str, docs: List[str], instruction: str = DEFAULT_INSTRUCTION
    ) -> List[str]:
        """Render one prompt body per (query, document) pair.

        Each pair becomes a single string. Handing the tokenizer
        ``(query, text)`` tuples instead would make it encode them as
        text/text_pair inputs, which puts the raw query a second time in front
        of ``<Instruct>:`` and no longer matches the reranker prompt template.

        Args:
            query: The search query.
            docs: Candidate documents, in caller order.
            instruction: Task instruction placed after ``<Instruct>:``.

        Returns:
            One ``<Instruct>/<Query>/<Document>`` string per document, in the
            same order as ``docs``.
        """
        return [f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}" for doc in docs]

    @bentoml.api
    async def rerank(self, query: str, documents: List[str], top_k: Optional[int] = None) -> List[int]:
        """Return the indices of the ``top_k`` most relevant documents.

        Args:
            query: The search query.
            documents: Candidate documents.
            top_k: Number of indices to return; ``None`` falls back to
                ``args.default_top_k``.

        Returns:
            Indices into ``documents`` ordered by descending score. When there
            are no more than ``top_k`` documents, no scoring is done and the
            indices are returned in input order.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        msg = f"Rerank process indices of {len(documents)} docs."
        logger.info(msg)
        docs_num = len(documents)
        top_k = top_k if top_k is not None else args.default_top_k
        # A negative value would otherwise reach the slice below and silently
        # drop documents from the tail instead of limiting the result size.
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}.")
        if docs_num <= top_k:
            return list(range(docs_num))
        if top_k == 0:
            # Nothing is requested, so skip model inference entirely.
            return []
        # NOTE(review): scoring below is a synchronous call with no await point,
        # so this coroutine does not yield while the model runs.
        all_scores = self._rerank_scores_inner(query, documents)
        # sorted() is stable, so documents with equal scores keep input order.
        scored_pairs: List[Tuple[float, int]] = sorted(
            zip(all_scores, range(docs_num), strict=True), key=lambda pair: pair[0], reverse=True
        )
        return [idx for _, idx in scored_pairs[:top_k]]

    def _process_input(self, prompt_pairs: List[str]) -> dict[str, torch.Tensor]:
        """Tokenize prompt bodies, add the prompt frame, pad, and move to the model device.

        Args:
            prompt_pairs: Prompt bodies as produced by ``_format_instruction_pairs``.

        Returns:
            The padded batch as tensors on ``self.model.device``.

        Raises:
            ValueError: If ``args.default_max_length`` leaves no room for any
                content between the prompt prefix and suffix.
        """
        max_content_length = args.default_max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        # A budget of exactly zero leaves no room for content either, so it is
        # rejected together with negative values.
        if max_content_length <= 0:
            raise ValueError("QwenRerankerService should set a bigger default_max_length")
        inputs = self.tokenizer(
            prompt_pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=max_content_length,
        )
        # Only the body is truncated above; the fixed frame is added afterwards
        # so it can never be cut off.
        for i, element in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + element + self.suffix_tokens
        inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=args.default_max_length)
        return {k: v.to(self.model.device) for k, v in inputs.items()}

    def _compute_logits(self, inputs: dict[str, torch.Tensor]) -> List[float]:
        """Run the model and return one "yes" probability per batch row.

        The logits of the "no" and "yes" tokens at the last position are
        normalized against each other, so each score lies in [0, 1].
        """
        with torch.no_grad():
            batch_scores = self.model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self.token_true_id]
            false_vector = batch_scores[:, self.token_false_id]
            # Column 0 is "no", column 1 is "yes".
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            return batch_scores[:, 1].exp().tolist()

    def _process_batch(self, query: str, batch_docs: List[str]) -> List[float]:
        """Score one batch of documents against ``query``."""
        prompt_pairs = self._format_instruction_pairs(query, batch_docs)
        inputs = self._process_input(prompt_pairs)
        return self._compute_logits(inputs)

    def _rerank_scores_inner(self, query: str, documents: List[str]) -> List[float]:
        """Score every document in chunks of ``args.reranker_batch_size``.

        Returns:
            One score per document, in the same order as ``documents``.
        """
        docs_num = len(documents)
        all_score: List[float] = []
        for i in range(0, docs_num, args.reranker_batch_size):
            batch_docs = documents[i : i + args.reranker_batch_size]
            all_score.extend(self._process_batch(query, batch_docs))
        return all_score