The inference flow in rllm

轨项尺, 6 days ago
Printing one inference trajectory

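In the previous post we got the rllm framework running. Now let's take a close look at what happens inside examples/math_tool/run_math_with_tool.py.
The core of run_math_with_tool.py looks roughly like this: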
    # Imports and the tokenizer / sampling_params / n_parallel_agents setup are omitted in this excerpt.
    agent_args = {
        "tools": ["python"],
        "parser_name": "qwen",
        "system_prompt": "You are a math assistant that can write python to solve math problems.",
    }
    env_args = {
        "tools": ["python"],
        "reward_fn": math_reward_fn,
    }

    engine = AgentExecutionEngine(
        agent_class=ToolAgent,
        agent_args=agent_args,
        env_class=ToolEnvironment,
        env_args=env_args,
        engine_name="openai",
        rollout_engine_args={"base_url": "http://localhost:30000/v1", "api_key": "None"},
        tokenizer=tokenizer,
        sampling_params=sampling_params,
        max_response_length=16384,
        max_prompt_length=2048,
        n_parallel_agents=n_parallel_agents,
    )
    test_dataset = DatasetRegistry.load_dataset("aime2024", "test")
    ...
    tasks = test_dataset.repeat(n=8)  # repeat each problem 8 times to evaluate pass@k
    ...
    results = asyncio.run(engine.execute_tasks(tasks[:5]))  # run only the first 5 tasks
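The script repeats every problem 8 times so that pass@k can be estimated. As a minimal sketch of what such a metric can look like, the block below uses the commonly used unbiased estimator 1 - C(n-c, k) / C(n, k), where n is the number of samples per problem and c the number of correct ones. The function name `pass_at_k` is hypothetical, and this is an assumption about the metric, not necessarily how rllm computes it.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Estimate pass@k from n samples of one problem, c of which are correct.

    Hypothetical helper: uses the unbiased estimator 1 - C(n-c, k) / C(n, k).
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        # Fewer than k incorrect samples: every size-k subset contains a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


# Example: 8 samples per problem, 2 of them correct.
print(pass_at_k(n=8, c=2, k=1))  # 0.25
print(pass_at_k(n=8, c=2, k=8))  # 1.0
```

Averaging this value over all problems gives the dataset-level pass@k.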
Let's print one inference trajectory to see what it looks like:
    first_traj = results[0]
    print("\n======= Sample trajectory =======")
    print("Question:", first_traj.task)
    for i, step in enumerate(first_traj.steps):
        print(f"\n--- Step {i} ---")
        print("Observation:", step.observation)
        print("Model response:", step.model_response)
        print("Action:", step.action)
        print("Reward:", step.reward)
        print("Done:", step.done)

    print("======================\n")
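The output is shown below. There are 5 steps in total (Step 0 to Step 4): in Step 0 the LLM receives the question, in Step 4 it outputs the final answer, and in the steps in between it reasons over the results of its tool calls. Observation is the information the model receives, such as the question or a tool result; Action is what the model produces, such as a tool call or the final reply.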
    Question: {'id': 60, 'problem': '...', 'answer': '204', 'url': '...', 'year': '2024', 'question': 'Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.', 'ground_truth': '204', 'data_source': 'math'}
    --- Step 0 ---
    Observation: {'id': 60, 'problem': '...', 'answer': '204', 'url': '...', 'year': '2024', 'question': 'Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.', 'ground_truth': '204', 'data_source': 'math'}
    Model response:
    ....
    <tool_call>
    {"name": "python", "arguments": {"code": "import math\n\na = 1\nb = 2\nc = -11.25\n\ndiscriminant = b**2 - 4*a*c\nsqrt_discriminant = math.sqrt(discriminant)\ns1 = (-b + sqrt_discriminant) / (2*a)\ns2 = (-b - sqrt_discriminant) / (2*a)\n\nprint(s1, s2)"}}
    </tool_call>
    Action: [{'id': '5c7285c2-d967-4e60-a228-7947d8c87524', 'type': 'function', 'function': {'name': 'python', 'arguments': '{"code": "import math\\n\\na = 1\\nb = 2\\nc = -11.25\\n\\ndiscriminant = b**2 - 4*a*c\\nsqrt_discriminant = math.sqrt(discriminant)\\ns1 = (-b + sqrt_discriminant) / (2*a)\\ns2 = (-b - sqrt_discriminant) / (2*a)\\n\\nprint(s1, s2)"}'}}]
    Reward: 0
    Done: False
    --- Step 1 ---
    Observation: {'tool_outputs': {'5c7285c2-d967-4e60-a228-7947d8c87524': '2.5 -4.5\n'}}
    Model response:
    ...
    <tool_call>
    {"name": "python", "arguments": {"code": "s = 2.5\n\n# Calculate t using equation (1)\nt = (4 - (9 / s)) * 60\nprint(t)"}}
    </tool_call>
    Action: [{'id': '6e6fc722-29c4-4806-917a-933e7ff87d1b', 'type': 'function', 'function': {'name': 'python', 'arguments': '{"code": "s = 2.5\\n\\n# Calculate t using equation (1)\\nt = (4 - (9 / s)) * 60\\nprint(t)"}'}}]
    Reward: 0
    Done: False
    --- Step 2 ---
    ...
    --- Step 3 ---
    ...
    --- Step 4 ---
    Observation: {'tool_outputs': {'dc3840b8-3d86-477e-b51e-32a26ad668d4': 's = 2.5 km/h\nt = 23.999999999999993 minutes\nTotal time for third scenario: 204.0 minutes\n'}}
    Model response:
    ...
    ### Final Answer:
    $$
    \boxed{204}
    $$
    Action: [{'id': 'b35c758b-ce5d-4f8b-a412-5b7fd86cf383', 'type': 'function', 'function': {'name': 'finish', 'arguments': {'response': "<think>\nOkay, let me check the output again. The code output shows that s is 2.5 km/h, t is approximately 24 minutes (due to floating-point precision), and the total time for the third scenario is exactly 204 minutes. That matches our manual calculations. So the answer is 204 minutes.\n\nThe slight discrepancy in t is because of how floating-point numbers are handled in Python, but since t was originally calculated as 24 minutes, the total time comes out perfectly as 204 minutes when converted. Therefore, the final answer should be 204 minutes.\n</think>\n\nThe number of minutes the walk takes Aya, including the $ t $ minutes spent in the coffee shop, when she walks at $ s + \\frac{1}{2} $ kilometers per hour, is **204 minutes**.\n\n### Step-by-Step Explanation:\n1. **Define Variables:**\n   - Let $ s $ be Aya's walking speed (km/h).\n   - Let $ t $ be the time spent in the coffee shop (minutes).\n\n2. **Set Up Equations:**\n   - **First Scenario:**  \n     Walking time: $ \\frac{9}{s} $ hours.  \n     Total time: $ \\frac{9}{s} + \\frac{t}{60} = 4 $ hours.  \n     Equation: $ \\frac{9}{s} + \\frac{t}{60} = 4 $.\n\n   - **Second Scenario:**  \n     Walking time: $ \\frac{9}{s+2} $ hours.  \n     Total time: $ \\frac{9}{s+2} + \\frac{t}{60} = 2.4 $ hours.  \n     Equation: $ \\frac{9}{s+2} + \\frac{t}{60} = 2.4 $.\n\n3. **Solve for $ s $ and $ t $:**\n   - Subtract the second equation from the first:  \n     $ \\frac{9}{s} - \\frac{9}{s+2} = 1.6 $.  \n     Simplify: $ \\frac{18}{s(s+2)} = 1.6 $.  \n     Solve: $ s^2 + 2s - 11.25 = 0 $.  \n     Roots: $ s = 2.5 $ km/h (discarding the negative root).\n\n   - Substitute $ s = 2.5 $ into the first equation:  \n     $ \\frac{9}{2.5} + \\frac{t}{60} = 4 $.  \n     Solve: $ t = 24 $ minutes.\n\n4. **Third Scenario:**\n   - Speed: $ s + 0.5 = 3 $ km/h.  \n   - Walking time: $ \\frac{9}{3} = 3 $ hours.  \n   - Total time: $ 3 + \\frac{24}{60} = 3.4 $ hours = $ 204 $ minutes.\n\n### Final Answer:\n$$\n\\boxed{204}\n$$"}}}]
    Reward: 1.0
    Done: True
    ======================
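From this we can work out the tool-calling flow of an Agent in rllm:

  • After observing the question, the Agent thinks and issues a function call.
  • The rllm framework detects the tool call, executes the tool, and returns the result.
  • The Agent continues its analysis based on the result the tool returned.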
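Before walking through the code, a few terms need to be pinned down:

  • Environment: passes the question to the Agent and executes tools.
  • Observation: the information given to the Agent at the current step (the question it received, tool execution results, and so on).
  • Action: the instruction the Agent sends to the environment, i.e. the tool-call arguments the Agent generates.
  • Reward: how good the current step was.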
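For example, when the Agent uses the code tool, it first receives (observes) the user's question from the environment, then thinks, and then generates the arguments for a call to the code tool (the content wrapped in <tool_call></tool_call>, which is the Agent's action). The environment executes the code the Agent generated and returns the result to the Agent, which observes it and continues its analysis.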
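To make the "action" concrete, here is a minimal sketch of how tool-call arguments could be pulled out of a model response in the `<tool_call>` format shown in the trajectory above. `parse_tool_calls` is a hypothetical helper written for illustration; rllm's own parser (selected by `parser_name`) may work differently.

```python
import json
import re

# Matches everything between <tool_call> and </tool_call>, across newlines.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(response: str) -> list:
    """Return the tool calls found in a response as {"name", "arguments"} dicts."""
    calls = []
    for block in TOOL_CALL_RE.findall(response):
        try:
            payload = json.loads(block)
        except json.JSONDecodeError:
            continue  # skip malformed blocks instead of failing the whole step
        if isinstance(payload, dict) and "name" in payload:
            calls.append({"name": payload["name"], "arguments": payload.get("arguments", {})})
    return calls


response = (
    "Let me compute this.\n"
    "<tool_call>\n"
    '{"name": "python", "arguments": {"code": "print(1 + 1)"}}\n'
    "</tool_call>"
)
print(parse_tool_calls(response))
# [{'name': 'python', 'arguments': {'code': 'print(1 + 1)'}}]
```

A response without any `<tool_call>` block yields an empty list, which is the case the Agent later turns into a finish action.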
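Below we look at the environment, the Agent that interacts with it, and the reward in turn. AgentExecutionEngine itself plays the coordinating role that ties them together.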
Environment

Defined in rllm.environments.tools.tool_env; it receives the user input and executes tool calls.
The main code is:
    # BaseEnv, self.task, self.reward_fn and self._execute_tool_calls come from rllm; this excerpt is abridged.
    class ToolEnvironment(BaseEnv):
        def step(self, action: list[dict] | str | dict):
            """
            Take a step in the environment based on the action.

            Args:
                action: The agent's action: a list of tool calls, or a plain string response

            Returns:
                next_observation, reward, done, info
            """
            reward = 0
            # A plain string action is treated as a final answer (other termination checks are omitted here)
            done = isinstance(action, str)

            # Check whether the action contains a "finish" call. When the Agent finds no tool call
            # in the model response, it emits a finish action, which means the answer is complete.
            if isinstance(action, list) and action:
                for tool_call in action:
                    if tool_call.get("function", {}).get("name") == "finish":
                        done = True
                        break

            # If the answer is complete, extract the LLM's answer and compute the reward
            if done:
                # Extract the LLM's answer
                if isinstance(action, str):
                    llm_response = action
                elif isinstance(action, list):
                    ...

                # Compute the reward from the task (question and ground truth) and the LLM's answer
                task_info = self.task if self.task is not None else {}
                reward_output = self.reward_fn(task_info=task_info, action=llm_response)
                return {}, reward_output.reward, done, {"response": action, "metadata": reward_output.metadata, "is_correct": reward_output.is_correct}

            # Otherwise, execute the tools and return their results
            tool_calls = action
            # Executing a tool invokes the call method of the tool class (generally defined in the rllm/tools folder)
            tool_outputs = self._execute_tool_calls(tool_calls)
            next_obs = {"tool_outputs": tool_outputs}
            return next_obs, reward, done, {"response": action, "metadata": {}}
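The tool execution itself is not shown in the excerpt above. As a rough sketch of the idea, the block below runs the code string of a python tool call, captures whatever it prints, and returns a mapping from tool call id to output, which is the shape of `tool_outputs` in the trajectory. `run_python` and `execute_tool_calls` are hypothetical stand-ins, not rllm's implementation, and running model-generated code with `exec` is unsafe outside a sandbox.

```python
import contextlib
import io
import json
import traceback


def run_python(code: str) -> str:
    """Run a code string and return what it printed (plus a traceback on error)."""
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(code, {"__name__": "__tool__"})  # unsafe for untrusted code; demo only
    except Exception:
        return buffer.getvalue() + traceback.format_exc()
    return buffer.getvalue()


def execute_tool_calls(tool_calls: list) -> dict:
    """Map each tool call id to the string output of the tool it names."""
    outputs = {}
    for call in tool_calls:
        function = call["function"]
        arguments = function["arguments"]
        if isinstance(arguments, str):
            # In the trajectory above, the arguments arrive as a JSON string.
            arguments = json.loads(arguments)
        if function["name"] == "python":
            outputs[call["id"]] = run_python(arguments["code"])
        else:
            outputs[call["id"]] = f"Unknown tool: {function['name']}"
    return outputs


action = [{"id": "call-1", "type": "function",
           "function": {"name": "python", "arguments": '{"code": "print(1 + 1)"}'}}]
print({"tool_outputs": execute_tool_calls(action)})
# {'tool_outputs': {'call-1': '2\n'}}
```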
Agent

The Agent mainly maintains a message list (the chat history) that holds the system prompt, user input, model replies, and tool call results:
    [
        {"role": "system", "content": ""},
        {"role": "user", "content": ""},
        {"role": "assistant", "content": ""},
        {"role": "tool", "content": "", "tool_call_id": ""},
        ...
    ]
The main code of ToolAgent is:

    import uuid
    from typing import Any

    # BaseAgent and Action come from rllm (imports omitted in this excerpt).
    class ToolAgent(BaseAgent):
        def _format_observation_as_messages(self, obs: Any) -> list[dict]:
            """Format an observation received from the environment as chat messages."""
            messages = []

            if isinstance(obs, dict):
                # A "question" field means the observation is the user's input: add it with role "user"
                if "question" in obs:
                    messages.append({"role": "user", "content": obs["question"]})
                # A "tool_outputs" field means tool results: add each one with role "tool"
                elif "tool_outputs" in obs:
                    # Format tool outputs from environment observation
                    for tool_call_id, tool_output_str in obs["tool_outputs"].items():
                        messages.append(
                            {
                                "role": "tool",
                                "content": tool_output_str,
                                "tool_call_id": tool_call_id,
                            }
                        )
            elif isinstance(obs, str):
                messages.append({"role": "user", "content": obs})
            elif obs:
                messages.append({"role": "user", "content": str(obs)})
            return messages

        def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
            """Append the observation received from the environment to the message list."""
            obs_messages = self._format_observation_as_messages(observation)

            self.messages.extend(obs_messages)

        def update_from_model(self, response: str, **kwargs) -> Action:
            """Parse the tool-call arguments generated by the model out of the response."""
            tool_calls_dict = []
            assistant_content = response
            # Parse tool calls from the model response
            try:
                tool_calls = self.tool_parser.parse(response)
                tool_calls_dict = [
                    {
                        "id": str(uuid.uuid4()),
                        "type": "function",
                        "function": tool_call.to_dict(),
                    }
                    for tool_call in tool_calls
                ]
            except Exception as e:
                # If parsing fails, fall through to the finish branch below
                print(f"Failed to parse tool calls from the response: {e}")
                tool_calls_dict = []
            # Build the assistant message from the model's full response
            assistant_message = {"role": "assistant", "content": assistant_content}

            if len(tool_calls_dict) > 0:
                # Some simple format conversion
                ...

            # If there is no tool call, set the current action to finish
            else:
                tool_calls_dict = [
                    {
                        "id": str(uuid.uuid4()),
                        "type": "function",
                        "function": {
                            "name": "finish",
                            "arguments": {
                                "response": assistant_content,
                            },
                        },
                    }
                ]
            # Append the assistant message to the message list
            self.messages.append(assistant_message)
            return Action(action=tool_calls_dict)

        def reset(self):
            """Initialize the message list with the system prompt."""
            self.messages = [{"role": "system", "content": self.system_prompt + self.tools_prompt}]
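To see how the message list grows over a trajectory, the following sketch replays (observation, model response) pairs into a chat history using the same role rules as above: a question becomes a user message, each tool output becomes a tool message tied to its tool_call_id, and every model response becomes an assistant message. `replay_messages` is a hypothetical helper for illustration only, and the ids and contents are made up.

```python
def replay_messages(system_prompt: str, steps: list) -> list:
    """Rebuild a chat history from (observation, model_response) pairs."""
    messages = [{"role": "system", "content": system_prompt}]
    for observation, model_response in steps:
        if "question" in observation:
            messages.append({"role": "user", "content": observation["question"]})
        elif "tool_outputs" in observation:
            for tool_call_id, output in observation["tool_outputs"].items():
                messages.append({"role": "tool", "content": output, "tool_call_id": tool_call_id})
        # The model's reply to that observation always follows it.
        messages.append({"role": "assistant", "content": model_response})
    return messages


history = replay_messages(
    "You are a math assistant.",
    [
        ({"question": "Solve s^2 + 2s - 11.25 = 0."}, "<tool_call>...</tool_call>"),
        ({"tool_outputs": {"id-1": "2.5 -4.5\n"}}, "The positive root is 2.5."),
    ],
)
print([m["role"] for m in history])
# ['system', 'user', 'assistant', 'tool', 'assistant']
```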
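Agent execution engine

The code lives in rllm/engine/agent_execution_engine.py (for simplicity, most of the parallelism and state-tracking code has been removed here).
As the code shows, the Agent execution engine coordinates the Agent and the environment and implements a ReAct-style reasoning loop.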
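    class AgentExecutionEngine:
        async def run_agent_trajectory_async(self, idx, application_id, seed=0, mode="Text", **kwargs):
            """Run one Agent inference trajectory."""
            # env and agent are the environment and Agent instances for trajectory idx (lookup omitted)
            # Initialization; reset() is assumed here to return the initial observation (the task) and an info dict
            observation, info = env.reset()
            agent.reset()
            # Feed the initial observation (the question) to the Agent
            agent.update_from_env(observation=observation, reward=0.0, done=False, info=info)

            for step_idx in range(self.max_steps):
                # Get the prompt
                prompt_messages = agent.chat_completions.copy()
                # Get the model response
                response = await self.get_model_response(prompt_messages, application_id, **kwargs)
                # Parse the action out of the response
                action: Action = agent.update_from_model(response)
                action = action.action
                # Execute the action
                next_observation, reward, done, info = env.step(action)
                # Update the Agent with the result
                agent.update_from_env(observation=next_observation, reward=reward, done=done, info=info)
                # Leave the loop once the episode is done
                if done:
                    break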
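The same loop can be tried end to end without a model server. The sketch below is a synchronous toy version under stated assumptions: `ToyEnv` offers a single hypothetical `add` tool, `ToyAgent` keeps a message list, and `scripted_policy` stands in for the LLM by calling the tool once and then finishing with the tool's output. None of these classes are part of rllm.

```python
import uuid


class ToyEnv:
    """Toy environment: executes an 'add' tool and grades the finish action."""

    def __init__(self, task: dict):
        self.task = task

    def reset(self):
        return self.task, {}

    def step(self, action: list):
        call = action[0]["function"]
        if call["name"] == "finish":
            correct = call["arguments"]["response"] == self.task["ground_truth"]
            return {}, 1.0 if correct else 0.0, True, {}
        result = call["arguments"]["a"] + call["arguments"]["b"]
        return {"tool_outputs": {action[0]["id"]: str(result)}}, 0.0, False, {}


class ToyAgent:
    """Toy agent: only maintains the message list."""

    def reset(self):
        self.messages = [{"role": "system", "content": "toy agent"}]

    def update_from_env(self, observation, reward, done, info):
        if "question" in observation:
            self.messages.append({"role": "user", "content": observation["question"]})
        for call_id, output in observation.get("tool_outputs", {}).items():
            self.messages.append({"role": "tool", "content": output, "tool_call_id": call_id})

    def update_from_model(self, function_call: dict) -> list:
        self.messages.append({"role": "assistant", "content": str(function_call)})
        return [{"id": str(uuid.uuid4()), "type": "function", "function": function_call}]


def scripted_policy(messages: list) -> dict:
    """Stand-in for the LLM: call the tool once, then finish with its output."""
    last = messages[-1]
    if last["role"] == "tool":
        return {"name": "finish", "arguments": {"response": last["content"]}}
    return {"name": "add", "arguments": {"a": 2, "b": 3}}


def run_episode(agent, env, policy, max_steps: int = 5) -> list:
    observation, info = env.reset()
    agent.reset()
    agent.update_from_env(observation, 0.0, False, info)
    trajectory = []
    for _ in range(max_steps):
        action = agent.update_from_model(policy(agent.messages))
        observation, reward, done, info = env.step(action)
        trajectory.append((action[0]["function"]["name"], reward, done))
        agent.update_from_env(observation, reward, done, info)
        if done:
            break
    return trajectory


print(run_episode(ToyAgent(), ToyEnv({"question": "2 + 3 = ?", "ground_truth": "5"}), scripted_policy))
# [('add', 0.0, False), ('finish', 1.0, True)]
```

As in the printed trajectory earlier, intermediate steps carry a reward of 0 and only the finish step is graded.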
Reward function

The reward function is defined in rllm/rewards/math_reward.py. Only a correctness reward is used here. The main code is:
    # RewardOutput, extract_answer, the grade_answer_* helpers and THOUGHT_DELIMITER_END come from rllm (imports omitted).
    class RewardMathFn:
        def __call__(self, task_info: dict, action: str) -> RewardOutput:
            model_response = action

            # Drop the content inside the <think></think> tags
            if THOUGHT_DELIMITER_END in model_response:
                model_solution = model_response.split(THOUGHT_DELIMITER_END)[1]
            else:
                model_solution = model_response
            # Extract the model's answer (usually wrapped in \boxed{})
            model_answer = extract_answer(model_solution)

            # Get the ground truth
            ground_truths = task_info.get("ground_truth", None)
            # Simplified handling in this excerpt: no ground truth means no correctness reward
            if ground_truths is None:
                return RewardOutput(reward=self.config.incorrect_reward, is_correct=False)
            # Wrap a single answer (e.g. '204') in a list so the loop below does not iterate over its characters
            if isinstance(ground_truths, (str, int, float)):
                ground_truths = [ground_truths]
            # If a ground truth contains \boxed, extract the answer from it
            processed_ground_truths = []
            for truth in ground_truths:
                truth = str(truth)
                if "\\boxed" in truth:
                    processed_truth = extract_answer(truth)
                    if processed_truth is not None:
                        processed_ground_truths.append(processed_truth)
                else:
                    processed_ground_truths.append(truth)
            # Assign the correctness reward
            for ground_truth in processed_ground_truths:
                # Is the model's answer correct?
                is_correct = grade_answer_mathd(model_answer, ground_truth) or grade_answer_sympy(model_answer, ground_truth)
                if is_correct:
                    reward = self.config.correct_reward
                    return RewardOutput(reward=reward, is_correct=True)

            # The model's answer is wrong
            return RewardOutput(reward=self.config.incorrect_reward, is_correct=False)
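For intuition about the answer-extraction step, here is a simplified, self-contained sketch: it drops everything up to the closing think tag, takes the last \boxed{...} in what remains (matching nested braces), and compares it to the ground truth by exact string equality. `extract_boxed` and `simple_math_reward` are hypothetical names, and the `</think>` delimiter is assumed from the trajectory above. rllm's real graders (grade_answer_mathd, grade_answer_sympy) are more lenient than a plain string comparison.

```python
from typing import Optional

THINK_END = "</think>"  # assumed delimiter, as seen in the model output above


def extract_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    chars = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(chars).strip()
        chars.append(ch)
    return None  # unbalanced braces


def simple_math_reward(response: str, ground_truth, correct: float = 1.0, incorrect: float = 0.0) -> float:
    """Exact-match correctness reward on the boxed answer after the think block."""
    solution = response.split(THINK_END)[-1]
    answer = extract_boxed(solution)
    if answer is None:
        return incorrect
    return correct if answer == str(ground_truth).strip() else incorrect


reply = "<think>maybe \\boxed{7}?</think>\n### Final Answer:\n$$\n\\boxed{204}\n$$"
print(simple_math_reward(reply, "204"))  # 1.0
print(extract_boxed("\\boxed{\\frac{1}{2}}"))  # \frac{1}{2}
```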