README.md

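Demo Run Guide

1 Background

1.0 Introduction to LangGraph

LangGraph official introduction

LangGraph is a newer member of the LangChain family and an extension of the LangChain Expression Language (LCEL). It coordinates multiple LLM calls and their shared state as a graph. Unlike a plain LCEL chain, the graph may contain cycles, which makes loops such as "retry with a rewritten query" possible. It is more complex to use than LCEL, but the control flow is clearer. LangGraph can also be viewed as an agent framework built on LangChain: LangChain's original agent implementations have all been reimplemented in LangGraph, so systems already built on LangChain can adopt it easily.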

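1.1 Introduction to RAG SDK

For details on RAG SDK, see the "RAG SDK User Guide".

RAG SDK is Ascend's knowledge-augmentation development kit for large language models. LLM knowledge is slow to update, and LLMs answer poorly in vertical domains. To address both problems, RAG SDK provides domain-specific tuning, generation augmentation, and knowledge management for LLM knowledge bases.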

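1.2 Introduction to Ascend MIS

Mind Inference Microservice

Ascend provides reranker and embedding services as Mind Inference Microservices (MIS), accelerated on Ascend hardware. Once deployed, they can serve a RAG application directly.

Embedding TEI installation link; Reranker TEI installation link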

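1.3 Features

This sample builds a RAG application on LangGraph and RAG SDK. It orchestrates multiple RAG nodes as a directed graph, including a retry loop, to implement a complete retrieval-augmented question-answering flow. Main features:

  • Graph node orchestration: LangGraph's StateGraph defines the RAG flow, with nodes for cache search, query decomposition, hybrid retrieval, reranking, generation, hallucination checking, query rewriting, and cache update
  • Semantic cache: answers to similar questions are cached; on a hit, the cached result is returned directly, saving LLM inference cost
  • Query decomposition: the LLM splits a complex question into independent sub-questions to improve retrieval quality
  • Hybrid retrieval: supports combining vector retrieval with BM25 retrieval
  • Hallucination check: the LLM evaluates the faithfulness and relevance of the generated answer; if either falls below its threshold, a query rewrite and retry is triggered automatically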

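2 Prerequisites

Before running the demo, read the "RAG SDK User Guide" and complete the software and hardware installation required by its "Installation and Deployment" chapter. This document provides sample code for the "Application Development" chapter to help developers get started quickly.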

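3 Environment Setup

Follow the "Installation and Deployment" chapter of the RAG SDK User Guide to install CANN and RAG SDK and to deploy the embedding and reranker services. Then install the langgraph package: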

pip3 install langgraph==0.2.19

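4 Overview

The RAG application is built on LangGraph and RAG SDK. Following LangGraph's model, it consists of nodes and a graph, and each node uses RAG SDK to implement its function.

RAG node definitions

  • cache search node: looks up the user question in the cache
  • query decompose node: splits the user question into sub-questions
  • hybrid retrieve node: retrieves documents for the user question using hybrid retrieval
  • rerank node: reranks the retrieved documents
  • generate node: generates the answer with the LLM
  • hallucination check node: checks the generated content for hallucinations (in the code it is the routing function of the conditional edge after generate)
  • query rewrite node: rewrites the user question
  • cache update node: writes the question and its answer to the cache

RAG graph definition

The graph is defined as follows:

[Figure: RAG graph]

The state transitions are listed in the table below: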

| Name | Type | Next hop | Input | Output |
|---|---|---|---|---|
| cache search | node | cache hit: return generation to the user (END); cache miss: query decompose | (question) | hit: (question, generation); miss: (question) |
| query decompose | node | hybrid retrieve | (question) | (question, sub_questions) |
| hybrid retrieve | node | rerank | (question, sub_questions) | (question, sub_questions, documents) |
| rerank | node | generate | (question, sub_questions, documents) | (question, sub_questions, documents) |
| generate | node | hallucination check | (question, sub_questions, documents) | (question, sub_questions, documents, generation) |
| hallucination check | conditional edge | pass: cache update; fail: query rewrite | (question, sub_questions, documents, generation) | (question, sub_questions, documents, generation) |
| query rewrite | node | cache search | (question, sub_questions, documents) or (question, sub_questions, documents, generation) | (question) |
| cache update | node | END | (question, sub_questions, documents, generation) | (question, sub_questions, documents, generation) |
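The transitions in the table can be simulated without LangGraph. The sketch below is a hypothetical, framework-free model, not the demo's code. The function `run_rag_graph`, its callback parameters, and the `max_retries` cap are illustrative assumptions. Decompose, retrieve, and rerank are collapsed into a single `answer_fn` call, and an exact-match dict stands in for the semantic cache.

The retry cap is an addition. The compiled graph in this document has no explicit cap, so an answer that keeps failing the check loops until LangGraph's `recursion_limit` (25 steps by default) raises an error.

```python
from typing import Callable, Dict, List, Optional


def run_rag_graph(question: str,
                  cache: Dict[str, str],
                  answer_fn: Callable[[str], str],
                  is_useful: Callable[[str, str], bool],
                  rewrite_fn: Callable[[str], str],
                  max_retries: int = 2) -> Dict[str, object]:
    """Hypothetical simulation of the state transitions in the table above."""
    trace: List[str] = []
    generation: Optional[str] = None
    for attempt in range(max_retries + 1):
        # cache search: a hit returns the cached generation and ends the run
        trace.append("cache_search")
        generation = cache.get(question)
        if generation is not None:
            break
        # cache miss: decompose -> retrieve -> rerank -> generate, collapsed into one call
        trace.append("generate")
        generation = answer_fn(question)
        # hallucination check: pass -> cache update -> END
        if is_useful(question, generation):
            trace.append("cache_update")
            cache[question] = generation
            break
        # fail -> query rewrite -> back to cache search (only while retries remain)
        if attempt < max_retries:
            trace.append("transform_query")
            question = rewrite_fn(question)
    return {"question": question, "generation": generation, "trace": trace}
```

Returning the last generation once the cap is reached is only one possible policy. A real application might return a fallback message instead.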

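5 RAG SDK Component Initialization

For the complete sample code, see langgraph_demo.py.

5.1 Document Loading and Splitting

The following code initializes a loader and a splitter for .docx files, splitting with chunk_size=200 and chunk_overlap=50. For the detailed API, see the RAG SDK user guide.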

from typing import Any, Dict


def create_loader_and_spliter(mxrag_component: Dict[str, Any],
                              chunk_size: int = 200,
                              chunk_overlap: int = 50):
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    from mx_rag.knowledge.doc_loader_mng import LoaderMng
    from mx_rag.document.loader import DocxLoader

    loader_mng = LoaderMng()
    # Read .docx files with DocxLoader
    loader_mng.register_loader(DocxLoader, [".docx"])
    # Split .docx text into chunks of about chunk_size characters; neighboring chunks share chunk_overlap characters
    loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".docx"],
                                 {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "keep_separator": False})
    mxrag_component["loader_mng"] = loader_mng
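To make chunk_size and chunk_overlap concrete, here is a minimal sliding-window splitter. It is not how RecursiveCharacterTextSplitter works. That class first tries to split on separators (paragraphs, lines, spaces) and then merges the pieces up to chunk_size, so real chunk boundaries follow the structure of the text.

The function name `window_chunks` is hypothetical. Lengths are counted in characters, which matches the splitter's default length function (`len`).

```python
from typing import List


def window_chunks(text: str, chunk_size: int = 200, chunk_overlap: int = 50) -> List[str]:
    """Fixed-width sliding window: each chunk shares chunk_overlap characters with the previous one."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap  # 150 characters between chunk starts with the defaults
    chunks: List[str] = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this chunk already reaches the end of the text
    return chunks
```

With the defaults, a 500-character text yields three chunks starting at offsets 0, 150, and 300.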

5.2 Remote Services

The following code initializes the connectors for the reranker, embedding, and LLM services. Pass in the URL of each service.

from typing import Any, Dict


def create_remote_connector(mxrag_component: Dict[str, Any],
                            reranker_url: str,
                            embedding_url: str,
                            llm_url: str,
                            llm_model_name: str):
    from mx_rag.llm.text2text import Text2TextLLM
    from mx_rag.embedding import EmbeddingFactory
    from mx_rag.reranker.reranker_factory import RerankerFactory
    # Assumed import locations for LLMParameterConfig and ClientParam; check langgraph_demo.py for your SDK version
    from mx_rag.llm import LLMParameterConfig
    from mx_rag.utils import ClientParam

    # TEI reranker service; k=3 is the number of top documents kept after reranking
    reranker = RerankerFactory.create_reranker(similarity_type="tei_reranker",
                                               url=reranker_url,
                                               client_param=ClientParam(use_http=True),
                                               k=3)
    mxrag_component['reranker_connector'] = reranker

    # TEI embedding service
    embedding = EmbeddingFactory.create_embedding(embedding_type="tei_embedding",
                                                  url=embedding_url,
                                                  client_param=ClientParam(use_http=True))
    mxrag_component['embedding_connector'] = embedding

    # LLM inference service
    llm = Text2TextLLM(base_url=llm_url, model_name=llm_model_name,
                       client_param=ClientParam(use_http=True),
                       llm_config=LLMParameterConfig(max_tokens=4096))
    mxrag_component['llm_connector'] = llm

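5.3 Knowledge Base

The following sample stores the user's knowledge documents, using MindFAISS (mxIndex) for vector retrieval. knowledge_files is a list of document file paths supplied by the user.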

from typing import Any, Dict, List


def create_knowledge_storage(mxrag_component: Dict[str, Any], knowledge_files: List[str]):
    from mx_rag.knowledge.knowledge import KnowledgeStore
    from mx_rag.knowledge import KnowledgeDB
    from mx_rag.knowledge.handler import upload_files
    from mx_rag.storage.document_store import SQLiteDocstore
    # Assumed import location for MindFAISS; check langgraph_demo.py for your SDK version
    from mx_rag.storage.vectorstore import MindFAISS

    npu_dev_id = 0

    # Path of the FAISS index file
    faiss_index_save_file: str = "/home/HwHiAiUser/rag_npu_faiss.index"
    # x_dim must match the output dimension of the embedding model
    vector_store = MindFAISS(x_dim=1024,
                             devs=[npu_dev_id],
                             load_local_index=faiss_index_save_file)
    mxrag_component["vector_store"] = vector_store

    # Path of the SQLite database that stores the text chunks
    sqlite_save_file: str = "/home/HwHiAiUser/rag_sql.db"
    chunk_store = SQLiteDocstore(db_path=sqlite_save_file)
    mxrag_component["chunk_store"] = chunk_store

    # Whitelisted directories for knowledge files; uploading a file outside them raises an exception
    white_paths = ["/home/HwHiAiUser/"]
    knowledge_store = KnowledgeStore(db_path=sqlite_save_file)
    knowledge_store.add_knowledge("rag", "Default01", "admin")
    knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, chunk_store=chunk_store, vector_store=vector_store,
                               knowledge_name="rag", white_paths=white_paths, user_id="Default01")

    # Load, split, embed, and store the knowledge files
    upload_files(knowledge_db, knowledge_files, loader_mng=mxrag_component.get("loader_mng"),
                 embed_func=mxrag_component.get("embedding_connector").embed_documents,
                 force=True)
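At query time, the vector store compares the query embedding with the stored chunk embeddings and returns the closest chunks. The brute-force sketch below shows the idea with a result count k and a score threshold; the Retriever in section 5.6 uses k=10 and score_threshold=0.4.

MindFAISS runs an NPU-accelerated index rather than a linear scan. This document does not say which similarity metric it uses or how mx_rag applies score_threshold, so cosine similarity here is an assumption. `top_k_search` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat zero vectors as dissimilar to everything
    return dot / (norm_a * norm_b)


def top_k_search(query: Sequence[float],
                 chunk_vectors: List[Sequence[float]],
                 k: int = 10,
                 score_threshold: float = 0.4) -> List[Tuple[int, float]]:
    """Return (chunk index, score) pairs, best first, keeping only scores >= score_threshold."""
    scored = [(i, cosine(query, vec)) for i, vec in enumerate(chunk_vectors)]
    kept = [pair for pair in scored if pair[1] >= score_threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return kept[:k]
```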

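5.4 Cache

This section defines a semantic cache that stores answers already given to users. When a user later asks a similar question, the result is returned quickly without another LLM inference, which improves end-to-end (E2E) performance. A semantic cache generally consists of a vector database, a scalar database, and the corresponding embedding and similarity computation methods.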

from typing import Any, Dict


def create_cache(mxrag_component: Dict[str, Any],
                 reranker_url: str,
                 embedding_url: str):
    from mx_rag.cache import SimilarityCacheConfig
    from mx_rag.cache import EvictPolicy
    from mx_rag.cache import MxRAGCache
    # Assumed import location for ClientParam; check langgraph_demo.py for your SDK version
    from mx_rag.utils import ClientParam

    npu_dev_id = 0
    # Cache persistence folder; on the next run the application reloads the cache from disk
    cache_data_save_folder = "/home/HwHiAiUser/mx_rag/cache_save_folder/"

    similarity_config = SimilarityCacheConfig(
        # vector database for cached questions
        vector_config={
            "vector_type": "npu_faiss_db",
            "x_dim": 1024,
            "devs": [npu_dev_id],
        },
        # scalar database for cached question/answer pairs
        cache_config="sqlite",
        emb_config={
            "embedding_type": "tei_embedding",
            "url": embedding_url,
            "client_param": ClientParam(use_http=True)
        },
        # model that computes the similarity between the new and cached questions
        similarity_config={
            "similarity_type": "tei_reranker",
            "url": reranker_url,
            "client_param": ClientParam(use_http=True)
        },
        retrieval_top_k=3,
        cache_size=100,
        auto_flush=100,
        # a cached answer is returned only when the similarity reaches this threshold
        similarity_threshold=0.70,
        data_save_folder=cache_data_save_folder,
        disable_report=True,
        # evict the least recently used entries when the cache is full
        eviction_policy=EvictPolicy.LRU
    )

    similarity_cache = MxRAGCache("similarity_cache", similarity_config)
    mxrag_component["cache"] = similarity_cache
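The cache configuration combines a similarity threshold with LRU eviction. The toy class below shows how those two settings interact. It assumes cosine similarity over embeddings from a caller-supplied `embed` function.

This is not MxRAGCache. Per the configuration above, the real cache uses an NPU FAISS store, scores similarity with the TEI reranker, and persists to disk. Only the method names `search`/`update` mirror how the graph nodes call the cache; `ToySemanticCache` and its parameters are hypothetical.

```python
import math
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple


class ToySemanticCache:
    """Hypothetical semantic cache: cosine-similarity lookup plus LRU eviction."""

    def __init__(self, embed: Callable[[str], List[float]],
                 cache_size: int = 100, similarity_threshold: float = 0.70):
        self.embed = embed
        self.cache_size = cache_size
        self.similarity_threshold = similarity_threshold
        # question -> (embedding, answer); order runs from least to most recently used
        self._entries: "OrderedDict[str, Tuple[List[float], str]]" = OrderedDict()

    @staticmethod
    def _cosine(a: List[float], b: List[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0

    def search(self, question: str) -> Optional[str]:
        """Return the answer of the most similar cached question, or None on a miss."""
        query = self.embed(question)
        best_key, best_score = None, -1.0
        for key, (vector, _) in self._entries.items():
            score = self._cosine(query, vector)
            if score > best_score:
                best_key, best_score = key, score
        if best_key is None or best_score < self.similarity_threshold:
            return None
        self._entries.move_to_end(best_key)  # a hit counts as a recent use
        return self._entries[best_key][1]

    def update(self, question: str, answer: str) -> None:
        self._entries[question] = (self.embed(question), answer)
        self._entries.move_to_end(question)
        if len(self._entries) > self.cache_size:
            self._entries.popitem(last=False)  # evict the least recently used entry
```

A hit moves its entry to the most-recently-used end. Frequently repeated questions therefore survive eviction longer than one-off questions.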

5.5 Evaluator

The following code initializes the evaluator, which uses the LLM to perform the evaluation.

def create_evaluate(mxrag_component):
    from mx_rag.evaluate import Evaluate

    llm = mxrag_component.get("llm_connector")
    embedding = mxrag_component.get("embedding_connector")
    mxrag_component["evaluator"] = Evaluate(llm=llm, embedding=embedding)

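5.6 Hybrid Retrieval

The following sample builds the hybrid retriever. Hybrid retrieval combines vector retrieval with BM25 retrieval. LangChain's EnsembleRetriever merges its retrievers' results by weighted Reciprocal Rank Fusion (RRF) to produce the final document ranking.

As written, the sample registers only the vector retriever (weight 1.0). To use both methods, add a BM25 retriever to the retrievers list with a matching weight.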

from typing import Any, Dict


def create_hybrid_search_retriever(mxrag_component: Dict[str, Any]):
    from langchain.retrievers import EnsembleRetriever

    from mx_rag.retrievers.retriever import Retriever
    from mx_rag.retrievers import BMRetriever  # BM25 retriever; not wired in below

    chunk_store = mxrag_component.get("chunk_store")
    vector_store = mxrag_component.get("vector_store")
    embedding = mxrag_component.get("embedding_connector")

    # Vector retriever: up to k=10 chunks, filtered by score_threshold=0.4
    npu_faiss_retriever = Retriever(vector_store=vector_store, document_store=chunk_store,
                                    embed_func=embedding.embed_documents, k=10, score_threshold=0.4)

    # For true hybrid search, append a BM25 retriever (e.g. built with BMRetriever) and its weight;
    # EnsembleRetriever applies one weight per retriever during RRF fusion
    hybrid_retriever = EnsembleRetriever(
        retrievers=[npu_faiss_retriever], weights=[1.0]
    )

    mxrag_component["retriever"] = hybrid_retriever
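Weighted Reciprocal Rank Fusion scores each document by summing weight / (c + rank) over every result list the document appears in, then sorts by that score. The function below is a standalone sketch of the formula. `weighted_rrf` is a hypothetical name, and c=60 matches the default constant of LangChain's EnsembleRetriever.

With a single retriever, as configured above, fusion leaves that retriever's order unchanged.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def weighted_rrf(rankings: Sequence[Sequence[str]],
                 weights: Sequence[float],
                 c: int = 60) -> List[str]:
    """Fuse ranked lists of document ids; earlier positions count more."""
    if len(rankings) != len(weights):
        raise ValueError("one weight per ranking is required")
    scores: Dict[str, float] = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):  # ranks are 1-based
            scores[doc_id] += weight / (c + rank)
    # highest fused score first; ties keep first-seen order (sorted is stable)
    return sorted(scores, key=lambda doc_id: scores[doc_id], reverse=True)
```

For example, fusing the vector list ["a", "b", "c"] with the BM25 list ["c", "a", "d"] at equal weights ranks "a" first. "a" is near the top of both lists, even though "c" leads the BM25 list.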

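6 LangGraph Graph Definition, Compilation, and Execution

For the complete sample code, see langgraph_demo.py.

6.1 Node Definitions

6.1.1 Cache Search

Looks up the user question in the RAG cache. On a hit, generation is not None.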

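import logging

# Module-level logger used by the node functions in this section
logger = logging.getLogger(__name__)


def cache_search(cache):
    def cache_search_process(state):
        logger.info("---QUERY SEARCH ---")
        question = state["question"]
        # The cached answer on a hit, None on a miss
        generation = cache.search(question)
        return {"question": question, "generation": generation}

    return cache_search_process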

Decides whether the cache search hit by checking generation: None means a cache miss, and any other value means a cache hit.

def decide_to_decompose(state):
    logger.info("---DECIDE TO DECOMPOSE---")
    cache_generation = state["generation"]

    if cache_generation is None:
        logger.warning(
            "---DECISION: CACHE MISS GO DECOMPOSE---"
        )
        return "cache_miss"

    logger.info("---DECISION: CACHE HIT END---")
    return "cache_hit"

6.1.2 Query Decompose

Uses a prompt to split the user question into sub-questions.

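from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


def decompose(llm):
    # Marker the LLM writes in front of each sub-question
    sub_question_key_words = "Q:"
    # Prompt (Chinese): "Following the examples below, split the user question into independent
    # sub-questions; if it cannot be split, return the original question", followed by two examples
    prompt = PromptTemplate(
        template="""
                    请你参考如下示例,拆分用户的问题为独立子问题,如果无法拆分则返回原始问题:
                    示例一:
                    用户问题: 今天的天气如何, 你今天过的怎么样?

                    {sub_question_key_words}今天的天气如何?
                    {sub_question_key_words}你今天过的怎么样?

                    示例二:
                    用户问题: 汉堡好吃吗?

                    {sub_question_key_words}汉堡好吃吗?

                    现在请你参考示例拆分以下用户问题:
                    用户的问题:{question}
                    """,
        input_variables=["question", "sub_question_key_words"]
    )

    sub_question_generator = LLMChain(llm=llm, prompt=prompt)

    def decompose_process(state):
        logger.info("---QUERY DECOMPOSITION ---")
        question = state["question"]

        sub_queries = sub_question_generator.predict(question=question, sub_question_key_words=sub_question_key_words)
        if sub_question_key_words not in sub_queries:
            # No marker in the output: retrieval falls back to the original question
            sub_queries = None
        else:
            # Drop the text before the first marker, strip surrounding whitespace, skip empty entries
            sub_queries = [q.strip() for q in sub_queries.split(sub_question_key_words)[1:] if q.strip()] or None

        return {"sub_questions": sub_queries, "question": question}

    return decompose_process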
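The parsing step of decompose_process can be isolated and tested on its own. `parse_sub_questions` is a hypothetical helper that mirrors it:

- It returns None when the "Q:" marker is absent, so retrieval falls back to the original question.
- It discards the text before the first marker.
- It strips the whitespace around each sub-question and drops empty entries.

```python
from typing import List, Optional


def parse_sub_questions(llm_output: str, key: str = "Q:") -> Optional[List[str]]:
    """Split LLM output on the sub-question marker; None means 'use the original question'."""
    if key not in llm_output:
        return None
    parts = [part.strip() for part in llm_output.split(key)[1:]]
    sub_questions = [part for part in parts if part]  # drop empty entries
    return sub_questions or None
```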

6.1.3 Hybrid Retrieve

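The following code performs hybrid retrieval. If sub_questions is None, it retrieves with the original question. Otherwise it retrieves with each sub-question and merges the results, removing duplicates.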

from langchain_core.retrievers import BaseRetriever


def retrieve(retriever: BaseRetriever):
    def retrieve_process(state):
        logger.info("---RETRIEVE---")
        sub_questions = state["sub_questions"]
        question = state["question"]

        documents = []
        docs = []
        if sub_questions is None:
            # No sub-questions: retrieve with the original question
            docs = retriever.get_relevant_documents(question)
        else:
            # Retrieve once per sub-question and collect all results
            for query in sub_questions:
                docs.extend(retriever.get_relevant_documents(query))

        # Keep only the first occurrence of each chunk text
        for doc in docs:
            if doc.page_content not in documents:
                documents.append(doc.page_content)

        return {"documents": documents, "question": question}

    return retrieve_process
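When several sub-questions retrieve overlapping chunks, retrieve_process keeps only the first occurrence of each text. Below is an equivalent order-preserving merge that uses a set for membership checks, which is O(n) instead of the O(n²) list scan. `merge_unique` is a hypothetical name.

```python
from typing import Iterable, List


def merge_unique(results_per_query: Iterable[Iterable[str]]) -> List[str]:
    """Concatenate per-query results, keeping the first occurrence of each text."""
    seen = set()
    merged: List[str] = []
    for texts in results_per_query:
        for text in texts:
            if text not in seen:
                seen.add(text)
                merged.append(text)
    return merged
```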

6.1.4 Rerank

Reranks the retrieved documents by their semantic relevance to the user question.

def rerank(reranker):
    def rerank_process(state):
        logger.info("---RERANK---")
        question = state["question"]
        documents = state["documents"]

        scores = reranker.rerank(query=question, texts=documents)
        documents = reranker.rerank_top_k(objs=documents, scores=scores)

        return {"documents": documents, "question": question}

    return rerank_process
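Conceptually, reranking scores every (question, document) pair and keeps the best documents. The sketch below shows that selection step. It assumes that rerank_top_k keeps the k highest-scoring documents in descending score order. k=3 mirrors the value passed to create_reranker in section 5.2, and `top_k_by_score` is a hypothetical helper, not the SDK function.

```python
from typing import List, Sequence


def top_k_by_score(texts: Sequence[str], scores: Sequence[float], k: int = 3) -> List[str]:
    """Return the k texts with the highest scores, best first."""
    if len(texts) != len(scores):
        raise ValueError("texts and scores must have the same length")
    order = sorted(range(len(texts)), key=lambda i: scores[i], reverse=True)
    return [texts[i] for i in order[:k]]
```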

6.1.5 Generate

Calls the LLM with a prompt template to generate the answer.

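from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


def generate(llm):
    # Prompt (Chinese): "Based on the known information above, answer the user question concisely and
    # professionally. If the answer cannot be derived from it, answer from your own experience."
    prompt = PromptTemplate(
        template="""{context}

                 根据上述已知信息,简洁和专业的来回答用户问题。如果无法从中已知信息中得到答案,请根据自身经验做出回答

                 {question}
                 """,
        input_variables=["context", "question"]
    )

    rag_chain = LLMChain(llm=llm, prompt=prompt)

    def generate_process(state):
        logger.info("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # Join the chunks into one context string rather than formatting the Python list itself
        context = "\n\n".join(documents)
        generation = rag_chain.predict(context=context, question=question)
        return {"documents": documents, "question": question, "generation": generation}

    return generate_process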

6.1.6 Hallucination Check

Uses LLM-based evaluation to judge whether the generated answer meets the user's needs.

def grade_generation_v_documents_and_question(evaluate,
                                              context_score_threshold: float = 0.6,
                                              answer_score_threshold: float = 0.6):
    # evaluate_creator is defined in the full langgraph_demo.py (not shown in this document); the wrapper
    # it returns takes the graph state and yields (answer_score, context_score) as lists of scores
    generate_evaluator = evaluate_creator(evaluate, "generate_relevancy")

    def grade_generation_v_documents_and_question_process(state):
        logger.info("---CHECK HALLUCINATIONS---")

        answer_score, context_score = generate_evaluator(state)

        # 1) Does the generation address the question?
        answer_score = answer_score[0]
        logger.info("---GRADE GENERATION vs QUESTION---")
        if answer_score < answer_score_threshold:
            logger.warning("---DECISION: GENERATION DOES NOT ADDRESS QUESTION, RE-TRY--- "
                           f"answer_score:{answer_score}, "
                           f"answer_score_threshold:{answer_score_threshold}")
            return "not useful"

        logger.info("---DECISION: GENERATION ADDRESSES QUESTION--- "
                    f"answer_score:{answer_score}, "
                    f"answer_score_threshold:{answer_score_threshold}")

        # 2) Is the generation grounded in the retrieved documents?
        context_score = context_score[0]
        logger.info("---GRADE GENERATION vs DOCUMENTS---")
        if context_score < context_score_threshold:
            logger.warning("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY--- "
                           f"context_score:{context_score}, "
                           f"context_score_threshold:{context_score_threshold}")
            return "not useful"

        logger.info("---DECISION: GENERATION GROUNDED IN DOCUMENTS--- "
                    f"context_score:{context_score}, "
                    f"context_score_threshold:{context_score_threshold}")
        return "useful"

    return grade_generation_v_documents_and_question_process

6.1.7 Cache Update

If the generated answer meets the quality requirements, the cache is updated.

def cache_update(cache):
    def cache_update_process(state):
        logger.info("---QUERY UPDATE ---")
        question = state["question"]
        generation = state["generation"]

        cache.update(question, generation)

        return state

    return cache_update_process

6.1.8 Query Rewrite

Uses a prompt to rewrite the user question into a form better suited to vector retrieval.

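from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


def transform_query(llm):
    # Prompt (Chinese): "You are a question rewriter. Carefully understand the content and meaning of the user
    # question and the retrieved documents, and, without changing its meaning, rewrite the question into a form
    # better suited to vector retrieval." Only the question itself is passed to this prompt.
    prompt = PromptTemplate(
        template="""
                 你是一个用户问题重写员, 请仔细理解用户问题的内容和语义和检索的文档,在不修改用户问题
                 语义的前提下,将用户问题重写为可以更好被矢量检索的形式

                 用户问题:{question}
                 """,
        input_variables=["question"]
    )

    question_rewriter = LLMChain(llm=llm, prompt=prompt)

    def transform_query_process(state):
        logger.info("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]

        better_question = question_rewriter.predict(question=question)

        return {"documents": documents, "question": better_question}

    return transform_query_process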

6.2 Graph Compilation

from typing import List, Optional, TypedDict


def build_mxrag_application(mxrag_component):
    from langgraph.graph import END, START, StateGraph

    # Shared graph state; each node returns a partial update that overwrites the matching keys
    class GraphState(TypedDict):
        question: str
        sub_questions: Optional[List[str]]
        generation: Optional[str]
        documents: List[str]

    llm = mxrag_component.get("llm_connector")
    retriever = mxrag_component.get("retriever")
    reranker = mxrag_component.get("reranker_connector")
    cache = mxrag_component.get("cache")
    evaluate = mxrag_component.get("evaluator")

    workflow = StateGraph(GraphState)
    workflow.add_node("cache_search", cache_search(cache))
    workflow.add_node("cache_update", cache_update(cache))
    workflow.add_node("decompose", decompose(llm))
    workflow.add_node("retrieve", retrieve(retriever))
    workflow.add_node("rerank", rerank(reranker))
    workflow.add_node("generate", generate(llm))
    workflow.add_node("transform_query", transform_query(llm))

    workflow.add_edge(START, "cache_search")

    # A cache hit ends the run; a miss continues with query decomposition
    workflow.add_conditional_edges(
        "cache_search",
        decide_to_decompose,
        {
            "cache_hit": END,
            "cache_miss": "decompose",
        },
    )

    workflow.add_edge("decompose", "retrieve")
    workflow.add_edge("retrieve", "rerank")
    workflow.add_edge("rerank", "generate")
    # Retry loop: the rewritten question goes through the cache search again
    workflow.add_edge("transform_query", "cache_search")
    # The hallucination check is the routing function of this conditional edge, not a separate node
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question(evaluate),
        {
            "useful": "cache_update",
            "not useful": "transform_query"
        },
    )

    workflow.add_edge("cache_update", END)
    app = workflow.compile()
    return app

6.3 Online Q&A

import time
from typing import Any, Dict

if __name__ == "__main__":
    mxrag_component: Dict[str, Any] = {}

    # Mind Inference Microservice TEI rerank endpoint
    mis_tei_reranker_url = "http://127.0.0.1:port/rerank"
    # Mind Inference Microservice TEI embed endpoint
    mis_tei_embedding_url = "http://127.0.0.1:port/embed"

    # LLM inference service address; replace with your own endpoint
    llm_url = "http://127.0.0.1:port"
    # LLM model name, e.g. Llama3-8B-Chinese-Chat
    llm_model_name = "Llama3-8B-Chinese-Chat"

    # List of knowledge files to load
    knowledge_files = ["/home/HwHiAiUser/doc1.docx"]

    create_loader_and_spliter(mxrag_component, chunk_size=200, chunk_overlap=50)

    create_remote_connector(mxrag_component,
                            reranker_url=mis_tei_reranker_url,
                            embedding_url=mis_tei_embedding_url,
                            llm_url=llm_url,
                            llm_model_name=llm_model_name)

    create_knowledge_storage(mxrag_component, knowledge_files=knowledge_files)

    create_cache(mxrag_component,
                 reranker_url=mis_tei_reranker_url,
                 embedding_url=mis_tei_embedding_url)

    create_hybrid_search_retriever(mxrag_component)

    create_evaluate(mxrag_component)

    rag_app = build_mxrag_application(mxrag_component)

    user_question = "your question"

    start_time = time.time()
    # invoke returns the final graph state; the answer is in its "generation" field
    final_state = rag_app.invoke({"question": user_question})
    end_time = time.time()
    user_answer = final_state.get("generation")

    print(f"user_question:{user_question}")
    print(f"user_answer:{user_answer}")
    print(f"app time cost:{(end_time - start_time) * 1000} ms")