找回密码
 立即注册
首页 业界区 业界 MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础 ...

MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础原理

百杲憔 前天 02:55
MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础原理


目录

  • MPK(Mirage Persistent Kernel)源码笔记(1)--- 基础原理

    • 0x00 概要

      • 0.1 传统LLM推理框架的瓶颈
      • 0.2 MPK的流程重构
      • 0.3 MPK的关键优势

    • 0x01 问题

      • 1.1 现有框架问题
      • 1.2 编程抽象层级

        • 1.2.1 GPU架构
        • 1.2.2 编程视角


    • 0x02 总体思路

      • 2.1 编译过程
      • 2.2 执行过程

    • 0x03 通过代码来打通流程

      • 3.1 核心模块说明

        • 3.1.1 三层结构化图模型
        • 3.1.2 PersistentKernel
        • 3.1.3 层级关系
        • 3.1.4 数据流关系

      • 3.2 main()代码
      • 3.3 关键步骤

        • 3.3.1 计算图构建过程

          • 初始化持久化内核
          • 定义张量
          • 构建计算层

        • 3.3.2 任务图生成
        • 3.3.3 runtime执行


    • 0xFF 参考


0x00 概要

CMU 贾志豪老师团队提出的MPK(Mirage Persistent Kernel)是依托 Mirage 编译器生态的创新运行时系统,其核心能力在于将多GPU环境下大语言模型(LLM)推理任务自动转换为适配GPU架构的高性能巨型内核(megakernel)。MPK的关键优势在于将传统由CPU负责的内核调度和任务依赖管理工作转移到GPU端,通过“长期驻留的巨型内核(Persistent Kernel)”自主完成,同时统筹GPU内部计算与跨GPU通信任务。这种设计不仅大幅削减了CPU-GPU交互带来的内核启动开销,还通过计算与通信的细粒度重叠,将推理延迟优化至接近硬件物理极限,显著提升推理效率。
0.1 传统LLM推理框架的瓶颈

传统LLM推理框架的流程存在固有瓶颈:CPU需逐个发起CUDA内核调用(如矩阵乘法、激活函数计算),待GPU执行完当前内核并反馈后,再触发下一个内核。这种“CPU发起-GPU执行-CPU等待”的循环,会产生频繁的CPU-GPU通信与内核启动开销,尤其在自回归生成场景中,单次token生成需多轮内核调用,开销会持续累积,严重拖累整体推理性能。
0.2 MPK的流程重构

MPK彻底重构了这一流程:它仅需CPU在推理初始化阶段,向GPU提交一个“永不主动退出的persistent_kernel”,之后所有任务分派(如层间计算顺序)、依赖管理(如等待前一层结果再执行下一层)均由GPU内部自主完成。此时CPU的角色从“实时调度的包工头”转变为“启动初始化的门卫”,仅负责触发首次内核启动,后续不再参与任何具体调度。
0.3 MPK的关键优势

MPK通过将多GPU的LLM推理任务转换为高性能的巨型内核,从根本上改变了GPU的运行模式。它不仅减少了内核启动开销,还通过细粒度的软件流水线和计算通信重叠,显著提高了推理效率。MPK提供了一种全新的思路,将性能优化的重心从“如何调用优化库”转移到了“如何为整个模型生成一个最优的、原生的执行体”,在多GPU环境下实现了更高的吞吐量和更低的延迟。
0x01 问题

重新设计类似 Mirage 的 MegaKernel的优势,是将所有计算和通信融合进一个单一的巨型内核(也称为持续内核)是降低大语言模型推理延迟的最有效方法之一。这种方法通过启动一个GPU内核来执行整个模型,从逐层计算到GPU间通信,整个过程无需中断。尽管有这些优势,将LLM编译成巨型内核仍然极具挑战性。
1.1 现有框架问题

现有框架难以支持单一的巨型内核。

  • 现有的高级ML框架,如PyTorch、Triton和TVM,并不原生支持端到端巨型内核生成。
  • 现代LLM系统由各种不同的专用内核库构建而成,这种碎片化使得将整个推理流水线整合进一个单一的、统一的内核变得非常困难。
  • 高性能GPU内核的手工编写需要大量的专家知识,如何自动生成高性能内核代码是一个痛点问题。传统做法依赖于专家编写好的内核或者手工融合规则,但这些方法维护成本高,容易漏掉跨内核/层级组合优化的机会。
1.2 编程抽象层级

从编程抽象层级上来看,也缺乏最优系统。
1.2.1 GPU架构

下图展示了当今 GPU 的层次结构。GPU 上的计算被组织为内核,每个内核都是一个函数,以单程序多数据(SPMD)的方式在多个 GPU 核心上同时执行。一个内核包括一个线程块网格,每个线程块在一个 GPU 流式多处理器上执行,并包括多个线程来对单个数据元素进行计算。每个线程都与一个每线程寄存器文件相关联,并且线程块内的所有线程都可以访问共享内存以启用集体操作。最后,内核的所有输入和输出都存储在 GPU 设备内存中。
下图是GPU hierarchy。
1.png

下图为GPU 计算架构和编程抽象示意图
2.png

1.2.2 编程视角

Triton 是一款高级 GPU 编程框架,其编程视角主要聚焦于块(Block)级别。该框架的设计允许开发者以块为单位进行编程,而块内部的优化工作则由 Triton 编译器自动完成。这种设计模式使开发者能够将精力集中在高层逻辑的构建上,无需深入研究线程(Thread)级别的细节实现。Triton 的核心优势在于其简洁的编程模型和自动化优化能力,这使得它在处理复杂并行任务时具有更高的效率。
Cutlass 则属于底层 GPU 编程库,其编程视角覆盖了块(Block)、线程束(Warp)与线程(Thread)的完整层级。Cutlass 提供了丰富的 CUDA 模板和底层控制接口,开发者可以利用这些工具精细调控每个线程的行为,从而实现高度优化的计算内核。这种细粒度的控制能力让 Cutlass 在对性能有极致要求的场景中表现出色,但同时也增加了编程的复杂性。
正是这种编程视角的层级差异,构成了当前高性能 GPU 编程领域的核心挑战:缺乏一套能够 “跨内核(Kernel)、线程块(Block)、线程(Thread)三个层级” 联合搜索最优计算方案,并自动验证方案正确性的系统。现有框架要么局限于单一层级的优化(例如 Triton 仅针对块内部逻辑进行优化,而 Cutlass 则需要开发者手动协调全层级的适配),要么无法在多层级协同后确保计算结果的准确性。这一问题在大型语言模型(LLM)推理等复杂张量计算场景中,会显著增加开发成本与优化难度。
0x02 总体思路

MegaKernel 可被视为一种 grid 级(网格级)的内核抽象。与 CUDA 的 thread 级(线程级)抽象、Triton 的 block 级(块级)抽象不同,它提供了层次更高的抽象能力,允许开发者在 grid 级开展编程工作。这种抽象设计能让开发者更灵活地管理 GPU 上的计算资源,进而实现更高效的内核生成与执行。
MPK 的工作原理主要包含以下两部分:

  • MPK 编译器:负责将大语言模型(LLM)的计算图转换为经过优化的任务图。
  • MPK 运行时系统:在单个巨型内核内部执行任务图,以此达成高吞吐量与低延迟的目标。
2.1 编译过程


  • 模型翻译:将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。
  • 任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。
2.2 执行过程


  • 任务调度:将生成的任务图交付调度器执行。调度器负责管理 GPU 中的流式多处理器(SM),并通过 warp specialization(线程束特化)技术将 SM 划分为 worker(工作单元)与 scheduler(调度单元)。这种设计与数据处理领域的 actor 模型(角色模型)相似:scheduler 负责协调任务的执行顺序,worker 则负责具体执行分配到的任务。
  • 性能优化:在小模型与低 Batch(批次)场景下,MPK 通过多种方式显著降低延迟,具体包括:消除内核启动开销、打破内核边界限制、实现细粒度的 SM 调度,以及对任务特定模式进行融合。
0x03 通过代码来打通流程

我们以demo_chat.py为例来进行全局打通。在此文件中会将Python模型结构映射为Mirage的计算图表示,然后编译为高效的持久化CUDA内核执行。
3.1 核心模块说明

3.1.1 三层结构化图模型

Mirage 实现了多层次计算图表示(μGraphs),通过 kernel-graph、block-graph 和 thread-graph 这三层结构化图模型,精确映射了 GPU 程序从内核到线程的执行逻辑与存储层级。这种三层结构与 CUDA 程序的执行层级及 GPU 的存储体系紧密对应,每层都清晰定义了“算子类型 - 张量存储 - 核心功能”的关联。
三层图功能如下:

  • Kernel Graph 是最高计算图,定义整个执行流程。通过自定义操作管理多个block graph
  • Block Graph 是嵌套在自定义操作中,定义线程块执行序列
  • Thread Graph是最低层,定义线程级别执行细节
3.1.2 PersistentKernel

PersistentKernel 作为计算图的容器和执行器,提供了从计算图构建、优化到执行的过程。
persistent_kernel.py是 PersistentKernel的Python接口,本质是Python到CUDA持久化内核系统的桥梁,允许用户用python定义复杂的计算图,然后在GPU上高效执行。
3.1.3 层级关系

计算图与 PersistentKernel 的关系如下:

  • 包含关系:PersistentKernel 内部包含并管理一个 Kernel graph
  • 构建关系:通过 PersistentKernel  的各种layer方法构建计算图。
  • 转换关系:PersistentKernel 将计算图转换为可执行的任务图
  • 执行关系:PersistentKernel  是计算图的执行引擎。
3.1.4 数据流关系

数据流关系可以近似如下图所示:
  1. 应用层:PersistentKernel.py(创建并管理kernel graph)
  2.     │                                                                  
  3.     │                                                                  
  4.     ▼   
  5. 输入张量
  6.     │                                                                  
  7.     │                                                                  
  8.     ▼  
  9. 计算图节点(各种layer方法添加)
  10.     │                                                                  
  11.     │                                                                  
  12.     ▼  
  13. 任务层:kernel graph(包括所有操作和计算流,即定义张量数据流)
  14.     │                                                                  
  15.     │                                                                  
  16.     ▼  
  17. 并行层:block graph(嵌套在自定义操作中,定义线程块执行序列,即定义内存访问模式)
  18.     │                                                                  
  19.     │                                                                  
  20.     ▼  
  21. 执行层:task graph(kernel graph生成的可执行任务图,taskDesc是可执行任务,EventDesc管理事件同步和依赖)
  22.     │                                                                  
  23.     │                                                                  
  24.     ▼  
  25. 运行时环境:PersistentKernel 执行引擎
  26.     │                                                                  
  27.     │                                                                  
  28.     ▼  
  29. 硬件层:Thread graph,在实际GPU线程中执行具体操作
复制代码
3.2 main()代码

demo_chat.py的main()如下。
  1. def main():
  2.     world_size, rank = setup_distributed_environment()
  3.     model, tokenizer = load_model_and_tokenizer(rank)
  4.     tokens = torch.full((1, MAX_SEQ_LEN), 0, dtype=torch.long, device="cuda")
  5.     step_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
  6.     mpk = None
  7.     if args.use_mirage:
  8.         # 构建计算图
  9.         mpk = build_mirage_graph(model, world_size, rank, args, tokens, step_tensor)
  10.     positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
  11.     position_embeddings = model.model.rotary_emb(positions)
  12.     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
  13.     while True:
  14.         prompt_container = [None]
  15.         if rank == 0:
  16.             try:
  17.                 prompt = input("> User: ")
  18.                 prompt_container[0] = prompt
  19.             except EOFError:
  20.                 prompt_container[0] = "exit"
  21.         if world_size > 1:
  22.             dist.broadcast_object_list(prompt_container, src=0)
  23.         prompt = prompt_container[0]
  24.         messages.append({"role": "user", "content": prompt})
  25.         text = tokenizer.apply_chat_template(
  26.             messages, tokenize=False, add_generation_prompt=True
  27.         )
  28.         model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
  29.         new_prompt_len = model_inputs.input_ids.shape[-1]
  30.         tokens[0, :new_prompt_len] = model_inputs.input_ids[0]
  31.         if new_prompt_len < tokens.shape[1]:
  32.             tokens[0, new_prompt_len:] = 0
  33.         prompt_len = new_prompt_len
  34.         if args.use_mirage:
  35.             end_pos, run_time, generated_len = run_mirage_generation(
  36.                 model, mpk, tokens, prompt_len, step_tensor, position_embeddings
  37.             )
  38.         else:
  39.             end_pos, run_time, generated_len = run_pytorch_generation(
  40.                 model, tokens, prompt_len, step_tensor, position_embeddings
  41.             )
  42.         if rank == 0:
  43.             assistant_response_ids = tokens[0, prompt_len:end_pos]
  44.             assistant_response = tokenizer.decode(assistant_response_ids, skip_special_tokens=True)
  45.     if world_size > 1:
  46.         dist.destroy_process_group()
  47.     print("Exiting demo.")
复制代码
总体过程如下:

  • 模型定义阶段

    • 使用PyTorch/HuggingFace定义模型结构。
    • 加载预训练权重
    • 初始化输入张量和相关参数。

  • 任务图构建阶段

    • 通过KNOperator定义计算操作
    • 构建完整的计算图结构
    • 设置任务配置参数

  • 任务图优化阶段

    • 分析任务间的依赖关系
    • 生成事件描述以管理依赖
    • 对任务进行合理分组以优化执行

  • 任务图转换阶段

    • 生成TaskDesc描述每个计算任务
    • 生成EventDesc描述任务间同步事件
    • 生成CUDA可执行代码
    • 输出JSON配置文件用于运行时加载。

  • 运行时初始化阶段

    • 配置GPU资源(worker,调度器等)
    • 分配GPU内存给任务队列和事件队列
    • 初始化工作队列和调度队列。
    • 设置事件计数器和相关同步机制

  • 持久化内核运行阶段

    • worker执行具体计算任务
    • 调度器负责任务调度和事件管理
    • 通过事件机制协调任务间的依赖关系
    • 支持多GPU环境下的分布式执行。

3.3 关键步骤

3.3.1 计算图构建过程

此处对应模型翻译过程,即将 PyTorch 框架下的模型翻译为 MPK 的指令集,这一步骤本质上相当于用 MPK 的指令重新构建模型的过程。尽管 PyTorch 具备强大的自动微分与优化能力,但要将模型完整转换为 MPK 的指令集,仍需进行大量手动调整与优化操作。
模型转换为计算图的工作是在build_mirage_graph函数中,其主要步骤如下:
初始化持久化内核

首先构建PersistentKernel实例。
  1.     mpk = mi.PersistentKernel(
  2.         world_size=world_size,
  3.         mpi_rank=rank,
  4.         num_workers=96,
  5.         num_local_schedulers=48,
  6.         num_remote_schedulers=0,
  7.         max_seq_length=4096,
  8.         eos_token_id=model.config.eos_token_id,
  9.         meta_tensors=[step_tensor, tokens_tensor],
  10.         profiler_tensor=profiler_tensor,
  11.     )
复制代码
定义张量

将模型权重和中间张量添加到计算图中。
  1. # 输入张量
  2. x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
  3. # 位置编码
  4. positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
  5. position_embeddings = model.model.rotary_emb(positions)
  6. x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
  7. cos_pos_embed = mpk.attach_input(
  8.     torch_tensor=position_embeddings[0][0, :MAX_CONTEXT_LEN, :],
  9.     name="cos_position_embedding",
  10. )
  11. sin_pos_embed = mpk.attach_input(
  12.     torch_tensor=position_embeddings[1][0, :MAX_CONTEXT_LEN, :],
  13.     name="sin_position_embedding",
  14. )
  15. # 计算图的中间结果张量
  16. embed_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="embed_out")
  17. attn_in = mpk.new_tensor(dims=(batch_size, fused_outdim_1 // world_size), dtype=mi.bfloat16, name="attn_in")
  18. attn_out = mpk.new_tensor(dims=(batch_size, num_local_q_heads * head_dim), dtype=mi.bfloat16, name="attn_out")
  19. is_nvshmem = "nvshmem_tensor" if world_size > 1 else "cuda_tensor"
  20. attn_proj_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_proj_out", io_category=is_nvshmem)
  21. allreduce_buf = mpk.new_tensor(dims=(world_size, batch_size, hidden_size), dtype=mi.bfloat16, name="all_reduce_buf", io_category=is_nvshmem)
  22. attn_allreduce_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_allreduce_out", io_category=is_nvshmem)
  23. mlp_mid = mpk.new_tensor(dims=(batch_size, fused_outdim_2 // world_size), dtype=mi.bfloat16, name="mlp_mid")
  24. mlp_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_out", io_category=is_nvshmem)
  25. mlp_final = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_final", io_category=is_nvshmem)
  26. argmax_in = mpk.new_tensor(dims=(batch_size, vocab_size), dtype=mi.bfloat16, name="argmax_in")
  27. argmax_part_value = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.bfloat16, name="argmax_part_value")
  28. argmax_part_index = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.int64, name="argmax_part_index")
  29. argmax_out = mpk.new_tensor(dims=(batch_size, 1), dtype=mi.int64, name="argmax_out")
复制代码
构建计算层

通过调用各种layer方法将模型层添加到计算图。此处会把HuggingFace模型权重映射到Mirage张量。也可以融合张量以提高计算效率。
  1. # --- Define the Model Graph ---
  2. w_embed = mpk.attach_input(torch_tensor=model.model.embed_tokens.weight, name="embed_tokens")
  3. mpk.embed_layer(input=x, weight=w_embed, output=embed_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
  4. x = embed_out
  5. for i, layer in enumerate(model.model.layers):
  6.      # Attention block
  7.      w_norm_attn = mpk.attach_input(torch_tensor=layer.input_layernorm.weight, name=f"layer_{i}_input_layernorm")
  8.      w_q = mpk.attach_input(torch_tensor=layer.self_attn.q_proj.weight, name=f"layer_{i}_q_proj")
  9.      w_k = mpk.attach_input(torch_tensor=layer.self_attn.k_proj.weight, name=f"layer_{i}_k_proj")
  10.      w_v = mpk.attach_input(torch_tensor=layer.self_attn.v_proj.weight, name=f"layer_{i}_v_proj")
  11.      w_qkv = mpk.fuse_tensors(inputs=[w_q, w_k, w_v], fused_dim=0, num_groups=num_local_kv_heads, name=f"layer_{i}_qkv_proj")
  12.      mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_attn, weight_linear=w_qkv, output=attn_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
  13.      w_q_norm = mpk.attach_input(torch_tensor=layer.self_attn.q_norm.weight, name=f"layer_{i}_q_norm")
  14.      w_k_norm = mpk.attach_input(torch_tensor=layer.self_attn.k_norm.weight, name=f"layer_{i}_k_norm")
  15.      k_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[0][i], name=f"layer_{i}_k_cache")
  16.      v_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[1][i], name=f"layer_{i}_v_cache")
  17.      mpk.attention_layer(input=attn_in, q_norm=w_q_norm, k_norm=w_k_norm, k_cache=k_cache, v_cache=v_cache, cos_pos_embed=cos_pos_embed, sin_pos_embed=sin_pos_embed, output=attn_out, grid_dim=(batch_size, num_local_kv_heads, 1), block_dim=(128, 1, 1))
  18.      w_o_proj = mpk.attach_input(torch_tensor=layer.self_attn.o_proj.weight, name=f"layer_{i}_o_proj")
  19.      mpk.linear_with_residual_layer(input=attn_out, weight=w_o_proj, residual=x, output=attn_proj_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
  20.      x = attn_proj_out
  21.      if world_size > 1:
  22.          mpk.allreduce_layer(input=attn_proj_out, buffer=allreduce_buf, output=attn_allreduce_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
  23.          x = attn_allreduce_out
  24.      # MLP block
  25.      residual_mlp = x
  26.      w_norm_mlp = mpk.attach_input(torch_tensor=layer.post_attention_layernorm.weight, name=f"layer_{i}_post_attn_layernorm")
  27.      w_gate_proj = mpk.attach_input(torch_tensor=layer.mlp.gate_proj.weight, name=f"layer_{i}_gate_proj")
  28.      w_up_proj = mpk.attach_input(torch_tensor=layer.mlp.up_proj.weight, name=f"layer_{i}_up_proj")
  29.      w_gatedup = mpk.fuse_tensors(inputs=[w_gate_proj, w_up_proj], fused_dim=0, num_groups=1, name=f"layer_{i}_gatedup_proj")
  30.      mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_mlp, weight_linear=w_gatedup, output=mlp_mid, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
  31.      w_down_proj = mpk.attach_input(torch_tensor=layer.mlp.down_proj.weight, name=f"layer_{i}_down_proj")
  32.      mpk.silu_mul_linear_with_residual_layer(input=mlp_mid, weight=w_down_proj, residual=residual_mlp, output=mlp_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
  33.      x = mlp_out
  34.      if world_size > 1:
  35.          mpk.allreduce_layer(input=mlp_out, buffer=allreduce_buf, output=mlp_final, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
  36.          x = mlp_final
  37. # Final layer
  38. w_final_norm = mpk.attach_input(torch_tensor=model.model.norm.weight, name="model_norm_weight")
  39. w_lm_head = mpk.attach_input(torch_tensor=lm_head_weight, name="lm_head")
  40. mpk.rmsnorm_linear_layer(input=x, weight_norm=w_final_norm, weight_linear=w_lm_head, output=argmax_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
  41. # Argmax
  42. mpk.argmax_partial_layer(input=argmax_in, output=(argmax_part_value, argmax_part_index), grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
  43. mpk.argmax_reduce_layer(input=(argmax_part_value, argmax_part_index), output=argmax_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
复制代码
3.3.2 任务图生成

此处对应任务图生成:编译器会将翻译后的模型进一步转换为细粒度任务图。该任务图属于有向无环图(DAG),图中每个节点代表一项具体任务,节点间的边则代表任务之间的依赖关系。这一步骤要求编译器能够准确识别并优化任务间的依赖关系,为后续高效调度奠定基础。
调用compile()方法生成最终的执行图。compile()函数内会执行:

  • 生成任务图。
  • 创建CUDA代码。
  • 调用nvcc编译器。
  • 创建Python绑定模块。
  1.     mpk.compile()
  2.     print("Mirage graph compiled.")
  3.     return mpk
复制代码
3.3.3 runtime执行

run_mirage_generation()函数是执行引擎运行任务图过程。
  1. def run_mirage_generation(model, mpk, tokens, prompt_len, step_tensor, position_embeddings):
  2.     # 初始化CUDA事件用于计时(starter记录开始,ender记录结束)
  3.     starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
  4.     # 创建CUDA流,用于管理异步计算任务的执行顺序
  5.     stream = torch.cuda.Stream()
  6.     # 预填充阶段(处理输入的prompt文本,生成初始上下文)
  7.     # 将步骤张量的值设为prompt长度减1,标记预填充阶段的结束位置
  8.     step_tensor.fill_(prompt_len - 1)
  9.     # 从输入 tokens 中截取前 prompt_len 个token作为初始输入(即prompt部分)
  10.     input_ids = tokens[:, 0:prompt_len]
  11.     # 提取与prompt长度匹配的位置编码(余弦部分)
  12.     cos_embeddings = position_embeddings[0][:, 0:prompt_len]
  13.     # 提取与prompt长度匹配的位置编码(正弦部分)
  14.     sin_embeddings = position_embeddings[1][:, 0:prompt_len]
  15.     # 调用模型前向传播,处理prompt并生成初始logits
  16.     logits = model.forward(
  17.         input_ids=input_ids,                   # 输入的prompt token序列
  18.         position_embeddings=(cos_embeddings, sin_embeddings),  # 对应的位置编码
  19.         step=step_tensor,                      # 当前处理步骤标记
  20.         stream=stream                          # 使用指定的CUDA流进行计算
  21.     )
  22.     # 从logits中选取概率最大的token作为下一个生成的token(取最后一个位置的输出)
  23.     next_token = logits.argmax(dim=-1)[0, -1]
  24.     # 将生成的第一个token写入tokens张量的prompt_len位置,作为生成阶段的起始
  25.     tokens[0, prompt_len] = next_token
  26.     # 等待CUDA流中的所有操作完成,确保预填充阶段计算结果就绪
  27.     torch.cuda.synchronize()
  28.     # 为下一轮生成重新初始化持久化内核(MPK)
  29.     # 收集元数据张量的指针地址,供内核访问
  30.     meta_tensors_ptr = [tensor.data_ptr() for tensor in mpk.meta_tensors]
  31.     # 获取性能分析缓冲区的指针(若不存在则设为0)
  32.     profiler_buffer_ptr = (
  33.         mpk.profiler_tensor.data_ptr() if mpk.profiler_tensor is not None else 0
  34.     )
  35.     # 调用MPK的初始化函数,配置内核运行参数
  36.     mpk.init_func(
  37.         meta_tensors_ptr,             # 元数据张量指针列表
  38.         profiler_buffer_ptr,          # 性能分析缓冲区指针
  39.         mpk.mpi_rank,                 # 当前MPI进程的排名(分布式场景)
  40.         mpk.num_workers,              # 工作单元(worker)的数量
  41.         mpk.num_local_schedulers,     # 本地调度器的数量
  42.         mpk.num_remote_schedulers     # 远程调度器的数量(分布式场景)
  43.     )
  44.     # 生成阶段(基于预填充的上下文,持续生成后续token)
  45.     # 将步骤张量的值设为prompt_len,标记生成阶段的起始位置
  46.     step_tensor.fill_(prompt_len)
  47.     # 记录生成阶段开始时间
  48.     starter.record()
  49.     # 执行持久化内核,启动生成过程
  50.     mpk()
  51.     # 记录生成阶段结束时间
  52.     ender.record()
  53.     # 等待CUDA操作完成,确保计时准确
  54.     torch.cuda.synchronize()
  55.     # 计算生成阶段的运行时间(毫秒)
  56.     run_time = starter.elapsed_time(ender)
  57.     # 获取生成结束时的位置(从步骤张量中提取具体数值)
  58.     end_pos = step_tensor[0].item()
  59.     # 计算实际生成的token长度(总长度减去prompt长度)
  60.     generated_len = end_pos - prompt_len
  61.     # 返回生成结束位置、运行时间和生成长度
  62.     return end_pos, run_time, generated_len
复制代码
0xFF 参考

如何评价CMU将LLM转化为巨型内核的Mirage Persistent Kernel(MPK)工作?
Mirage: A Multi-Level Superoptimizer for Tensor Programs 简记  尘伊光
OSDI2025论文笔记:Mirage: A Multi-Level Superoptimizer for Tensor Programs  画饼充饥
Mirage: A Compiler for High-Performance Tensor Programs on GPUs
https://mirage-project.readthedocs.io/en/latest/mugraph.html
https://mirage-project.readthedocs.io/en/latest/transpiler.html
https://zhihaojia.medium.com/compiling-llms-into-a-megakernel-a-path-to-low-latency-inference-cf7840913c17
舍弃CUDA编程!CMU等用代码将LLM编译成巨型内核,推理延迟降6.7倍  机器之心Pro

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册