作者:企鹅火烈鸟🦩
来源:https://research.colfax-intl.com/cutlass-tutorial-persistent-kernels-and-stream-k/
前言
这篇是Hopper矩阵乘法系列的最后一篇文章了,欢迎来到我们关于 GEMM(通用矩阵乘法)教程系列的第三部分。在第一和第二部分中,我们从单个线程块的视角详细讨论了 GEMM,介绍了 WGMMA matmul 原语、流水线技术以及 warp 专业化。在本部分中,我们将从整个网格的角度考察 GEMM。在这一层面,主要有两类优化方式:(1)利用线程块刷选(swizzling)和集群来最大化 L2 缓存命中率;(2)更好地将任务分配到各线程块,以充分利用 GPU 的计算资源并实现良好的负载均衡。本文将重点讲解第二类优化(但在附录中也会讨论第一类优化)。
具体来说,我们将讨论一种名为 Stream-K 的划分策略,它用以解决波量化(wave quantization)问题——当工作单元的数量不能被流处理多处理器(SM)的数量整除时就会出现该问题。当采用标准的基于块的输出划分方式无法充分利用 GPU 时(例如 M 和 N 较小但 K 很大),Stream-K 也同样非常有用。
本博文结构如下:我们首先介绍波量化问题及持久内核(persistent kernel)的概念。随后,我们将介绍多种在线程块间划分 GEMM 工作负载的策略,包括 Stream-K 及其前身 Split-K,重点关注它们如何应对波量化问题。接着,我们将说明内核开发者如何编写自己的 tile 调度器;作为例子,我们在本系列第二部分的 GEMM 内核中增加了 Stream-K 的实现。
Wave quantization
一块 NVIDIA GPU 由多个流处理多处理器(SM)组成:每个 SM 拥有自己的共享内存、寄存器文件、张量核心等硬件资源,并且各自独立运行。理想情况下,负载能在 SM 之间最大程度地并行分配,使所有 SM 在整个内核执行期间都能保持忙碌。然而,如果某些 SM 完成分配任务的速度快于其他 SM,这些 SM 就会处于空闲状态,等待剩余 SM 完成,这便造成了负载不均衡的现象。
假定某项计算可以被划分为等大小的工作单元,而且每个工作单元都能被单个 SM 在相同的时间内完成。例如,GEMM 通常会被划分为多个工作单元,每个单元负责计算一个 bM x bN 的输出 tile。这些工作单元随后会被分配给线程块(CTA),每个 CTA 会在可用的 SM 上完成其分配的工作单元。我们将工作单元分配给 SM 的过程称为调度(scheduling)。
当工作单元的数量超过可用 SM 的数量时,这些工作单元会分批次被处理,每一批次称为一个波(wave):即每个可用 SM 各自完成一个工作单元,便构成一个波的完成。
波量化(wave quantization)问题就出现在工作单元数量不能被可用 SM 数量整除的时候。例如,假设有 10 个工作单元和 4 个 SM,那么工作单元的执行时间轴如下所示:

在这种情况下,前两波是完整的波,每个 SM 都得到了充分利用。但最后一波是一个部分波,只占用了一半的 SM。
当工作单元的数量相对于 SM 数量较少时,波量化现象会严重降低性能。例如,在一块拥有 114 个 SM 的 H100 PCIe GPU 上,如果计算任务有 115 个工作单元,则需要 2 个波——这与有 228 个工作单元的计算任务所需的波数完全相同!换句话说,仅仅多出第 115 个工作单元,就会让设备利用率大约减半。另一方面,虽然拥有 114,001 个工作单元的计算任务也会受到同样的量化影响,但其带来的性能损失与整个内核的总开销相比可以忽略不计。你可以在 NVIDIA 深度学习性能指南中找到更多相关信息。
为了观察波量化的实际影响,我们可以采用本系列第二部分中实现的 GEMM 内核,并在不同波数下测量其性能。假设我们要计算一个 MxK 的矩阵 A 与 KxN 的矩阵 B。令 bM 和 bN 为工作 tile 的尺寸,并假设它们能整除 M 和 N。则总波数为 ceil((M/bM * N/bN)/num_SMs)
。为了研究量化效应,我们关注每个 SM 分到的 tile 数,即 (M/bM * N/bN)/num_SMs
;其小数部分表示最后一波的填满程度。因此,我们将固定 M=1024、K=4096,并以 bN(对于我们来说是 192)为步长,逐步增加 N 的值。

左侧的图展示了性能(以 TFLOPs/s 为单位),右侧的图则显示了耗时,所有数据均在 H100 PCIe GPU 上进行的基准测试。图中的虚线表示波的分界线,即每个 SM 分配到的 tile 数量跨越整数值的位置。左图清晰地展现了波量化效应——每当跨越波界时,性能会出现明显的下降。相应地,右图显示了耗时主要由总波数这个离散参数决定(具体来说,当 x ∈ (0,1] 时为 1 波,x ∈ (1,2] 时为 2 波,以此类推)。
需要注意的是,第二次量化带来的影响要小于第一次——随着波数的增加,波量化效应的影响会逐渐减弱。然而,想要增加波数并非易事,尤其是在 NVIDIA GPU 的 SM 数量随着新架构不断增加的情况下。因此,我们必须提出一些策略,在不对问题规模做假设的前提下,尽可能减轻波量化的影响。
Persistent kernels
为了解决波量化问题,我们需要设计更优的划分和调度方案。在本博客之前展示的内核实现中,网格(grid)的大小通常取决于问题的维度,每个 CTA 负责处理一个工作单元。例如,在 GEMM 中,工作单元指的是 MxN 输出矩阵中的 bM×bN tile,其中 bM 和 bN 在编译时就已固定。每个工作单元会由一个 CTA 在 M/bM × N/bN 的网格中独立计算。因此,我们的 kernel 启动参数通常如下所示:
dim3 dimGrid(ceil_div(M, bM), ceil_div(M, bN));
这种做法的问题在于,虽然我们可以在一定程度上控制线程块如何分配到 SM,但很难实现更复杂的调度策略。因此,我们将采用另一种设计思路:持久化内核(persistent kernels) 。在持久化内核中,网格的大小是一个固定值。通常,这个值等于可用 SM 的数量,这样每个 CTA 就能独占一个 SM。我们可以通过如下 CUDA 代码获取用于 dimGrid
的 SM 数量:
int num_SMs;
cudaGetDeviceAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, device_id);
dim3 dimGrid(num_SMs);
每个 CTA 会在其所分配的 SM 上持续存在,并不断处理多个工作单元,直到所有任务都完成为止。这种设计上的改变让程序员在调度方面拥有了更大的自主权,可以灵活地控制每个 CTA 如何遍历各个工作单元。借助这种灵活性,我们能够以最小化波量化和负载不均衡的方式分配工作任务。
在实际应用中,工作单元分配给 CTA 的过程通常由 tile 调度器(tile scheduler)负责。tile 调度器本质上就是一个功能增强的迭代器,用于告知每个 CTA 其下一个工作单元的位置以及何时终止处理。虽然每个输出 tile 所需的总工作量并未发生改变,但通过切换不同的 tile 调度器,我们能够尝试更复杂的策略来最小化负载不均衡,比如采用 Stream-K 方案。
处理持久化内核中的wave quantization
为了逐步理解 Stream-K,我们也有必要先看看一些更简单但效率较低的波量化应对方法。Stream-K 相关论文对此有非常深入的讨论,强烈建议大家阅读。为了便于读者理解,这里我们对论文中的讨论内容做一个简要总结。
在本节中,为了让数据更易于理解,我们将以一块虚构的 GPU——Hipparchus H10(仅有 4 个 SM)为例。
数据并行
我们先从最基本的版本说起:也就是简单地在 M 和 N 方向上平均划分 tile,并采用轮询(round-robin)方式分配给各个 CTA。需要注意的是,这种做法本质上和非持久化、基于工作 tile 网格启动的内核情况几乎是一样的,唯一的区别在于分配顺序得到了保证。尽管如此,研究这种策略仍然有意义,因为它能帮助我们理解在什么情况下波量化会成为问题。由于各个工作单元之间没有依赖关系,这种调度方式也被称为数据并行(data-parallel)工作调度。

图 1 展示了一个划分示例。在这里,GEMM 工作负载被分成了 9 个工作 tile。由于每个工作单元都是等价的,这些 tile 会以“波”的形式被处理。具体来说,这 9 个工作 tile 会在 H10 的 4 个 SM 上分 3 波执行:前两波为完整的波,最后一波是部分波,只占用了 4 个 SM 中的 1 个。如果每个工作 tile 在其 SM 上都能达到 100% 利用率,那么整个计算过程的平均利用率就是 2.25/3 = 75%。
最直接的应对方法,就是回到这样一个事实:如果工作单元数量增加,波量化问题的影响就会减小——而我们可以通过缩小每个工作单元的规模来增加工作单元的数量。

在图 2 中,我们将 bN 在 N 方向上缩小了一半。现在共有 18 个工作 tile,可以分成 5 个波来执行:4 个完整的波,以及 1 个部分波(其中 4 个 SM 里只有 2 个被占用)。再次假设每个工作 tile 都能以 100% 利用率被计算,那么整个计算过程的平均利用率就是 4.5/5 = 90%。此外,图 2 中的每个工作 tile 所需的 FLOPs 是图 1 的一半——粗略估算,每个波执行的时间应该是图 1 的一半。因此,虽然图 2 有 5 个波而图 1 只有 3 个,图 2 的总用时只有 (5*0.5)/3 = 图 1 的 83%!那会有什么问题吗?
不幸的是,这里我们做了过多的简化假设,已经无法正确模拟 Hipparchus H10 的实际行为了。核心问题在于,随着 tile 尺寸的减小,每个工作 tile 的计算效率可能变得更低。因此,假设 tile 尺寸减半就能让计算时间也减半,或者让单个 CTA 的利用率保持不变,这往往是不成立的。
其中一个主要缺点就是算术强度的降低。由于内存访问耗时较高,我们希望每次内存访问能配合大量的算术运算,以掩盖内存访问延迟。对于 GEMM,一个 CTA 计算一个 bM × bN × bK 的 matmul tile,会执行 2·bM·bN·bK 次算术操作,并进行 (bM·bK + bN·bK + bM·bN) 次全局内存(GMEM)访问。可以看到,bN 减半时,第一个数减半,而第二个数却没有减半。例如,128 x 128 x 128 的 tile 大小会带来每次 GMEM 传输 85.3 次运算,而 128 x 64 x 128 的 tile 大小则只有 64 次运算/每次 GMEM 传输。
还有一个问题是,如果 CTA 的大小不变,tile 尺寸减半意味着每个 CTA 里的 warp 需要处理的指令数也减半。这会减少 warp 调度器的延迟隐藏空间,而延迟隐藏对于流水线 GEMM 的高性能至关重要。
最后,tile 大小还可能受 MMA atom 选择的约束。例如,H10 可能要求使用 128 x 128 x 16 的 WGMMA atom 以获得最大吞吐量。这也对 tile 的最小尺寸提出了限制。
在这些因素之间取得平衡并不是一件显而易见的事,针对特定问题找到合适的 tile 大小往往需要反复试验——比如使用 CUTLASS Profiler 这样的工具。
Split-K
到目前为止,我们只在 M 和 N 方向上进行了划分,但其实还有另一个可以划分的维度:K 方向。当 K 很大时,在 K 方向上划分(Split-K)会非常有效;不过和之前一样,如果 bK 太小,同样会带来算术强度和延迟隐藏方面的损失。
Split-K 调度方式会将一个 tile 沿 K 方向均匀地分成若干份。例如,在图 3 中,我们就在 K 方向上将 tile 分成了 2 个工作单元。

这种策略带来了一个新的复杂性:每个 CTA 只为其 bM × bN 输出 tile 累加了一部分结果。为了完成计算,负责同一个输出 tile 的多个 CTA 需要将它们的结果合并。通常的做法是在辅助 GMEM 工作区中进行“turnstile reduction”。每个协作计算同一 tile 的 CTA,都会等前一个 K 索引的 CTA 到达同步屏障(barrier)后,再把自己的部分结果累加到工作区,然后自己也到达屏障。最后一个 CTA 则不是再向工作区累加,而是从工作区读取数据到自己的累加器,并计算收尾(epilogue)。注意,额外的 GMEM 访问和 barrier 同步会引入额外开销,在图 3 中以“arrive”和“reduce”模块的形式表现出来。
Split-K 也引入了一个新的超参数——分割数(splits 数量),这会带来一系列权衡:
-
增加分割数可以减小波量化带来的影响,从而可能提升整体 SM 利用率。 -
增加分割数会让 K 方向的 tile 尺寸减小,这可能会提升 GMEM 访问与计算的比值(内存带宽压力变大)。 -
增加分割数还会减少每个 CTA 执行的指令数,因此降低隐藏延迟的能力。 -
我们还引入了同步和归约的额外开销,这是 Split-MN 不存在的。分割数越多,同步的成本也越高。
Stream-K
到目前为止,我们讨论的这些策略虽然改善了波量化问题,但并没有真正消除它。回到最初的例子——将 9 个工作 tile 分布在 4 个 SM 上——理想情况下,每个 SM 应该能够运行 2.25 个波。这也正是 Stream-K 策略的动机所在。
Stream-K 策略为每个 SM 分配一个持久化的单一 CTA。每个 CTA 会被分配到一个“分数”数量的工作 tile,其中被拆分的 tile 会沿 K 方向进行划分。和 Split-K 策略一样,对于每个被拆分的工作 tile,协作的 CTA 可以在 GMEM 工作区中通过 turnstile reduction 合并它们的结果。

例如,在图 4 中,SM0 上的持久化 CTA 计算了完整的工作 tile 0、完整的工作 tile 1,以及工作 tile 2 的 1/4。SM1 上的持久化 CTA 计算了 tile 2 的其余部分、完整的 tile 3,以及 tile 4 的一半,依此类推。部分 tile 的调度方式是:确保每个 tile 的第一个部分会比最后一个部分提前很多被计算,从而最小化同步开销(但要注意,如果 tile 在 K 方向上非常长,这点未必总能做到)。
我们来把 Stream-K 和之前讨论的策略进行对比:
-
通过消除“波”,我们消除了量化问题。每个 CTA 计算 2.25 个工作 tile。除了同步和归约所需的额外时间外,总体计算时间大约是 2.25 个单位,而原始 kernel 需要 3 个单位。 -
许多原本 128 x 128 x 128 的工作 tile 还是由单个 CTA 完全处理,因此部分保留了大 tile 的优势:高算术强度、较长的指令序列,以及能够使用大 WGMMA 指令。如果第一个 kernel 每个 CTA 都能达到 100% 利用率,这里同样可以做到。 -
在很多情况下,我们可以让输出 tile 的前几个部分比最后一个部分提前很多被计算,这样负责 epilogue 的 CTA 实际上不需要长时间等待同步屏障。 -
但此内核确实需要额外的 GMEM 传输,以便不同 CTA 之间共享部分 tile 的数据。
Hybrid Stream-K
我们还可以对内核做最后一项优化,这涉及到缓存性能。对于一个采用 tile 划分的 GEMM 内核来说,每个操作数的 tile 通常会被多个输出工作 tile 所共享。例如,在 split-MN 的情况下,B0 tile 会被用来计算输出的 tile 0、1 和 2。

这里,输出 tile 0、1 和 2 是同时被计算的。当某个 CTA 从全局内存中获取 tile B0 时,这个 tile 也会被放入 L2 缓存。其他 CTA 如果也请求 tile B0,就可以直接命中缓存,从而更快地加载数据。但由于缓存容量有限,旧的数据可能被驱逐,所以这些请求需要在时间上尽量接近,才能最大化缓存利用率。
更具体地说,操作数 tile 也会在 K 方向上进行分块,每个 CTA 会在其操作数 tile 的 K 块上做内层循环。当第 0 波开始时,SM0、SM1 和 SM2 会同时请求 tile B0 的第 0 个 K 块,其中有两个请求会命中缓存。在下一次循环迭代中,SM0、SM1 和 SM2 又会同时请求 tile B0 的第 1 个 K 块,依此类推。
然而,stream-K 内核引入了“错位”(skew):由于每个 SM 开始时处理的部分 tile 大小不同,它们往往会在同一时刻访问不同的 K 偏移。回到图 4,SM0 和 SM2 虽然都在第 0 波开始时用到 B0 的数据,但 SM0 需要第 0 个 K 块,而 SM2 需要更靠中的数据。实际上,这种调度下各个 SM 的 K 偏移不会对齐,使得缓存命中的概率大大降低。总之,消除“波”并让不同 SM 的调度不同步,带来了缓存性能下降这一隐性代价。
我们可以通过重新调度,将计算过程设计成持久化内核与普通数据并行内核的混合体来解决这个问题。由于数据并行调度不会出现错位,因此我们应该尽可能多地采用这种方式,仅将 Stream-K 用于处理波量化残留的极少量 tile。为了让 Stream-K 阶段 SM 之间的负载均衡,需要分配 1 个完整波以及剩余的部分波到这个阶段。
这种调度如图 6 所示。最初的 Stream-K 阶段会处理 1 到 2 个完整波的计算。每个 SM 最多分配到 2 个部分工作 tile。通过这种设计,这些 tile 的总大小与 CTA 无关,因此所有 CTA 预计会在同一时间完成这阶段的计算。一旦这阶段结束,只剩下完整的工作 tile,并且剩下的数量能被 SM 数整除。这样,这些工作 tile 就可以用非持久化、数据并行的策略来计算——这种策略没有波量化问题,同时缓存性能更好。如图 6 所示:

在这里,我们可以预期工作 tile 6、7 和 8 会在几乎相同的时间被计算,从而在访问操作数 tile B2 时能够命中缓存。类似地,工作 tile 5 和 8 由于共享同一个 A tile,也能利用缓存。在这个例子中,数据并行阶段只包含 1 个波,但如果 GEMM 更大、工作 tile 更多,数据并行阶段会更长,缓存的利用率也会更高。
Tile 调度器抽象
由于工作划分和调度的问题在很大程度上与每个 CTA 的内存及计算操作是独立的,像 CUTLASS 这样的 GEMM 实现通常会用一种称为 tile 调度器(tile scheduler)的抽象来封装这些逻辑。(这种方式不仅适用于 GEMM——比如 FlashAttention-3 也支持基于 tile scheduler 类的持久化内核。)在下一节我们会具体分析 CUTLASS 的实现,这里我们先概述一下 tile 调度器一般承担的职责。
首先,内核的 grid 形状取决于 tile 调度方式。因此,tile 调度器负责决定内核的 grid 大小。对于非持久化内核,这个 grid 大小与逻辑网格相同,取决于问题规模;对于持久化内核,grid 大小通常是固定的,并很可能等于 SM 的数量。我们会在启动内核前向 tile 调度器查询 grid 大小,并用它来配置 kernel 启动参数。
在内核内部,每个线程会构造一个 tile 调度器的实例。主循环和 epilogue(收尾计算)现在会被包裹在一个由调度器提供的工作 tile 循环中,大致如下所示:
for (auto worktile = scheduler.get_initial_tile();
scheduler.is_valid(worktile);
worktile = scheduler.get_next_tile(worktile)) {
auto [m_block, n_block, k_block_start, k_block_stop] = worktile.get_block_coord();
for (k_block = k_block_start; k_block < k_block_stop; ++k_block) {
// mainloop
}
// epilogue
}
实现这些迭代器原语的一种简单方式,是让调度器维护一个线性的工作 tile 索引。对于持久化内核,每个 CTA 最初会获得索引为 blockIdx.x
的工作 tile(这其实就是底层 SM 的线性编号);CTA 通过每次递增 gridDim.x
(即 SM 的数量)来前进到下一个 tile;只要 tile 的索引没有超过总 tile 数,就说明当前 tile 是有效的。将线性索引映射到实际的 (M, N) tile 坐标的工作则交由 worktile 对象来处理。
这种方式已经足够实现持久化数据并行调度了,但更复杂的调度方式则需要更多功能。例如,对于 Stream-K,工作分配在 K 方向上的大小依赖于 tile,因此 worktile 实际上应该为内核提供四个坐标,如代码示例所示。
对于 Stream-K 和 Split-K,两者都存在一些 CTA 会输出需要聚合的部分结果,这带来了如下影响:
-
需要额外的 GMEM 工作区,既要存放部分结果,还要存放允许多个 CTA 协同工作的 barrier(同步屏障)对象数组。所需空间大小依赖于问题规模,因此必须在 kernel 启动前动态分配。在内核执行期间,调度器应为 CTA 提供合适的工作区指针。 -
每当开始处理新工作 tile 时,每个 CTA 需要知道当前 tile 是完整输出 tile(此时结果应写回输出张量),还是部分结果 tile(此时结果应写入工作区)。 -
只有一个 CTA 负责输出 tile 的 epilogue(收尾计算)。这个 CTA 不是把结果累加到工作区,而是从工作区累加到自己的寄存器,然后进行 epilogue。调度器需要告知每个 CTA 它是否需要对当前 tile 执行 epilogue。
正如 CUTLASS 的实现所展示的,上述简单轮廓还可以有许多改进,比如调度器可以决定 tile 的启动顺序,使用启发式方法在 Stream-K、Split-K 和数据并行模式之间切换,以及在 Hopper 架构上合理使用 cluster。我们接下来会对这些进行分析。
我们的 GitHub 代码示例中提供了三种调度器的例子:一个简单的非持久化调度器(为每个 CTA 分配 1 个 tile,网格大小由问题形状决定);一个数据并行持久化调度器;还有一个 Stream-K 混合调度器,实现了部分但不是全部 CUTLASS 的优化。实践中我们发现,要获得较好的性能,许多 CUTLASS 的优化都是必要的:特别是由于归约带来的额外 GMEM 访问和更小的 tile 尺寸是确实的性能损失,Stream-K 的工作分界也需要仔细调整以最小化该损失。
下图展示了 Stream-K tile 调度器的一些性能指标。与数据并行调度器相比,我们的 Stream-K 实现能在每个波的初期表现良好,显著减弱波量化效应,但当进入剩余部分波时,性能会下降。”Heuristic” 曲线采用了 CUTLASS 的启发式策略,即当最后一波至少填满一半时,从 Stream-K 切换回数据并行,这显然是一个很好的选择。

结论
在本文中,我们讨论了波量化(wave quantization)及其对 GEMM 性能的影响。我们注意到,在第二部分实现的 GEMM 中,波量化带来了显著的性能波动。随后,我们探讨了多种应对波量化的策略,重点介绍了 Stream-K 方法。最后,我们展示了一个 Stream-K tile 调度器的实现版本,以消除我们 GEMM 实现中的波量化影响。至此,我们基于 CUTLASS/CuTe 抽象实现高性能 Hopper 架构 GEMM 的三部曲就全部结束了。
(文:GiantPandaCV)