当现代大语言模型(LLM)将上下文长度推向 128K、256K 甚至百万 token 级别时,一个不起眼的训练操作正在成为严重瓶颈:注意力蒸馏(Attention Distillation)。这个被广泛用于知识蒸馏、模型压缩、持续学习和稀疏注意力训练的核心操作,其现有实现需要将两个 O(N²) 的注意力矩阵完整写入 HBM,再逐元素计算 KL 散度——在 64K 上下文、32 注意力头、BF16 精度下,仅存储两个注意力分布就需要 512 GB,是单张 NVIDIA H200(141 GB HBM)的 3.6 倍。
由上海交通大学、华为和复旦大学联合团队提出的 StreamKL,以首个注意力 KL 散度融合 GPU 原语彻底解决了这一问题。通过巧妙的在线公式推导和 tile-by-tile 流式计算,StreamKL 将注意力蒸馏的额外 HBM 占用从 O(N²) 降至 O(1),前向速度提升最高 43 倍,反向提升 14 倍,使得原本需要多 GPU 的长上下文蒸馏任务可以在单 GPU 上完成。
注意力蒸馏为什么重要
注意力蒸馏的核心思想很简单:训练一个注意力分布去匹配另一个,损失函数是它们之间的 KL 散度。这个看似朴素的操作在现代 Transformer 工作流中无处不在:
- 知识蒸馏:MiniLM 等工作通过蒸馏教师模型与学生的注意力分布来迁移知识;
- 模型压缩:剪枝或量化后的模型通过注意力蒸馏保持原始模型的注意力模式;
- 持续学习:通过将当前模型的注意力分布与历史快照对齐,缓解灾难性遗忘;
- 稀疏注意力训练:DeepSeek V3.2 和 GLM-5 等模型在训练稀疏注意力时,使用注意力蒸馏将轻量级 indexer 的分布与原始密集注意力分布对齐。
这些场景的共同点是:训练目标都是最小化两个 N_Q × N_K 的注意力矩阵 P₁ 和 P₂ 之间的 KL 散度。当 N_Q = N_K = 64K 时,这就是一个 64K × 64K 的矩阵,且有两个。
为什么不能直接套用 FlashAttention 的思路
熟悉 GPU 优化的读者可能会问:FlashAttention 不是已经通过融合 kernel 解决了标准注意力的 O(N²) 物化问题吗?直接套用不就行了?
答案是:不能直接套用。标准注意力计算只需要处理一个 softmax 分布,然后乘以 V 矩阵。而 KL 散度耦合了两个独立的 softmax 分布,通过一个加权的 logit 差值项连接。如果直接使用在线 softmax,要么仍然需要物化 O(N²) 的中间结果,要么需要两次独立的遍历——计算量和 HBM 读取量翻倍。
核心困难在于:当两个分布的 running maximum 各自变化时,如何正确地对累积的加权 logit 差值进行 rescale?这需要全新的数学推导。
StreamKL 的技术创新
在线公式:耦合双分布 KL 的增量计算
StreamKL 的核心贡献是对注意力 KL 散度进行了巧妙的数学重写。对于行级 KL 散度 L = Σᵢ P₁ⁱ log(P₁ⁱ/P₂ⁱ),将 softmax 概率用在线统计量(running maximum m 和未归一化和 l)表示后,经过推导可以得到:
L = acc / l₁ + LSE₂ − LSE₁
其中 acc = Σᵢ e^(S₁ⁱ − m₁) · (S₁ⁱ − S₂ⁱ) 是一个可以在流式处理中增量更新的累加器。关键的是,这个累加器和标准化常数 l₁, l₂ 都可以随着 key tile 的流入,使用类似在线 softmax 的 rescale 机制逐 tile 更新——且累加器的 rescale 只是一个 B_Q 维向量,而非 FlashAttention 中需要 rescale 的 B_Q × d 输出矩阵,算术开销降低了 d 倍(通常 ≥128)。
前向:单遍流式 Kernel
StreamKL 的前向 kernel 将 query 分块,每个 thread block 处理一个 query tile,在内部循环中依次流式加载 K₁ 和 K₂ 的 tile。每处理一个 tile:
- 在 SRAM 中计算 logit tile S₁ = Q₁K₁ᵀ 和 S₂ = Q₂K₂ᵀ
- 更新 running maximum m₁, m₂
- 计算 correction factor 进行 rescale
- 更新累加器 acc
前向结束后,仅将 KL 值和两个 LSE 向量(O(N_Q) 个标量)写回 HBM。
反向:基于重计算的 Tile-by-Tile 梯度
反向传播不存储任何 O(N²) 中间结果。通过保存的 LSE₁ 和 LSE₂,backward kernel 可以随时按需重建 P₁ⁱ = exp(S₁ⁱ − LSE₁) 和 P₂ⁱ = exp(S₂ⁱ − LSE₂)。StreamKL 支持两种梯度设置:优化 P₂(标准蒸馏)和优化 P₁,并对后者设计了数值稳定的 log-ratio 计算,避免概率趋近于零时产生 NaN。
GPU Kernel 工程优化
在算法创新之外,StreamKL 还实现了一系列精巧的 GPU kernel 优化:
- Split-K 变体:当 batch size 小或 N_Q 少导致 SM 利用率不足时,沿 key 维度分片,将部分统计量写回 HBM 后由轻量级 reduce kernel 合并,充分饱和 GPU;
- 融合反向 Kernel:通过原子操作将 dQ 和 dK 的计算合并到一个 kernel 中,每个 QK tile 对只访问一次,HBM 流量减半;
- Hopper 架构适配:利用 Tensor Memory Accelerator(TMA)进行异步 bulk copy,并用 exp₂ 替代自然指数以利用单周期 SFU 指令。
实际效果
在 NVIDIA H200 和 A100 上的实验表明:
| 指标 | 前向 | 反向 |
|---|---|---|
| 对比 PyTorch(causal) | 43× 加速 | 14× 加速 |
| 对比 PyTorch(non-causal) | 18× 加速 | 6.5× 加速 |
| 对比 torch.compile | 7.0× 加速 | — |
| 对比 FLA | 7.1× 加速 | — |
| HBM 额外占用 | O(1) | O(1) |
最关键的指标是内存:StreamKL 是唯一能在单 GPU 上支撑 64K+ 上下文蒸馏的方案。传统方法在 chunk size 1024 时就会触发 H200 的 OOM,而 StreamKL 在 128K 上下文下仍稳定运行,甚至支持最高 512K token 的蒸馏任务。
算法-系统协同设计的价值
StreamKL 是 GPU kernel 优化中"算法-系统协同设计"的典型案例。它展示了一条清晰的路径:当现有系统的瓶颈来自底层计算的数学结构(而非单纯的工程实现),通过重新推导公式、改变计算顺序,可以释放数量级的效率提升。
这一思路与 FlashAttention 一脉相承——后者通过在线 softmax 将标准注意力的 HBM 读写从 O(N²) 降至 O(N)。StreamKL 将同样的哲学推广到了更复杂的双分布 KL 散度场景,填补了注意力蒸馏在系统层面的空白。
对于正在训练稀疏注意力 LLM、进行大规模知识蒸馏、或探索长上下文模型压缩的团队而言,StreamKL 提供了一个将"不可能"变为"单 GPU 可行"的关键工具。论文来自上海交通大学、华为和复旦大学,已于 2026 年 6 月 18 日发布在 arXiv 上。

