CoRAG接口参考

CoRAG模块

CoRAG(Chain of Retrieval-Augmented Generation)是一种基于链式检索增强生成的多轮问答框架,通过迭代生成子查询、获取相关文档并整合信息,实现复杂问题的深度推理和回答。

CoRagBaseConfig

类功能

功能描述

CoRAG基础配置类,包含共享的核心参数,用于初始化CoRAG相关组件。

函数原型

from mx_rag.corag.config import CoRagBaseConfig
CoRagBaseConfig(base_llm, retrieve_api_url, num_threads, max_path_length, final_llm, sub_answer_llm, judge_llm)

参数说明

参数名 数据类型 可选/必选 默认值 说明
base_llm Text2TextLLM 必选 - 基础LLM实例,用于生成子查询和答案。具体可参见Text2TextLLM
retrieve_api_url str 必选 - 检索API的URL地址,用于获取相关文档,必须支持POST请求,请求体为JSON格式,包含查询文本query,返回体为JSON格式,支持多种结构(详见下方请求体和响应体示例)。
num_threads int 可选 8 并行处理的线程数。
max_path_length int 可选 3 最大路径长度,表示生成子查询的最大轮数。
final_llm Text2TextLLM 可选 None 最终答案生成LLM实例,若不提供则使用base_llm。具体可参见Text2TextLLM
sub_answer_llm Text2TextLLM 可选 None 子答案生成LLM实例,若不提供则使用base_llm。具体可参见Text2TextLLM
judge_llm Text2TextLLM 可选 None 判断LLM实例,用于评估答案正确性。具体可参见Text2TextLLM
retrieve_top_k int 可选 5 检索API返回的文档数量,默认值为5。
client_param ClientParam 可选 ClientParam() HTTP客户端参数,用于配置请求行为。具体可参见ClientParam

请求体示例

{
    "query": "Which company acquired by Google was founded first?", "top_k": 5
}

响应体示例

支持多种响应格式,以下是常见的几种:

格式1:包含document_ids和documents的标准格式

{
    "document_ids": ["doc1", "doc2"],
    "documents": ["Google is a multinational technology company.", "YouTube was founded on February 14, 2005."]
}

格式2:包含chunks字段的格式

{
    "chunks": [
        "Google is a multinational technology company that specializes in Internet-related services and products.",
        "YouTube is an American online video sharing and social media platform owned by Google."
    ]
}

格式3:包含data字段的格式

{
    "data": [
        {
            "id": "doc1",
            "content": "Google is a multinational technology company."
        },
        {
            "id": "doc2",
            "content": "YouTube was founded on February 14, 2005."
        }
    ]
}

格式4:包含results字段的格式

{
    "results": [
        {
            "doc_id": "doc1",
            "text": "Google is a multinational technology company that specializes in Internet-related services and products."
        },
        {
            "doc_id": "doc2",
            "text": "YouTube is an American online video sharing and social media platform owned by Google."
        }
    ]
}

格式5:包含docs字段的格式

{
    "docs": [
        {
            "id": "doc1",
            "contents": "Google is a multinational technology company."
        },
        {
            "id": "doc2",
            "contents": "YouTube was founded on February 14, 2005."
        }
    ]
}

格式6:包含passages字段的格式

{
    "passages": [
        {
            "id": "doc1",
            "content": "Google is a multinational technology company."
        },
        "YouTube was founded on February 14, 2005."
    ]
}

支持的响应字段说明

  • 文档内容可从以下字段提取:content, contents, text
  • 文档ID可从以下字段提取:id, doc_id
  • 支持直接返回字符串列表或包含上述字段的字典列表

ReasoningPath

类功能

功能描述

表示CoRAG推理路径的数据类,包含原始查询、子查询、子答案、文档ID、思考和文档的列表。

函数原型

from mx_rag.corag.corag_agent import ReasoningPath
ReasoningPath(original_query, subqueries, subanswers, document_ids, reasoning_steps, documents)

参数说明

参数名 数据类型 可选/必选 默认值 说明
original_query str 必选 - 原始查询文本。
subqueries List[str] 可选 [] 子查询列表。
subanswers List[str] 可选 [] 子答案列表。
document_ids List[List[str]] 可选 [] 文档ID列表,每个子查询对应多个文档ID。
reasoning_steps List[str] 可选 [] 思考过程列表,每个子查询对应一个思考过程。
documents List[List[str]] 可选 [] 文档内容列表,每个子查询对应多个文档内容。

CoRagAgent

类功能

功能描述

CoRAG智能体类,负责生成推理路径和最终答案,是CoRAG框架的核心组件。

函数原型

from mx_rag.corag.corag_agent import CoRagAgent
CoRagAgent(base_llm, retrieve_api_url, final_llm, sub_answer_llm)

参数说明

参数名 数据类型 可选/必选 默认值 说明
base_llm Text2TextLLM 必选 - 基础LLM实例,用于生成子查询和答案。具体可参见Text2TextLLM
retrieve_api_url str 必选 - 检索API的URL地址,用于获取相关文档。
final_llm Text2TextLLM 可选 None 最终答案生成LLM实例,若不提供则使用base_llm。具体可参见Text2TextLLM
sub_answer_llm Text2TextLLM 可选 None 子答案生成LLM实例,若不提供则使用base_llm。具体可参见Text2TextLLM
retrieve_top_k int 可选 5 检索API返回的文档数量,默认值为5。
client_param ClientParam 可选 ClientParam() HTTP客户端参数,用于配置请求行为。具体可参见ClientParam

调用示例

from mx_rag.corag.corag_agent import CoRagAgent
from mx_rag.llm import Text2TextLLM, LLMParameterConfig
from mx_rag.utils import ClientParam

# 初始化LLM实例
llm = Text2TextLLM(base_url="https://{ip}:{port}/v1/chat/completions",
                   model_name="qianwen-7b",
                   llm_config=LLMParameterConfig(max_tokens=512),
                   client_param=ClientParam(ca_file="/path/to/ca.crt")
                   )

# 初始化CoRagAgent
agent = CoRagAgent(
    base_llm=llm,
    retrieve_api_url="http://your-retrieve-api.com/retrieve",
    retrieve_top_k=5,
    client_param=ClientParam(ca_file="/path/to/ca.crt")
)

# 生成推理路径
task_desc = "回答用户的复杂问题,通过多轮子查询获取相关信息"
rag_path = agent.sample_path(
    query="什么是CoRAG框架的工作原理?",
    task_desc=task_desc,
    max_path_length=3
)

# 生成最终答案
final_answer = agent.generate_final_answer(
    rag_path=rag_path,
    task_description=task_desc
)

print("最终答案:", final_answer)

sample_path

功能描述

通过迭代生成子查询,根据子查询从数据源检索相关文档,并收集子答案和相关文档,构建一个完整的推理路径。

函数原型

def sample_path(self, query, task_desc, max_path_length)

参数说明

参数名 数据类型 可选/必选 默认值 说明
query str 必选 - 原始查询文本。
task_desc str 必选 - 任务描述,指导LLM的行为。
max_path_length int 可选 3 最大路径长度,表示生成子查询的最大轮数。

返回值说明

数据类型 说明
ReasoningPath 包含完整推理路径的ReasoningPath对象。

generate_final_answer

功能描述

基于生成的推理路径,整合所有信息生成最终答案。

函数原型

def generate_final_answer(self, rag_path, task_description)

参数说明

参数名 数据类型 可选/必选 默认值 说明
rag_path ReasoningPath 必选 - 包含推理路径的ReasoningPath对象。
task_description str 必选 - 任务描述,指导LLM的行为。

返回值说明

数据类型 说明
str 生成的最终答案文本。

SampleGenerator

类功能

功能描述

样本生成器类,负责生成CoRAG训练样本,通过多线程并行处理输入数据,为每个查询生成有效的推理路径,并将其转换为可用于训练的样本格式。

函数原型

from mx_rag.corag.sample_generator import SampleGenerator
SampleGenerator(config)

参数说明

参数名 数据类型 可选/必选 说明
config CoRagBaseConfig 必选 配置对象,包含LLM实例、API地址和并行参数等。具体可参见CoRagBaseConfig

调用示例

from mx_rag.corag.sample_generator import SampleGenerator
from mx_rag.corag.config import CoRagBaseConfig
from mx_rag.llm import Text2TextLLM, LLMParameterConfig
from mx_rag.utils import ClientParam

# 初始化LLM实例
llm = Text2TextLLM(base_url="https://{ip}:{port}/v1/chat/completions",
                   model_name="qianwen-7b",
                   llm_config=LLMParameterConfig(max_tokens=512),
                   client_param=ClientParam(ca_file="/path/to/ca.crt")
                   )


# 初始化配置
config = CoRagBaseConfig(
    base_llm=llm,
    retrieve_api_url="http://your-retrieve-api.com/query",
    num_threads=4,
    max_path_length=3,
    client_param=ClientParam(ca_file="/path/to/ca.crt")
)

# 初始化样本生成器
generator = SampleGenerator(config)

# 生成训练样本
samples = generator.generate(
    input_file="data/train_queries.json",
    output_file="results/corag_train_samples.jsonl",
    n_samples=3
)

print("生成的样本数量:", sum(len(query_samples) for query_samples in samples))

generate

功能描述

生成样本主方法,从输入文件加载数据,并行处理生成训练样本,并保存到输出文件。

函数原型

def generate(self, input_file, output_file, n_samples)

参数说明

参数名 数据类型 可选/必选 默认值 说明
input_file str 必选 - 输入数据文件路径(JSONL格式),每条数据包含一个查询-答案对,示例:{"query": "中国的首都是哪里?", "answer": "北京"}
output_file str 必选 - 输出文件路径。
n_samples int 可选 5 每个查询采样的路径数量。

返回值说明

数据类型 说明
List[List[Dict[str, Any]]] 处理后的样本列表,示例:{"type": "subquery_generation", "messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "SubQuery: ..."}]}

微调

FineTuneArguments

类功能

功能描述

模型微调参数类,用于配置模型微调相关参数,包括模型路径、训练数据文件路径和最大序列长度等。

函数原型

from mx_rag.corag import FineTuneArguments
FineTuneArguments(model_name_or_path, train_file, max_len)

输入参数说明

参数名 数据类型 可选/必选 说明
model_name_or_path str 可选 预训练模型路径,仅支持本地模型,默认值为"Qwen/Qwen2.5-7B-Instruct"。
train_file Optional[str] 可选 训练数据文件路径(jsonl格式),默认值为"data/aligned_train.jsonl"。
max_len int 可选 分词后的最大输入序列长度,默认值为2048。

SubqueryFineTuner

类功能

功能描述

子查询微调器类,用于微调模型以优化子查询生成。支持NPU加速,使用前需要调用torch.npu.set_device设置NPU设备。

函数原型

from mx_rag.corag import SubqueryFineTuner
SubqueryFineTuner(finetune_args, train_args)

输入参数说明

参数名 数据类型 可选/必选 说明
finetune_args FineTuneArguments 必选 模型微调参数。
train_args TrainingArguments 必选 训练参数,来自transformers库的TrainingArguments。

核心方法

train

功能描述

训练模型,执行模型准备、数据准备、训练器初始化,然后执行训练并保存模型。

函数原型

def train(self)

返回值说明

无返回值,训练完成后会保存模型和分词器到指定目录。

使用示例

基本使用示例

from mx_rag.corag import SubqueryFineTuner, FineTuneArguments
from transformers import TrainingArguments
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# 设置NPU设备
torch.npu.set_device(0)

# 配置微调参数
finetune_args = FineTuneArguments(
    model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
    train_file="data/aligned_train.jsonl",
    max_len=2048
)

# 配置训练参数
train_args = TrainingArguments(
    output_dir="./output",
    do_train=True,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    logging_dir="./logs",
    learning_rate=1e-5,
    logging_steps=10,
    save_steps=500,
    remove_unused_columns=False
)

# 创建微调器实例
tuner = SubqueryFineTuner(finetune_args, train_args)

# 执行训练
tuner.train()

CoRagEvaluator

类功能

功能描述

CoRAG评估器类,通过多线程并行处理评估数据,计算检索召回率等指标,并生成详细的评估报告。

函数原型

from mx_rag.corag.evaluator import CoRagEvaluator
CoRagEvaluator(config)

参数说明

参数名 数据类型 可选/必选 说明
config CoRagBaseConfig 必选 配置对象,包含LLM实例、API地址和并行参数等。具体可参见CoRagBaseConfig

调用示例

from mx_rag.corag.evaluator import CoRagEvaluator
from mx_rag.corag.config import CoRagBaseConfig
from mx_rag.llm import Text2TextLLM, LLMParameterConfig
from mx_rag.utils import ClientParam

# 初始化LLM实例
llm = Text2TextLLM(base_url="https://{ip}:{port}/v1/chat/completions",
                   model_name="qianwen-7b",
                   llm_config=LLMParameterConfig(max_tokens=512),
                   client_param=ClientParam(ca_file="/path/to/ca.crt")
                   )


# 初始化配置
config = CoRagBaseConfig(
    base_llm=llm,
    retrieve_api_url="http://your-retrieve-api.com/retrieve",
    num_threads=4,
    max_path_length=3,
    client_param=ClientParam(ca_file="/path/to/ca.crt")
)

# 初始化评估器
evaluator = CoRagEvaluator(config)

# 执行评估
eval_results = evaluator.evaluate(
    eval_file="data/eval_data.json",
    save_file="results/corag_eval_results.json",
    calc_recall=True,
    enable_naive_retrieval=True
)

# 输出评估结果
print("评估结果汇总:", eval_results[0])

evaluate

功能描述

执行评估主方法,从评估文件加载数据,并行处理生成评估结果,并保存到输出文件。

函数原型

def evaluate(self, eval_file, save_file, calc_recall, enable_naive_retrieval, num_contexts)

参数说明

参数名 数据类型 可选/必选 默认值 说明
eval_file str 必选 - 评估数据文件路径(JSON格式)。支持HotpotQA和MuSiQue两种格式,详见下方示例。
save_file str 必选 - 结果保存文件路径。
calc_recall bool 可选 True 是否计算召回率。
enable_naive_retrieval bool 可选 True 是否启用朴素检索对比。朴素检索是指通过原始问题直接调用检索API检索相关文档,不依赖CoRAG流程。
num_contexts int 可选 10 检索上下文数量。

HotpotQA格式

[
  {
    "question": "Which company acquired by Google was founded first?",
    "answer": "YouTube",
    "context": [
      ["Title1", ["sentence1", "sentence2"]],
      ["Title2", ["sentence3", "sentence4"]]
    ],
    "supporting_facts": [
      ["Title1", [0, 1]],
      ["Title2", [0]]
    ]
  }
]

MuSiQue格式

[
  {
    "question": "Which company acquired by Google was founded first?",
    "answer": "YouTube",
    "paragraphs": [
      {
        "paragraph_text": "Google is a multinational technology company that specializes in Internet-related services and products.",
        "is_supporting": false
      },
      {
        "paragraph_text": "YouTube is an American online video sharing and social media platform owned by Google. It was founded on February 14, 2005.",
        "is_supporting": true
      },
      {
        "paragraph_text": "Google Maps is a web mapping platform and consumer application offered by Google. It was first launched in February 2005.",
        "is_supporting": false
      }
    ]
  }
]

返回值说明

数据类型 说明
List[Dict[str, Any]] 评估结果列表,第一个元素是聚合指标,后续元素是每个样本的详细评估结果。详见下方示例

评估输出

[
    {
        "type": "Summary",
        "total_samples": 11,
        "corag_accuracy": 0.36,
        "naive_accuracy": 0.090,
        "corag_correct_count": 4,
        "naive_correct_count": 1,
        "avg_path_time": 142.308,
        "avg_time": 56.682,
        "corag_micro_recall": 0.863,
        "naive_micro_recall": 0.68
    },
    {
        "question": "Who is the child of the performer of song Me And Bobby Mcgee?",
        "ground_truth": "Dean Miller",
        "corag_prediction": "xxx",
        "naive_prediction": "xxx",
        "is_correct": true,
        "naive_is_correct": false,
        "reasoning_steps": [
            {
                "subquery": "subquery1",
                "subanswer": "subanswer1"
            }
        ],
        "time": [
            144.0536253452301,
            66.77552223205566
        ],
        "corag_recall": {
            "hits": 1,
            "total": 2,
            "recall": 0.5
        },
        "naive_recall": {
            "hits": 1,
            "total": 2,
            "recall": 0.5
        }
    }, ...
]