)
在前面的文章里Flash Attention 这个名字反复出现第 2 篇讲 attention 时提到它是「现代推理框架的标配」第 5 篇讲长上下文时把它列为「四大攻坚维度」之一第 11 篇讲推理优化时它是 prefill 阶段的核心加速器这一篇我们正式把它讲透。为什么 Flash Attention 值得单独一篇因为它代表了深度学习系统优化的一个里程碑思路——它没有改变任何数学计算结果完全等价只通过重新设计数据在显存中的搬运方式把 attention 的速度提升了 2-4 倍把显存占用从 O(n²) 降到 O(n)。如果你做过相关工作下面这些问题应该不陌生为什么 vLLM、SGLang、TensorRT-LLM 都默认用 Flash Attention为什么把attn_implementationflash_attention_2加上模型就能跑得快很多Flash Attention 的分块、在线 softmax到底是什么H100 的 Flash Attention v3 比 v2 快多少端到端训练用了 Flash Attention 后能省多少显存读完本文你将能理解 GPU 显存层级HBM vs SRAM—— 这是 Flash Attention 的物理基础理解 Flash Attention 的两个核心技巧Tiling Online Softmax知道 v1 / v2 / v3 之间的演进针对你的硬件选对版本用 PyTorch / HuggingFace / vLLM 三种方式启用 Flash Attention判断什么场景 Flash Attention 不适用我们开始。一、为什么 Attention 需要专门优化1.1 一个被忽视的事实GPU 不是只有算力很多人对 GPU 的认知停留在「TFLOPS 多少」——比如 H100 SXM 是 989 TFLOPSFP16。但 GPU 还有一个同等重要的指标显存带宽。GPUFP16 算力显存带宽算力/带宽比V100125 TFLOPS0.9 TB/s139A100 80G312 TFLOPS2.0 TB/s156H100 SXM989 TFLOPS3.35 TB/s295H200989 TFLOPS4.8 TB/s206B2002250 TFLOPS8 TB/s281注意「算力/带宽比」——越高表示单位带宽对应的算力越多。关键认知GPU 算力增长比显存带宽增长快得多。从 V100 到 H100算力翻了 8 倍带宽只翻了 3.7 倍。这意味着「IO 瓶颈」越来越严重。1.2 GPU 显存的三层结构我们再深入一层——GPU 内部其实有多级存储HBM (High Bandwidth Memory) 80 GB, 3.3 TB/s ↑ ↑ 显存所有数据默认在这里 非常大相对慢 L2 Cache 50 MB, ~12 TB/s ↑ 中间层 SRAM (Shared Memory Registers) ~228 KB / SM ~19 TB/s ↑ 片上极快但极小 H100 有 132 个 SM总共也才 30 MB简化版[ 80 GB ] HBM ← 慢 ↕↕↕↕↕↕ 数据搬运 [ 30 MB ] SRAM ← 极快 ↑ 计算实际发生的地方核心矛盾数据默认在 HBM80 GB 富余但计算必须在 SRAM 进行每次计算都要把数据从 HBM 搬到 SRAMHBM 带宽3.3 TB/s远低于 SRAM19 TB/s这就是为什么 IO 成了瓶颈——GPU 算力再强数据搬不进来也没用。1.3 传统 Attention 的 IO 噩梦回顾 attention 计算S Q · K^T # [n, n] 矩阵 P softmax(S) # [n, n] 矩阵 O P · V # [n, d] 矩阵朴素实现把每一步的中间结果写回 HBM然后下一步再读回来1. 读 Q, K 到 SRAM 2. 计算 S QK^T 3. 把 S 写回 HBM ← O(n²) 写 4. 读 S 回 SRAM 5. 计算 P softmax(S) 6. 把 P 写回 HBM ← O(n²) 写 7. 读 P, V 到 SRAM 8. 计算 O PV 9. 写 O 到 HBM问题第 3、6 步要把n × n大小的矩阵在 HBM 和 SRAM 之间来回搬。对于 n8K 序列S 矩阵显存8K × 8K × 4 bytes 256 MB这 256 MB 反复在 HBM ↔ SRAM 间来回搬实测数据A100 上 attention 计算实际算力消耗约 5% 的 GPU 算力实际 IO 消耗约 95% 的 GPU 时间也就是说95% 的时间在搬数据5% 的时间在算——这是工程优化的巨大空间。1.4 Flash Attention 的「Aha Moment」Flash Attention 论文Tri Dao, 2022的一句话总结了它的核心思想能不能让 attention 计算不要物化中间矩阵 S 和 P如果可以那么IO 量从 O(n²) 降到 O(n)显存占用从 O(n²) 降到 O(n)速度提升 2-4×算力终于能跑满但难点在于softmax 需要看到整行才能归一化——你不知道总和之前怎么知道每个元素的归一化值Flash Attention 的天才之处在于它用一种叫在线 softmax的算法让 softmax 可以流式计算。二、Flash Attention v1 原理深入2.1 核心技巧 1Tiling分块Flash Attention 不一次计算整个 attention而是按块计算。把 Q、K、V 切成 blockQ : [n, d] → 切成 Tr 块每块 [Br, d] K : [n, d] → 切成 Tc 块每块 [Bc, d] V : [n, d] → 切成 Tc 块每块 [Bc, d]Br、Bc 设计成能装进 SRAM典型值 128。然后双层循环for j in range(Tc): # 外层循环 K, V 块 把 Kj, Vj 加载到 SRAM for i in range(Tr): # 内层循环 Q 块 把 Qi 加载到 SRAM 在 SRAM 中计算 Qi · Kj^T → Sij (小矩阵) 在 SRAM 中应用 softmax → Pij 在 SRAM 中计算 Pij · Vj → 输出累积 把累积结果写回 HBM关键整个n × n大矩阵 S 从未物化在 HBM只有小的Br × Bc块在 SRAM 里HBM ↔ SRAM 的数据搬运量从 O(n²) 降到 O(n²/M)M 是 SRAM 大小2.2 核心技巧 2在线 Softmax但 softmax 是个全局操作——它需要先看到整行才能归一化softmax(x) exp(x_i) / Σ exp(x_j) ↑ 需要总和Flash Attention 用在线 softmax解决# 增量计算 softmax # 假设我们已经处理了前 i 个 block # m_i 前 i 个 block 的最大值 # s_i 前 i 个 block 的 exp 总和 新来一个 block计算它的 softmax m_new max(m_i, max(new_block)) s_new exp(m_i - m_new) * s_i exp(m_new - m_new) * sum(exp(new_block - m_new)) 输出 用 m_new 和 s_new 重新归一化所有已处理的部分这个算法的核心数学技巧exp(a) exp(b) exp(max) * [exp(a - max) exp(b - max)] ↑ 防止 overflow 可流式合并直观上每个 block 自己算 softmax用本地 max 防 overflow处理完后保存 (max, sum) 两个状态来新 block 时用两个 max 之间的换算因子调整之前的累积这个算法数学上完全等价于一次性 softmax——没有任何精度损失。2.3 完整伪代码def flash_attention(Q, K, V): n, d Q.shape M SRAM_SIZE # SRAM 大小约 100 KB Br, Bc derive_block_size(M, d) # 通常 128 Tr, Tc n // Br, n // Bc # 初始化输出和状态 O zeros((n, d), in_hbmTrue) l zeros(n, in_hbmTrue) # 累积的 sum m full(n, -inf, in_hbmTrue) # 累积的 max for j inrange(Tc): Kj load_to_sram(K[j*Bc:(j1)*Bc]) Vj load_to_sram(V[j*Bc:(j1)*Bc]) for i inrange(Tr): Qi load_to_sram(Q[i*Br:(i1)*Br]) Oi load_to_sram(O[i*Br:(i1)*Br]) li load_to_sram(l[i*Br:(i1)*Br]) mi load_to_sram(m[i*Br:(i1)*Br]) # 在 SRAM 内计算 Sij Qi Kj.T / sqrt(d) # [Br, Bc] mij row_max(Sij) # [Br] Pij exp(Sij - mij[:, None]) # [Br, Bc] lij row_sum(Pij) # [Br] # 在线 softmax 合并 m_new max(mi, mij) l_new exp(mi - m_new) * li exp(mij - m_new) * lij # 更新输出 Oi_new ( (li * exp(mi - m_new))[:, None] * Oi exp(mij - m_new)[:, None] * (Pij Vj) ) / l_new[:, None] # 写回 HBM write_to_hbm(O[i*Br:(i1)*Br], Oi_new) write_to_hbm(l[i*Br:(i1)*Br], l_new) write_to_hbm(m[i*Br:(i1)*Br], m_new) return O整体效果数学等价于标准 attention中间矩阵 S, P 从未离开过 SRAMHBM IO 量降为原来的 1/MM SRAM 大小约 100 KB2.4 Flash Attention v1 的实际收益序列长度朴素 AttentionFlash Attention速度提升5121×1.2×1.2×10241×1.8×1.8×40961×2.7×2.7×163841×3.5×3.5×结论序列越长Flash Attention 越赚。显存占用序列长度朴素 (n² 矩阵)Flash Attention8K256 MB 1 MB32K4 GB 4 MB128K64 GB 16 MB这就是为什么没有 Flash Attention 根本搞不动长上下文——光 attention 矩阵就把显存吃光了。三、Flash Attention v2 / v3 的演进3.1 v2 (2023.07)进一步加速Flash Attention v2 的改进点改进 1减少非矩阵乘法的开销v1 中有不少 rescale、max compare 等非 matmul 操作这些操作虽然简单但累积起来不少。v2 重新设计算法把它们减少到最少。改进 2更好的并行化v1 内层循环只在 Q 上并行。v2 把外层循环也并行化更充分利用 GPU 的多个 SM。改进 3分配更好的 warp把 SRAM 分配给更细粒度的 warp进一步提升计算密度。实测比 v1 快~2×在 A100 上达到 50-70% 的理论算力在长序列下尤其明显3.2 v3 (2024.07)H100 时代的飞跃Flash Attention v3 专为 H100 设计引入了 H100 的特殊功能特性 1异步加载async TMAH100 引入了TMATensor Memory Accelerator——可以异步搬运数据让计算和数据搬运 overlap。v3 充分利用这个计算 block 1 ── 同时加载 block 2 计算 block 2 ── 同时加载 block 3 计算 block 3 ── 同时加载 block 4 ...特性 2FP8 支持v3 第一次支持 FP8 attention• 精度约 0.1% 掉点• 速度比 FP16 再快 2×特性 3Warpgroup 异步矩阵乘法H100 的WGMMAWarpgroup MMA让矩阵乘法本身就是异步的。v3 充分利用这个让算力打满。实测v3 在 H100 上达到75% 的理论算力vs v2 的 35%FP8 模式下接近1.5 PFLOPS3.3 三个版本性能对比测试设置H100 SXM序列长度 8Kd128BF16版本TFLOPS利用率标准 PyTorch111.1%Flash v119519.7%Flash v234835.2%Flash v374074.8%Flash v3 FP8141771.6% (vs FP8 ceil)结论标准 PyTorch → Flash v367× 加速v2 → v3约2× 加速H100 专属3.4 哪个版本配哪个硬件GPU推荐 Flash 版本V100 / T4v1v2/v3 不一定支持A100 / L40v2v3 部分支持但优化不到位H100 / H200v3B200v3v4 即将出专门为 Blackwell 优化四、工程实战怎么用上 Flash Attention4.1 用 HuggingFace Transformers 自动启用最简单的方式from transformers import AutoModelForCausalLM, AutoTokenizer model AutoModelForCausalLM.from_pretrained( Qwen/Qwen3-32B-Instruct, attn_implementationflash_attention_2, # ← 关键 torch_dtypeauto, device_mapauto, )支持的选项attn_implementation eager # 朴素PyTorch 实现慢 attn_implementation sdpa # PyTorch 2.0 内置使用 backend attn_implementation flash_attention_2 # Flash v2 attn_implementation flash_attention_3 # Flash v3 (Transformers 4.46 支持)Tipsdpa是 PyTorch 内置的scaled_dot_product_attention它在底层会自动选择 Flash 或 Memory-Efficient 实现——很多情况下这就够用flash_attention_2/flash_attention_3需要pip install flash-attn4.2 用 PyTorch SDPA最通用PyTorch 2.0 内置了scaled_dot_product_attention会自动用 Flash Attention 后端import torch.nn.functional as F def my_attention(q, k, v, maskNone): # 自动用 Flash Attention 如果可用 output F.scaled_dot_product_attention( q, k, v, attn_maskmask, dropout_p0.0, is_causalTrue, ) return output控制后端from torch.nn.attention import SDPBackend, sdpa_kernel with sdpa_kernel(SDPBackend.FLASH_ATTENTION): output F.scaled_dot_product_attention(q, k, v, is_causalTrue)可选 backend•FLASH_ATTENTION── Flash Attention 实现•EFFICIENT_ATTENTION── Memory Efficient Attention•MATH── 标准实现fallback•CUDNN_ATTENTION── cuDNN 实现新4.3 在 vLLM 中vLLM默认就用 Flash Attention你什么都不用做vllm serve Qwen/Qwen3-32B-Instruct # 自动用 Flash Attention v2/v3看硬件强制版本# vLLM 0.6 支持 VLLM_ATTENTION_BACKENDFLASH_ATTN vllm serve ... VLLM_ATTENTION_BACKENDFLASH_ATTN_3 vllm serve ...4.4 训练时的 Flash Attention训练阶段 Flash Attention 收益更明显——因为序列更长、需要反向传播。from transformers import AutoModelForCausalLM, TrainingArguments model AutoModelForCausalLM.from_pretrained( ..., attn_implementationflash_attention_2, torch_dtypetorch.bfloat16, ) training_args TrainingArguments( ..., bf16True, # 必须用低精度才能用 Flash Attention gradient_checkpointingTrue, # 配合用节省更多显存 )实测训练 70B 模型 8K context不开 Flash Attention每 step 4.2 秒显存 78 GB开 Flash Attention v2每 step 1.8 秒显存 42 GB2× 加速 47% 显存节省。这就是为什么训练大模型必须用 Flash Attention。4.5 安装# Flash Attention v2 pip install flash-attn --no-build-isolation # Flash Attention v3H100 only目前仍在 hopper 分支 pip install githttps://github.com/Dao-AILab/flash-attention.githopper常见安装坑CUDA 版本要匹配建议 12.x编译时间长首次安装 30-60 分钟需要 ≥ 8 GB 内存编译没有预编译 wheel 时编译失败 → 安装 ninja 试试五、扩展话题Flash 家族还在演进5.1 Flash Decoding推理专用Flash Attention v2 主要为训练优化长 seq、batch 大。推理有不同的瓶颈Decode 阶段每次只处理 1 个 tokenKV Cache 上的 attention 是 1 × N 矩阵不是 N × N真正的瓶颈是并行度不足Flash DecodingDao 2023.10专门解决这个把 KV 序列也切到多个 SM上并行每个 SM 处理 KV 的一部分最后用 log-sum-exp 合并效果Decode 速度提升2-8×看 batch 和 seq长上下文场景尤其明显128K decode 提升 5×当下地位vLLM、SGLang 等推理框架都已集成。5.2 Ring Attention跨卡 Flash第 5 篇我们讲过 Ring Attention——它本质上就是 Flash Attention 的分布式版本把 KV 切到多张卡每张卡持有局部 KVKV 在卡间环形传递每张卡轮流和其他卡的 KV 做 Flash Attention这是训练 / 推理 1M 上下文的基础。5.3 Triton 实现Flash Attention 原生用 CUDA 写但Triton 版本越来越流行Triton 是 OpenAI 开源的 GPU kernel DSL比 CUDA 简单性能接近 CUDAv2 大概 90%v3 仍在追赶可读性极强——你可以读 Triton 版的 Flash Attention 来理解算法vLLM 部分 backend 就是 Triton 实现。5.4 什么时候 Flash Attention 不适用虽然 Flash Attention 是标配但有一些场景不适用或收益有限场景原因序列极短 256IO 占比不大传统 attention 反而更快自定义 attention如 ALiBi 老版Flash 默认不支持任意 mask要专门修改FP32 训练Flash v1/v2 仅支持 FP16/BF16v3 加 FP8老 GPUPascal / VoltaFlash 需要 Ampere 架构极特殊 attention 模式局部 全局混合需要专门定制但 95% 的场景Flash Attention 都是无脑选项。六、Flash Attention 给工程师的启示6.1 算法 硬件 真正的优化Flash Attention 的成功不是算法创新softmax 还是那个 softmax也不是新硬件GPU 没变而是两者结合理解算法的数学结构理解硬件的物理特性重新设计两者的接口这是大模型系统优化的核心方法论不要只看算法也不要只看硬件而是两者协同。6.2 IO 优化的普适性Flash Attention 的分块 流式合并思路在很多地方都能用量化W4 FP16 也用类似思想分块MoE专家计算和数据搬运的 overlap分布式训练通信和计算的 overlap训练 checkpointing分块保存激活如果你做系统优化多想想能不能不物化中间结果——这是个屡试不爽的优化方向。6.3 不要害怕底层Flash Attention 的实现要写 CUDA / Triton kernel这让很多工程师望而却步。但理解它的原理并不要求你能从零写——理解 Tiling、在线 softmax、IO/compute 平衡这些概念已经足够你做正确的部署决策。七、结语Flash Attention 是大模型时代的基础设施读完本文你应该明白GPU 算力增长比带宽快IO 是大模型的主要瓶颈Flash Attention 用 Tiling Online Softmax 把 attention IO 量从 O(n²) 降到 O(n)v1 / v2 / v3 演进v1 开创、v2 优化、v3 适配 H100 FP8使用方式HuggingFace 加attn_implementation、PyTorch 用SDPA、vLLM 默认启用训练比推理收益更大训练 70B 8K 上下文2× 加速 47% 显存节省Flash Decoding / Ring AttentionFlash 家族在持续演进参考文献13.Flash Attention 原理与实践让 Attention 重新成为算力游戏