
不只是少一次 kernel launch而是让碎片化专家计算、L2 复用与 dispatch/combine 通信进入同一条流水线适合懂 Transformer/MoE 基础的工程读者阅读约 9 分钟。Grouped GEMM 译为分组矩阵乘。MoE 让大模型拥有更多参数但也把原本规整的 FFN 计算变成了许多大小不一、负载不均的小矩阵乘。在 Dense FFN 里一个 batch 通常对应少数几个大 GEMM。GPU 喜欢这种形态矩阵足够大tile 足够多Tensor Core 容易喂饱。到了 MoErouter 会把 token 分给不同 expert。每个 expert 拿到的 token 数不同于是每个 expert 要做的 GEMM 的 M 维也不同。有的 expert 很忙有的 expert 只有几个 token。NVIDIA 对 MoE 通信的描述中dispatch 负责把 attention 输出 token 路由到对应 expertcombine 负责把 expert 输出再路由回 attention 输出。换句话说MoE 的性能问题不是单纯“矩阵乘够不够快”而是碎片化专家计算如何和 dispatch/combine 通信共存。看图 1 时可以重点看中间的专家区域每个 expert 的 GEMM 形状不完全一样。如果用最直接的写法就是对 expert 做一个 for-loopExpert 0 启一个 GEMM kernelExpert 1 再启一个 GEMM kernel……这样做的缺点很明显kernel launch 多、单个小 GEMM 的并行度不够、专家之间负载不均GPU 很难一直保持高利用率。Grouped GEMM 的直觉就是不要把每个 expert 当成一个孤立 GEMM而是把一组 GEMM 交给同一个 grouped kernel 统一调度。CUTLASS 的 grouped GEMM 示例把这类 workload 描述为一批 GEMM 操作每个 GEMM 可以有不同 problem size矩阵指针、leading dimension 和 problem size 以数组形式传入这也正是它和普通 batched GEMM 的关键区别。Triton 的 Group GEMM 教程也把它描述为启动固定数量的 CTA 来计算一组 GEMM并在 device 侧完成调度。Grouped GEMM 到底解决什么问题Grouped GEMM 不是一个神秘的新数学算子。数学上它仍然是在算C_i A_i × B_i, i 0, 1, ..., G-1区别在于这里的每个i可以代表一个 expert、一个 shard 或一个独立的 GEMM problem。它们的M/N/K、矩阵地址、stride 都可以不同。Grouped GEMM 关心的是如何把这些独立问题放进一个 kernel 的调度空间里让 CTA 持续拿 tile 干活。这件事在 MoE 里尤其重要。因为 MoE 的 token 分布天然不均匀热门 expert 的M大冷门 expert 的M小。直接逐 expert 调 GEMM热门 expert 会拖尾冷门 expert 又不够大最后很容易变成“有些 SM 很忙有些 SM 等活”。Grouped GEMM 把专家小 GEMM 变成一张 tile 队列。调度器关心的不再是“先做完 Expert 0再做 Expert 1”而是“下一个 CTA 应该拿哪个 expert 的哪个 tile”。CUTLASS 文档中也强调grouped kernel scheduler 的核心职责就是把一组 problem 中的 tile 分配给 threadblockthreadblock 会持续查询下一个 tile执行 MMA 和 epilogue再推进调度状态。这带来三个直接收益减少 launch提升小 GEMM 的并行度缓解 expert 间负载不均。只做成 grouped 还不够真正难的是流水线和内存很多人介绍 Grouped GEMM 时会停在“多个 GEMM 合成一个 kernel”这一层。但在 MoE 训练或推理里这还不够。MoE 层同时有三类压力expert 计算要读 activation 和 weightdispatch/combine 要搬 token下一层或其他并行通信也可能在抢 HBM、NVLink 或 RDMA 资源。因此一个真正可用的 Grouped GEMM 设计通常要同时看四个旋钮tiling、double-buffer pipeline、L2 cache conflict 缓解、HBM 带宽控制。1. Tiling把专家不均匀变成 tile 队列Tiling 的第一层作用是把不同 expert 的 GEMM 切成统一粒度的 tile。比如 Expert 1 有 160 个 tokenExpert 2 只有 32 个 token。直接按 expert 调度会看到一个很大的任务和一个很小的任务切成 tile 后它们都变成调度器可以分发的工作单元。这样热门 expert 不会独占整个 kernel冷门 expert 也不会因为太小而浪费一次 launch。Tiling 的第二层作用是让 kernel 可以调节寄存器、shared memory、occupancy 和 Tensor Core 使用效率。tile 太小调度开销和访存开销占比会上升tile 太大又容易让长尾 expert 拖住 SM。MoE 场景中tile size 不是越大越好而是要配合 token 分布、expert 数、输出维度和通信节奏一起选。2. Double-buffer pipeline别让 Tensor Core 等数据GEMM kernel 的主循环通常在做两件事从 HBM/L2 把下一块数据搬进 shared memory同时让 Tensor Core 对当前块做 MMA。Double-buffer pipeline 的基本做法是准备两套 buffer当前 buffer 用来 compute另一个 buffer 用来 load 下一块。下一轮交换角色。理想情况下数据加载被计算覆盖Tensor Core 不会频繁饿死。在 Grouped GEMM 中这个问题更复杂。因为不同 expert 的矩阵地址不连续tile 形状和剩余 K 维可能不同。pipeline 不仅要解决单个 GEMM 内部的 load/compute overlap还要处理跨 expert 的 tile 切换什么时候预取下一个 expert 的 weight什么时候切换 activation slice什么时候停止过度预取避免把 HBM 带宽打满这些都属于 grouped kernel 里的工程细节。3. L2 cache conflict 缓解不要让 tile 调度互相拆台L2 cache 在这里像一个共享中转站。Grouped GEMM 想复用 expert weight、activation tile 和 epilogue 相关数据dispatch/combine 也会产生大量 token 搬运。调度顺序如果太随意多个 CTA 可能在相近时间访问互相冲突的数据区域导致 L2 hit rate 下降最终表现为 HBM 压力增加。因此很多优化会围绕 locality 做文章让相近 tile 更容易复用 L2让同一 expert 的 weight 不要刚进 cache 就被别的访问模式挤掉让 tile 排布不要集中冲击同一类地址。PyTorch 的 MoE locality-aware GEMM 文章就展示过通过改变调度顺序来改善数据局部性可以显著提升 MoE GEMM 的硬件利用率。这里的关键不是记住某一种固定顺序而是理解原则Grouped GEMM 的调度器不仅是在分配计算也是在塑造缓存访问模式。4. HBM 带宽控制给通信留水位MoE 里的 GEMM 和通信不是两条互不相干的路。如果 Grouped GEMM 过度预取把 HBM 带宽打满dispatch/combine 可能被挤压如果为了通信把 GEMM 做得太保守Tensor Core 又会吃不饱。真正难的是在计算吞吐和通信带宽之间找水位。工程上可以调的旋钮包括CTA 数量、pipeline stage 数、prefetch distance、tile size、persistent block 数量、是否按 chunk 分段执行以及通信 stream 和 compute stream 的重叠策略。这也是为什么 MoE kernel 不能只看单 kernel TFLOPS。一个单独跑起来很漂亮的 GEMM如果上线后把 dispatch/combine 挤到尾部端到端 latency 仍然可能变差。和 dispatch/combine 一起看Grouped GEMM 是 MoE 管线的一段MoE 层可以粗略看成三步dispatch tokens - expert GEMMs - combine outputs但高性能实现不会真的把这三步完全串行化。更常见的思路是 chunk 化当前 chunk 做 expert GEMM下一 chunk 做 dispatch上一 chunk 做 combine。这样计算和通信可以在时间线上重叠。这时 Grouped GEMM 的角色会变化它不只是“尽可能快地完成 expert GEMM”还要“以不会破坏整体流水线的方式完成 expert GEMM”。这解释了为什么 HBM 控制、L2 conflict 缓解和 tile 调度会变得重要。MoE 的瓶颈往往不是某一个 kernel而是多个 kernel、通信库、缓存层级和 HBM 带宽共同形成的系统瓶颈。Fused GemmAdd把 FP32 加法塞进 epilogue最后说 Fused GemmAdd。一个 GEMM kernel 通常可以分成 mainloop 和 epilogue。mainloop 负责 MMA把乘加结果累积到 accumulatorepilogue 负责把 accumulator 做缩放、类型转换、写回等后处理。CUTLASS 文档也把 GEMM operator 描述为 MMA 后接 epilogue operation。如果 GEMM 后面还有一次 FP32 加法例如Y GEMM(X, W) Add朴素做法是先让 GEMM 写回C再启动一个 Add kernel把C和Add读出来相加再写回Y。这会多一次 kernel launch也会多一次 HBM 读写。Fused GemmAdd 的思路很直接既然 GEMM accumulator 本来就在 epilogue 里准备写回那就在写回前完成 FP32 加法Y cast(Acc_fp32 Add_fp32)或者更一般地写成Y cast(alpha * Acc_fp32 beta * Add_fp32)Fused GemmAdd 的收益主要来自两点。第一减少一次中间结果写回和再读取。对小 GEMM 或 memory-bound 后处理来说这往往比省一点算术指令更重要。第二减少 kernel launch 和调度缝隙。MoE 本来就有很多 expert、很多通信和很多小 kernel如果能把简单后处理融合进 epilogue就能让执行图更短、更连续。需要注意的是“FP32 加法融合”并不等于最终输出必须保持 FP32。常见做法是 accumulator 和 add 在 FP32 中完成再按输出要求 cast 到 BF16、FP16 或其他格式。这里要明确数值路径加法在哪个精度做什么时候 cast是否需要 saturation 或 rounding。什么时候 Grouped GEMM 不一定赚Grouped GEMM 很适合 MoE但不是万能药。如果 expert 数很少、每个 expert 的 GEMM 已经足够大普通 GEMM 可能已经能跑满 GPU。此时 grouped 调度带来的收益会变小。如果 token 分布极端不均某个 expert 远大于其他 expertGrouped GEMM 仍然会遇到长尾。此时需要排序、拆分大 expert、调整 tile 或做 persistent scheduling。如果端到端瓶颈主要在 all-to-all、dispatch/combine 或跨节点通信单独优化 GEMM 只能改善一部分。你会看到 kernel profile 变漂亮但 step time 改善有限。如果 epilogue 里融合太多逻辑也可能带来寄存器压力、occupancy 下降和更复杂的边界处理。Fused GemmAdd 适合简单、规则、访存代价高的后处理不是所有后处理都应该塞进 epilogue。实战 profiling 看什么评估 Grouped GEMM不建议只看 TFLOPS。至少要同时看以下几类指标维度重点观察计算Tensor Core 利用率、SM occupancy、tile 尾部空转调度expert 间负载均衡、CTA 分配、长尾 tile缓存L2 hit rate、weight/activation 复用情况带宽HBM throughput、dispatch/combine 是否被挤压图执行kernel launch 数、GEMM 与通信重叠比例数值FP32 accumulate、epilogue add、输出 cast 路径一个好的 Grouped GEMM 优化最终应该体现为端到端收益MoE 层耗时下降通信尾部减少GPU 时间线更紧凑而不是只有孤立 GEMM microbenchmark 变好。小结Grouped GEMM 的认知地图可以用一句话总结Grouped GEMM 是 MoE 里把碎片化专家计算重新组织成可调度整体的关键算子。更细一点它包含五层设计Grouped GEMM 解决的是“很多大小不一的小 GEMM 如何合并调度”Tiling 解决的是“专家负载不均如何变成 tile 粒度的工作队列”Double-buffer pipeline 解决的是“Tensor Core 如何不等 HBM 数据”L2 conflict 缓解和 HBM 带宽控制解决的是“计算如何不把通信挤死”Fused GemmAdd 解决的是“简单 FP32 后处理如何少一次 launch、少一次读写”。所以MoE 里的 Grouped GEMM 不只是一个 GEMM 优化而是一个小型系统工程它连接 router 之后的专家计算也连接 dispatch/combine 通信它既要追求 Tensor Core 吞吐也要尊重缓存和 HBM 的边界。真正的优化目标不是“某个 GEMM kernel 跑得最猛”而是MoE 这一层整体跑得最稳、最短、最少互相干扰。