找回密码
 立即注册
首页 业界区 业界 探秘Transformer系列之(31)--- Medusa

探秘Transformer系列之(31)--- Medusa

这帜 2025-6-2 22:43:14
探秘Transformer系列之(31)--- Medusa


目录

  • 探秘Transformer系列之(31)--- Medusa

    • 0x00 概述
    • 0x01 原理

      • 1.1 动机
      • 1.2 借鉴
      • 1.3 思路

        • 1.3.1 单模型 & 多头
        • 1.3.2 Tree 验证
        • 1.3.3 小结


    • 0x02 设计核心点

      • 2.1 流程
      • 2.2 模型结构
      • 2.3 多头

        • 2.3.1 head结构
        • 2.3.2 位置

      • 2.4 缺点

    • 0x03 Tree Verification

      • 3.1 解码路径
      • 3.2 最佳构造方式
      • 3.3 实现

        • 3.3.1 关键变量
        • 3.3.2 示例代码
        • 3.3.3 总体可视化
        • 3.3.4 使用

          • 调用
          • 初始化
          • 生成候选路径
          • 验证候选路径
          • 计算最优路径


      • 3.4 Typical Acceptance

        • 3.4.1 常见采用方法
        • 3.4.2 思路
        • 3.4.3 Typical Acceptance


    • 0x04 训练

      • 4.1 MEDUSA-1
      • 4.2 MEDUSA-2
      • 4.3 代码

    • 0x05 Decoding

      • 5.1 示例
      • 5.2 计算和空间复杂度

    • 0xFF 参考


0x00 概述

Medusa 是自投机领域较早的一篇工作,对后续工作启发很大,其主要思想是multi-decoding head + tree attention + typical acceptance(threshold)。Medusa 没有使用独立的草稿模型,而是在原始模型的基础上增加多个解码头(MEDUSA heads),并行预测多个后续 token。
正常的LLM只有一个用于预测t时刻token的head。Medusa 在 LLM 的最后一个 Transformer层之后保留原始的 LM Head,然后额外增加多个(假设是k个) 可训练的Medusa Head(解码头),分别负责预测t+1,t+2,...,和t+k时刻的不同位置的多个 Token。Medusa 让每个头生成多个候选 token,而非像投机解码那样只生成一个候选。然后将所有的候选结果组装成多个候选序列,多个候选序列又构成一棵树。再通过树注意力机制并行验证这些候选序列。
注:全部文章列表在这里,估计最终在35篇左右,后续每发一篇文章,会修改此文章列表。
cnblogs 探秘Transformer系列之文章列表
0x01 原理

1.1 动机

1.jpeg

投机采样的核心思路如上图下方所示,首先以低成本的方式(一般来说是用小模型)快速生成多个候选 Token,然后通过一次并行验证阶段快速验证多个 Token,进而减少大模型的 Decoding Step,实现加速的目的。然而,采用一个独立的“推测”模型也有缺点,具体如下:

  • 很难找到一个小而强的模型来生成对于原始的模型来说比较简单的token。

    • draft模型和大模型很难对齐,存在distribution shift。
    • 并不是所有的LLM都能找到现成的小模型。重新训练一个小模型需要较多的额外投入。

  • 在一个系统中维护2个不同的模型,即增加了推理过程的计算复杂度,也导致架构上的复杂性,在分布式系统上的部署难度增大。
  • 使用投机采样的时候,会带来额外的解码开销,尤其是当使用一个比较高的采样温度值时。
1.2 借鉴

Medua主要借鉴了两个工作:BPD和SpecInfer。

  • 大模型自身带有一个LM head,用于把隐藏层输出映射到词表的概率分布,以实现单个token的解码。为了生成多个token,论文“Blockwise Parallel Decoding for Deep Autoregressive Models”在骨干模型上使用多个解码头来加速推理,通过训练辅助模型,使得模型能够预测未来位置的输出,然后利用这些预测结果来跳过部分贪心解码步骤,从而加速解码过程。
  • 论文“SpecInfer: Accelerating Generative Large Language Model Serving with Speculative Inference and Token Tree Verification”的思路是:既然小模型可以猜测大模型的输出并且效率非常高,那么一样可以使用多个小模型来猜测多个 Token 序列,这样提供的候选更多,猜对的机会也更大;为了提升这多个 Token 序列的验证效率,作者提出 Token Tree Attention 的机制,首先将多个小模型生成的多个 Token 序列组合成 Token 树,然后将其展开输入模型,即可实现一次 decoding step 完成整个 Token 树的验证。
1.3 思路

基于这两个思路来源,Medusa决定让target LLM自己进行预测,即在target LLM最后一层decoder layer之上引入了多个额外的预测头,使得模型可以在每个解码步并行生成多个token,作为“推测”结果。我们进行具体分析。
1.3.1 单模型 & 多头

为了抛弃独立的 Draft Model,只保留一个模型,同时保留 Draft-then-Verify 范式,Medusa 在主干模型的最终隐藏层之后添加了若干个 Medusa Heads,每个解码头是一个带残差连接的单层前馈网络。这些Medusa Heads是对BPD中多 Head 的升级,即由原来的一个 Head 生成一个 Token 变成一个 head 生成多个候选 Token。因为这些 Heads 具有预测对应位置 token 的能力,并且可以并行地执行,因此可以实现在一次前向中得到多个 draft tokens。具体如下图所示。
可能有读者会有疑问,后面几个head要跨词预测,其准确率应该很难保证吧?确实是这样的,但是,如果我每个预测时间步都取top3出来,那么最终预测成功的概率就高不少了。而且,Medusa 作者观察到,虽然在预测 next next Token 的时候 top1 的准确率可能只有 60%,但是如果选择 top5,则准确率有可能超过 80%。而且,因为 MEDUSA 解码头与原始模型共享隐藏层状态,所以分布差异较小。
2.jpeg

1.3.2 Tree 验证

因为贪心解码的正确率不够高,加速效果不够显著,因此Medusa让每个Head解码top-k个候选,不同head的候选集合组成一个树状结构。为了更高效地验证这些 draft tokens,Medusa根据这些 Head 生成 Token 的笛卡尔积来构建出多个 Token 序列。然后使用Tree Attention方法,在注意力计算中,只允许同一延续中的 token 互相看到(attention mask),再加上位置编码的配合,就可以在不增加 batch size 的情况下并行处理多个候选。
Medusa 中的树和注意力掩码矩阵如下图所示。在每一跳中,我们看到图中Medusa保留了多个可能的token,也就是概率最高的几个token。这样构成了所谓的树结构,直观来说,就是每1跳的每1个token都可能和下1跳的所有token组合成句子,也可以就在这1跳终止。例如,在图中,一共2个head生成了2跳的token,那么这棵树包含了6种可能的句子:Head 1 在下一个位置生成 2 个可能的 Token(It 和 I),Head 2 在下下一个位置生成 3 个可能的 Token(is,’ 和 the),这样下一个位置和下下一个位置就有了 2 x 3 = 6 种可能的候选序列,如下图左侧所示。
而其对应的 Attention Mask 矩阵如右侧所示。与原始投机解码略有不同的地方是,树中有多条解码路径,不同解码路径之间不能相互访问。比如,(1) "It is"和 (2) "I is"是两条路径,那么在计算(1).is的概率分布时,只能看到(1).it,而不能看到(2)中的"I"。因此,Medusa新建了在并行计算多条路径概率分布时需要的attention mask,称为"Tree attention"。本质上就是同一条路径内遵从因果mask的规则,不同路径之间不能相互访问。
Medusa作者称,SpecInfer中每个speculator生成称的序列长度不同,所以Mask是动态变化的。而Medusa的Tree Attention Mask在Infrence过程中是静态不变的,这使得对树注意力Mask的预处理进一步提高了效率。
3.jpeg

1.3.3 小结

下表给出了BPD,SpecInfer,Medusa之间的差异。
领域Blockwise Parallel DecodingSpecInferMedusa多模型没有真的构造出k-1个辅助模型,只对原始模型略作改造,让其具备预测后k个token的能力采用一批small speculative models(SSMs),并行预测多个候选SSM,可以是原始LLM的蒸馏、量化、剪枝版本多头加入k个project layer,这k个project layer的输出就是k个不同位置token的logits在 LLM 的最后一个 Transformer Layer 之后保留原始的 LM Head,然后额外增加多个 Medusa Head,获得多个候选的 Token 序列Tree将SSMs预测的多个候选merge为一个新的token tree,采用原始LLM做并行验证。SpecInfer中每个speculator生成称的序列长度不同,所以Mask是动态变化的。Medusa的Tree Attention Mask在Infrence过程中是静态不变的,这使得对树注意力Mask的预处理进一步提高了效率。训练重新训练原始模型训练小模型并不需要重新训练整个大模型,而是冻结大模型而只训练解码头0x02 设计核心点

2.1 流程

MEDUSA的大致思路和投机解码类似,其中每个解码步骤主要由三个子步骤组成:

  • 生成候选者。MEDUSA通过接在原模型的多个Medusa解码头来获取多个位置的候选token
  • 处理候选者。MEDUSA把各个位置的候选token进行处理,选出一些候选序列。然后通过tree attention来进行验证。由于 MEDUSA 头位于原始模型之上,因此,此处计算的 logits可以用于下一个解码步骤。
  • 接受候选者。通过typical acceptance(典型接受)来选择最终输出的结果。
Medusa更大的优势在于,除了第一次Prefill外,后续可以达到边verify边生成的效果,即 Medusa 的推理流程可以理解:Prefill + Verify + Verify + ...。
2.2 模型结构

下面代码给出了美杜莎的模型结构。Medusa 是在 LLM 的最后一个 Transformer Layer 之后保留原始的 LM Head,然后额外加多个 Medusa Head,也就是多个不同分支输出。这样可以预测出多个候选的 Token 序列。
Medusa head的输入是大模型的隐藏层输出。这是和使用外挂小模型投机解码的另一个重要不同。外挂小模型的输入是查表得到的token embedding,比这里的大模型最后一层隐藏层要弱的多,因此比较依赖小模型的性能。正是因为借助大模型的隐藏层输出,这里的Medusa head的结构都十分简单。
  1. class MedusaLlamaModel(KVLlamaForCausalLM):
  2.     """The Medusa Language Model Head.
  3.     This module creates a series of prediction heads (based on the 'medusa' parameter)
  4.     on top of a given base model. Each head is composed of a sequence of residual blocks
  5.     followed by a linear layer.
  6.     """
  7.     def __init__(
  8.         self,
  9.         config,
  10.     ):
  11.         # Load the base model
  12.         super().__init__(config)
  13.         # For compatibility with the old APIs
  14.         medusa_num_heads = config.medusa_num_heads
  15.         medusa_num_layers = config.medusa_num_layers
  16.         base_model_name_or_path = config._name_or_path
  17.         self.hidden_size = config.hidden_size
  18.         self.vocab_size = config.vocab_size
  19.         self.medusa = medusa_num_heads
  20.         self.medusa_num_layers = medusa_num_layers
  21.         self.base_model_name_or_path = base_model_name_or_path
  22.         self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
  23.         # Create a list of Medusa heads
  24.         self.medusa_head = nn.ModuleList(
  25.             [
  26.                 nn.Sequential(
  27.                     *([ResBlock(self.hidden_size)] * medusa_num_layers),
  28.                     nn.Linear(self.hidden_size, self.vocab_size, bias=False),
  29.                 )
  30.                 for _ in range(medusa_num_heads)
  31.             ]
  32.         )
  33.         
  34.     def forward(
  35.         self,
  36.         input_ids=None,
  37.         attention_mask=None,
  38.         past_key_values=None,
  39.         output_orig=False,
  40.         position_ids=None,
  41.         medusa_forward=False,
  42.         **kwargs,
  43.     ):
  44.         """Forward pass of the MedusaModel.
  45.         Args:
  46.             input_ids (torch.Tensor, optional): Input token IDs.
  47.             attention_mask (torch.Tensor, optional): Attention mask.
  48.             labels (torch.Tensor, optional): Ground truth labels for loss computation.
  49.             past_key_values (tuple, optional): Tuple containing past key and value states for attention.
  50.             output_orig (bool, optional): Whether to also output predictions from the original LM head.
  51.             position_ids (torch.Tensor, optional): Position IDs.
  52.         Returns:
  53.             torch.Tensor: A tensor containing predictions from all Medusa heads.
  54.             (Optional) Original predictions from the base model's LM head.
  55.         """
  56.         if not medusa_forward:
  57.             return super().forward(
  58.                 input_ids=input_ids,
  59.                 attention_mask=attention_mask,
  60.                 past_key_values=past_key_values,
  61.                 position_ids=position_ids,
  62.                 **kwargs,
  63.             )
  64.         with torch.inference_mode():
  65.             # Pass input through the base model
  66.             outputs = self.base_model.model(
  67.                 input_ids=input_ids,
  68.                 attention_mask=attention_mask,
  69.                 past_key_values=past_key_values,
  70.                 position_ids=position_ids,
  71.                 **kwargs,
  72.             )
  73.             if output_orig:
  74.                 # 原始模型输出
  75.                 orig = self.base_model.lm_head(outputs[0])
  76.         # Clone the output hidden states
  77.         hidden_states = outputs[0].clone()
  78.         medusa_logits = []
  79.         # TODO: Consider parallelizing this loop for efficiency?
  80.         for i in range(self.medusa):
  81.             # 美杜莎头输出
  82.             medusa_logits.append(self.medusa_head[i](hidden_states))
  83.         if output_orig:
  84.             return torch.stack(medusa_logits, dim=0), outputs, orig
  85.         return torch.stack(medusa_logits, dim=0)        
复制代码
2.3 多头

2.3.1 head结构

Medusa 额外新增 medusa_num_heads 个 Medusa Head,每个 Medusa Head 是一个加上了残差连接的单层前馈网络,其中的 Linear 和模型的默认 lm_head 维度一样,这样可以预测后续的 Token。
  1. self.medusa_head = nn.ModuleList(
  2.     [
  3.         nn.Sequential(
  4.             *([ResBlock(self.hidden_size)] * medusa_num_layers),
  5.             nn.Linear(self.hidden_size, self.vocab_size, bias=False),
  6.         )
  7.         for _ in range(medusa_num_heads)
  8.     ]
  9. )
复制代码
下面代码为打印出来的实际内容。
  1. ModuleList(
  2.   (0-3): 4 x Sequential(
  3.     (0): ResBlock(
  4.       (linear): Linear(in_features=4096, out_features=4096, bias=True)
  5.       (act): SiLU()
  6.     )
  7.     (1): Linear(in_features=4096, out_features=32000, bias=False)
  8.   )
  9. )
复制代码
把第k个解码头在词表上的输出分布记作 \(p_t^{(t)}\),其计算方式如下。d是hidden state的输出维度,V是词表大小,原始模型的预测表示为  \(p_t^{(0)}\)  。
4.jpeg

下面是把代码和模型结构结合起来的示意图。
5.jpeg

2.3.2 位置

Medusa每个头预测的偏移量是不同的,第k个头用来预测位置t+k+1的输出token(k的取值是1~K)。原模型的解码头依然预测位置t+1的输出,相当于k=0。具体而言,把原始模型在位置t的最后隐藏状态 \(ℎ_t\)接入到K个解码头上,对于输入token序列 \(t_0,t_1,..,t_i\),原始的head根据输入预测$ t_{i+1}$,Medusa新增的第一个head根据输入预测  \(t_{i+2}\)的token,也就是跳过token  \(t_{i+1}\) 预测下一个未来的token。并且每个头可以指定topk个结果。这些头的预测结果构成了多个候选词汇序列,然后利用树形注意力机制同时处理这些候选序列。在每个解码步,选择最长被接受的候选序列作为最终的预测结果。这样,每步可以预测多个词汇,从而减少了总的解码步数,提高了推理速度。
如下图所示,Medusa在原始模型基础上,增加了3个额外的Head,可以并行预测出后4个token的候选。
6.jpeg

2.4 缺点

Medusa的缺点如下:

  • Medusa 新增的 lm_head 和最后一个 Transformer Block 中间只有一个 MLP,表达能力可能有限。
  • Medusa 增加了模型参数量,会增加显存占用;
  • Medusa 每个 head 都是独立执行的,也就是 “next next token” 预测并不会依赖上一个 “next token” 的结果,导致生成效果不佳,接受率比较低,在大 batch size 时甚至可能负优化。
  • 缺乏序列依赖也可能导致低效的树剪枝算法。
  • 草稿质量仍然不高,加速效果有限,并且在非贪婪解码 (non-greedy decoding) 下不能保证输出分布与目标LLM一致。
因此,后续有研究工作对此进行了改进。比如Clover重点是提供序列依赖和加入比单个 MLP 具有更强的表征能力的模块。Hydra 增加了 draft head 预测之间的关联性。Hydra++使用 base model 的输出预测概率作为知识蒸馏的教师模型输出来训练 draft head。并且类似EAGLE,Hydra++增加一个独立的 decoder layer,每个 Hydra head 除了上一个 token 本身,还添加了上一个 token 在这个 decoder layer 的 representation 作为输入。
0x03 Tree Verification

每个Medusa Head 会生成 top-k 个预测标记,然后通过计算这些预测的笛卡尔积来形成候选序列。我们可以对于每个候选序列都走一遍模型来验证,但是这样做太耗时。因此,Medusa 作者设计了一种tree attention的机制,在候选树内进行掩码操作,掩码限制某个token对前面token的注意力。同时,也要为相应地为position embedding设置正确的位置索引。因为有 tree attention 的存在,所以 Medusa 可以并行地构建、维护和验证多个候选序列。
3.1 解码路径

在Medusa中,基础版本解码采用greedy方式取Top-1 Token。Medusa增加额外的解码头之后,使用 Top-K Sampling,每一个 Head 都会输出 k 个 tokens。不同Medusa头之间预测结果不一致。\(p(t_{t+1}|t_0,...,t_i)\)和\(p(t_{t+2}|t_0,...,t_i)\)形式上是条件独立的,但是实际上\(p_{t+2}\)依赖\(p_{t+1}\),不能直接取\(p(t_{t+1}|t_0,...,t_i)\)和\(p(t_{t+2}|t_0,...,t_i)\)最大的token作为verify阶段的输入,这样组成的句子可能会在逻辑上不一致。因此,Medusa还引入采样topk组合作为候选序列的方式去缓解这个问题。最终把LM_head 的输出作为根节点构建出树状结构,树的深度自顶向下遍历称为解码路径(论文中叫做candidates path)。每个候选序列可以表示所构建的tree上的一条路径上所有的node(而不只是leaf node,因为tree attention验证的时候会把路径上所有token都进行验证)。
由于有K个head,每个head选取\(\text{top-}s_k\)个预测输出,则所有路径可能组合为树上所有节点的总和,即\(\sum_{k=1}^K \prod_{i=1}^k s_i\)。在构建树形结构时,最简单的方法是通过笛卡尔积来获取多个解码头组成的所有可能的候选序列。下图例子使用了Cartesian product对两个解码头的结果进行处理,获得所有候选序列。具体来说就是将每个头的top-k个词作为节点,每个头作为树的一层。图上一共存在6条解码路径,相当于 Head 1 在下一个位置生成 2 个可能的 Token(It 和 I),Head 2 在下下一个位置生成 3 个可能的 Token(is,’ 和 the),这样下一个位置和下下一个位置就有了 2 x 3 = 6 种可能的候选序列。为了区分不同的 prefix,Medusa 设置了一些冗余,例如 Head 2 的三个预测 token 均出现了两次,这是为了分别对应 It 和 I 这两个不同的 prefix。每个 token 在 tree mask 的作用下只可以看见自己的 prefix。
7.jpeg

3.2 最佳构造方式

上图采用top-3,两个头一共有6条候选路径。如果解码头数量数量比较多,每个头给出的候选token也比较多。解码路径会随着Top-k 和头数增多急剧增加,会产生大量的候选路径,具有庞大的搜索空间。虽然增加候选序列的数量,最终接受token的命中率就会提升,但是验证更多候选序列也会带来额外的计算消耗。那么新的问题是:

  • 如何能减少候选解码路径?
  • 如何能在候选解码路径中,得到最优解码路径?
直观上来说,那些由不同头部的topk预测组成的候选结果可能具有不同的准确率。应该优先选择那些更准确的预测,从而构建一个更有效的树,而不需要使用所有可能的排列组合。Medusa 通过在校准数据集上统计每个解码头的 top-k 预测的准确率,然后贪婪地选择那些能够最大化期望接受长度的节点,将它们加入到树中。这样可以在相同的节点总数下,获得更高的加速比。其实,此方法本质上就是用剪枝来加速,剪去每个head中\(\text{top-}s_k\)的若干个。
具体来说,我们可以使用一个calibration dataset(比如Alpaca-eval dataset)来获取不同解码头给出的各个token的准确率:把第 k 个解码头给出的第 i 个token的准确率记为 \(a_k^{(i)}\)。假设各个token的准确率之间是独立的,那么一个由$[i_1,i_2,\cdots,i_k] $构成的候选序列的准确率可以写作 \(\prod_{j=1}^ka_j^{(i_j)}\)。我们用 I 表示候选序列的集合,那么集合里的候选序列的expectation of acceptance length就表示为:

\[\sum_{[i_1,i_2,\cdots,i_k]\in I}\prod_{j=1}^ka_j^{(i_j)}\]

在构建tree的时候,Medusa 用贪心算法优先加入当前有最大准确率的候选序列,直到tree的节点数量达到接受长度的期望值上限,这样能最大化expectation of acceptance length,也就能最大化acceleration rate。这是一种手工设计的稀疏树结构,越靠前的节点,有更多的子节点路径。即,把所有位置组合穷举,然后选取前
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册