找回密码
 立即注册
首页 业界区 业界 Spring AI 代码分析(九)--记忆能力实现

Spring AI 代码分析(九)--记忆能力实现

郦珠雨 2025-11-27 03:35:00
记忆能力分析

请关注微信公众号:阿呆-bot
1. 工程结构概览

Spring AI 提供了完整的对话记忆(Chat Memory)能力,支持将对话历史持久化到各种存储后端。记忆能力是构建多轮对话应用的基础。
  1. spring-ai-model/
  2. └── chat/memory/                  # 记忆核心抽象
  3.     ├── ChatMemory.java           # 记忆接口
  4.     ├── ChatMemoryRepository.java # 存储仓库接口
  5.     ├── MessageWindowChatMemory.java  # 窗口记忆实现
  6.     └── InMemoryChatMemoryRepository.java  # 内存实现
  7. memory/repository/                # 持久化实现
  8. ├── spring-ai-model-chat-memory-repository-jdbc/      # JDBC 实现
  9. ├── spring-ai-model-chat-memory-repository-mongodb/    # MongoDB 实现
  10. ├── spring-ai-model-chat-memory-repository-neo4j/      # Neo4j 实现
  11. ├── spring-ai-model-chat-memory-repository-cassandra/  # Cassandra 实现
  12. └── spring-ai-model-chat-memory-repository-cosmos-db/  # Cosmos DB 实现
复制代码
2. 技术体系与模块关系

记忆系统采用分层设计:记忆接口 → 存储仓库 → 具体实现
1.png

3. 关键场景示例代码

3.1 基础使用

使用内存记忆:
  1. // 创建内存记忆
  2. ChatMemory memory = new MessageWindowChatMemory(
  3.     new InMemoryChatMemoryRepository(),
  4.     10  // 窗口大小:保留最近 10 条消息
  5. );
  6. // 添加消息
  7. memory.add("conversation-1", new UserMessage("你好"));
  8. memory.add("conversation-1", new AssistantMessage("你好!有什么可以帮助你的?"));
  9. // 获取对话历史
  10. List<Message> history = memory.get("conversation-1");
复制代码
3.2 使用 JDBC 持久化

使用数据库持久化记忆:
  1. @Autowired
  2. private DataSource dataSource;
  3. @Bean
  4. public ChatMemory chatMemory() {
  5.     JdbcChatMemoryRepository repository =
  6.         JdbcChatMemoryRepository.builder()
  7.             .dataSource(dataSource)
  8.             .build();
  9.    
  10.     return new MessageWindowChatMemory(repository, 20);
  11. }
复制代码
3.3 使用 MongoDB

使用 MongoDB 持久化:
  1. @Autowired
  2. private MongoTemplate mongoTemplate;
  3. @Bean
  4. public ChatMemory chatMemory() {
  5.     MongoChatMemoryRepository repository =
  6.         MongoChatMemoryRepository.builder()
  7.             .mongoTemplate(mongoTemplate)
  8.             .build();
  9.    
  10.     return new MessageWindowChatMemory(repository, 30);
  11. }
复制代码
3.4 在 ChatClient 中使用

记忆可以通过 Advisor 集成到 ChatClient:
  1. ChatMemory memory = new MessageWindowChatMemory(repository, 10);
  2. MessageChatMemoryAdvisor memoryAdvisor =
  3.     MessageChatMemoryAdvisor.builder()
  4.         .chatMemory(memory)
  5.         .conversationId("user-123")
  6.         .build();
  7. ChatClient chatClient = ChatClient.builder(chatModel)
  8.     .defaultAdvisors(memoryAdvisor)
  9.     .build();
  10. // 对话会自动保存到记忆
  11. String response = chatClient.prompt()
  12.     .user("我的名字是张三")
  13.     .call()
  14.     .content();
  15. // 后续对话会自动包含历史
  16. String response2 = chatClient.prompt()
  17.     .user("我的名字是什么?")
  18.     .call()
  19.     .content();  // 模型会记住名字是张三
复制代码
4. 核心实现图

4.1 记忆存储和检索流程

2.png

5. 入口类与关键类关系

3.png

6. 关键实现逻辑分析

6.1 ChatMemory 接口设计

ChatMemory 接口提供了简单的记忆 API:
  1. public interface ChatMemory {
  2.     void add(String conversationId, List<Message> messages);
  3.     List<Message> get(String conversationId);
  4.     void clear(String conversationId);
  5. }
复制代码
这个接口设计简洁,但功能强大。它支持:

  • 多对话管理:通过 conversationId 区分不同对话
  • 批量添加:支持一次添加多条消息
  • 清理功能:支持清除特定对话的记忆
6.2 MessageWindowChatMemory 实现

MessageWindowChatMemory 实现了窗口记忆策略:
  1. public class MessageWindowChatMemory implements ChatMemory {
  2.     private final ChatMemoryRepository repository;
  3.     private final int windowSize;
  4.    
  5.     @Override
  6.     public void add(String conversationId, List<Message> messages) {
  7.         // 1. 获取现有消息
  8.         List<Message> existing = repository.findByConversationId(conversationId);
  9.         
  10.         // 2. 添加新消息
  11.         List<Message> allMessages = new ArrayList<>(existing);
  12.         allMessages.addAll(messages);
  13.         
  14.         // 3. 应用窗口策略(只保留最近的 N 条)
  15.         List<Message> windowed = applyWindow(allMessages);
  16.         
  17.         // 4. 保存
  18.         repository.saveAll(conversationId, windowed);
  19.     }
  20.    
  21.     @Override
  22.     public List<Message> get(String conversationId) {
  23.         List<Message> messages = repository.findByConversationId(conversationId);
  24.         return applyWindow(messages);
  25.     }
  26.    
  27.     private List<Message> applyWindow(List<Message> messages) {
  28.         if (messages.size() <= windowSize) {
  29.             return messages;
  30.         }
  31.         // 只返回最近的 N 条消息
  32.         return messages.subList(messages.size() - windowSize, messages.size());
  33.     }
  34. }
复制代码
支持的数据库

  • PostgreSQL
  • MySQL/MariaDB
  • H2
  • SQLite
  • Oracle
  • SQL Server
  • HSQLDB
每个数据库都有自己的 Dialect 实现,处理 SQL 方言差异。
6.4 MongoDB 实现

MongoDB 实现使用文档存储:
  1. public class JdbcChatMemoryRepository implements ChatMemoryRepository {
  2.     private final JdbcTemplate jdbcTemplate;
  3.     private final ChatMemoryRepositoryDialect dialect;
  4.    
  5.     @Override
  6.     public List<Message> findByConversationId(String conversationId) {
  7.         String sql = dialect.getSelectByConversationIdSql();
  8.         
  9.         return jdbcTemplate.query(sql,
  10.             new Object[]{conversationId},
  11.             (rs, rowNum) -> {
  12.                 String content = rs.getString("content");
  13.                 String type = rs.getString("type");
  14.                 Map<String, Object> metadata = parseMetadata(rs.getString("metadata"));
  15.                
  16.                 return createMessage(type, content, metadata);
  17.             }
  18.         );
  19.     }
  20.    
  21.     @Override
  22.     public void saveAll(String conversationId, List<Message> messages) {
  23.         // 1. 删除现有消息
  24.         deleteByConversationId(conversationId);
  25.         
  26.         // 2. 批量插入新消息
  27.         String sql = dialect.getInsertSql();
  28.         List<Object[]> batchArgs = messages.stream()
  29.             .map(msg -> new Object[]{
  30.                 conversationId,
  31.                 msg.getText(),
  32.                 msg.getMessageType().name(),
  33.                 toJson(msg.getMetadata()),
  34.                 Timestamp.from(Instant.now())
  35.             })
  36.             .collect(toList());
  37.         
  38.         jdbcTemplate.batchUpdate(sql, batchArgs);
  39.     }
  40. }
复制代码
MongoDB 文档结构
  1. public class MongoChatMemoryRepository implements ChatMemoryRepository {
  2.     private final MongoTemplate mongoTemplate;
  3.    
  4.     @Override
  5.     public List<Message> findByConversationId(String conversationId) {
  6.         Query query = Query.query(
  7.             Criteria.where("conversationId").is(conversationId)
  8.         ).with(Sort.by("timestamp").descending());
  9.         
  10.         List<Conversation> conversations = mongoTemplate.find(
  11.             query, Conversation.class
  12.         );
  13.         
  14.         return conversations.stream()
  15.             .map(this::mapMessage)
  16.             .collect(toList());
  17.     }
  18.    
  19.     @Override
  20.     public void saveAll(String conversationId, List<Message> messages) {
  21.         // 1. 删除现有消息
  22.         deleteByConversationId(conversationId);
  23.         
  24.         // 2. 转换为文档并保存
  25.         List<Conversation> conversations = messages.stream()
  26.             .map(msg -> new Conversation(
  27.                 conversationId,
  28.                 new Conversation.Message(
  29.                     msg.getText(),
  30.                     msg.getMessageType().name(),
  31.                     msg.getMetadata()
  32.                 ),
  33.                 Instant.now()
  34.             ))
  35.             .collect(toList());
  36.         
  37.         mongoTemplate.insert(conversations, Conversation.class);
  38.     }
  39. }
复制代码
6.5 Neo4j 实现

Neo4j 实现使用图数据库:
  1. {
  2.   "conversationId": "user-123",
  3.   "message": {
  4.     "text": "你好",
  5.     "type": "USER",
  6.     "metadata": {}
  7.   },
  8.   "timestamp": "2025-01-01T00:00:00Z"
  9. }
复制代码
Neo4j 图结构
  1. public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
  2.     @Override
  3.     public List<Message> findByConversationId(String conversationId) {
  4.         String cypher = """
  5.             MATCH (s:Session {id: $conversationId})-[:HAS_MESSAGE]->(m:Message)
  6.             OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:Metadata)
  7.             OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:Media)
  8.             RETURN m, metadata, collect(media) as medias
  9.             ORDER BY m.idx ASC
  10.             """;
  11.         
  12.         return driver.executableQuery(cypher)
  13.             .withParameters(Map.of("conversationId", conversationId))
  14.             .execute(record -> mapToMessage(record));
  15.     }
  16.    
  17.     @Override
  18.     public void saveAll(String conversationId, List<Message> messages) {
  19.         String cypher = """
  20.             MERGE (s:Session {id: $conversationId})
  21.             WITH s
  22.             UNWIND $messages AS msg
  23.             CREATE (m:Message {
  24.                 text: msg.text,
  25.                 type: msg.type,
  26.                 idx: msg.idx
  27.             })
  28.             CREATE (s)-[:HAS_MESSAGE]->(m)
  29.             """;
  30.         
  31.         driver.executableQuery(cypher)
  32.             .withParameters(Map.of(
  33.                 "conversationId", conversationId,
  34.                 "messages", toMessageParams(messages)
  35.             ))
  36.             .execute();
  37.     }
  38. }
复制代码
6.6 Cassandra 实现

Cassandra 实现使用分布式存储:
  1. (Session {id: "user-123"})-[:HAS_MESSAGE]->(Message {text: "你好", type: "USER"})
  2. (Session {id: "user-123"})-[:HAS_MESSAGE]->(Message {text: "你好!", type: "ASSISTANT"})
复制代码
7. 实现对比分析

特性JDBCMongoDBNeo4jCassandra存储模型关系型文档型图型列族型查询方式SQLQuery DSLCypherCQL适用场景通用灵活结构关系查询大规模分布式性能中等高中等极高扩展性好很好好优秀事务支持✅✅✅❌8. 外部依赖

不同实现的依赖:
8.1 JDBC


  • Spring JDBC:JDBC 模板
  • 数据库驱动:PostgreSQL、MySQL 等
8.2 MongoDB


  • Spring Data MongoDB:MongoDB 集成
8.3 Neo4j


  • Neo4j Java Driver:Neo4j 官方驱动
8.4 Cassandra


  • Cassandra Java Driver:Cassandra 官方驱动
9. 工程总结

Spring AI 的记忆能力设计有几个值得学习的地方:
分层抽象。ChatMemory 提供高级 API,ChatMemoryRepository 提供存储抽象,具体实现处理数据库差异。这种设计让记忆功能既易用又灵活。想换存储后端?换个 ChatMemoryRepository 实现就行。
窗口记忆策略。MessageWindowChatMemory 实现了智能的消息管理,只保留最近的 N 条消息,这既控制了上下文长度,又保持了相关性。不会因为对话历史太长导致 token 超限。
多存储后端支持。支持 JDBC、MongoDB、Neo4j、Cassandra 等多种存储,用户可以根据需求选择最合适的后端。想用关系数据库?用 JDBC。想用图数据库?用 Neo4j。
统一的数据模型。所有实现都使用相同的 Message 模型,这让切换存储后端变得简单。今天用 PostgreSQL,明天想换 MongoDB?改个配置就行。
自动模式初始化。大多数实现都支持自动创建表/集合,简化了部署。不用手动建表,启动时自动搞定。
总的来说,Spring AI 的记忆能力既简单又强大。简单的 API 让使用变得容易,强大的实现让系统可以适应各种场景。这种设计让开发者可以轻松构建支持多轮对话的 AI 应用。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册