阅读之前你可以前往 RWKV wiki 了解一些关于 RWKV 的基本知识,不过他们的 wiki 似乎没有对模型架构的详细介绍,于是便有了这篇文章。
RWKV-7 的核心:动态状态演化机制
RWKV-V7 的动态状态演化机制可以通俗理解为 “在线学习上下文关系的动态记忆更新” 。它的核心思想是:通过实时计算和更新一个内部状态(state)来动态捕捉上下文中 key 和 value 的关联关系,并利用这个状态处理当前输入的 query(在 RWKV 中是 r)以生成输出 。
1. 状态(state)的本质
在 RWKV 中,state 是一个三维张量(B, H, N, N),其中:
- B:批量大小(batch size)
- H:注意力头数量(head count)
- N:每个头的维度(head size)
state 的作用是 维护一个动态的“知识库” ,记录历史输入的 key(k)和 value(v)之间的关系。它类似于传统 RNN 的隐藏状态,但更复杂,因为它显式建模了 key-value 的交互。
2. 状态更新公式与代码实现
官方公式:
St=St−1⋅diag(wt)+(St−1at⊤bt+vt⊤kt)
在代码中(RWKV7_OP 函数),这个公式被拆解为:
state = state * w + state @ a @ b + v @ k该公式对应 RWKV-V7 的 动态状态演化机制 ,其核心思想是通过 在线学习 维护一个低秩状态矩阵 S_t,以捕捉上下文中的 key-value 关系。公式分为三部分:
- 权重衰减项(遗忘旧信息)
state * w
- 作用 :控制历史信息的遗忘速率。
- 数学形式 :S_t = S_{t-1} ⋅ diag(w_t),其中 w_t 是负值(通过 exp(-exp(...)) 生成)。
- 思想 :类似 RNN 的遗忘门,w 越小(更负),遗忘越多,确保模型专注于当前上下文。
- 学习率调整项(动态修正知识)
state @ aa @ bb
- 作用 :根据当前输入调整已有知识的权重。
- 数学形式 :S_t = S_{t-1} ⋅ a_t^⊤ b_t,其中 a 和 b 是动态学习率参数(a = -k, b = k ⋅ η)。
- 思想 :模拟梯度下降中的学习率乘以梯度方向,使状态更新更适应当前输入的特征分布。
- 新信息注入项(更新知识库)
vv @ kk
- 作用 :将当前 token 的 key-value 对(k_t, v_t)外积加入状态。
- 数学形式 :S_t += v_t^⊤ ⋅ k_t。
- 思想 :直接注入新信息,确保模型能实时捕捉上下文中的新关联(如实体关系、语义依赖)
GPT 与 RNN 模式
在进一步了解之前需要知道的是,RWKV 分为两种模式,一种是 GPT 模式,一种是 RNN 模式,区别在于:
GPT 模式 (rwkv_v7_demo.py)
- 计算方式 :采用 Transformer 的 GPT 模式 ,一次性处理整个输入序列(类似标准 Transformer 的前向传播)。
- 状态管理 :无需显式维护隐藏状态(state),直接通过注意力机制处理上下文。
- 适用场景 :适合 预填充(prefill) 阶段(即处理初始提示词),但自回归生成时效率较低。
- 性能特点 :
- 预填充阶段可以并行计算所有 token 的表示。
- 自回归生成时需重新计算所有 token 的表示(无状态缓存),导致速度较慢。
RNN 模式 (rwkv_v7_demo_rnn.py)
- 计算方式 :采用 RNN 模式 ,逐个 token 处理,显式维护隐藏状态(state)。
- 状态管理 :通过 state 变量保存每个时间步的中间状态(如 att_kv、ffn_x_prev 等),避免重复计算。
- 适用场景 :适合 自回归生成 (逐 token 生成),效率更高。
- 性能特点 :
- 预填充阶段效率低(需逐 token 处理)。
- 生成阶段只需计算当前 token 的表示,速度显著优于 GPT 模式。
此外,官方还提供了结合两种模式的 fast 模式: rwkv_v7_demo_fast.py
GPT 模式
Time mix(RWKV_Tmix_x070)
Time mix 是整个 RWKV-7 的核心,也是最复杂的部分。
RWKV_Tmix_x070 是 RWKV 模型中实现 Time Mixing 的核心模块,用于处理序列数据中的时序依赖关系。它是 RWKV-V7 架构中负责模拟注意力机制和动态状态更新的组件,类似于 Transformer 中的自注意力模块,但采用了更高效的线性复杂度设计。- @MyFunction
- def forward(self, x, v_first):
- # 获取输入张量的维度:批量大小 B、序列长度 T、特征维度 C
- B, T, C = x.size()
-
- # 获取注意力头数量 H 和每个头的维度 N(固定为 64)
- H = self.n_head
- N = self.head_size
- # 使用 time_shift 操作获取前一个时间步的信息,并与当前输入做差值
- xx = self.time_shift(x) - x # 得到时间差分项,用于生成不同方向的输入
- # 根据时间差分项和可学习参数,生成不同方向的输入
- xr = x + xx * self.x_r # receptance 输入方向
- xw = x + xx * self.x_w # decay weight 输入方向
- xk = x + xx * self.x_k # key 输入方向
- xv = x + xx * self.x_v # value 输入方向
- xa = x + xx * self.x_a # learning rate 输入方向
- xg = x + xx * self.x_g # gate 输入方向
- # 通过线性层生成 r(receptance),类似于传统 attention 中的 query
- r = self.receptance(xr)
- # 构造 w(权重衰减项):控制 state 的遗忘速率
- # 使用 tanh 和 softplus 确保 w 值在 (-inf, -0.5) 范围内
- w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5
- # 通过线性层生成 key 向量
- k = self.key(xk)
- # 通过线性层生成 value 向量
- v = self.value(xv)
- # 如果是第一层,则保存当前 v 作为 v_first(初始 value)
- if self.layer_id == 0:
- v_first = v
- else:
- # 否则使用门控机制对 v 进行残差连接,保留一部分初始信息
- v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
- # 生成动态学习率 a,控制 state 更新的学习率(类似在线梯度下降中的学习率)
- a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)
- # 生成门控参数 g,控制最终输出的激活强度
- g = torch.sigmoid(xg @ self.g1) @ self.g2
- # 对 key 进行缩放和 L2 归一化,增强数值稳定性
- kk = k * self.k_k
- kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
- # 调整 key 的尺度,结合学习率 a 和可学习参数 self.k_a
- k = k * (1 + (a - 1) * self.k_a)
- # 调用 RWKV7_OP 函数进行状态更新和 attention 输出计算
- x = RWKV7_OP(r, w, k, v, -kk, kk * a)
- # 使用 GroupNorm 对输出进行归一化
- x = self.ln_x(x.view(B * T, C)).view(B, T, C)
- # 添加一个额外的残差连接,融合 r, k, v 的乘积项
- x = x + ((r.view(B, T, H, -1) *
- k.view(B, T, H, -1) *
- self.r_k).sum(dim=-1, keepdim=True) *
- v.view(B, T, H, -1)).view(B, T, C)
- # 最终输出经过门控 g 控制,并通过线性层输出
- x = self.output(x * g)
- # 返回输出和 v_first,供下一层使用
- return x, v_first
复制代码 模块基于以下两个核心机制:
一. 时间差分与方向调整
首先,Time Shift 通过 time_shift(x) - x 得到当前 token 与前一 token 的差异,然后结合可学习参数生成多个“方向”的输入(如 xr, xw, xk, xv, xa, xg),分别用于构建不同的注意力相关参数。
具体来说,Time Shift 的作用是将输入张量在时间维度上向前移动一个单位,这样在处理当前时间步 xt 时,模型可以获取到到上一个时间步 xt−1 的信息。
举个例子:- # 假设原始输入是:
- x = [
- [x_0_0, x_0_1, ..., x_0_C], # 时间步 t=0
- [x_1_0, x_1_1, ..., x_1_C], # 时间步 t=1
- ...
- [x_T_0, x_T_1, ..., x_T_C] # 时间步 t=T
- ]
- # Time Shift 后得到的是:
- x_shifted = [
- [0, 0, ..., 0], # 第一个位置用 0 填充(表示没有前一步)
- [x_0_0, x_0_1, ..., x_0_C], # 第二个位置是 t=0 的值(相当于延迟了一个时间步)
- ...
- [x_{T-1}_0, ..., x_{T-1}_C] # 最后一个位置是 t=T-1 的值
- ]
复制代码
接着利用 Time Shift 的结果 xx 我们就可以计算不同的注意力相关参数。- xr = x + xx * self.x_r
- xw = x + xx * self.x_w
- xk = x + xx * self.x_k
- xv = x + xx * self.x_v
- xa = x + xx * self.x_a
- xg = x + xx * self.x_g
复制代码 这些新变量 xr, xw, xk, xv, xa, xg 分别用于生成注意力机制中的不同角色:
变量对应角色作用xrreceptance (r)类似 query,在状态更新中起调节作用xwdecayweight (w)控制遗忘速率xkkey (k)用于构建上下文关联xvvalue (v)上下文信息载体xalearning rate (a)动态调整学习率xggate (g)非线性激活门控二.线性变换生成核心参数
- r = self.receptance(xr)
- w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
- k = self.key(xk)
- v = self.value(xv)
复制代码
- receptance(xr) : 生成 r(类似传统注意力中的 query,但用于控制当前状态如何与新的输入进行交互,决定哪些信息会被“接收”并用于输出计算)。
- w : 权重衰减参数(控制遗忘速率),类似于 RNN 中的遗忘门。
- 公式 :w = -softplus(-(...)) - 0.5,确保 w 为负值。
- 作用 :模拟指数衰减,遗忘旧信息。
- key(xk) : 生成 k(key 向量,决定新信息如何与当前 state 进行交互)。
- value(xv) : 生成 v(value 向量,表示输入的“值”,包含实际的内容信息)。
三. 残差连接机制
- if self.layer_id == 0:
- v_first = v # store the v of the first layer
- else:
- v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
复制代码 <blockquote>一、什么是残差连接?
✅ 简单定义:
残差连接(Residual Connection) 是指将某一层的输入直接加到该层的输出上。
数学表达为:
output=input+F(input)
其中 F(input) 是网络中某个函数或模块对输入的变换。
<strong>
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |