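Task-aware Retrieval with Instructions
Official repository: github.com/facebookresearch/tart
Model description
facebook/tart-full-flan-t5-xl is a multi-task cross-encoder model trained via instruction tuning on approximately 40 retrieval tasks. It is initialized from google/flan-t5-xl.
TART-full is a 1.5-billion-parameter cross-encoder that reranks candidate documents given a query and a natural-language instruction (e.g., "find a Wikipedia paragraph that answers this question"). Experimental results on the widely used BEIR and LOTTE benchmarks, and on our new evaluation benchmark X^2-Retrieval, show that TART-full outperforms previous state-of-the-art methods by leveraging natural-language instructions.
For more details on modeling and training, please refer to our paper: Task-aware Retrieval with Instructions.
Installation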
git clone https://github.com/facebookresearch/tart
cd tart
pip install -r requirements.txt
cd TART
How to use?
TART-full can be loaded through our customized EncT5 model.
from src.modeling_enc_t5 import EncT5ForSequenceClassification
from src.tokenization_enc_t5 import EncT5Tokenizer
import torch
import torch.nn.functional as F
import numpy as np
# load TART full and tokenizer
model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
tokenizer = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")
model.eval()
q = "What is the population of Tokyo?"
in_answer = "retrieve a passage that answers this question from Wikipedia"
p_1 = "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million at the start of 2022, the first decline since 1996, the metropolitan government reported Monday."
p_2 = "Tokyo, officially the Tokyo Metropolis (東京都, Tōkyō-to), is the capital and largest city of Japan."
# 1. TART-full can identify the more relevant paragraph.
features = tokenizer(['{0} [SEP] {1}'.format(in_answer, q), '{0} [SEP] {1}'.format(in_answer, q)], [p_1, p_2], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million ... "
# 2. TART-full can identify the document that is more relevant AND follows instructions.
in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
q_1 = "How many people live in Tokyo?"
features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
print([p_1, q_1][np.argmax(normalized_scores)]) # "How many people live in Tokyo?"
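
The two examples above repeat the same three steps: prefix the query with the instruction, score each (prompt, candidate) pair, and pick the candidate with the highest probability at logit index 1. The steps could be wrapped in small helpers, as in the minimal sketch below. The helper names (build_inputs, positive_probs, rank_candidates) are hypothetical and not part of the TART codebase, and the sketch assumes, as the examples above do, that index 1 of each logit row is the "relevant" class. It uses plain Python so it runs without the model; the logits in the demo are made-up placeholders, not model output.

```python
import math
from typing import List, Sequence, Tuple


def build_inputs(instruction: str, query: str,
                 candidates: Sequence[str]) -> Tuple[List[str], List[str]]:
    """Pair every candidate with the same '<instruction> [SEP] <query>' prompt."""
    prompt = "{0} [SEP] {1}".format(instruction, query)
    return [prompt] * len(candidates), list(candidates)


def positive_probs(logits: Sequence[Sequence[float]]) -> List[float]:
    """Softmax over each logit row; return the probability at index 1."""
    probs = []
    for row in logits:
        peak = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(x - peak) for x in row]
        probs.append(exps[1] / sum(exps))
    return probs


def rank_candidates(candidates: Sequence[str],
                    logits: Sequence[Sequence[float]]) -> List[Tuple[str, float]]:
    """Return (candidate, score) pairs sorted from most to least relevant."""
    scores = positive_probs(logits)
    return sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    prompts, docs = build_inputs(
        "retrieve a passage that answers this question from Wikipedia",
        "What is the population of Tokyo?",
        ["passage A", "passage B"],
    )
    # Placeholder logits for illustration only; with the real model they would
    # come from model(**tokenizer(prompts, docs, ...)).logits.tolist().
    dummy_logits = [[0.2, 1.5], [1.0, -0.3]]
    for doc, score in rank_candidates(docs, dummy_logits):
        print(doc, round(score, 3))
```

With the model loaded as shown earlier, prompts and docs would be passed to the tokenizer as its first and second arguments, and the resulting logits converted with .tolist() before calling rank_candidates.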