找回密码
 立即注册
首页 业界区 业界 AI大模型应用开发入门-LangChain开发RAG增强检索生成 ...

AI大模型应用开发入门-LangChain开发RAG增强检索生成

孔季雅 2025-6-15 08:02:10
检索增强生成(RAG)是一种结合“向量检索”与“大语言模型”的技术路线,能在问答、摘要、文档分析等场景中大幅提升准确性与上下文利用率。
本文将基于 LangChain 构建一个完整的 RAG 流程,结合 PGVector 作为向量数据库,并用 LangGraph 构建状态图控制流程。
大语言模型初始化(llm_env.py)

我们首先使用 LangChain 提供的模型初始化器加载 gpt-4o-mini 模型,供后续问答使用。
  1. # llm_env.py
  2. from langchain.chat_models import init_chat_model
  3. llm = init_chat_model("gpt-4o-mini", model_provider="openai")
复制代码
RAG 主体流程(rag.py)

以下是整个 RAG 系统的主流程代码,主要包括:文档加载与切分、向量存储、状态图建模(analyze→retrieve→generate)、交互式问答。
  1. # rag.py
  2. import os
  3. import sys
  4. import time
  5. sys.path.append(os.getcwd())
  6. from llm_set import llm_env
  7. from langchain_openai import OpenAIEmbeddings
  8. from langchain_postgres import PGVector
  9. from langchain_community.document_loaders import WebBaseLoader
  10. from langchain_core.documents import Document
  11. from langchain_text_splitters import RecursiveCharacterTextSplitter
  12. from langgraph.graph import START, StateGraph
  13. from typing_extensions import List, TypedDict, Annotated
  14. from typing import Literal
  15. from langgraph.checkpoint.postgres import PostgresSaver
  16. from langgraph.graph.message import add_messages
  17. from langchain_core.messages import HumanMessage, BaseMessage
  18. from langchain_core.prompts import ChatPromptTemplate
  19. # 初始化 LLM
  20. llm = llm_env.llm
  21. # 嵌入模型
  22. embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
  23. # 向量数据库初始化
  24. vector_store = PGVector(
  25.     embeddings=embeddings,
  26.     collection_name="my_rag_docs",
  27.     connection="postgresql+psycopg2://postgres:123456@localhost:5433/langchainvector",
  28. )
  29. # 加载网页内容
  30. url = "https://python.langchain.com/docs/tutorials/qa_chat_history/"
  31. loader = WebBaseLoader(web_paths=(url,))
  32. docs = loader.load()
  33. for doc in docs:
  34.     doc.metadata["source"] = url
  35. # 文本分割
  36. text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
  37. all_splits = text_splitter.split_documents(docs)
  38. # 添加 section 元数据
  39. total_documents = len(all_splits)
  40. third = total_documents // 3
  41. for i, document in enumerate(all_splits):
  42.     if i < third:
  43.         document.metadata["section"] = "beginning"
  44.     elif i < 2 * third:
  45.         document.metadata["section"] = "middle"
  46.     else:
  47.         document.metadata["section"] = "end"
  48. # 检查是否已存在向量
  49. existing = vector_store.similarity_search(url, k=1, filter={"source": url})
  50. if not existing:
  51.     _ = vector_store.add_documents(documents=all_splits)
  52.     print("文档向量化完成")
复制代码
分析、检索与生成模块

接下来,我们定义三个函数构成 LangGraph 的流程:analyze → retrieve → generate。
  1. class Search(TypedDict):
  2.     query: Annotated[str, "The question to be answered"]
  3.     section: Annotated[
  4.         Literal["beginning", "middle", "end"],
  5.         ...,
  6.         "Section to query.",
  7.     ]
  8. class State(TypedDict):
  9.     messages: Annotated[list[BaseMessage], add_messages]
  10.     query: Search
  11.     context: List[Document]
  12.     answer: set
  13. # 分析意图 → 获取 query 与 section
  14. def analyze(state: State):
  15.     structtured_llm = llm.with_structured_output(Search)
  16.     query = structtured_llm.invoke(state["messages"])
  17.     return {"query": query}
  18. # 相似度检索
  19. def retrieve(state: State):
  20.     query = state["query"]
  21.     if hasattr(query, 'section'):
  22.         filter = {"section": query["section"]}
  23.     else:
  24.         filter = None
  25.     retrieved_docs = vector_store.similarity_search(query["query"], filter=filter)
  26.     return {"context": retrieved_docs}
复制代码
生成模块基于 ChatPromptTemplate 和当前上下文生成回答:
  1. prompt_template = ChatPromptTemplate.from_messages(
  2.     [
  3.         ("system", "尽你所能按照上下文:{context},回答问题:{question}。"),
  4.     ]
  5. )
  6. def generate(state: State):
  7.     docs_content = "\n\n".join(doc.page_content for doc in state["context"])
  8.     messages = prompt_template.invoke({
  9.         "question": state["query"]["query"],
  10.         "context": docs_content,
  11.     })
  12.     response = llm.invoke(messages)
  13.     return {"answer": response.content, "messages": [response]}
复制代码
构建 LangGraph 流程图

定义好状态结构后,我们构建 LangGraph:
  1. graph_builder = StateGraph(State).add_sequence([analyze, retrieve, generate])
  2. graph_builder.add_edge(START, "analyze")
复制代码
PG 数据库中保存中间状态(Checkpoint)

我们通过 PostgresSaver 记录每次对话的中间状态:
  1. DB_URI = "postgresql://postgres:123456@localhost:5433/langchaindemo?sslmode=disable"
  2. with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
  3.     checkpointer.setup()
  4.     graph = graph_builder.compile(checkpointer=checkpointer)
  5.     input_thread_id = input("输入thread_id:")
  6.     time_str = time.strftime("%Y%m%d", time.localtime())
  7.     config = {"configurable": {"thread_id": f"rag-{time_str}-demo-{input_thread_id}"}}
  8.     print("输入问题,输入 exit 退出。")
  9.     while True:
  10.         query = input("你: ")
  11.         if query.strip().lower() == "exit":
  12.             break
  13.         input_messages = [HumanMessage(query)]
  14.         response = graph.invoke({"messages": input_messages}, config=config)
  15.         print(response["answer"])
复制代码
效果

1.png

总结

本文通过 LangChain 的模块式能力,结合 PGVector 向量库与 LangGraph 有状态控制系统,实现了一个可交互、可持久化、支持多文档结构的 RAG 系统。其优势包括:

  • 支持结构化提问理解(分区查询)
  • 自动化分段与元数据标记
  • 状态流追踪与恢复
  • 可拓展支持文档上传、缓存优化、多用户配置

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册