"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
from typing import Callable
from loguru import logger
from mx_rag.llm import Text2TextLLM
from mx_rag.reranker import Reranker
from mx_rag.tools.finetune.dataprocess.generate_qd import GENERATE_QD_PROMPT
from mx_rag.tools.finetune.dataprocess.improve_query import improve_query
from mx_rag.tools.finetune.dataprocess.llm_preferred import SCORING_QD_PROMPT
from mx_rag.tools.finetune.generator.common import BaseGenerator
from mx_rag.utils.common import (validate_params, validate_list_str, TEXT_MAX_LEN, STR_MAX_LEN,
MAX_PATH_LENGTH, MAX_PROMPT_LENGTH, BOOL_TYPE_CHECK_TIP)
from mx_rag.utils.file_check import FileCheck
from mx_rag.utils.file_operate import write_jsonl_to_file, read_jsonl_from_file
# Upper bound on the length of the query list and the document list that
# TrainDataGenerator._rewrite_query accepts; longer inputs are rejected there.
MAX_DATASET_LEN = 10000
class DataProcessConfig:
    """Configuration for generating synthetic fine-tuning data.

    The constructor arguments are checked through the ``validate_params``
    decorator and then stored unchanged as instance attributes.

    Attributes:
        generate_qd_prompt: (str) Prompt used by the LLM to generate
            query-document (QD) pairs; length in (0, MAX_PROMPT_LENGTH].
        llm_preferred_prompt: (str) Prompt used by the LLM to score QD pairs;
            length in (0, MAX_PROMPT_LENGTH].
        question_number: (int) Number of queries generated for each source
            document; value in (0, 20].
        featured: (bool) Switch for the "featured" selection step (BM25 and
            reranker fusion ranking).
        featured_percentage: (float) Selection ratio of the featured step;
            value in (0.0, 1.0).
        preferred: (bool) Switch for the "preferred" selection step (scoring
            by the LLM).
        llm_threshold_score: (float) Score threshold handed to the LLM-based
            selection step; value in (0.0, 1.0).
        rewrite: (bool) Switch for query rewriting.
        query_rewrite_number: (int) Number of rewrites per query; value in
            (0, 20].
    """

    @validate_params(
        generate_qd_prompt=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= MAX_PROMPT_LENGTH,
                                message=f"param must be a str and its length meets (0, {MAX_PROMPT_LENGTH}]"),
        llm_preferred_prompt=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= MAX_PROMPT_LENGTH,
                                  message=f"param must be a str and its length meets (0, {MAX_PROMPT_LENGTH}]"),
        # bool is a subclass of int, so without the explicit exclusion True
        # would be accepted as the integer 1.
        question_number=dict(validator=lambda x: isinstance(x, int) and not isinstance(x, bool) and 0 < x <= 20,
                             message="param must meets: Type is int, value range (0, 20]"),
        featured=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
        featured_percentage=dict(validator=lambda x: isinstance(x, (float, int)) and 0.0 < x < 1.0,
                                 message="param must be float or int and value range (0.0, 1.0)"),
        preferred=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
        llm_threshold_score=dict(validator=lambda x: isinstance(x, (float, int)) and 0.0 < x < 1.0,
                                 message="param must be float or int and value range (0.0, 1.0)"),
        rewrite=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
        # Same bool exclusion as question_number.
        query_rewrite_number=dict(validator=lambda x: isinstance(x, int) and not isinstance(x, bool) and 0 < x <= 20,
                                  message="param must meets: Type is int, value range (0, 20]"),
    )
    def __init__(self,
                 generate_qd_prompt: str = GENERATE_QD_PROMPT,
                 llm_preferred_prompt: str = SCORING_QD_PROMPT,
                 question_number: int = 3,
                 featured: bool = True,
                 featured_percentage: float = 0.8,
                 preferred: bool = True,
                 llm_threshold_score: float = 0.8,
                 rewrite: bool = True,
                 query_rewrite_number: int = 2):
        """Store the validated configuration values on the instance."""
        self.generate_qd_prompt = generate_qd_prompt
        self.llm_preferred_prompt = llm_preferred_prompt
        self.question_number = question_number
        self.featured = featured
        self.featured_percentage = featured_percentage
        self.preferred = preferred
        self.llm_threshold_score = llm_threshold_score
        self.rewrite = rewrite
        self.query_rewrite_number = query_rewrite_number
class TrainDataGenerator(BaseGenerator):
    """Generate embedding fine-tuning data (query/corpus pairs) from split documents.

    ``generate_train_data`` drives the pipeline: rough query-document pairs
    are generated with the LLM, optionally filtered (bm25+reranker selection,
    then LLM scoring), optionally extended with rewritten queries, and finally
    written to ``train_data.jsonl`` under ``dataset_path``.
    """

    @validate_params(
        llm=dict(validator=lambda x: isinstance(x, Text2TextLLM), message="param must be instance of Text2TextLLM"),
        dataset_path=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= MAX_PATH_LENGTH,
                          message=f"param must be str and str length range (0, {MAX_PATH_LENGTH}]"),
        reranker=dict(validator=lambda x: isinstance(x, Reranker),
                      message="param must be instance of Reranker"),
        encrypt_fn=dict(validator=lambda x: x is None or isinstance(x, Callable),
                        message="encrypt_fn must be None or callable function"),
        decrypt_fn=dict(validator=lambda x: x is None or isinstance(x, Callable),
                        message="decrypt_fn must be None or callable function")
    )
    def __init__(self, llm: Text2TextLLM, dataset_path: str, reranker: Reranker,
                 encrypt_fn: Callable[[str], str] = None, decrypt_fn: Callable[[str], str] = None):
        """Initialise the generator.

        Args:
            llm: LLM used for pair generation, scoring and query rewriting.
            dataset_path: Directory in which intermediate and final jsonl files live.
            reranker: Reranker handed to the featured selection step.
            encrypt_fn: Optional callable forwarded to the base class together
                with ``decrypt_fn``; ``None`` is accepted.
            decrypt_fn: Optional counterpart of ``encrypt_fn``.
        """
        super().__init__(llm, dataset_path, encrypt_fn, decrypt_fn)
        self.reranker = reranker

    @validate_params(
        split_doc_list=dict(validator=lambda x: validate_list_str(x, [1, TEXT_MAX_LEN], [1, STR_MAX_LEN]),
                            message=f"param must meets: Type is List[str], list length range [1, {TEXT_MAX_LEN}], "
                                    f"str length range [1, {STR_MAX_LEN}]"),
        data_process_config=dict(validator=lambda x: isinstance(x, DataProcessConfig),
                                 message="param must be instance of DataProcessConfig"),
        # bool is a subclass of int, so exclude it explicitly (True would pass as 1).
        batch_size=dict(validator=lambda x: isinstance(x, int) and not isinstance(x, bool) and 0 < x <= 1024,
                        message="param must meets: Type is int, value range (0, 1024]"),
    )
    def generate_train_data(self,
                            split_doc_list: list[str],
                            data_process_config: DataProcessConfig,
                            batch_size: int = 8
                            ):
        """Run the data pipeline and write ``train_data.jsonl``.

        Nothing is done when ``train_data.jsonl`` already exists in
        ``dataset_path``.

        Args:
            split_doc_list: Source document chunks.
            data_process_config: Switches and parameters of the pipeline steps.
            batch_size: Batch size forwarded to the LLM-driven steps.

        Returns:
            None. The result is written to disk.
        """
        FileCheck.dir_check(self.dataset_path)
        train_data_path = os.path.join(self.dataset_path, "train_data.jsonl")
        if os.path.exists(train_data_path):
            logger.info("embedding train data has been created.")
            return

        logger.info("step Generating rough problem documentation pairs")
        query_list, doc_list = self._generate_coarsest_qd_pairs(split_doc_list, data_process_config.question_number,
                                                                data_process_config.generate_qd_prompt,
                                                                batch_size)
        logger.info("step Generated rough problem documentation pairs finished")

        if data_process_config.featured:
            logger.info("step bm25+reranker query document pair selection")
            query_list, doc_list = self._feature_qd_pair(query_list,
                                                         doc_list,
                                                         self.reranker,
                                                         data_process_config.featured_percentage)
            logger.info("step bm25+reranker selection finished")

        if data_process_config.preferred:
            logger.info("step LLM optimizing query document pair")
            query_list, doc_list = self._prefer_qd_pair(query_list,
                                                        doc_list,
                                                        data_process_config.llm_threshold_score,
                                                        data_process_config.llm_preferred_prompt,
                                                        batch_size)
            logger.info("step LLM optimizing query document pair finished")

        if data_process_config.rewrite:
            logger.info("step Enhancing query diversity and preserving training data")
            # NOTE(review): _rewrite_query returns two empty lists when it
            # rejects its input, so an empty record list is then handed to
            # write_jsonl_to_file below - confirm this is intended.
            query_list, doc_list = self._rewrite_query(query_list,
                                                       doc_list,
                                                       data_process_config.query_rewrite_number,
                                                       batch_size)
            logger.info("step Enhancing query diversity and preserving training data finished")

        train_data = [{"query": self._encrypt(query), "corpus": self._encrypt(doc)}
                      for query, doc in zip(query_list, doc_list)]
        write_jsonl_to_file(train_data, train_data_path)

    def _rewrite_query(self,
                       preferred_query_list: list[str],
                       preferred_doc_list: list[str],
                       query_rewrite_number: int,
                       batch_size: int):
        """Extend the selected pairs with LLM-rewritten queries.

        Rewritten pairs are appended to ``rewrite_data.jsonl`` as they are
        produced. If that file already exists, its pairs are loaded and only
        the rewrites still missing (judged by the number of saved rows) are
        generated.

        Args:
            preferred_query_list: Queries to rewrite.
            preferred_doc_list: Documents paired index-wise with the queries.
            query_rewrite_number: Number of rewrite rounds over the full query list.
            batch_size: Number of queries sent to the LLM per call.

        Returns:
            ``(query_list, doc_list)``: the original pairs followed by the
            rewritten ones, index-aligned. Two empty lists are returned when
            the inputs exceed ``MAX_DATASET_LEN`` or differ in length.
        """
        if len(preferred_query_list) > MAX_DATASET_LEN or len(preferred_doc_list) > MAX_DATASET_LEN:
            logger.error(f"inputs len should not bigger than {MAX_DATASET_LEN}")
            return [], []
        if len(preferred_query_list) != len(preferred_doc_list):
            logger.error("preferred_query_list and preferred_doc_list has different len")
            return [], []
        FileCheck.dir_check(self.dataset_path)

        query_list = list(preferred_query_list)
        doc_list = list(preferred_doc_list)
        rewrite_data_path = os.path.join(self.dataset_path, "rewrite_data.jsonl")

        if not os.path.exists(rewrite_data_path):
            # Fresh run: rewrite the whole list query_rewrite_number times.
            for i in range(query_rewrite_number):
                logger.info(f"The {i + 1}st times rewrite the query")
                new_queries, kept_docs = self._rewrite_pairs(preferred_query_list, preferred_doc_list,
                                                             rewrite_data_path, batch_size)
                query_list.extend(new_queries)
                # Extend with the docs that actually got a rewritten query so
                # that query_list and doc_list stay index-aligned.
                doc_list.extend(kept_docs)
            return query_list, doc_list

        # Resume: load what an earlier run already saved.
        rewrite_data_list = read_jsonl_from_file(rewrite_data_path)
        for rewrite_data in rewrite_data_list:
            query_list.append(self._decrypt(rewrite_data["query"]))
            doc_list.append(self._decrypt(rewrite_data["corpus"]))

        pair_count = len(preferred_query_list)
        saved_count = len(rewrite_data_list)
        # A complete run saves pair_count rows per round, i.e.
        # pair_count * query_rewrite_number rewritten rows in total. Compare
        # the saved rows only; the originals are not part of that target.
        if saved_count == pair_count * query_rewrite_number:
            logger.info("rewrite query finished, skip rewrite query process")
            return query_list, doc_list

        if pair_count == 0:
            per_rewrite_number = 0
            remain_number = query_rewrite_number - 1
        else:
            # Completed full rounds, and position reached inside the current round.
            rewritten_number, per_rewrite_number = divmod(saved_count, pair_count)
            remain_number = query_rewrite_number - rewritten_number - 1

        if remain_number < 0:
            logger.info('calculated based on parameters, do not need to rewrite the query')
            return query_list, doc_list

        logger.info("rewrite query not finished, continue to rewrite query process")
        # Rest of the interrupted round followed by the rounds not yet started.
        remain_queries = preferred_query_list[per_rewrite_number:] + preferred_query_list * remain_number
        remain_docs = preferred_doc_list[per_rewrite_number:] + preferred_doc_list * remain_number
        new_queries, kept_docs = self._rewrite_pairs(remain_queries, remain_docs, rewrite_data_path, batch_size)
        query_list.extend(new_queries)
        doc_list.extend(kept_docs)
        return query_list, doc_list

    def _rewrite(self, preferred_query_list, preferred_doc_list, rewrite_data_path, batch_size):
        """Rewrite the queries batch by batch and return the rewritten queries.

        Thin wrapper around ``_rewrite_pairs`` that keeps the original return
        shape (queries only). Empty rewritten queries are omitted, so the
        result may be shorter than ``preferred_query_list``.
        """
        new_queries, _ = self._rewrite_pairs(preferred_query_list, preferred_doc_list,
                                             rewrite_data_path, batch_size)
        return new_queries

    def _rewrite_pairs(self, source_query_list, source_doc_list, rewrite_data_path, batch_size):
        """Rewrite queries in batches and return index-aligned (queries, docs).

        Each batch is passed to ``improve_query``; every non-empty rewritten
        query is kept together with its document, and the kept pairs of the
        batch are appended (encrypted via ``self._encrypt``) to
        ``rewrite_data_path``.

        Args:
            source_query_list: Queries to rewrite.
            source_doc_list: Documents paired index-wise with the queries.
            rewrite_data_path: jsonl file the rewritten pairs are appended to.
            batch_size: Number of queries per ``improve_query`` call.

        Returns:
            ``(new_queries, kept_docs)`` with equal length.
        """
        logger.info(f"rewrite query count: {len(source_query_list)}")
        new_queries = []
        kept_docs = []
        for count, start in enumerate(range(0, len(source_query_list), batch_size)):
            chunk_query_list = source_query_list[start:start + batch_size]
            chunk_doc_list = source_doc_list[start:start + batch_size]
            # Presumably improve_query returns one entry per input query and
            # uses "" for a failed rewrite - TODO confirm against improve_query.
            rewritten_chunk = improve_query(self.llm, chunk_query_list)
            temp_qd_pairs = []
            for query, doc in zip(rewritten_chunk, chunk_doc_list):
                if query != "":
                    new_queries.append(query)
                    kept_docs.append(doc)
                    temp_qd_pairs.append({"query": self._encrypt(query), "corpus": self._encrypt(doc)})
            write_jsonl_to_file(temp_qd_pairs, rewrite_data_path, 'a')
            logger.info(f"The {count + 1} st time rewrite query success by chunk {batch_size}")
        return new_queries, kept_docs