找回密码
 立即注册
首页 业界区 业界 Minimind-一个开源LLM项目的代码分析1:模型结构 ...

Minimind-一个开源LLM项目的代码分析1:模型结构

烯八 2025-9-20 20:23:49
如果你是一名刚接触大语言模型(LLM)的初学者,很可能会在社交媒体上看到这样一个项目——MiniMind
1.jpg
2.jpg

这个项目实现了一个参数规模较小但功能完整的 LLM,涵盖了预训练、LoRA 微调、SFT、蒸馏以及基于人类反馈的强化学习(RLHF)等多个模块,可以说是非常难得的入门教材。
MiniMind 提供了清晰的复现指南和环境配置说明,但在代码背后的原理解释上并不算详细。对于像笔者这样并非 NLP 出身的初学者来说,直接啃源码还是有相当的难度,因此有必要把一些关键的基础知识梳理下来,既能帮助加深理解,也便于后续复习。
因此,本文主要记录了笔者在学习该项目过程中新掌握或重新温习的重要知识点,并在文末推荐了一些适合入门的参考博客。同时,文中还附上了带注释的源码片段,希望能为理解整个项目的实现提供更直观的帮助。
本文是此系列的第一遍博客,主要介绍了model_minimind.py中涉及的基础知识,后续会继续更新,敬请期待。
RMSNorm

现在LLM架构或者其他transfomer架构喜欢使用RMSNorm,我们在这里一并区分三种常见的Norm:
BatchNorm(for cv)

BatchNorm,是按照batch维度进行归一化。常用于CV任务, BatchNorm把一个batch中同一通道的所有特征(如下图红色通道对应特征图)视为一个分布(有几个通道就有几个分布),并将其标准化。
3.jpeg

代码:一个batchsize的数据送入网络,经过卷积层,得到一个四维的tensor,形状为(batch_size, channels, height, width),即(N,C,H,W)然后对这个tensor进行BatchNorm操作。
假设输入张量形状是:

\[x \in \mathbb{R}^{(N, C, H, W)}\]

  • \(N\):batch size
  • \(C\):通道数(channel)
  • \(H, W\):空间维度(height, width)
对每个通道 \(c\),在整个 batch(N 个样本,H×W 个位置)上统计均值和方差,然后做标准化。
对于通道 \(c\),我们先算均值和方差:

\[\mu_c = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W x_{n,c,h,w}\]

\[\sigma_c^2 = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^N \sum_{h=1}^H \sum_{w=1}^W (x_{n,c,h,w} - \mu_c)^2\]
然后标准化:

\[\hat{x}_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}\]
最后再加上可学习的缩放和平移参数:

\[y_{n,c,h,w} = \gamma_c \hat{x}_{n,c,h,w} + \beta_c\]
其中:

  • \(\gamma_c, \beta_c\) 是学习到的参数(每个通道一对)。
  • \(\epsilon\) 是防止除零的常数。
手动实现:
  1. import torch
  2. def batchnorm2d(x, gamma, beta, eps=1e-5):
  3.     # x: [8,3,32,32]
  4.     mean = x.mean(dim=(0, 2, 3), keepdim=True)       # mean: [1,3,1,1]
  5.     var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False) # var:  [1,3,1,1]
  6.    
  7.     x_hat = (x - mean) / torch.sqrt(var + eps)       # 标准化,x_hat: [8,3,32,32],mean,var从[1,3,1,1] 广播为[8,3,32,32]
  8.     y = gamma.view(1, -1, 1, 1) * x_hat + beta.view(1, -1, 1, 1)
  9.     return y
  10. # 测试
  11. x = torch.randn(8, 3, 32, 32)  # batch=8, 通道=3, 32x32
  12. gamma = torch.ones(3)          # 初始缩放
  13. beta = torch.zeros(3)          # 初始平移
  14. y = batchnorm2d(x, gamma, beta)
  15. print(y.shape)  # torch.Size([8, 3, 32, 32])
复制代码
调库实现
  1. import torch
  2. import torch.nn as nn
  3. batch_size, channels, height, width = 8, 3, 32, 32
  4. # 创建一个BatchNorm层,只需要指定通道数
  5. batch_norm = nn.BatchNorm2d(channels)
  6. input_tensor = torch.randn(batch_size, channels, height, width)
  7. output_tensor = batch_norm(input_tensor)
  8. print(output_tensor.shape)  # 输出形状仍然是 (batch_size, channels, height, width)
复制代码
多提一嘴:复习广播机制:

  • 如果有一个维度是 1,可以扩展到另一个维度的大小。
  • 如果维度不相等,且没有 1,那么报错。
  • 赋值原则:若从(1,x,y,1)扩展到(a,x,y,b),则(1,x,y,1)的值会被复制a*b次
Layernorm (for nlp)

先明确一下nlp处理的tensor形状是什么,nlp中常用的输入形状是(batch_size, seq_length, embedding_dim),即(N, L, D),其中N是batch size(几个句子),L是序列长度(一个句子多少个token),D是每个词的嵌入维度。
4.jpeg

LayerNorm是对每个句子的所有词向量进行归一化。也就是每一个Embedding进行归一化,这样做的目的是保证了每个 token 的表示数值稳定,不会因为 embedding 的绝对大小不同而影响训练。
假设有一个 batch,里面有 \(N\) 个句子,每个句子有 \(M\) 个 token,每个 token 的 embedding 维度是 \(H\):

\[x \in \mathbb{R}^{N \times M \times H}\]
LayerNorm 不会跨样本,不会跨 token。它只会对 每个 token 的 hidden 维度 \(H\) 求均值方差。数学上,如果第 \(n\) 个句子、第 \(m\) 个 token 的 embedding 是

\[x_{n,m,:} = (x_{n,m,1}, x_{n,m,2}, \dots, x_{n,m,H})\]
那么 LN 的均值和方差是:

\[\mu_{n,m} = \frac{1}{H}\sum_{i=1}^H x_{n,m,i}, \quad \sigma_{n,m}^2 = \frac{1}{H}\sum_{i=1}^H (x_{n,m,i}-\mu_{n,m})^2\]
然后归一化:

\[\hat{x}_{n,m,i} = \frac{x_{n,m,i} - \mu_{n,m}}{\sqrt{\sigma_{n,m}^2+\epsilon}}\]
再乘上可学习参数:

\[y_{n,m,i} = \gamma_i \hat{x}_{n,m,i} + \beta_i\]
调库实现
  1. import torch
  2. import torch.nn as nn
  3. N, M, H = 2, 4, 6   # batch=2, seq_len=4, hidden=6
  4. x = torch.randn(N, M, H)
  5. layernorm = nn.LayerNorm(H)  # 只对最后一维 hidden 归一化
  6. y = layernorm(x)
  7. print("输入形状:", x.shape)  # (2,4,6)
  8. print("输出形状:", y.shape)  # (2,4,6)
复制代码
手写
  1. def layernorm(x, gamma, beta, eps=1e-5):
  2.     # x: [N, M, H]
  3.     mean = x.mean(dim=-1, keepdim=True)   # [N, M, 1] 逐样本逐序列求均值
  4.     var = x.var(dim=-1, keepdim=True, unbiased=False)  # [N, M, 1] 方差
  5.    
  6.     x_hat = (x - mean) / torch.sqrt(var + eps)  # [N, M, H] 标准化
  7.     y = gamma * x_hat + beta                   # [N, M, H] 缩放平移
  8.     return y
复制代码
RMSNorm

RMSNorm 干脆 不要均值减法,只用平方均值 (Root Mean Square, RMS):

\[\mathrm{RMSNorm}(x)=\frac x{\mathrm{RMS}(x)}\cdot\gamma \]
其中

\[\mathrm{RMS}(x)=\sqrt{\frac1d\sum_{i=1}^dx_i^2+\epsilon}\]
在Minimind的代码中,实现如下:
  1. class RMSNorm(torch.nn.Module):
  2.     def __init__(self, dim: int, eps: float = 1e-5):
  3.         super().__init__()
  4.         self.eps = eps
  5.         self.weight = nn.Parameter(torch.ones(dim))  # 可学习的权重参数gamma
  6.     def _norm(self, x):
  7.         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) # rsqrt:reciprocal square root 倒数平方根运算
  8.     def forward(self, x):
  9.         return self.weight * self._norm(x.float()).type_as(x)   # 确保类型一致
  10.         
复制代码
RoPE

RoPE (Rotary Position Embedding) 是一种位置编码方法,旨在为 Transformer 模型引入位置信息。与传统的绝对位置编码(如正弦余弦位置编码)不同,RoPE 通过对查询(Q)和键(K)向量进行旋转变换来实现相对位置编码。
定义二位旋转矩阵:

\[\boldsymbol{f}(\boldsymbol{q},m)=\binom{\cos m\theta\quad-\sin m\theta}{\sin m\theta\quad\cos m\theta}\left(\frac{q_0}{q_1}\right)\]
由于内积满足线性叠加性,因此任意偶数维的RoPE,我们都可以表示为二维情形的拼接,即

\[\underbrace{\begin{pmatrix}\cos m\theta_{0}&-\sin m\theta_{0}&0&0&\cdots&0&0\\\sin m\theta_{0}&\cos m\theta_{0}&0&0&\cdots&0&0\\0&0&\cos m\theta_{1}&-\sin m\theta_{1}&\cdots&0&0\\0&0&\sin m\theta_{1}&\cos m\theta_{1}&\cdots&0&0\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&0&\cdots&\cos m\theta_{d/2-1}&-\sin m\theta_{d/2-1}\\0&0&0&0&\cdots&\sin m\theta_{d/2-1}&\cos m\theta_{d/2-1}\end{pmatrix}}_{\mathbf{a}_{m}}\begin{pmatrix}q_{0}\\q_{1}\\q_{2}\\q_{3}\\\vdots\\q_{d-2}\\q_{d-1}\end{pmatrix}\]
也就是说,给位置为\(m\)的向量\(q\)乘上矩阵\(\mathcal{R}_m\)、位置为\(\color{red}{n}\)的向量\(k\)乘上矩阵\(\mathcal{R}_n\),用变换后的\(Q,K\)序列
做Attention,那么Attention就自动包含相对位置信息了,因为成立恒等式:

\[(\mathcal{R}_m\boldsymbol{q})^\top(\mathcal{R}_n\boldsymbol{k})=\boldsymbol{q}^\top\mathcal{R}_m^\top\mathcal{R}_n\boldsymbol{k}=\boldsymbol{q}^\top\mathcal{R}_{n-m}\boldsymbol{k}\]
鉴于计算中的稀疏性,直接用矩阵乘法来实现会很浪费算力,推荐通过下述方式来实现RoPE:

\[\begin{pmatrix}q_0\\q_1\\q_2\\q_3\\\vdots\\q_{d-2}\\q_{d-1}\end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_0\\\cos m\theta_0\\\cos m\theta_1\\\cos m\theta_1\\\vdots\\\cos m\theta_{d/2-1}\\\cos m\theta_{d/2-1}\end{pmatrix}+\begin{pmatrix}-q_1\\q_0\\-q_3\\q_2\\\vdots\\-q_{d-1}\\q_{d-2}\end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_0\\\sin m\theta_0\\\sin m\theta_1\\\sin m\theta_1\\\vdots\\\sin m\theta_{d/2-1}\\\sin m\theta_{d/2-1}\end{pmatrix}\]
最后说一下\(\theta\)的取值,代码里面\(\theta_i\)记为了\(\omega_i\):
先构造频率

\[\omega_i=\frac1{\theta^{i/d}},\quad i=0,2,4,\ldots,dim-2\]
一共 dim/2 个不同的频率。然后乘以位置 \(m\):

\[\text{freqs}[m,i]=m\cdot\omega_i\]
这就对应到公式里的 \(m\theta_i\)。换句话说,freqs 里面存的就是 角度 \(m\theta_i\)。
Minimind的代码实现:
  1. def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
  2.     # 计算所有维度位置的频率
  3.     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  4.     # 不同位置的q,k向量乘不同的m,这里生成0,1,2,...,end-1
  5.     t = torch.arange(end, device=freqs.device)
  6.     # 外积得到(end, dim/2) 第i行对应适用于token位置i的q,k
  7.     freqs = torch.outer(t, freqs).float()
  8.     # 计算cos和sin,并拼接成(end, dim)
  9.     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
  10.     freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
  11.     return freqs_cos, freqs_sin
  12. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  13.     def rotate_half(x):
  14.         return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
  15.     q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
  16.     k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
  17.     return q_embed, k_embed
复制代码
GQA-attention

在梳理Attention模块的的代码实现之前,我们需要提前认知几个重要概念,并做一些回顾:
torch操作回顾


  • 在 PyTorch 里,@ 和 torch.matmul 对高维张量矩阵乘法的定义是:只取最后两个维度做矩阵乘法!
  1. A.shape = [2, 3, 4, 5]   # (batch=2, heads=3, 4×5矩阵)
  2. B.shape = [2, 3, 5, 6]   # (batch=2, heads=3, 5×6矩阵)
  3. C = A @ B  # [2, 3, 4, 6]
复制代码

  • 对高维张量,transpose(dim1, dim2) 只会交换这两个维度,其他维度保持不变。(注意,dim从前往后数是0开始的,使用负数的时候,-1 表示最后一维,-2 表示倒数第二维)
  1. A.shape = [2, 3, 4, 5]   # (batch=2, heads=3, 4×5矩阵)
  2. B = A.transpose(1, 2)  # 交换第1维和第2维
  3. B.shape  # [2, 4, 3, 5]
  4. x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
  5. y = x.transpose(-1, -3)      # 相当于 transpose(3, 1)
  6. print(y.shape)               # (2, 5, 4, 3)
复制代码

  • nn.Linear 本质上也是矩阵乘法,当高维 tensor 输入 linear 层时,只有最后两个维度会参与线性计算,前面的维度会被视为 batch 维度而保留下来。例如:
  1. proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
复制代码
如果输入的形状是 (bsz, seq_len, hidden_size),经过 proj 之后会变成 (bsz, seq_len, num_attention_heads * head_dim)。这里的线性层相当于对每个 token 的隐状态向量做一次相同的全连接变换,把它映射到多头注意力所需的维度空间。
kv-cache

一种缓存机制,用于加速推理过程。鉴于计算注意力得分的时候,key和value需要被复用,因此可以缓存之前的key和value,空间换时间。
5.gif
  1. # kv_cache实现,推理时使用,训练时关闭
  2. if past_key_value is not None:
  3.     xk = torch.cat([past_key_value[0], xk], dim=1)
  4.     xv = torch.cat([past_key_value[1], xv], dim=1)
  5. past_kv = (xk, xv) if use_cache else None
复制代码
mask操作

推荐一个写的很不错的博客:https://zhuanlan.zhihu.com/p/28786272137
因果mask:(Causal Mask,下三角)

  • Q/K/V的形状: (bsz, num_heads, seq_len, head_dim)
  • 转置后形状:(bsz, self.n_local_heads, seq_len, self.head_dim)
矩阵乘法默认发生在最后两个维度上,因此 Q @ K^T 的结果形状是 (bsz, num_heads, seq_len, seq_len),表示一个bsz中,每个head的每个query与所有key的相似度得分。
因果mask矩阵作用在(seq_len, seq_len)维度上,确保每个位置只能看到它之前(包括它自己)的token。具体的,每个待处理token之后的位置被设置为负无穷(-inf),这样在softmax之后,这些位置的权重就变成了0。
6.png

padding mask
在神经网络的训练过程中,同一个batch会包含有多个文本序列,不同的序列长度并不一定会一致。而神经网络的输入需要一个规整的张量。为了符合模型的输入方式,在数据集的生成过程中,我们要对输入序列进行对齐,使同一个batch内所有序列的长度一致。具体来说就是:
7.png

因此,综合考虑两种mask,我们可以将它们相加,得到最终的注意力掩码矩阵。这个矩阵会在计算注意力得分时使用,确保模型只能关注到合法的位置。
8.png

完整代码解析+注释

现在,我们进行源代码的逐行解析,在看代码之前,我们先明确一下Dense model涉及到的所有参数
模块参数名含义典型值备注GQA (Grouped Query Attention)hidden_size输入 embedding 维度512Q/K/V 输入输出的基准维度num_attention_headsQ 的头数8Query 被分成多少个子空间num_key_value_headsKV 的头数2K/V 使用更少的头数,节省显存head_dim每个头的维度hidden_size // num_attention_heads = 64Q/K/V 每个头的子空间大小n_repQ 头 / KV 头比值num_attention_heads // num_key_value_heads = 4每个 KV 头复制给多个 Q 头FFN (Feed-Forward Network)hidden_size输入维度512与 Transformer 输入输出一致intermediate_size中间层维度通常 hidden_size * 4 = 2048扩展后再压缩hidden_act激活函数SiLU典型是 ReLU / GELU / SiLUdropoutDropout 比例0.0防止过拟合简写几个关键变量:

  • batch_size = bsz
  • seq_len = L
  • hidden_size = 512
  • num_attention_heads = 8
  • num_key_value_heads = 2
  • head_dim = 64
9.png

现在可以看Attention部分的源代码了,我做了详细的注释
  1. class Attention(nn.Module):
  2.     def __init__(self, args: MiniMindConfig):
  3.         super().__init__()
  4.         # GQA 中 q的头数可以多于kv的头数,因此先做一个判断
  5.         self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
  6.         assert args.num_attention_heads % self.num_key_value_heads == 0 # 确保q的头数是kv头数的整数倍
  7.         self.n_local_heads = args.num_attention_heads # Q 的头数
  8.         self.n_local_kv_heads = self.num_key_value_heads # K/V 的头数
  9.         self.n_rep = self.n_local_heads // self.n_local_kv_heads # Q头数 / KV头数,表示每个KV头会被复制给多少个Q头
  10.         self.head_dim = args.hidden_size // args.num_attention_heads # 每个头的维度,即原始Embedding压缩后的维度 (一般这种压缩满足:原始hidden_size = num_heads * head_dim)
  11.         self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False) # W_q,多头已经并起来了
  12.         self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) # W_k,多头已经并起来了
  13.         self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) # W_v,多头已经并起来了
  14.         self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False) # 输出投影,把多头结果线性组合(变换)为原始hidden_size维度
  15.         self.attn_dropout = nn.Dropout(args.dropout) # 注意力得分的dropout
  16.         self.resid_dropout = nn.Dropout(args.dropout) # 输出的dropout
  17.         self.dropout = args.dropout # 用于flash attention的dropout
  18.         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn # 检查是否能用flash attention这一高效实现,如果可以则self.flash = True
  19.         # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
  20.     # x的输入形状: (bsz, seq_len, hidden_size)
  21.     def forward(self,
  22.                 x: torch.Tensor,
  23.                 position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # 修改为接收cos和sin
  24.                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
  25.                 use_cache=False,    # 训练时关闭,推理时打开
  26.                 attention_mask: Optional[torch.Tensor] = None # (bsz, seq_len) 1表示有效,0表示padding(后续转化为False,用于mask)
  27.                 ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
  28.         bsz, seq_len, _ = x.shape
  29.         xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
  30.         xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
  31.         xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
  32.         xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
  33.         cos, sin = position_embeddings
  34.         xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len]) #(应用RoPE后)
  35.         # kv_cache实现,推理时使用,训练时关闭
  36.         if past_key_value is not None:
  37.             xk = torch.cat([past_key_value[0], xk], dim=1)
  38.             xv = torch.cat([past_key_value[1], xv], dim=1)
  39.         past_kv = (xk, xv) if use_cache else None
  40.         # repeat k,v的头数与q匹配
  41.         # 转置:(bsz, seq_len, self.n_local_heads, self.head_dim) -> (bsz, self.n_local_heads, seq_len, self.head_dim)
  42.         xq, xk, xv = (
  43.             xq.transpose(1, 2),
  44.             repeat_kv(xk, self.n_rep).transpose(1, 2),
  45.             repeat_kv(xv, self.n_rep).transpose(1, 2)
  46.         )
  47.         # 如果可以使用flash attention
  48.         if self.flash and seq_len != 1:
  49.             # 训练时打开 dropout,推理时关闭
  50.             dropout_p = self.dropout if self.training else 0.0
  51.             attn_mask = None
  52.             if attention_mask is not None: # 有padding mask
  53.                 attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1) # bsz, self.n_local_heads, seq_len, seq_len
  54.                 attn_mask = attn_mask.bool() if attention_mask is not None else None # 转0,1为bool,符合flash attention的要求
  55.             output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True) # 高性能实现,自动加上因果 mask output = softmax(QK^T/sqrt(d)+mask)V->droupout (bsz,self.n_local_heads,seq_len,self,head_dim)
  56.         else:
  57.             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) # Q @ K^T / sqrt(d_k),形状:(bsz, self.n_local_heads, seq_len, seq_len)
  58.             scores = scores + torch.triu(
  59.                 torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
  60.                 diagonal=1
  61.             ).unsqueeze(0).unsqueeze(0)  # scores+mask (bsz,self.n_local_heads,seq_len,seq_len)
  62.             if attention_mask is not None:
  63.                 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # (bsz, seq_len)->(bsz, 1, 1, seq_len)
  64.                 extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 #
  65.                 scores = scores + extended_attention_mask # (bsz,self.n_local_heads,seq_len,seq_len)
  66.             scores = F.softmax(scores.float(), dim=-1).type_as(xq) # (bsz,self.n_local_heads,seq_len,seq_len)
  67.             scores = self.attn_dropout(scores)
  68.             output = scores @ xv # (bsz,self.n_local_heads,seq_len,seq_len)*(bsz,self.n_local_heads,seq_len,self.head_dim)
  69.         # output = (bsz,self.n_local_heads,seq_len,self.head_dim)
  70.         output = output.transpose(1, 2).reshape(bsz, seq_len, -1)  # (bsz, seq_len,self.n_local_heads*self.head_dim)
  71.         output = self.resid_dropout(self.o_proj(output))  
  72.         return output, past_kv
  73.         # 最终输出形状;# (bsz, seq_len,self.n_local_heads*self.head_dim)
  74.         # =  (bsz, seq_len, hidden_size)
复制代码
FFN部分(采用了GLU结构)

前置知识1:函数ACT2FN
  1. self.act_fn = ACT2FN[config.hidden_act]
复制代码

  • ACT2FN 是个字典,存了激活函数名字到实现的映射,比如:
    1. ACT2FN = {
    2.     "relu": torch.nn.functional.relu,
    3.     "gelu": torch.nn.functional.gelu,
    4.     "silu": torch.nn.functional.silu,  # SiLU = Swish
    5. }
    复制代码
  • 在 LLaMA 系列里,激活函数是 SiLU(又叫 Swish)。
    数学形式:

    \[\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}\]
前置知识2:GLU风格的FFN结构

这个 FFN 用的是 Gated Linear Unit (GLU) 风格:

\[\text{FFN}(x) = W_{down} \Big( \text{SiLU}(W_{gate} x) \odot (W_{up} x) \Big)\]
其中:

  • \(W_{gate}, W_{up} \in \mathbb{R}^{d_{hidden} \times d_{inter}}\)
  • \(W_{down} \in \mathbb{R}^{d_{inter} \times d_{hidden}}\)
  • \(\odot\) 表示逐元素乘法。
  1. # 输入形状: (bsz, seq_len, hidden_size)
  2. # 中间有门控节
  3. # 输出形状: (bsz, seq_len, hidden_size)
  4. class FeedForward(nn.Module):
  5.     def __init__(self, config: MiniMindConfig):
  6.         super().__init__()
  7.         if config.intermediate_size is None:
  8.             intermediate_size = int(config.hidden_size * 8 / 3) # 通常设为hidden_size的8/3倍
  9.             config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64) # 对齐到 64 的整数倍,硬件友好
  10.         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
  11.         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  12.         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
  13.         self.dropout = nn.Dropout(config.dropout)
  14.         self.act_fn = ACT2FN[config.hidden_act]
  15.     def forward(self, x):
  16.         return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
复制代码
MOE 模块 (混合专家模型)

10.png

整体结构

详细的可参见苏剑林博客:https://spaces.ac.cn/archives/10699 我觉得写的很好,我自己如何整理也不如这篇博客讲的清晰了。
简答来说,在MInimind的模型里,MOE模块是这样的

  • 1 有一部分共享专家,总是被选择
  • 2 有一个打分器(Router),其数学形式为\(\underbrace{[\rho_1,\rho_2,\cdots,\rho_n]}_{\rho}=h(\boldsymbol{xW}^{(R)})\quad\in\mathbb{R}_{\geq0}^n\)
  • 3 选择top-k个专家,并对其打分进行softmax归一化,加权求和
  • 4 计算负载均衡loss,鼓励路由器均匀使用专家
前置torch知识


  • torch.Tensor.scatter_add_(dim, index, src):在指定的维度 dim 上,按照 index 里的位置,把 src 中的值「加到」当前 Tensor 的对应位置上。


  • index: 索引 Tensor,和 src 形状相同,表示要加到目标 Tensor 的哪个位置。
  • src: 源 Tensor,包含要加的值。
  1. import torch
  2. # 初始 Tensor
  3. out = torch.zeros(3, 5, dtype=torch.float)
  4. # 索引
  5. index = torch.tensor([[0, 1, 2],
  6.                       [2, 3, 4]])
  7. # 源值
  8. src = torch.tensor([[1, 1, 1],
  9.                     [2, 2, 2]], dtype=torch.float)
  10. out.scatter_add_(1, index, src)
  11. print(out)
复制代码
结果输出;
  1. tensor([[1., 1., 1., 0., 0.],
  2.         [0., 0., 2., 2., 2.],
  3.         [0., 0., 0., 0., 0.]])
复制代码
dim=1 表示在第 1 个维度(列递增方向)上 scatter。第一行的 1 分别加到了第 0、1、2 列。第二行的 2 分别加到了第 2、3、4 列。

  • torch.nn.functional.one_hot(tensor, num_classes):将整数 Tensor 转换为 one-hot 编码形式。


  • tensor: 输入的整数 Tensor,元素值应在 [0, num_classes-1] 范围内。
  • num_classes: 类别总数,决定 one-hot 向量的长度。
  1. import torch
  2. # 输入的整数 Tensor
  3. tensor = torch.tensor([0, 2, 1, 3])
  4. # 转换为 one-hot 编码
  5. one_hot = torch.nn.functional.one_hot(tensor, num_classes=4)  
  6. """
  7. tensor([[1, 0, 0, 0],
  8.         [0, 0, 1, 0],
  9.         [0, 1, 0, 0],
  10.         [0, 0, 0, 1]])
  11. """
复制代码
MOE Gate
  1. # 输入形状: (bsz, seq_len, hidden_size)
  2. # 输出: topk_idx (bsz*seq_len, top_k), topk_weight (bsz*seq_len, top_k), aux_loss (可求导的标量)
  3. class MoEGate(nn.Module):
  4.     def __init__(self, config: MiniMindConfig):
  5.         super().__init__()
  6.         self.top_k = config.num_experts_per_tok        # 每个 token 选择多少个专家
  7.         self.n_routed_experts = config.n_routed_experts # 总共多少个专家
  8.         self.scoring_func = config.scoring_func        # 打分方式(一般是 softmax)
  9.         self.alpha = config.aux_loss_alpha             # 辅助损失的权重
  10.         self.seq_aux = config.seq_aux                  # 是否用序列级别的辅助损失
  11.         self.norm_topk_prob = config.norm_topk_prob    # 是否对 top-k 概率归一化
  12.         self.gating_dim = config.hidden_size
  13.         # Router 的核心参数,相当于 [n_experts, hidden_dim] 的打分矩阵
  14.         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
  15.         self.reset_parameters()
复制代码
forward部分:首先输出每个 token 在专家上的分布,形状为(bsz*seq_len, n_experts) (token数量,可投票选择的专家数量),然后选出topk及其索引,根据接收参数决定是都要归一化
  1. def forward(self, hidden_states):
  2.     bsz, seq_len, h = hidden_states.shape
  3.     hidden_states = hidden_states.view(-1, h)  # [batch * seq, hidden_dim]
  4.    
  5.     # 计算 gating logits: [batch*seq, n_experts]
  6.     logits = F.linear(hidden_states, self.weight, None)
  7.    
  8.     # softmax 得到每个 token 在专家上的分布
  9.     scores = logits.softmax(dim=-1)
  10.     # 选出前 top-k 个专家
  11.     # topk_weight: [batch*seq, top_k], topk_idx: [batch*seq, top_k]
  12.     topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
  13.     # 如果需要,归一化 top-k 权重,让它们和为 1
  14.     if self.top_k > 1 and self.norm_topk_prob:
  15.         denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
  16.         topk_weight = topk_weight / denominator
复制代码
紧接着计算aux_loss,促进负载平衡.关于aux_loss的计算,有序列级别(句子级别)和token级别两种计算方式。先看序列级别的aux_loss:
  1.     # 负载均衡辅助损失 (aux_loss),避免所有 token 都路由到同一个专家
  2.     if self.training and self.alpha > 0.0: # 只在训练时计算 aux_loss,判断alpha是否大于0是因为如果alpha=0则不需要计算aux_loss
  3.         scores_for_aux = scores # (bsz*seq_len, n_experts)
  4.         aux_topk = self.top_k # 每个 token 选择的专家数
  5.         # 将 topk_idx 转换为二维形状,方便后续计算
  6.         topk_idx_for_aux_loss = topk_idx.view(bsz, -1) # (bsz, seq_len*top_k) 每一行是一个bsz的所有句子token选择的专家索引,排列在一起共计seq_len*k个
  7.         
  8.         if self.seq_aux: # 是否使用序列级别的辅助损失
  9.             scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) # (B,seq_len,n_experts)
  10.             ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) # (bsz, n_experts) 一个bsz等于一个句子,ce用于累计每个句子中各专家的使用频率
  11.             # 统计方式:每一个bsz内累加专家使用次数,然后除以归一化常数
  12.             ce.scatter_add_(
  13.                 1, topk_idx_for_aux_loss,
  14.                 torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
  15.             ).div_(seq_len * aux_topk / self.n_routed_experts)
  16.             # 计算scores_for_seq_aux.mean(dim=1)  (B,seq_len,n_experts)-> (B,n_experts) 每个bsz内各专家获得的平均打分
  17.             # ce (bsz, n_experts) 每个bsz内各专家的实际使用频率
  18.             aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
复制代码
从代码来看,aux_loss的含义比较清晰,就是打分(归一化概率分布)和实际使用频率的乘积之和。这一项存在约束:分数之和是固定的,而实际使用频率和打分的分布大致是正相关的。在这样的约束下,我们感性上便可以发现(类似于基本不等式),如果专家使用的不均衡,那么这一项auxloss应该会更大一些。反之,如果专家使用均衡,那么这一项auxloss会更小一些。因此,最小化auxloss的目标,实际上是鼓励专家使用均衡。更详细的、严格的解释可见:https://spaces.ac.cn/archives/10735
aux_loss的另一种计算方式是token级别的aux_loss(原理类似):
  1.         else:
  2.             # token 级别:类似 one-hot,鼓励负载均衡
  3.             # (bsz, seq_len*top_k) -> (bsz*seq_len*top_k, n_experts)
  4.             mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
  5.             ce = mask_ce.float().mean(0)          # 每个专家的实际使用比例
  6.             Pi = scores_for_aux.mean(0)           # 理论分布
  7.             fi = ce * self.n_routed_experts       # 归一化因子
  8.             aux_loss = (Pi * fi).sum() * self.alpha
  9.     else:
  10.         aux_loss = 0
复制代码
最后,返回topk的索引、权重和aux_loss
  1.     return topk_idx, topk_weight, aux_loss
复制代码
MOE的FFN

先看一般的前馈
  1. class MOEFeedForward(nn.Module):
  2.     def __init__(self, config: MiniMindConfig):
  3.         super().__init__()
  4.         self.config = config
  5.         self.experts = nn.ModuleList([
  6.             FeedForward(config)
  7.             for _ in range(config.n_routed_experts)
  8.         ])
  9.         self.gate = MoEGate(config)
  10.         if config.n_shared_experts > 0:
  11.             self.shared_experts = nn.ModuleList([
  12.                 FeedForward(config)
  13.                 for _ in range(config.n_shared_experts)
  14.             ])
  15.     # 输入 形状: (bsz, seq_len, hidden_size)
  16.     # 输出 形状: (bsz, seq_len, hidden_size)
  17.     def forward(self, x):
  18.         identity = x
  19.         orig_shape = x.shape
  20.         bsz, seq_len, _ = x.shape
  21.         # 使用门控机制选择专家
  22.         topk_idx, topk_weight, aux_loss = self.gate(x) # topk_idx: (bsz*seq_len, top_k), topk_weight: (bsz*seq_len, top_k)
  23.         x = x.view(-1, x.shape[-1]) # (bsz*seq_len, hidden_size)
  24.         flat_topk_idx = topk_idx.view(-1) # (bsz*seq_len*top_k, ) # 所有句子所有token的topk id选择排列为一行
  25.         if self.training:
  26.             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) # (bsz*seq_len*top_k, hidden_size)
  27.             y = torch.empty_like(x, dtype=torch.float16) # (bsz*seq_len*top_k, hidden_size)
  28.             for i, expert in enumerate(self.experts): # 遍历所有专家
  29.                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # 所有使用了专家i的token都送进去expert i计算,赋值到y对应位置
  30.             # 按照 topk_weight 加权求和
  31.             # (bsz*seq_len,top_k, hidden_size)
  32.             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) # (bsz*seq_len,top_k, hidden_size)* (bsz*seq_len,top_k,1) sum-> (bsz*seq_len, hidden_size)
  33.             y = y.view(*orig_shape) # (bsz, seq_len, hidden_size)
  34.         else: # 推理时
  35.             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
  36.         if self.config.n_shared_experts > 0:
  37.             for expert in self.shared_experts:
  38.                 y = y + expert(identity)
  39.         self.aux_loss = aux_loss
  40.         return y
复制代码
推理时的MOE计算:
  1.     @torch.no_grad()
  2.     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
  3.         expert_cache = torch.zeros_like(x)
  4.         idxs = flat_expert_indices.argsort()
  5.         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
  6.         token_idxs = idxs // self.config.num_experts_per_tok
  7.         # 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4)
  8.         # 且token_idxs = [3, 7, 19, 21, 24, 25,  4,  5,  6, 10, 11, 12...] 时
  9.         # 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok)
  10.         # 接下来9个位置token_idxs[6:15] -> [4,  5,  6, 10, 11, 12...]属于专家1处理的token...依此类推
  11.         for i, end_idx in enumerate(tokens_per_expert):
  12.             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
  13.             if start_idx == end_idx:
  14.                 continue
  15.             expert = self.experts[i]
  16.             exp_token_idx = token_idxs[start_idx:end_idx]
  17.             expert_tokens = x[exp_token_idx]
  18.             expert_out = expert(expert_tokens).to(expert_cache.dtype)
  19.             expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
  20.             expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
  21.         return expert_cache
复制代码
完全的架构

单个Minimind block  = GQA + FFN/ MOE

MOEFFN
11.png
12.png
  1. class MiniMindBlock(nn.Module):
  2.     def __init__(self, layer_id: int, config: MiniMindConfig):
  3.         super().__init__()
  4.         # 注意力头的数量
  5.         self.num_attention_heads = config.num_attention_heads
  6.         # 隐层维度大小
  7.         self.hidden_size = config.hidden_size
  8.         # 每个注意力头的维度 = hidden_size / num_heads
  9.         self.head_dim = config.hidden_size // config.num_attention_heads
  10.         # 自注意力层(包含 QKV 计算、多头注意力、输出映射等)
  11.         self.self_attn = Attention(config)
  12.         # 层的编号(主要用于模型内部调试或分布式并行时定位)
  13.         self.layer_id = layer_id
  14.         # 输入到 Attention 前的 RMSNorm(Pre-LN 结构)
  15.         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  16.         # 输入到 FFN/MoE 前的 RMSNorm(第二个 Pre-LN)
  17.         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  18.         # 前馈网络:可以是普通 FFN,也可以是 MoE 版本
  19.         self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
  20.     def forward(
  21.         self,
  22.         hidden_states,        # 输入序列 (B, L, d),B=batch,L=序列长度,d=隐层维度
  23.         position_embeddings,  # 位置编码,用于注意力计算
  24.         past_key_value=None,  # KV cache(推理时加速用)
  25.         use_cache=False,      # 是否启用 KV 缓存
  26.         attention_mask=None   # 注意力 mask(避免看未来或 padding 部分)
  27.     ):
  28.         # === 1. Self-Attention 子层 ===
  29.         residual = hidden_states  # 保存残差
  30.         hidden_states, present_key_value = self.self_attn(
  31.             self.input_layernorm(hidden_states),  # LN -> Attention
  32.             position_embeddings,
  33.             past_key_value,
  34.             use_cache,
  35.             attention_mask
  36.         )
  37.         hidden_states += residual  # 残差连接: H = H + Attention(LN(H))
  38.         # === 2. FFN/MoE 子层 ===
  39.         residual = hidden_states  # 再次保存残差
  40.         hidden_states = hidden_states + self.mlp(
  41.             self.post_attention_layernorm(hidden_states)  # LN -> MLP
  42.         )
  43.         # 输出处理后的序列表示 + KV cache
  44.         return hidden_states, present_key_value
复制代码
整体Bone

1. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)


  • 作用:这是一个 可训练的查表层(lookup table),常用于词向量(token embedding)。
  • 输入:一个 LongTensor,形状 (batch_size, seq_len),里面是 token 的 整数索引(范围 [0, vocab_size-1])。
  • 输出:一个 FloatTensor,形状 (batch_size, seq_len, hidden_size),即每个 token 被映射为一个 hidden_size 维的向量。
  • 训练性:参数(embedding 矩阵大小为 (vocab_size, hidden_size))是 可学习的,会在训练时更新。
    除非你加载了某个预训练好的 embedding 矩阵,否则默认是随机初始化。
比如:
  1. import torch
  2. import torch.nn as nn
  3. embed = nn.Embedding(1000, 64)   # 1000个token,维度64
  4. x = torch.tensor([[1, 5, 8], [2, 9, 3]])  # batch=2, seq_len=3
  5. out = embed(x)   # (2, 3, 64)
复制代码
2. register_buffer 的作用
  1. self.register_buffer("freqs_cos", freqs_cos, persistent=False)
  2. self.register_buffer("freqs_sin", freqs_sin, persistent=False)
复制代码

  • 作用:把 freqs_cos 和 freqs_sin 注册为 buffer
  • 区别于参数

    • 参数 (nn.Parameter):会被优化器更新(训练时变化)。
    • buffer:不会训练更新,但会随着 model.state_dict() 保存/加载。

  • 是不是全局变量?
    不是全局变量,而是模型内部的持久状态。你可以通过 self.freqs_cos 访问,但它不在 Python 全局命名空间里,只属于模型对象。
  • 能不能随处访问?
    可以在模型的方法里随意用,但要通过模型实例访问,比如:
    1. model.freqs_cos   # ✅
    2. freqs_cos         # ❌(除非单独定义)
    复制代码
所以 register_buffer 的本质是:模型的常量状态,存着就行,不要训练更新
3. forward 最终输出


  • 输出
    1. hidden_states  # shape = (batch_size, seq_len, hidden_size)
    复制代码
  • 这相当于是 transformer encoder/decoder 最后一层的 上下文表示,还没有做 softmax。
  1. class MiniMindModel(nn.Module):
  2.     def __init__(self, config: MiniMindConfig):
  3.         super().__init__()
  4.         self.config = config
  5.         # 模型的词表大小和层数
  6.         self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
  7.         # 词嵌入层 (token embedding),把输入的 token id 映射到 hidden_size 维的向量
  8.         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  9.         # dropout 用于防止过拟合
  10.         self.dropout = nn.Dropout(config.dropout)
  11.         # 堆叠多个 Transformer Block,每层是 MiniMindBlock
  12.         self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
  13.         # 最后的归一化层,使用 RMSNorm(比 LayerNorm 更高效)
  14.         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  15.         # 预计算旋转位置编码 RoPE 所需的 cos 和 sin 值
  16.         freqs_cos, freqs_sin = precompute_freqs_cis(
  17.             dim=config.hidden_size // config.num_attention_heads,  # 每个注意力头的维度
  18.             end=config.max_position_embeddings,                   # 最大支持的序列长度
  19.             theta=config.rope_theta                                # RoPE 的缩放参数
  20.         )
  21.         # 注册为 buffer,表示这些不是参数(不会参与训练),但会随着模型保存/加载
  22.         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
  23.         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
  24.     def forward(self,
  25.                 input_ids: Optional[torch.Tensor] = None,             # 输入 token 序列 [batch_size, seq_length]
  26.                 attention_mask: Optional[torch.Tensor] = None,        # 注意力掩码
  27.                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,  # KV 缓存
  28.                 use_cache: bool = False,                              # 是否启用 KV 缓存(加速推理)
  29.                 **kwargs):
  30.         # 获取 batch_size 和序列长度
  31.         batch_size, seq_length = input_ids.shape
  32.         # 如果没有传递 KV 缓存,初始化为 None
  33.         past_key_values = past_key_values or [None] * len(self.layers)
  34.         # 如果有缓存,确定从哪个位置开始(start_pos)
  35.         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
  36.         # 将 token id 转换为向量,并做 dropout
  37.         hidden_states = self.dropout(self.embed_tokens(input_ids))
  38.         # 取出对应序列长度的 RoPE 位置编码 (cos, sin)
  39.         position_embeddings = (
  40.             self.freqs_cos[start_pos:start_pos + seq_length],
  41.             self.freqs_sin[start_pos:start_pos + seq_length]
  42.         )
  43.         # 保存每一层的 KV,用于下次增量推理
  44.         presents = []
  45.         # 遍历每一层 Transformer Block
  46.         for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
  47.             hidden_states, present = layer(
  48.                 hidden_states,             # 当前层输入
  49.                 position_embeddings,       # RoPE 编码
  50.                 past_key_value=past_key_value,
  51.                 use_cache=use_cache,
  52.                 attention_mask=attention_mask
  53.             )
  54.             presents.append(present)       # 保存当前层的 KV
  55.         # 最后一层 RMSNorm
  56.         hidden_states = self.norm(hidden_states)
  57.         # 如果是 MoE 层,需要计算 auxiliary loss(负载均衡损失)
  58.         aux_loss = sum(
  59.             layer.mlp.aux_loss
  60.             for layer in self.layers
  61.             if isinstance(layer.mlp, MOEFeedForward)
  62.         )
  63.         # 输出:最后的 hidden states( shape = (batch_size, seq_len, hidden_size))、每层 KV 缓存、MoE 的辅助损失
  64.         return hidden_states, presents, aux_loss
复制代码
Head(下游任务和因果语言建模)

上文,我们得到了最终的hidden_states,形状为(batch_size, seq_len, hidden_size)。但我们知道,大语言模型的本质是把下一个token的预测问题转化为一个类别数量=词表大小的多分类问题,这需要我们把hidden_states通过softmax转换为词表大小的维度上输出概率。因此,最后的代码实际上起到了这样的作用。
  1. class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
  2.     # -------------------------------
  3.     # 1. HuggingFace 的规范要求
  4.     # -------------------------------
  5.     # PreTrainedModel 是 HuggingFace 的基类,负责参数保存/加载、权重初始化等。
  6.     # GenerationMixin 提供了 generate() 方法的实现(自回归解码用)。
  7.     # config_class 是 HuggingFace 约定的字段,告诉框架用什么配置类来构建模型。
  8.     config_class = MiniMindConfig
  9.     def __init__(self, config: MiniMindConfig = None):
  10.         # --------------------------------
  11.         # 2. 初始化配置
  12.         # --------------------------------
  13.         # 如果没有传入配置,就使用默认的 MiniMindConfig。
  14.         self.config = config or MiniMindConfig()
  15.         # 必须调用父类的 __init__ 来注册 config(PreTrainedModel 需要它)。
  16.         super().__init__(self.config)
  17.         # --------------------------------
  18.         # 3. 模型主体
  19.         # --------------------------------
  20.         # 主体是 MiniMindModel (相当于 Transformer Encoder/Decoder 堆叠)
  21.         self.model = MiniMindModel(self.config)
  22.         # 语言建模头(LM Head)
  23.         # 线性层: (hidden_size -> vocab_size)
  24.         # 输入: (bsz, seq_len, hidden_size)
  25.         # 输出: (bsz, seq_len, vocab_size)
  26.         self.lm_head = nn.Linear(
  27.             self.config.hidden_size, self.config.vocab_size, bias=False
  28.         )
  29.         # 权重 tying (权重共享):
  30.         # 将 embedding 层和输出层共享参数,以减少参数量并提升泛化。
  31.         # 两者形状都是 (vocab_size, hidden_size)。
  32.         self.model.embed_tokens.weight = self.lm_head.weight
  33.         # HuggingFace 约定的输出容器(dict-like),存储 logits/hidden_states 等。
  34.         self.OUT = CausalLMOutputWithPast()
  35.     def forward(self,
  36.                 input_ids: Optional[torch.Tensor] = None,   # (bsz, seq_len),输入 token 序列
  37.                 attention_mask: Optional[torch.Tensor] = None, # (bsz, seq_len),mask 用于避免 padding 干扰
  38.                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
  39.                 # past_key_values 用于缓存 kv,推理时减少重复计算
  40.                 use_cache: bool = False,  # 是否启用缓存(加速生成)
  41.                 logits_to_keep: Union[int, torch.Tensor] = 0,
  42.                 **args):
  43.         # -------------------------------
  44.         # 4. 主干模型前向
  45.         # -------------------------------
  46.         # h: (bsz, seq_len, hidden_size) -> 隐藏状态
  47.         # past_kvs: List[...] -> 缓存的 KV,用于加速生成
  48.         # aux_loss: 可能存在的额外损失(如 MoE 专家均衡损失)
  49.         h, past_kvs, aux_loss = self.model(
  50.             input_ids=input_ids,
  51.             attention_mask=attention_mask,
  52.             past_key_values=past_key_values,
  53.             use_cache=use_cache,
  54.             **args
  55.         )
  56.         # -------------------------------
  57.         # 5. logits 截取
  58.         # -------------------------------
  59.         # logits_to_keep 用于控制返回哪些时间步的预测结果:
  60.         # - 若为 int,如 1,表示只保留最后 1 个 token 的预测结果(常见于推理)。
  61.         # - 若为 0,默认保留全部。
  62.         # - 若为张量,可自定义索引。
  63.         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  64.         # 计算输出 logits
  65.         # h[:, slice_indices, :] : (bsz, slice_len, hidden_size)
  66.         # lm_head -> (bsz, slice_len, vocab_size)
  67.         logits = self.lm_head(h[:, slice_indices, :])
  68.         # -------------------------------
  69.         # 6. 按 HuggingFace 规范组织输出
  70.         # -------------------------------
  71.         # 存储最后隐藏状态 (训练时可能需要)
  72.         self.OUT.__setitem__('last_hidden_state', h)  # (bsz, seq_len, hidden_size)
  73.         # 存储 logits (训练/推理都需要)
  74.         self.OUT.__setitem__('logits', logits)  # (bsz, slice_len, vocab_size)
  75.         # 存储额外损失(如 MoE 负载均衡损失)
  76.         self.OUT.__setitem__('aux_loss', aux_loss)  # (标量)
  77.         # 存储 past_kvs,推理时用于缓存
  78.         self.OUT.__setitem__('past_key_values', past_kvs)
  79.         return self.OUT
复制代码
重要参考文献


  • 原始开源项目:https://github.com/jingyaogong/minimind
    13.jpg
    14.jpg

  • https://zhuanlan.zhihu.com/p/28786272137 介绍了transfomer中使用的全部mask,从目的到具体形式,图文并茂,写的很清晰,感谢作者给出这样优质的博客!
  • https://spaces.ac.cn/archives/10699 苏剑林大佬的MOE环游记系列,清晰地介绍了MOE的提出动机,提高效率的原理,以及aux_loss的设计
  • Rope(Su. et al.):https://spaces.ac.cn/archives/8265
  • 有关transfomer的基础教程,网络上很多了,不再一一赘婿。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册