缓存模块
总体说明
MxRAGCache的功能主要是基于开源组件GPTCache进行二次开发,支持以下cache基本功能:
- Cache初始化
- Cache更新
- Cache老化
- Cache查询
- Cache级联
相比GPTCache,MxRAGCache扩展了以下功能:
- 语义近似Cache矢量检索过程支持FAISS_NPU检索(Index SDK)。
- 语义近似Cache embedding支持RAG优化的TEI Embedding。
- 语义近似Cache相似度计算过程支持RAG优化的TEI Reranker。
- 支持RAG SDK chain(图生图,文生文,文生图)的Cache功能。
在原始RAG SDK流程中,在知识文档检索前,增加问答Cache,如果查询命中了Cache,就不会经过大模型推理过程,节省了知识文档检索和大模型推理时延,提升了端到端性能,经过性能测试,缓存命中相比缓存未命中可以提升50倍性能。
模块介绍
缓存类型
章节介绍
缓存类型分为memory cache(hash map + 精确匹配)和similarity cache(向量数据库 + embedding + 相似度计算),其中memory cache用于用户问题完全匹配场景,similarity cache用于用户问题不完全一样,但是比较相似的场景。在实际使用的过程中,也可以将Cache进行串联。
memory cache
Memory cache完全匹配的cache,在内存采用的是hash map结构,hash key是用户的query hash值,hash value是用户的问题,用户问题需要满足完全匹配。
similarity cache
Similarity cache是语义相似匹配的cache,存储结构是sqlite + 向量数据库(faiss, npu_faiss,milvus)。
查询时首先对用户的问题做embedding,从向量数据库查询相似TOPK的结果,然后从sqlite获取缓存答案和问题,再将缓存的问题和用户的问题进行reranker精排,得到最相似的结果返回给用户。该cache不需要满足完全匹配,只需要语义相似即可命中。
缓存配置
章节介绍
本章主要用于阐述提供给用户进行配置数据的内容。
CacheConfig
类功能
用于配置memory cache的配置数据结构。
from gptcache.config import Config
from mx_rag.cache import CacheConfig
# 继承自Config
CacheConfig(cache_size, eviction_policy, auto_flush, data_save_folder, min_free_space, similarity_threshold, disable_report, lock)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| cache_size | int | 必选 | 缓存大小,配置缓存条目数。 cache_size不能小于等于0。 取值范围(0, 100000] |
| eviction_policy | EvictPolicy | 可选 | 缓存老化策略。 默认值为LRU,具体可参考EvictPolicy。 |
| auto_flush | int | 可选 | 数据落盘频度,即缓存多少条条目之后,进行一次落盘操作。 默认值:20 取值范围(0, cache_size] |
| data_save_folder | str | 可选 | 缓存落盘路径,路径长度不能超过1024,不能为软链接和相对路径。 默认值为当前用户家目+"/Ascend/mxRag/cache_save_folder",若不存在需要用户创建。存放路径不能在路径列表中:["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"]。 |
| min_free_space | int | 可选 | 用于检查落盘路径的可用空间,单位字节,默认值为1GB。 取值范围[20MB, 100GB] |
| similarity_threshold | float | 可选 | 相似度计算阈值 默认值:0.8 取值范围[0.0, 1.0] |
| disable_report | bool | 可选 | 是否需要支持维测数据功能 默认值:False 取值范围:True表示不支持;False表示支持。 |
| lock | multiprocessing.synchronize.Lock, _thread.LockType | 可选 | CacheConfig不支持多线程或者多进程进行处理,如果用户需要多进程或者多线程调用此接口需要申请锁。默认值为None。 可选值: None:表示不使用锁,此时该接口不支持并发。 multiprocessing.Lock():表示进程锁,此时该接口支持多进程调用。 threading.Lock():表示线程锁。此时该接口支持多线程调用。 |
Note
- 本接口内部使用了pickle模块,有被恶意构造的数据在unpickle期间攻击的风险。需要保证在被加载的落盘数据data_save_folder是安全存储,仅可加载可信的落盘数据。
- 对于memory cache来说,它的落盘文件不能超过100MB大小。
from paddle.base import libpaddle
from mx_rag.cache import CacheConfig
from mx_rag.cache import EvictPolicy
from mx_rag.cache import MxRAGCache
cache_config = CacheConfig(
cache_size=100,
eviction_policy=EvictPolicy.LRU,
data_save_folder="path_to_cache_save_folder"
)
mxrag_l1_cache = MxRAGCache("memory_cache", cache_config)
SimilarityCacheConfig
类功能
用于配置similarity cache的配置数据结构。
from mx_rag.cache import SimilarityCacheConfig
SimilarityCacheConfig(vector_config, cache_config, emb_config, similarity_config, retrieval_top_k, clean_size, **kwargs)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| vector_config | Dict[str, Any] | 必选 | 配置矢量数据库。具体配置参见表1,注意这里如果是faiss数据库,MindFAISS的load_local_index参数会被data_save_folder落盘路径覆盖,auto_save参数置为False,配置中的字符串长度不能超过1024,字典中包含的可迭代序列长度不能超过1024,字典长度不能超过1024,字典嵌套深度不能超过2层。 |
| cache_config | str | 必选 | 配置标量数据库。当前只支持配置为“sqlite”。 |
| emb_config | Dict[str, Any] | 必选 | 配置embedding模型,请参见表2。字典长度不能超过1024,字典中字符串长度不能超过1024,字典嵌套深度不能超过1层。 |
| similarity_config | Dict[str, Any] | 必选 | 配置相似度计算模型,字典长度不能超过1024,字典中字符串长度不能超过1024,字典嵌套深度不能超过1层。请参考表3。 |
| retrieval_top_k | int | 可选 | 相似检索时的topk值。默认值为1。取值范围:(0, 1000]。 |
| clean_size | int | 可选 | 每次缓存数据添加超过cache_size时,老化的个数,默认值为1。取值范围:(0, cache_size] |
| **kwargs | Any | 必选 | 参数介绍可参见CacheConfig。 |
Note
- 本接口内部使用了pickle模块,有被恶意构造的数据在unpickle期间攻击的风险。需要保证在被加载的落盘数据data_save_folder是安全存储,仅可加载可信的落盘数据。
- vector_config和cache_config必须同时为None或同时不为None。如果vector_config和cache_config同时为None,则等同于memory cache。
- 对于sqlite数据库来说,它的落盘文件不能超过30GB,对于矢量数据库来说,它的落盘文件不能超过20GB。
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| **kwargs | Dict[str, Any] | 必选 | 具体介绍可参见create_storage。 |
| top_k | int | 可选 | 相似度检索时的topk个数。默认值:5。 |
| vector_save_file | str | 必选 | 落盘路径。vector_type为"npu_faiss_db"时,该参数会覆盖MindFAISS中的load_local_index参数作为落盘路径;对于milvus_db,该参数不生效。 |
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| x_dim | int | 可选 | embedding模型的维度,默认值为0。 |
| skip_emb | bool | 可选 | 是否跳过embedding,默认值为False。 |
| **kwargs | Dict[str, Any] | 必选 | 具体介绍可参见create_embedding。 |
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| score_min | float | 可选 | 相似度计算可能存在范围的最小值,默认值为0.0。 取值范围:[0.0, 100.0] |
| score_max | float | 可选 | 相似度计算可能存在范围的最大值,默认值为1。 取值范围:[1.0, 100.0],"score_max"必须大于等于"score_min"。 |
| reverse | bool | 可选 | 相似度分数和相似度的关系,默认值为"False"。 |
| **kwargs | Dict[str, Any] | 必选 | 具体介绍可参见create_reranker。 |
示例1:faiss_npu + local_embedding + local_reranker
from mx_rag.cache import SimilarityCacheConfig
from mx_rag.cache import MxRAGCache
dim = 1024
dev = 1
similarity_config = SimilarityCacheConfig(
vector_config={
"vector_type": "npu_faiss_db",
"x_dim": dim,
"devs": [dev],
},
cache_config="sqlite",
emb_config={
"embedding_type": "local_text_embedding",
"x_dim": dim,
"model_path": "path_to_embedding_model", # emb 模型路径
"dev_id": dev
},
similarity_config={
"similarity_type": "local_reranker",
"model_path": "path_to_reranker_model", # reranker 模型路径
"dev_id": dev
},
retrieval_top_k=1,
cache_size=1000,
clean_size=20,
similarity_threshold=0.86,
data_save_folder="path_to_cache_save_folder", # 落盘路径
disable_report=True
)
similarity_cache = MxRAGCache("similarity_cache", similarity_config)
示例2:milvus_db + tei_embedding + tei_reranker
import getpass
from paddle.base import libpaddle
from mx_rag.cache import SimilarityCacheConfig
from mx_rag.cache import EvictPolicy
from mx_rag.cache import MxRAGCache
from mx_rag.utils import ClientParam
from pymilvus import MilvusClient
dim = 1024
client = MilvusClient("https://x.x.x.x:port", user="xxx", password=getpass.getpass(), secure=True, client_pem_path="path_to/client.pem", client_key_path="path_to/client.key", ca_pem_path="path_to/ca.pem", server_name="localhost")
similarity_config = SimilarityCacheConfig(
vector_config={
"client": client,
"vector_type": "milvus_db",
"x_dim": dim,
"collection_name": "mxrag_cache_123", # milvus db的标签
"param": None
},
cache_config="sqlite",
emb_config={
"embedding_type": "tei_embedding",
"url": "https://<ip>:<port>/embed", # tei_embedding 服务的IP地址和侦听端口
"client_param": ClientParam(ca_file="/path/to/ca.crt")
},
similarity_config={
"similarity_type": "tei_reranker",
"url": "https://<ip>:<port>/rerank", # tei_reranker 服务的IP地址和侦听端口
"client_param": ClientParam(ca_file="/path/to/ca.crt")
},
retrieval_top_k=1,
cache_size=100,
auto_flush=100,
similarity_threshold=0.70,
data_save_folder="path_to_cache_save_folder",
disable_report=True,
eviction_policy=EvictPolicy.FIFO
)
similarity_cache = MxRAGCache("similarity_cache", similarity_config)
EvictPolicy
类功能
用于配置缓存的替换策略。
from mx_rag.cache import EvictPolicy
class EvictPolicy(Enum)
| 属性名 | 数据类型 | 说明 |
|---|---|---|
| LRU | str | 替换最久没有访问的缓存。 |
| LFU | str | 替换使用频率最低的缓存。 |
| FIFO | str | 按照先进先出的规则进行替换。 |
| RR | str | 随机替换缓存。 |
缓存数据
章节介绍
本章节主要用于阐述给用户提供缓存数据更新,缓存数据查找,缓存数据落盘功能的MxRAGCache。
MxRAGCache
类功能
主要提供用户问题和答案的缓存存储、缓存更新、缓存落盘功能。
from mx_rag.cache import MxRAGCache
MxRAGCache(cache_name, config)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| cache_name | str | 必选 | 缓存名字,该命名会体现在落盘文件名中,字符串长度范围:(0, 64)。取值范围:[0-9a-zA-Z_] 只能是字母、数字或下划线。 |
| config | CacheConfig/SimilarityCacheConfig | 必选 | 缓存配置,参考缓存配置。 |
import json
import getpass
from paddle.base import libpaddle
from pymilvus import MilvusClient
from mx_rag.cache import CacheConfig, SimilarityCacheConfig
from mx_rag.cache import EvictPolicy
from mx_rag.cache import MxRAGCache
from mx_rag.utils import ClientParam
dim = 1024
cache_config = CacheConfig(
cache_size=100,
eviction_policy=EvictPolicy.LRU,
data_save_folder="path_to_cache_save_folder"
)
cache = MxRAGCache("memory_cache", cache_config)
# 检查cache实例是否初始化成功
cache_obj = cache.get_obj()
if cache_obj is None:
print(f"cache init failed")
similarity_config = SimilarityCacheConfig(
vector_config={
"vector_type": "milvus_db",
"x_dim": dim,
"client": MilvusClient("https://x.x.x.x:port", user="xxx", password=getpass.getpass(), secure=True, client_pem_path="path_to/client.pem", client_key_path="path_to/client.key", ca_pem_path="path_to/ca.pem", server_name="localhost")
"collection_name": "mxrag_cache_123", # milvus db的标签
"use_http": False,
"param": None
},
cache_config="sqlite",
emb_config={
"embedding_type": "tei_embedding",
"url": "https://<ip>:<port>/embed", # tei_embedding 服务的IP地址和侦听端口
"client_param": ClientParam(ca_file="/path/to/ca.crt")
},
similarity_config={
"similarity_type": "tei_reranker",
"url": "https://<ip>:<port>/rerank", # tei_reranker 服务的IP地址和侦听端口
"client_param": ClientParam(ca_file="/path/to/ca.crt")
},
retrieval_top_k=1,
cache_size=100,
auto_flush=100,
similarity_threshold=0.70,
data_save_folder="path_to_cache_save_folder",
disable_report=True,
eviction_policy=EvictPolicy.FIFO
)
similarity_cache = MxRAGCache("similarity_cache", similarity_config)
# 设置缓存级联
cache.join(similarity_cache)
# 设置缓存每条的字符限制为4000个字符
cache.set_cache_limit(4000)
# 设置是否详细显示缓存过程
cache.set_verbose(False)
# 手动更新缓存
cache.update("小明的爸爸是谁?", json.dumps({"小明的爸爸是谁?": "小明的爸爸名字是大明"}))
# 精确匹配结果
res = cache.search("小明的爸爸是谁?")
print(f"memory match res: {res}")
# 语义近似匹配结果
res = cache.search("小明的爸爸叫什么名字")
print(f"similarity match res: {res}")
# 手动调用flush 将缓存落盘,也会按照auto_flush配置进行自动落盘
cache.flush()
# 删除已落盘的文件和数据
cache.clear()
clear
删除data_save_folder下已落盘的缓存文件,对于memory cache,由于关闭时会自动flush再次写入缓存,需要用户程序中再次清理。
def clear()
flush
该函数主要是将用户的缓存数据强制从内存中刷新至磁盘空间,刷新地址为类功能中的配置参数“data_save_folder”。
必须在初始化之后进行调用。
def flush()
get_obj
获得gptcache对象,用于适配LangChain等开源RAG框架,必须在初始化之后进行调用。
def get_obj()
| 数据类型 | 说明 |
|---|---|
| gptcache对象 | gptcache对象 |
join
将两个cache进行串联,达到多级缓存目的。
def join(next_cache)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| next_cache | MxRAGCache | 必选 | 下一级缓存,缓存必须先初始化,不能串联自己。下级缓存必须是MxRAGCache类型。缓存串联时不能成环,并且最大串联深度为6 |
search
该函数主要负责根据用户的问题找到对应的答案,必须在初始化之后进行调用。
def search(query)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| query | str | 必选 | 用户的问题,字符串长度范围(0, 128*1024*1024]。 |
| 数据类型 | 说明 |
|---|---|
| Dict | 返回QA对的 Unicode 编码字符串,如:{"\u66aa....?": "\u6ae6...",...} |
set_cache_limit
设置cache缓存时LLM返回答案的字符数限制。如果LLM返回的字符串超过这个限制,则不会被缓存。
@classmethod
def set_cache_limit(cache_limit: int)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| cache_limit | int | 必选 | 每一条被缓存答案的字符数限制,默认值是一百万个字符,取值范围(0, 1000000]。中文是计算转为unicode后长度。 |
set_verbose
设置是否开启日志记录。
@classmethod
def set_verbose(verbose: bool)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| verbose | bool | 必选 | 是否开启详细日志记录,默认值为False。 |
update
该函数主要负责将用户问题和答案进行存储,必须在初始化之后调用,否则会抛出异常。
def update(query, answer)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| query | str | 必选 | 用户的问题,字符串长度(0, 128*1024*1024]范围 |
| answer | str | 必选 | 用户的问题所对应的答案,其字符数范围(0, min(1000000 , cache_limit)],否则不会被缓存,具体可参考set_cache_limit。 |
缓存适配
章节介绍
本章节主要介绍MxRAGCache的chain适配。
CacheChainChat
类功能
用于适配RAG SDK的文生文、图生图、文生图的各种chain,同时也提供访问MxRAGCache的能力,当缓存未命中时,将进行大模型推理,然后将结果再刷新至缓存。
from mx_rag.cache import CacheChainChat
CacheChainChat(cache,chain,convert_data_to_cache,convert_data_to_user)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| cache | MxRAGCache | 必选 | RAG SDK缓存。 |
| chain | Chain | 必选 | RAG SDK chain,用于访问大模型。 |
| convert_data_to_cache | Callable[[Any], Dict] | 可选 | 该回调函数主要用于当用户数据无法转换为字符串格式时,由用户提供转换函数。 默认为不做转换。 |
| convert_data_to_user | Callable[[Dict], Any] | 可选 | 该回调函数主要是配合convert_data_to_cache使用,当用户问题命中时,将cache存储的格式转换为用户格式。默认为不做转换。 |
import time
from paddle.base import libpaddle
from langchain.text_splitter import RecursiveCharacterTextSplitter
from mx_rag.chain import SingleText2TextChain
from mx_rag.document.loader import DocxLoader
from mx_rag.embedding.local import TextEmbedding
from mx_rag.knowledge import KnowledgeDB
from mx_rag.knowledge.knowledge import KnowledgeStore
from mx_rag.llm import Text2TextLLM
from mx_rag.storage.document_store import SQLiteDocstore
from mx_rag.knowledge.handler import upload_files
from mx_rag.document import LoaderMng
from mx_rag.storage.vectorstore import MindFAISS
from mx_rag.utils import ClientParam
from mx_rag.cache import CacheChainChat, MxRAGCache, SimilarityCacheConfig
#向量维度
dim = 1024
# NPU卡id
dev = 0
similarity_config = SimilarityCacheConfig(
vector_config={
"vector_type": "npu_faiss_db",
"x_dim": dim,
"devs": [dev],
},
cache_config="sqlite",
emb_config={
"embedding_type": "local_text_embedding",
"x_dim": dim,
"model_path": "/path to emb", # emb 模型路径
"dev_id": dev
},
similarity_config={
"similarity_type": "local_reranker",
"model_path": "/path to reranker", # reranker 模型路径
"dev_id": dev
},
retrieval_top_k=1,
cache_size=1000,
clean_size=20,
similarity_threshold=0.86,
data_save_folder="/save path", # 落盘路径
disable_report=True
)
similarity_cache = MxRAGCache("similarity_cache", similarity_config)
# cache 初始化
cache = MxRAGCache("similarity_cache", similarity_config)
# Step1离线构建知识库,首先注册文档处理器
loader_mng = LoaderMng()
# 加载文档加载器,可以使用RAG SDK自有的,也可以使用langchain的
loader_mng.register_loader(DocxLoader, [".docx"])
# 加载文档切分器,使用langchain的
loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".xlsx", ".docx", ".pdf"],
{"chunk_size": 200, "chunk_overlap": 50, "keep_separator": False})
emb = TextEmbedding(model_path="/path to emb", dev_id=dev)
# 初始化文档chunk关系数据库
chunk_store = SQLiteDocstore(db_path="./sql.db")
# 初始化知识管理关系数据库
knowledge_store = KnowledgeStore(db_path="./sql.db")
# 初始化矢量检索
vector_store = MindFAISS(x_dim=dim,
devs=[dev],
load_local_index="./faiss.index"
)
#添加知识库及管理员
knowledge_store.add_knowledge(knowledge_name="test", user_id='Default', role='admin')
# 初始化知识库管理
knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
chunk_store=chunk_store,
vector_store=vector_store,
knowledge_name="test",
user_id='Default',
white_paths=["/home"])
# 完成离线知识库构建,上传领域知识test.docx文档。
upload_files(knowledge_db, ["/path to files"],
loader_mng=loader_mng,
embed_func=emb.embed_documents,
force=True)
# Step2在线问题答复,初始化检索器
retriever = vector_store.as_retriever(document_store=chunk_store,
embed_func=emb.embed_documents, k=3, score_threshold=0.3)
# 配置reranker
# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
llm = Text2TextLLM(base_url="https://<ip>:<port>",
model_name="Llama3-8B-Chinese-Chat",
client_param=ClientParam(ca_file="/path/to/ca.crt"))
text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever)
cache_chain = CacheChainChat(chain=text2text_chain, cache=cache)
start_time = time.time()
res = cache_chain.query("请描述2024年高考作文题目")
end_time = time.time()
print(f"no cache query time cost:{(end_time - start_time) * 1000}ms")
print(f"no cache answer {res}")
start_time = time.time()
res = cache_chain.query("2024年的高考题目是什么", )
end_time = time.time()
print(f"cache query time cost:{(end_time - start_time) * 1000}ms")
print(f"cache answer {res}")
query
提供给用户查询缓存的接口,当缓存无法查询时,会访问大模型。
def query(text, *args, **kwargs) -> Union[Dict, Iterator[Dict]]
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| text | str | 必选 | 用户的原始问题,字符数范围(0, 128M] |
| llm_config | LLMParameterConfig | 可选 | 大模型参数,具体介绍可参见LLMParameterConfig。 |
| *args/**kwargs | Any | 可选 | 继承父类方法签名,RAG SDK不涉及使用。 |
| 数据类型 | 说明 |
|---|---|
| Union[Dict, Iterator[Dict]] | 返回问答结果,其中Dict内容为: |
自动生成QA作为缓存
章节介绍
本章节主要介绍生成QA,并更新到MxRAGCache缓存中的接口。
QAGenerate
类功能
问答生成类,输入标题和正文,调用大模型根据该正文生成问答对。
from mx_rag.cache import QAGenerate
QAGenerate(config: QAGenerationConfig)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| config | QAGenerationConfig | 必选 | QAGenerationConfig对象,生成QA的相关参数。 关于QAGenerationConfig原型说明请参见QAGenerationConfig。 |
from paddle.base import libpaddle
from transformers import AutoTokenizer
from mx_rag.cache import QAGenerationConfig, QAGenerate
from mx_rag.llm import Text2TextLLM
from mx_rag.utils import ClientParam
llm = Text2TextLLM(base_url="https://ip:port/v1/chat/completions", model_name="llama3-chinese-8b-chat",
client_param=ClientParam(ca_file="/path/to/ca.crt"))
# 使用模型的tokenizer, 传入模型存放路径
tokenizer = AutoTokenizer.from_pretrained("/home/model/Llama3-8B-Chinese-Chat/", local_files_only=True)
# 可以调用MarkDownParser生成titles和contents
titles = ["2024年高考语文作文题目"]
contents = ['2024年高考语文作文试题\n新课标I卷\n阅读下面的材料,根据要求写作。(60分)\n'
'随着互联网的普及、人工智能的应用,越来越多的问题能很快得到答案。那么,我们的问题是否会越来越少?\n'
'以上材料引发了你怎样的联想和思考?请写一篇文章。'
'要求:选准角度,确定立意,明确文体,自拟标题;不要套作,不得抄袭;不得泄露个人信息;不少于800字。']
config = QAGenerationConfig(titles, contents, tokenizer, llm, qas_num=1)
qa_generate = QAGenerate(config)
qas = qa_generate.generate_qa()
print(qas)
generate_qa
通过QAGenerationConfig传入标题和正文,会按照QAGenerationConfig中的max_tokens值截断正文。
大模型返回的QA如果不符合格式和个数要求,则会跳过。比如生成三个符合要求的QA:
Q1:如何查询成都火车站的停运列车?
参考段落:'查询方式:铁路12306网页首页。查询流程:第一步:进入铁路12306app首页,点击【车站大屏】;第二步:左上角车站名下拉选择成都东站;第三步:搜索框输入车次即可查询车次情况。'
Q2:四川省将洪水灾害防御响应提升至哪个级别?
参考段落:四川将洪水灾害防御四级响应提升至三级。
Q3:在7月14日,四川省气象台发布了哪种天气预警?
参考段落:7月14日15时30分,四川省气象台继续发布暴雨蓝色预警。
def generate_qa(llm_config)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| llm_config | LLMParameterConfig | 可选 | 调用大模型的参数,此处修改默认值temperature为0.5,top_p为0.95,其余参数说明请参见LLMParameterConfig。 |
| 数据类型 | 说明 |
|---|---|
| Dict | 返回生成的QA对列表,格式为 {"从成都到重庆要多久? : 乘坐高铁1个小时",...} |
QAGenerationConfig
生成QA的参数。
from mx_rag.cache import QAGenerationConfig
QAGenerationConfig(titles, contents, tokenizer, llm, max_tokens, qas_num)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| titles | List[str] | 必选 | 标题列表,标题和正文列表一一对应。列表长度范围[1, 10000],字符串长度范围[1, 100]。 |
| contents | List[str] | 必选 | 正文列表,标题和正文列表一一对应。列表长度范围[1, 10000],字符串长度范围[1, 1048576]。 |
| tokenizer | transformers.PreTrainedTokenizerBase | 必选 | tokenizer实例,通过AutoTokenizer.from_pretrained加载。加载外部模型有安全风险,local_files_only设置为True。 |
| llm | Text2TextLLM | 必选 | 大模型对象实例,具体类型请参见Text2TextLLM。 |
| max_tokens | int | 可选 | 用于截断正文的最大token大小,超出部分丢弃,取值范围[500, 100000],默认值为1000。 |
| qas_num | int | 可选 | 生成QA对数量,取值范围[1, 10],默认值为5。 |
from paddle.base import libpaddle
from transformers import AutoTokenizer
from mx_rag.cache import QAGenerationConfig, QAGenerate
from mx_rag.llm import Text2TextLLM
from mx_rag.utils import ClientParam
llm = Text2TextLLM(base_url="https://ip:port/v1/chat/completions", model_name="llama3-chinese-8b-chat",
client_param=ClientParam(ca_file="/path/to/ca.crt"))
# 使用模型的tokenizer, 传入模型存放路径
tokenizer = AutoTokenizer.from_pretrained("/home/model/Llama3-8B-Chinese-Chat/", local_files_only=True)
# 可以调用MarkDownParser生成titles和contents
titles = ["2024年高考语文作文题目"]
contents = ['2024年高考语文作文试题\n新课标I卷\n阅读下面的材料,根据要求写作。(60分)\n'
'随着互联网的普及、人工智能的应用,越来越多的问题能很快得到答案。那么,我们的问题是否会越来越少?\n'
'以上材料引发了你怎样的联想和思考?请写一篇文章。'
'要求:选准角度,确定立意,明确文体,自拟标题;不要套作,不得抄袭;不得泄露个人信息;不少于800字。']
config = QAGenerationConfig(titles, contents, tokenizer, llm, qas_num=1)
qa_generate = QAGenerate(config)
qas = qa_generate.generate_qa()
print(qas)
MarkDownParser
类功能
解析Markdown,返回的标题和正文的类。
from mx_rag.cache import MarkDownParser
MarkDownParser(file_path, max_file_num)
| 参数名 | 数据类型 | 可选/必选 | 说明 |
|---|---|---|---|
| file_path | str | 必选 | markdown文件所在的文件夹路径,路径长度不能超过1024。调用parse时会校验:不能为软链接和相对路径,文件夹路径下的.md文件大小不能超过10MB,.md文件数量不能超过max_file_num。路径不能在路径列表中:["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"]。 |
| max_file_num | int | 可选 | 最大解析的markdown文件个数,默认值为1000,取值范围[1, 10000]。 |
| 数据类型 | 说明 |
|---|---|
| Tuple[List[str], List[str]] | 返回的是Markdown解析后的titles列表和contents列表。 |
from paddle.base import libpaddle
from mx_rag.cache import MarkDownParser
dir_path = "path of .md document "
parser = MarkDownParser(dir_path)
titles, contents = parser.parse()
print(titles)
print(contents)
parse
返回markdown的标题和正文,文件夹内markdown文件数量不能超过max_file_num。
def parse()
| 数据类型 | 说明 |
|---|---|
| Tuple[List[str], List[str]] | 返回的是markdown对应的titles列表和contents列表。 |

