Prompt压缩
PromptCompressor
类功能
功能描述
prompt压缩抽象类
函数原型
from mx_rag.compress.base_compressor import PromptCompressor
class PromptCompressor(ABC)
compress_texts
功能描述
压缩prompt文本
函数原型
@abstractmethod
def compress_texts(self, context, question)
参数说明
| 参数名 |
数据类型 |
可选/必选 |
说明 |
| context |
str |
必选 |
待总结的长文本。 |
| question |
str |
必选 |
总结长文本的指令。 |
RerankCompressor
类功能
功能描述
通过排序模型计算question(总结长文本的指令)和context(总结长文本的指令)切片之间的相关性得分,根据设定的压缩率阈值,优先保留相关性高的切片,从而实现对长文本的有效压缩。
函数原型
from mx_rag.compress.rerank_compressor import RerankCompressor
class RerankCompressor(reranker, splitter)
输入参数说明
| 参数名 |
数据类型 |
可选/必选 |
说明 |
| reranker |
Reranker |
必选 |
排序模型实例,实现对文本切片进行精排,只能为mx_rag.reranker的Reranker对象,具体可参见Reranker。 |
| splitter |
TextSplitter |
可选 |
文档切分函数,只能为继承自langchain的TextSplitter的子类。默认为langchain.text_splitter的RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0, separators=["\n", ""], keep_separator=True) |
调用示例
from mx_rag.compress.rerank_compressor import RerankCompressor
from mx_rag.reranker.local import LocalReranker
from mx_rag.reranker.service import TEIReranker
from langchain.text_splitter import RecursiveCharacterTextSplitter
from mx_rag.utils import ClientParam
context="""需要压缩的prompt文本"""
question="请给上述内容起一个标题"
tei_reranker=False
if tei_reranker:
reranker = TEIReranker.create(url="https://ip:port/rerank",
client_param=ClientParam(ca_file="/path/to/ca.crt"))
else:
reranker = LocalReranker(model_path="reranker_path", dev_id=0)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0, separators=["\n", ""], keep_separator=True)
compressor=RerankCompressor(reranker=reranker, splitter=text_splitter)
res=compressor.compress_texts(context, question, 0.3)
print(res)
compress_texts
功能描述
根据指令(question)、长文本(context)以及压缩率(compress_rate)压缩文本
函数原型
def compress_texts(context, question, compress_rate, context_reorder)
输入参数说明
| 参数名 |
数据类型 |
可选/必选 |
说明 |
| context |
str |
必选 |
待总结的长文本。长度范围:[1, 16MB] |
| question |
str |
必选 |
总结长文本的指令,用于计算与context文本切片的相关性。长度范围:[1, 1000*1000] |
| compress_rate |
float |
可选 |
压缩率,默认为0.6,取值范围:(0, 1) |
| context_reorder |
bool |
可选 |
是否根据得分重排,默认为False,若为True,计算完相关性之后,将根据压缩率优先保留相关性低的文本切片。 |
返回值说明
ClusterCompressor
类功能
功能描述
通过聚类模型将嵌入后的文本进行聚类,将其划分为多个语义簇。随后,计算context的切片与question(总结长文本的指令)的余弦相似度。根据设定的压缩率,在每个簇内删除相似度较低的切片,从而保留与指令最相关的信息,实现长文本的压缩式总结。
函数原型
from mx_rag.compress.cluster_compressor import ClusterCompressor
class ClusterCompressor(cluster_func, embed, splitter, dev_id):
输入参数说明
| 参数名 |
数据类型 |
可选/必选 |
说明 |
| cluster_func |
Callable[[List[List[float]]], Union[List[int], np.ndarray]] |
必选 |
聚类函数,将嵌入后的文本切片进行聚类,将其划分为多个语义簇,返回的结果必须为List[int]或ndarray,长度不能超过1000*1000,且长度要和文本切片数量一致。 |
| embed |
Embeddings |
必选 |
嵌入对象,把文本切片转换为向量,只能为继承自langchain_core.embeddings的Embeddings的子类。 |
| splitter |
TextSplitter |
可选 |
文档切分对象,只能为继承自langchain_text_splitters.base的TextSplitter的子类。默认为langchain.text_splitter的RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0, separators=["。", "!", "?", "\n", ",", ";", " ", ""]) |
| dev_id |
int |
可选 |
NPU id,通过npu-smi info查询可用ID,取值范围[0, 63],默认为卡0。 |
调用示例
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sklearn.cluster import HDBSCAN
from mx_rag.compress.cluster_compressor import ClusterCompressor
from mx_rag.embedding.local import TextEmbedding
from mx_rag.embedding.service import TEIEmbedding
from mx_rag.utils import ClientParam
context="""需要压缩的prompt文本"""
question="请给上述内容起一个标题"
tei_emb=False
if tei_emb:
emb = TEIEmbedding.create(url="https://ip:port/embed", client_param=ClientParam(ca_file="/path/to/ca.crt"))
else:
emb = TextEmbedding(model_path="embedding_path", dev_id=0)
def _get_community(sentences_embedding):
node_num=len(sentences_embedding)
min_cluster_size=2
hdbscan = HDBSCAN(min_cluster_size=min(min_cluster_size, node_num))
labels = hdbscan.fit_predict(sentences_embedding)
return labels
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0, separators=["。", "!", "?", "\n", ",", ";", " ", ""], )
compressor=ClusterCompressor(cluster_func=_get_community, embed=emb, splitter=splitter, dev_id=0)
res=compressor.compress_texts(context, question, 0.6)
print(res)
compress_texts
功能描述
根据提供的指令(question)、长文本(context)和压缩率(compress_rate)压缩文本。
函数原型
def compress_texts(context, question, compress_rate)
输入参数说明
| 参数名 |
数据类型 |
可选/必选 |
说明 |
| context |
str |
必选 |
待总结的长文本。长度范围:[1, 16MB] |
| question |
str |
必选 |
总结长文本的指令,用于计算与context文本切片的相关性。长度范围:[1, 1000*1000] |
| compress_rate |
float |
可选 |
压缩率,默认值为0.6,取值范围:(0, 1) |
返回值说明