"""
Copyright 2026 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import getpass
import json
import os
import re
import sys
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from dataclasses import dataclass
from typing import Optional, Tuple
from backend.config.config import Config
from backend.utils.file_validator import FileValidator
from backend.services.project_service import ProjectService
from backend.services.llm_service import LLMService
from backend.services.document_service import DocumentService
from backend.services.question_service import QuestionService
from backend.services.dataset_service import DatasetService
from backend.models.param_config import DatasetCreateConfig
from backend.models.constants import (
DOCUMENT_STATUS_QUESTIONS_GENERATED,
QUESTION_STATUS_UNANSWERED,
QUESTION_STATUS_ANSWERED,
QUESTION_STATUS_REVIEWED,
API_KEY_MIN_LENGTH,
API_KEY_MAX_LENGTH,
OPENAI_KEY_PREFIX,
OPENAI_KEY_PATTERN,
API_KEY_SUSPICIOUS_PATTERNS,
DEFAULT_LLM_PROVIDER,
ARG_CONFIG,
ARG_PROJECT,
ARG_DOCUMENT,
ARG_GENERATE_QUESTIONS,
ARG_GENERATE_ANSWERS,
ARG_GENERATE_ANSWERS_NO_CHAIN,
ARG_CREATE_DATASET,
ARG_EXPORT_DATASET,
ARG_IMPORT_QUESTIONS,
ARG_CONFIG_DEFAULT,
ARG_STORE_DEFAULT,
HELP_CONFIG,
HELP_PROJECT,
HELP_DOCUMENT,
HELP_GENERATE_QUESTIONS,
HELP_GENERATE_ANSWERS,
HELP_GENERATE_ANSWERS_NO_CHAIN,
HELP_CREATE_DATASET,
HELP_EXPORT_DATASET,
HELP_IMPORT_QUESTIONS,
MSG_API_KEY_EMPTY,
MSG_API_KEY_TOO_SHORT,
MSG_API_KEY_TOO_LONG,
MSG_API_KEY_PREFIX,
MSG_DOCUMENT_NOT_FOUND,
MSG_XLSX_NOT_FOUND,
MSG_NO_DOCUMENTS,
MSG_DATASET_NOT_FOUND,
MSG_CONFIG_NOT_FOUND,
)
from backend.utils.logger import init_logger
logger = init_logger(__name__)
class ConfigurationError(Exception):
"""配置错误异常类"""
pass
def _read_stdin_config() -> dict:
"""从标准输入读取配置(JSON格式)"""
if sys.stdin.isatty():
return {}
def read_and_parse():
data = sys.stdin.read().rstrip("\n")
if data:
return json.loads(data)
return None
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(read_and_parse)
result = future.result(timeout=5.0)
return result if result is not None else {}
except TimeoutError:
return {}
except Exception:
return {}
def _is_real_terminal() -> bool:
"""检测是否为真实终端(用于判断是否可以使用getpass隐藏输入)"""
if not sys.stdin.isatty():
return False
try:
return os.isatty(sys.stdin.fileno())
except Exception:
return False
def _prompt_api_key(provider: str) -> str:
"""提示用户输入API密钥"""
if _is_real_terminal():
return getpass.getpass(f"\n需要 LLM API KEY (provider: {provider})\n请输入 API KEY: ")
else:
logger.warning(f"\n警告: 使用IDE终端执行不支持密码隐藏,输入将明文显示")
logger.info(f"需要 LLM API KEY (provider: {provider})")
return input("请输入 API KEY: ").strip()
def _validate_api_key(api_key: str, provider: str) -> Tuple[bool, str]:
"""验证API KEY的格式和基本要求"""
if not api_key or not isinstance(api_key, str):
return False, MSG_API_KEY_EMPTY
api_key = api_key.strip()
if len(api_key) < API_KEY_MIN_LENGTH:
return False, MSG_API_KEY_TOO_SHORT.format(API_KEY_MIN_LENGTH)
if len(api_key) > API_KEY_MAX_LENGTH:
return False, MSG_API_KEY_TOO_LONG.format(API_KEY_MAX_LENGTH)
if provider.lower() == DEFAULT_LLM_PROVIDER:
if not api_key.startswith(OPENAI_KEY_PREFIX):
return False, MSG_API_KEY_PREFIX.format(OPENAI_KEY_PREFIX)
if not OPENAI_KEY_PATTERN.match(api_key):
return False, "OpenAI API KEY格式不正确"
for pattern in API_KEY_SUSPICIOUS_PATTERNS:
if re.search(pattern, api_key.lower()):
return False, f"API KEY包含可疑内容: {pattern}"
return True, ""
def _parse_arguments() -> argparse.Namespace:
"""解析命令行参数"""
parser = argparse.ArgumentParser(description="问答数据集生成工具")
parser.add_argument(ARG_CONFIG, type=str, default=ARG_CONFIG_DEFAULT, help=HELP_CONFIG)
parser.add_argument(ARG_PROJECT, type=str, nargs='+', help=HELP_PROJECT)
parser.add_argument(ARG_DOCUMENT, type=str, help=HELP_DOCUMENT)
parser.add_argument(ARG_GENERATE_QUESTIONS, action=ARG_STORE_DEFAULT, help=HELP_GENERATE_QUESTIONS)
parser.add_argument(ARG_GENERATE_ANSWERS, action=ARG_STORE_DEFAULT, help=HELP_GENERATE_ANSWERS)
parser.add_argument(ARG_GENERATE_ANSWERS_NO_CHAIN, action=ARG_STORE_DEFAULT, help=HELP_GENERATE_ANSWERS_NO_CHAIN)
parser.add_argument(ARG_CREATE_DATASET, type=str, help=HELP_CREATE_DATASET)
parser.add_argument(ARG_EXPORT_DATASET, type=str, help=HELP_EXPORT_DATASET)
parser.add_argument(ARG_IMPORT_QUESTIONS, type=str, help=HELP_IMPORT_QUESTIONS)
return parser.parse_args()
@dataclass
class DatasetCreationContext:
"""
数据集创建上下文
用于封装数据集创建过程中需要的服务和配置参数。
"""
args: argparse.Namespace
config: Config
document_service: DocumentService
question_service: QuestionService
dataset_service: DatasetService
project: any
def _needs_api_key(args: argparse.Namespace) -> bool:
"""判断当前命令是否需要LLM API KEY"""
return args.generate_questions or args.generate_answers or args.generate_answers_no_chain
def _resolve_api_key(config: Config) -> Optional[str]:
"""获取并验证LLM API KEY"""
llm_api_key = None
stdin_config = _read_stdin_config()
if stdin_config:
llm_api_key = stdin_config.get('llm_api_key')
if not llm_api_key:
try:
llm_api_key = _prompt_api_key(config.llm_provider)
except (EOFError, IOError):
pass
if not llm_api_key:
error_detail = (
f"无法获取LLM API KEY\n"
f"请通过以下方式之一提供正确的API KEY:\n"
f" 1. 在前端页面<配置管理>中配置 <API 密钥>\n"
f" 2. 使用CLI而非IDE运行main.py"
)
raise ConfigurationError(error_detail)
is_valid, error_msg = _validate_api_key(llm_api_key, config.llm_provider)
if not is_valid:
error_detail = (
f"API KEY验证失败: {error_msg}\n"
f"请通过以下方式之一提供正确的API KEY:\n"
f" 1. 在前端页面<配置管理>中配置 <API 密钥>\n"
f" 2. 使用CLI而非IDE运行main.py"
)
raise ConfigurationError(error_detail)
return llm_api_key
def _create_llm_service(config: Config, api_key: Optional[str]) -> LLMService:
"""创建LLMService实例"""
from backend.models.param_config import (
LLMProviderConfig,
LLMGenerationConfig,
LLMPromptConfig,
)
provider_config = LLMProviderConfig(
provider=config.llm_provider,
api_key=api_key,
model_name=config.llm_model,
llm_api=config.llm_api,
)
generation_config = LLMGenerationConfig(
max_tokens=config.max_tokens,
temperature=config.temperature,
)
prompt_config = LLMPromptConfig(
system_prompt=config.system_prompt,
answer_prompt=config.answer_prompt,
chain_of_thought_prompt=config.chain_of_thought_prompt,
)
return LLMService(
provider_config=provider_config,
generation_config=generation_config,
prompt_config=prompt_config,
)
def _get_project_dir(config: Config, project) -> str:
"""获取项目数据目录"""
return os.path.join(config.data_dir, str(project.id))
def _handle_document(args, document_service, project) -> None:
"""处理文档导入"""
if not (args.document and project):
return
if not os.path.exists(args.document):
raise ConfigurationError(MSG_DOCUMENT_NOT_FOUND.format(args.document))
absolute_document_path = os.path.abspath(args.document)
document_validation = FileValidator().validate_file(
os.path.basename(absolute_document_path),
absolute_document_path,
os.path.dirname(absolute_document_path)
)
if not document_validation.get('valid', False):
errors = document_validation.get('errors', [])
raise ConfigurationError(f"文档文件校验失败: {'; '.join(errors)}")
logger.info(f"正在处理文档: {args.document}")
document = document_service.create_document(project.id, args.document)
chunks = document_service.split_document(document)
logger.info(f"文档已分割为 {len(chunks)} 个文本块")
def _collect_unanswered_questions(document_service, question_service, project) -> list:
"""收集项目中所有未回答的问题"""
all_questions = []
documents = document_service.get_documents(project.id)
for document in documents:
questions = question_service.get_questions_by_document_id(document.id)
all_questions.extend(questions)
return [q for q in all_questions if q.status == QUESTION_STATUS_UNANSWERED]
def _handle_generate_questions(args, document_service, question_service, project) -> None:
"""处理问题生成"""
if not (args.generate_questions and project):
return
logger.info("正在生成问题...")
documents = document_service.get_documents(project.id)
document = documents[-1]
if document.status < DOCUMENT_STATUS_QUESTIONS_GENERATED:
questions = question_service.generate_questions_for_document(document)
logger.info(f"已为文档 {document.file_name} 生成 {len(questions)} 个问题")
else:
logger.info(f"文档 {document.file_name} 的问题已生成")
def _handle_generate_answers(args, document_service, question_service, project) -> None:
"""处理答案生成(含/不含思维链)"""
if not (args.generate_answers or args.generate_answers_no_chain) or not project:
return
with_no_chain = args.generate_answers_no_chain
log_message = "正在生成答案(不含思维链)..." if with_no_chain else "正在生成答案和思维链..."
success_message = "已为 {0} 个问题生成答案(不含思维链)" if with_no_chain else "已为 {0} 个问题生成答案和思维链"
logger.info(log_message)
unanswered = _collect_unanswered_questions(document_service, question_service, project)
if unanswered:
updated = question_service.generate_answers(unanswered, with_chain_of_thought=False) \
if with_no_chain else question_service.generate_answers(unanswered, with_chain_of_thought=True)
logger.info(success_message.format(len(updated)))
else:
logger.info("所有问题都已回答")
def _handle_import_questions(args, document_service, question_service, project) -> None:
"""处理问题导入"""
if not (args.import_questions and project):
return
if not os.path.exists(args.import_questions):
raise ConfigurationError(MSG_XLSX_NOT_FOUND.format(args.import_questions))
logger.info(f"正在从 {args.import_questions} 导入问题...")
documents = document_service.get_documents(project.id)
if not documents:
raise ConfigurationError(MSG_NO_DOCUMENTS)
document = documents[0]
chunk_id = document.chunks[0] if document.chunks else 0
imported = question_service.import_questions_from_xlsx(
args.import_questions, document.id, chunk_id
)
logger.info(f"已从XLSX文件中导入 {len(imported)} 个问题")
xlsx_path = args.import_questions
if imported:
if not os.path.exists(xlsx_path):
logger.warning(f"XLSX文件不存在,跳过删除: {xlsx_path}")
elif not xlsx_path.lower().endswith('.xlsx'):
logger.warning(f"文件不是XLSX格式,跳过删除: {xlsx_path}")
else:
try:
os.remove(xlsx_path)
logger.info(f"已删除XLSX文件: {xlsx_path}")
except OSError as e:
logger.error(f"删除XLSX文件失败: {e}")
def _collect_answered_questions(document_service, question_service, project) -> list:
"""收集项目中所有已回答和已审核的问题"""
all_questions = []
documents = document_service.get_documents(project.id)
for document in documents:
questions = question_service.get_questions_by_document_id(document.id)
all_questions.extend(questions)
return [
q for q in all_questions
if q.status in (QUESTION_STATUS_ANSWERED, QUESTION_STATUS_REVIEWED)
]
def _handle_create_dataset(ctx: DatasetCreationContext) -> None:
"""处理数据集创建"""
if not (ctx.args.create_dataset and ctx.project):
return
logger.info(f"正在创建数据集 '{ctx.args.create_dataset}'...")
dataset = ctx.dataset_service.create_dataset(
project_id=ctx.project.id,
name=ctx.args.create_dataset,
format=ctx.config.default_dataset_format,
file_type=ctx.config.default_dataset_file_type,
system_prompt=ctx.config.system_prompt,
)
answered = _collect_answered_questions(ctx.document_service, ctx.question_service, ctx.project)
if answered:
dataset = ctx.dataset_service.add_question_to_dataset(dataset, answered)
logger.info(f"已将 {len(answered)} 个已回答的问题添加到数据集 '{dataset.name}'")
else:
logger.info("没有找到已回答的问题,无法创建数据集")
def _handle_export_dataset(args, dataset_service, project) -> None:
"""处理数据集导出"""
if not (args.export_dataset and project):
return
logger.info(f"正在导出数据集 '{args.export_dataset}'...")
dataset = dataset_service.get_dataset_by_name(project.id, args.export_dataset)
logger.info(f"已找到数据集 '{dataset}'")
if dataset:
export_path = dataset_service.export_dataset(dataset)
logger.info(f"数据集已导出到: {export_path}")
else:
logger.error(MSG_DATASET_NOT_FOUND.format(args.export_dataset))
def execute_application():
"""主函数,实现命令行问答数据集生成工具"""
args = _parse_arguments()
try:
config = Config.load(args.config)
except FileNotFoundError as e:
logger.error(MSG_CONFIG_NOT_FOUND.format(e))
sys.exit(1)
project_service = ProjectService(config.data_dir, config.__dict__)
project_name = ' '.join(args.project) if isinstance(args.project, list) else args.project
if not project_name:
return
project = project_service.get_or_create_project(project_name)
try:
llm_api_key = _resolve_api_key(config) if _needs_api_key(args) else None
llm_service = _create_llm_service(config, llm_api_key)
project_dir = _get_project_dir(config, project)
document_service = DocumentService(project_dir, config.__dict__)
question_service = QuestionService(project_dir, llm_service, system_prompt=config.system_prompt)
dataset_service = DatasetService(project_dir)
_handle_document(args, document_service, project)
_handle_generate_questions(args, document_service, question_service, project)
_handle_generate_answers(args, document_service, question_service, project)
_handle_import_questions(args, document_service, question_service, project)
dataset_ctx = DatasetCreationContext(
args=args,
config=config,
document_service=document_service,
question_service=question_service,
dataset_service=dataset_service,
project=project
)
_handle_create_dataset(dataset_ctx)
_handle_export_dataset(args, dataset_service, project)
except ConfigurationError as e:
logger.error(str(e))
sys.exit(1)
if __name__ == "__main__":
execute_application()