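tart-full-flan-t5-xl: an instruction-tuned multi-task cross-encoder for document reranking and retrieval

A 1.5-billion-parameter cross-encoder, instruction-tuned on roughly 40 retrieval tasks, that reranks documents under the guidance of natural language instructions. It outperforms prior state-of-the-art methods on benchmarks such as BEIR, improving both retrieval relevance and instruction following. [This summary was generated by AI]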


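Task-aware Retrieval with Instructions

Official repository: github.com/facebookresearch/tart

Model description

facebook/tart-full-flan-t5-xl is a multi-task cross-encoder model trained via instruction tuning on approximately 40 retrieval tasks. It is initialized from google/flan-t5-xl.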

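TART-full is a 1.5-billion-parameter cross-encoder that reranks candidate documents given a query and a natural language instruction (e.g., "find a Wikipedia paragraph that answers this question"). Experimental results on the widely used BEIR and LOTTE benchmarks, and on our new evaluation benchmark X^2-Retrieval, show that TART-full outperforms previous state-of-the-art methods by leveraging natural language instructions.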
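As a rough illustration of how a query and an instruction are combined, the sketch below mirrors the `'{0} [SEP] {1}'` template used in the usage example later in this card. The helper name `build_pairs` is hypothetical and not part of the TART codebase; it only prepares the two text lists that a tokenizer would then receive.

```python
from typing import List, Tuple


def build_pairs(instruction: str, query: str,
                candidates: List[str]) -> Tuple[List[str], List[str]]:
    """Pair one instruction-prefixed query with every candidate text.

    Hypothetical helper: the "[SEP]" template is copied from the usage
    example in this card, everything else is an illustrative assumption.
    """
    prefixed = "{0} [SEP] {1}".format(instruction, query)
    # The same prefixed query is repeated once per candidate.
    return [prefixed] * len(candidates), list(candidates)


left, right = build_pairs(
    "retrieve a passage that answers this question from Wikipedia",
    "What is the population of Tokyo?",
    ["passage A", "passage B"],
)
print(left[0])
```

Keeping the instruction on the query side means a single candidate list can be rescored under a different instruction without touching the candidates themselves.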

For more details on modeling and training, see our paper: Task-aware Retrieval with Instructions

Installation

git clone https://github.com/facebookresearch/tart
cd tart
pip install -r requirements.txt
cd TART

How to use?

TART-full can be loaded through our customized EncT5 model.

from src.modeling_enc_t5 import EncT5ForSequenceClassification
from src.tokenization_enc_t5 import EncT5Tokenizer
import torch
import torch.nn.functional as F
import numpy as np

# load TART full and tokenizer
model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
tokenizer = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")
model.eval()

q = "What is the population of Tokyo?"
in_answer = "retrieve a passage that answers this question from Wikipedia"

p_1 = "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million at the start of 2022, the first decline since 1996, the metropolitan government reported Monday."
p_2 = "Tokyo, officially the Tokyo Metropolis (東京都, Tōkyō-to), is the capital and largest city of Japan."

# 1. TART-full can identify the more relevant paragraph.
features = tokenizer(['{0} [SEP] {1}'.format(in_answer, q), '{0} [SEP] {1}'.format(in_answer, q)], [p_1, p_2], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million ... "

# 2. TART-full can identify the document that is more relevant AND follows instructions.
in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
q_1 = "How many people live in Tokyo?"
features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, q_1][np.argmax(normalized_scores)]) #  "How many people live in Tokyo?"
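The example above picks only the single best candidate with `np.argmax`. A minimal sketch of turning the same kind of two-class logits into a full ranking follows. It assumes, as the example does, that index 1 of each logit row corresponds to the "relevant" class. The function names and the logit values are made up for illustration and were not produced by TART-full.

```python
import math
from typing import List, Sequence, Tuple


def positive_probability(logits: Sequence[float]) -> float:
    """Softmax over one row of logits; return the probability at index 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    return exps[1] / sum(exps)


def rank_candidates(candidates: List[str],
                    logit_rows: List[Sequence[float]]) -> List[Tuple[float, str]]:
    """Return (probability, candidate) pairs sorted from most to least relevant."""
    scored = [(positive_probability(row), text)
              for row, text in zip(logit_rows, candidates)]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)


# Hypothetical logits, one [not relevant, relevant] row per candidate.
example_logits = [[0.2, 1.5], [1.0, -0.5], [0.0, 0.0]]
ranking = rank_candidates(["doc A", "doc B", "doc C"], example_logits)
for prob, text in ranking:
    print(f"{prob:.3f}  {text}")
```

With real model output, each row of `model(**features).logits` converted via `.tolist()` could be passed in place of `example_logits`.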
