Posted by 季卓然 on 2025-06-08 12:10:52

Multi-iteration summarization of long text with LangChain

Previously we covered RAG Q&A with LangChain; if you'd like a refresher, see
RAG Q&A over real-time content with LangChain and the Hunyuan LLM

Today we'll look at how to build long-text summarization on top of that earlier setup.
Why do we need text summarization?

Meeting transcripts are usually long-winded; being able to extract the key points saves a great deal of time.
Can't the model summarize on its own? Why single out "long text" as a separate problem?

Most models cap the input length; once a transcript exceeds that limit, it cannot be summarized in a single call. For example, a meeting transcript running to tens of thousands of tokens simply will not fit into a model whose context window is only a few thousand tokens.
Approaches

LangChain offers several strategies to choose from: https://python.langchain.com/v0.1/docs/use_cases/summarization/

[*]stuff: whole-text summarization; the entire text is fed to the model in one request, which may still exceed the model's limit
[*]MapReduce: split the text into n chunks, summarize each chunk separately, then summarize the combined results; this solves the length problem, but it can break the context across chunks and hurt the final result
[*]refine: like MapReduce, the text is split into n chunks, but they are processed in a loop: summarize the first chunk, then summarize that result together with the second chunk, and so on; this preserves the original semantics much better (see the sketch below)
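
To make the three options concrete, here is a minimal sketch using the stock load_summarize_chain from LangChain (llm stands for any LangChain model instance and docs for the list of split Document chunks; both are assumed rather than defined here):

from langchain.chains.summarize import load_summarize_chain

# "stuff": push the whole text into a single prompt
stuff_chain = load_summarize_chain(llm, chain_type="stuff")
# "map_reduce": summarize each chunk, then summarize the summaries
map_reduce_chain = load_summarize_chain(llm, chain_type="map_reduce")
# "refine": roll a running summary through the chunks one by one
refine_chain = load_summarize_chain(llm, chain_type="refine")

result = refine_chain.invoke({"input_documents": docs})
print(result["output_text"])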
Challenges


[*]The code itself
[*]Streaming responses
[*]Deciding which response is the final round (with streaming, every round returns a summary, so how do we know a given round is the last one whose result should go back to the frontend? see the consumer sketch after this list)
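
For the last point, the trick used later in this post is to dispatch a custom event from inside the refine chain and watch for it on the consumer side. A rough sketch of such a consumer, assuming a streaming chat model and a chain that emits a "last_doc_mark" custom event the way the modified RefineDocumentsChain below does:

async def stream_final_summary(chain, docs):
    # Forward model tokens only once the final refine round has started.
    is_last_round = False
    async for event in chain.astream_events({"input_documents": docs}, version="v2"):
        if event["event"] == "on_custom_event" and event["name"] == "last_doc_mark":
            is_last_round = event["data"]["chunk"]
        elif event["event"] == "on_chat_model_stream" and is_last_round:
            yield event["data"]["chunk"]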
Implementation

Parts of LangChain are fairly tightly coupled, which makes secondary development awkward, so some of its source code has to be modified along the way.
1. Create a text-loading tool for loading the transcript
from typing import Dict, Optional

from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.callbacks import CallbackManagerForChainRun

from modules.BasinessException import BusinessException
from modules.resultCodeEnum import ResultCodeEnum
from service.SubtitleService import SubtitleService
from utils import constants
from utils.logger import logger

class download_summarize_chain(AnalyzeDocumentChain):
    def _call(
            self,
            inputs: Dict,
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> dict:

      docs = self.get_docs(inputs, run_manager)

      # Other keys are assumed to be needed for LLM prediction
      other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
      other_keys[self.combine_docs_chain.input_key] = docs

      _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
      return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
      )

    def get_docs(self, inputs, run_manager):

      file_download_url = str(inputs[self.input_key])
      if file_download_url is not None and file_download_url.startswith("http"):
            # Download the file from the given URL
            loader = WebBaseLoader(file_download_url, None, False)
            """Split document into chunks and pass to CombineDocumentsChain."""
            document = loader.load().page_content
            if len(document) <= 0:
                logger.error(f"file not exists:{file_download_url}")
                raise BusinessException.new_instance_with_rce(400, ResultCodeEnum.EMPTY_CONTENT)
      else:
            # Fetch the subtitles by enterprise id and meeting id
            enterprise_id: str = run_manager.metadata.get(constants.ENTERPRISE_ID)
            meeting_id: str = run_manager.metadata.get(constants.MEETING_ID)
            logger.info(f"process task with llm:{enterprise_id}-{meeting_id}")
            document = SubtitleService().fetch_subtitles(enterprise_id=enterprise_id, meeting_id=meeting_id)

      docs = self.text_splitter.create_documents([document])
      logger.info("number of split doc parts:{}", len(docs))
      return docs
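
The text_splitter used above is provided by the project's model adapter (get_text_splitter()), so each model can split according to its own context window. As a rough sketch of what such a splitter might look like (the chunk sizes here are placeholders, not the project's actual settings):

from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
docs = text_splitter.create_documents([document])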

3. The refine chain
"""Load summarizing chains."""

from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate

from adapters.langchain.chains.refine import RefineDocumentsChain


class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain."""

    def __call__(
            self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
      """Callable to load the combine documents chain."""


def _load_stuff_chain(
      llm: BaseLanguageModel,
      prompt: BasePromptTemplate = stuff_prompt.PROMPT,
      document_variable_name: str = "text",
      verbose: Optional[bool] = None,
      **kwargs: Any,
) -> StuffDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)# type: ignore
    # TODO: document prompt
    return StuffDocumentsChain(
      llm_chain=llm_chain,
      document_variable_name=document_variable_name,
      verbose=verbose,# type: ignore
      **kwargs,
    )


def _load_map_reduce_chain(
      llm: BaseLanguageModel,
      map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
      combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
      combine_document_variable_name: str = "text",
      map_reduce_document_variable_name: str = "text",
      collapse_prompt: Optional[BasePromptTemplate] = None,
      reduce_llm: Optional[BaseLanguageModel] = None,
      collapse_llm: Optional[BaseLanguageModel] = None,
      verbose: Optional[bool] = None,
      token_max: int = 3000,
      callbacks: Callbacks = None,
      *,
      collapse_max_retries: Optional[int] = None,
      **kwargs: Any,
) -> MapReduceDocumentsChain:
    map_chain = LLMChain(
      llm=llm,
      prompt=map_prompt,
      verbose=verbose,# type: ignore
      callbacks=callbacks,# type: ignore
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
      llm=_reduce_llm,
      prompt=combine_prompt,
      verbose=verbose,# type: ignore
      callbacks=callbacks,# type: ignore
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
      llm_chain=reduce_chain,
      document_variable_name=combine_document_variable_name,
      verbose=verbose,# type: ignore
      callbacks=callbacks,
    )
    if collapse_prompt is None:
      collapse_chain = None
      if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
      _collapse_llm = collapse_llm or llm
      collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,# type: ignore
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
      )
    reduce_documents_chain = ReduceDocumentsChain(
      combine_documents_chain=combine_documents_chain,
      collapse_documents_chain=collapse_chain,
      token_max=token_max,
      verbose=verbose,# type: ignore
      callbacks=callbacks,
      collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
      llm_chain=map_chain,
      reduce_documents_chain=reduce_documents_chain,
      document_variable_name=map_reduce_document_variable_name,
      verbose=verbose,# type: ignore
      callbacks=callbacks,
      **kwargs,
    )


def _load_refine_chain(
      llm: BaseLanguageModel,
      question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
      refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
      document_variable_name: str = "text",
      initial_response_name: str = "existing_answer",
      refine_llm: Optional[BaseLanguageModel] = None,
      verbose: Optional[bool] = None,
      **kwargs: Any,
) -> RefineDocumentsChain:
    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)# type: ignore
    _refine_llm = refine_llm or llm
    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)# type: ignore
    return RefineDocumentsChain(
      initial_llm_chain=initial_chain,
      refine_llm_chain=refine_chain,
      document_variable_name=document_variable_name,
      initial_response_name=initial_response_name,
      verbose=verbose,# type: ignore
      **kwargs,
    )


def load_summarize_chain(
      llm: BaseLanguageModel,
      chain_type: str = "stuff",
      verbose: Optional[bool] = None,
      **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
      llm: Language Model to use in the chain.
      chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
      verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.

    Returns:
      A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
      "stuff": _load_stuff_chain,
      "map_reduce": _load_map_reduce_chain,
      "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
      raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
      )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
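
The only functional change from the upstream module is that _load_refine_chain returns our customized RefineDocumentsChain imported from adapters.langchain.chains.refine (shown in the next listing). A usage sketch, with placeholder prompts whose input variables match the defaults above ("text" and "existing_answer"):

from langchain_core.prompts import PromptTemplate

question_prompt = PromptTemplate.from_template(
    "Summarize the following meeting transcript:\n{text}"
)
refine_prompt = PromptTemplate.from_template(
    "Here is the existing summary: {existing_answer}\n"
    "Refine it with this additional transcript:\n{text}"
)
chain = load_summarize_chain(llm, chain_type="refine",
                             question_prompt=question_prompt,
                             refine_prompt=refine_prompt)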

4. Calling the chain
"""Combine documents by doing a first pass and then refining on more documents."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple

from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks, dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

from utils.logger import logger


def _get_default_document_prompt() -> PromptTemplate:
    return PromptTemplate(input_variables=["page_content"], template="{page_content}")


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
      .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
               template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
      default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

    @property
    def output_keys(self) -> List[str]:
      """Expect input key.

      :meta private:
      """
      _output_keys = super().output_keys
      if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
      return _output_keys

    model_config = ConfigDict(
      arbitrary_types_allowed=True,
      extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def get_return_intermediate_steps(cls, values: Dict) -> Any:
      """For backwards compatibility."""
      if "return_refine_steps" in values:
            values["return_intermediate_steps"] = values["return_refine_steps"]
            del values["return_refine_steps"]
      return values

    @model_validator(mode="before")
    @classmethod
    def get_default_document_variable_name(cls, values: Dict) -> Any:
      """Get default document variable name, if not provided."""
      if "initial_llm_chain" not in values:
            raise ValueError("initial_llm_chain must be provided")

      llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
      if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables
            else:
                raise ValueError(
                  "document_variable_name must be provided if there are "
                  "multiple llm_chain input_variables"
                )
      else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                  f"document_variable_name {values['document_variable_name']} was "
                  f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
      return values

    def combine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
      """Combine by mapping first chain over all, then stuffing into final chain.

      Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

      Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
      """
      inputs = self._construct_initial_inputs(docs, **kwargs)
      dispatch_custom_event("last_doc_mark", {"chunk": False})
      doc_length = len(docs)
      if doc_length == 1:
            dispatch_custom_event("last_doc_mark", {"chunk": True})
      logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
      res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
      refine_steps = [res]
      for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index+1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                dispatch_custom_event("last_doc_mark", {"chunk": True})
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
      logger.info(f"refine_docs finished of {kwargs}, result:{res}")
      return self._construct_result(refine_steps, res)

    async def acombine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
      """Async combine by mapping a first chain over all, then stuffing
         into a final chain.

      Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

      Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
      """
      inputs = self._construct_initial_inputs(docs, **kwargs)
      res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
      refine_steps = [res]
      for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
      return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
      if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
      else:
            extra_return_dict = {}
      return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
      return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
      }

    def _construct_initial_inputs(
            self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
      base_info = {"page_content": docs[0].page_content}
      base_info.update(docs[0].metadata)
      document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
      base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
      }
      inputs = {**base_inputs, **kwargs}
      return inputs

    @property
    def _chain_type(self) -> str:
      return "refine_documents_chain"5.过滤最后一次迭代

5. Filtering for the last iteration
import json

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableConfig

from utils import constants
from utils.logger import logger

# ModelTypeEnum, ModelAdapter, PromptSynchronizer, QuestionTypeEnum, BaseTool and the
# two chains defined earlier come from project-internal modules, omitted here.


def process(tool: BaseTool, prompt_type: QuestionTypeEnum, input_dict: dict,
            run_manager=None):
    # Get the model instance
    model_type = ModelTypeEnum.from_string(run_manager.metadata.get(constants.MODEL_TYPE))
    model_instance = ModelAdapter.get_model_instance(model_type)

    # Prompt templates for this model and question type
    prompt_template = PromptSynchronizer.get_prompt_template(model_type=model_type, questionType=prompt_type)
    prompt_map = json.loads(prompt_template)
    refine_prompt = PromptTemplate.from_template(prompt_map["refine_template"], template_format="f-string")
    question_prompt = PromptTemplate.from_template(prompt_map["prompt"])

    logger.info("invoke tool input_dicts:{}",input_dict)

    combine_docs_chain = load_summarize_chain(
        llm=model_instance,
        chain_type="refine",
        question_prompt=question_prompt,
        refine_prompt=refine_prompt,
        return_intermediate_steps=True,
        input_key="text",
        output_key="existing_answer",
        verbose=True,
    )

    res = (tool.pre_handler
           | download_summarize_chain(
               combine_docs_chain=combine_docs_chain,
               text_splitter=model_instance.get_text_splitter(),
               verbose=True,
               input_key="input_document",
           )
           | tool.post_handler).invoke(input={"input_document": "", **input_dict}, config=RunnableConfig())
    return res
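
The actual filtering lives in tool.post_handler, which is project-internal and not shown here. A hedged sketch of the core idea, using a callback that listens for the "last_doc_mark" custom event dispatched by the refine chain (the class name and wiring are illustrative, not the project's real post handler):

from typing import Any
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler


class LastRoundMarker(BaseCallbackHandler):
    """Flips a flag once the refine loop reaches its final document."""

    def __init__(self) -> None:
        self.is_last_round = False

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        if name == "last_doc_mark":
            self.is_last_round = bool(data.get("chunk"))

A post handler can then drop every streamed summary produced while is_last_round is False and forward only the final round to the frontend.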