💡 这是 How to Scale Your Model: A Systems View of LLMs on TPUs 的中文翻译版本,版权归原作者所有。当前这个版本基于commit f3c0a12翻译,后续会随开源仓库的更新不定期更新内容。

训练大语言模型常常感觉像是一门炼金术,但理解和优化模型性能并不一定如此神秘。本书旨在揭开扩展语言模型的科学奥秘:TPU(和GPU)如何工作以及它们如何相互通信,大语言模型如何在实际硬件上运行,以及如何在训练和推理过程中并行化模型,使其在大规模环境下高效运行。如果你曾经想知道"训练这个大语言模型应该花费多少"或"我自己部署这个模型需要多少内存"或"什么是AllGather",我们希望本书对你有所帮助。

深度学习的很多方面仍然归结为一种黑魔法,但优化模型性能并不一定如此——即使在大规模场景下也是如此!相对简单的原则可以适用于所有场景——从处理单个加速器到数万个——理解这些原则可以让你做很多有用的事情:

  • 大致评估模型各部分与理论最优水平的接近程度。

  • 在不同规模下对不同并行方案做出明智选择(如何在多个设备上拆分计算)。

  • 估算训练和运行大型Transformer模型所需的成本和时间。

  • 设计能利用特定硬件优势的算法。

  • 基于对当前算法性能限制的明确理解来设计硬件。

预期背景知识:我们假设你对LLM和Transformer架构有基本了解,但不一定了解它们如何在大规模环境下运行。你应该了解LLM训练的基础知识,最好对JAX有一些基本熟悉。一些有用的背景阅读资料包括这篇博客关于Transformer架构的介绍和这些优秀的幻灯片关于JAX中的LLM扩展。

目标与反馈:到最后,你应该能够舒适地估算在给定硬件平台上Transformer模型的最佳并行方案,以及训练和推理大致需要多长时间。如果你做不到,请给我们发邮件或留言!我们很想知道如何使内容更清晰。

为什么你应该关注这些?

三四年前,我认为大多数ML研究人员不需要了解本书中的任何内容。但如今,即使是"小型"模型也运行得非常接近硬件极限,以至于开展创新研究需要你考虑大规模环境下的效率问题。如果在基准测试上获得20%的提升,但代价是降低20%的硬件效率上限,那这种提升就毫无意义。有前途的模型架构经常失败,要么是因为它们无法在大规模环境下高效运行,要么是因为没有人投入精力使它们能够这样做。

"模型扩展"的目标是能够增加用于训练或推理的芯片数量,同时实现吞吐量的成比例线性增长。这被称为"强扩展"。虽然增加额外的芯片("并行化")通常会减少计算时间,但它也带来了芯片之间额外通信的成本。当通信时间超过计算时间时,我们就会受到"通信瓶颈"的限制,无法实现强扩展。如果我们对硬件了解得足够好,能够预测这些瓶颈将在何处出现,我们就可以设计或重新配置模型来避免它们。

我们在本书中的目标是解释TPU(和GPU)硬件的工作原理,以及Transformer架构如何演变以在当前硬件上表现良好。我们希望这对设计新架构的研究人员和致力于使当前一代LLM快速运行的工程师都有所帮助。

整体概述

本书的整体结构如下:

第1章解释了屋顶线分析以及哪些因素会限制我们的扩展能力(通信、计算和内存)。第2章第3章详细讨论了TPU和现代GPU的工作原理,既包括作为单个芯片时的工作方式,也包括—这一点至关重要—作为带有有限带宽和延迟的芯片间链接的互连系统时的工作方式。我们将回答以下问题:

  • 特定大小的矩阵乘法应该花费多长时间?在什么情况下它会受到计算能力、内存或通信带宽的限制?

  • TPU如何连接在一起形成训练集群?系统的各个部分有多少带宽?

  • 在多个TPU上收集、分散或重新分配数组需要多长时间?

  • 如何有效地乘以在不同设备上分布不同的矩阵?

来自第2章的图表,展示了TPU如何执行元素级乘法。根据我们数组的大小和各种链接的带宽,我们可能会遇到计算受限(使用全部硬件计算能力)或通信受限(受内存加载瓶颈影响)的情况。

五年前,机器学习拥有丰富多彩的架构景观——卷积网络、LSTM、多层感知器、Transformer——但现在我们主要只使用Transformer。我们坚信理解Transformer架构的每个部分都很有价值:每个矩阵的确切大小、规范化发生的位置、每个部分中的参数和浮点运算数量。

第4章 仔细讲解了这些"Transformer数学",展示了如何计算训练和推理过程中的参数和浮点运算数量。这告诉我们模型将使用多少内存,在计算或通信上花费多少时间,以及注意力机制何时相对于前馈模块变得重要。

标准Transformer层,每个矩阵乘法(matmul)显示为圆圈内的一个点。所有参数(不包括规范化层)以紫色显示。第4章更详细地讲解了这个图表。

第5章:训练第7章:推理是本文的核心,我们在这里讨论一个基本问题:给定某一规模的模型和若干芯片,如何并行化模型以保持在"强扩展"状态?这是一个简单的问题,却有着出人意料的复杂答案。从高层次来看,有4种主要的并行技术用于在多个芯片上拆分模型(数据张量流水线专家并行),以及一些其他减少内存需求的技术(重物化优化器/模型分片(又称ZeRO)主机卸载梯度累积)。我们将在这里讨论其中的许多技术。

我们希望在学完这些章节后,你能够为新架构或新环境自行选择适合的并行策略。第6章第8章是实用教程,将这些概念应用于LLaMA-3这一流行的开源模型。

最后,第9章第10章探讨了如何在JAX中实现这些想法,以及当出现问题时如何分析和调试代码。

我们努力为你提供可以自己解决的问题。请不要感到必须阅读所有章节或按顺序阅读。也请留下反馈。目前,这仍是一份草稿,将继续修订。谢谢!

我们要感谢James Bradbury和Blake Hechtman,他们提出了本文中的许多想法。

跳转到各章节

这个系列可能比必要的更长,但我们希望这不会阻碍你。前三章是基础知识,如果你已经熟悉可以跳过,尽管它们介绍了后面使用的符号。最后三个部分可能最具实用价值,因为它们解释了如何处理真实模型。

第一部分:基础知识

第二部分:Transformer

第三部分:实用教程

第1章:屋顶线分析简介

时间都去哪儿了?

让我们从一个极其简单的问题开始:为什么一个算法需要50毫秒而不是50秒或5毫秒?模型内部实际发生了什么占用了大量时间,我们应该预期它需要多长时间?

计算:深度学习模型本质上是一堆矩阵乘法,每个乘法由浮点乘法和加法"操作"(FLOPs)组成。我们的加速器速度决定了这些计算需要多长时间:

$$ \begin{equation} T_\text{math} = \frac{\text{计算FLOPs数}}{\text{加速器FLOPs/秒}} \end{equation} $$

例如,NVIDIA H100可以执行约9.89e14 bfloat16 FLOPs/秒,而TPU v6e可以执行9.1e14 FLOPs/秒。这意味着在H100上执行1e12 FLOPs大约需要1e12 / 9.89e14 = 1.01ms,在TPU v6e上需要1e12 / 9.1e14 = 1.1ms

芯片内部通信:在加速器内部,张量需要在片上内存(HBM)和计算核心之间传输。这种链接的带宽被称为"HBM带宽"。在H100上,这约为3.35TB/秒,在TPU v6e上约为1.6TB/秒

芯片间通信:当我们跨多个加速器分布模型时,张量经常需要在它们之间传输。在我们的硬件上通常有几个选项(ICI、DCN和PCIe),每个选项都有不同的带宽。

无论是芯片内部还是芯片之间的通信,我们都以字节/秒来测量,并使用以下公式估计总通信时间:

$$ \begin{equation} T_\text{comms} = \frac{\text{通信字节数}}{\text{网络/内存带宽字节/秒}} \end{equation} $$

通常情况下(但并非总是),单个芯片内的计算可以与芯片内部和芯片之间的通信重叠。这意味着我们可以通过使用计算和通信时间的最大值来确定训练和推理时间的下限。我们也可以用它们的总和作为上限。在实践中,我们优化的目标是最大值,因为代数计算更简单,而且通过重叠通信和计算,我们通常可以接近这个界限。如果我们以最大值为目标进行优化,那么下限和上限最多相差2倍,因为Tmath + Tcomms ≤ 2 * max (Tmath, Tcomms)。然后,我们通过建模"重叠区域"和开销来进一步提高准确性,这可以通过分析特定模型和目标系统的性能来获取。

$$ T_{lower} = max (T_{math}, T_{comms}) \ T_{upper} = T_{math} + T_{comms} $$

如果我们假设可以完美地重叠通信和计算,当Tmath > Tcomms时,我们能充分利用硬件。我们称之为"计算受限"。当Tcomms > Tmath时,我们往往是"通信受限",至少有一部分加速器的FLOPs/秒被浪费在等待数据传输上。判断一个操作是计算受限还是通信受限的一种方法是查看其"算术强度"或"运算强度"。

定义:算法的算术强度由其执行的总FLOPs与需要通信的字节数之比给出——无论是在芯片内部还是芯片之间。

$$ \begin{equation} \text{算术强度} = \frac{\text{计算FLOPs}}{\text{通信字节数}} \end{equation} $$

算术强度衡量给定操作的"每字节FLOPs数"。从一阶近似来看,当我们的算术强度高时,Tmath相对于Tcomms较大,我们通常能利用大部分可用的FLOPs。当情况相反时,我们在通信上花费更多时间,浪费FLOPs。这种转换发生的点是我们硬件的"峰值算术强度",即峰值加速器FLOPs/秒与加速器带宽的比率。

$$ \begin{align*} T_\text{math} > T_\text{comms} \Leftrightarrow \frac{\text{计算FLOPs}} {\text{加速器FLOPs/秒}} > \frac{\text{通信字节数}}{\text{带宽字节/秒}} & \[0.5em] \Leftrightarrow \frac{\text{计算FLOPs}}{\text{通信字节数}} > \frac{\text{加速器FLOPs/秒}}{\text{带宽字节/秒}} & \[0.5em] \Leftrightarrow \text{强度}(\text{计算}) > \text{强度}(\text{加速器}) & \ \end{align*} $$

强度(加速器)这个量是加速器达到峰值FLOPs/秒时的算术强度。对于TPU v5e MXU,这大约是240 FLOPs/字节,因为TPU可以执行1.97e14 FLOPs/秒并从HBM加载8.2e11字节/秒。这意味着如果一个算法的算术强度低于240 FLOPs/字节,它将受到字节加载的限制,因此我们无法充分利用硬件。让我们看一个这样的例子:

示例(点积):要在bfloat16精度下计算两个向量的点积,x • y: bf16[N], bf16[N] → bf16[1],我们需要从内存加载x和y,每个向量有2 * N = 2N字节,执行N次乘法和N − 1次加法,并将2字节写回HBM

$$ \begin{equation} \text{强度}(\text{点积}) = \frac{\text{总FLOPs}}{\text{总字节数}} = \frac{N + N - 1}{2N + 2N + 2} = \frac{2N - 1}{4N + 2} \rightarrow \frac{1}{2} \end{equation} $$

当N → ∞时。所以点积的算术强度为\frac{1}{2},或者换句话说,点积每加载一个字节执行0.5次浮点运算。这意味着我们的算术强度低于硬件的算术强度,我们将受通信限制。

可视化屋顶线模型

我们可以使用屋顶线图来可视化内存和计算之间的权衡,它在y轴上绘制算法在我们硬件上可达到的峰值FLOPs/秒(吞吐量),在x轴上绘制该算法的算术强度。以下是一个对数-对数图示例:

图示:一个屋顶线图示例,展示了两种具有不同算术强度的算法(算法1和算法2)及其在不同带宽(BW1和BW2)下的相应理论峰值吞吐量。在红色区域中,算法在两种带宽下都受带宽限制,浪费了硬件峰值FLOPs/秒的一部分。黄色区域仅在较低带宽(BW1)下受带宽限制。绿色区域在所有带宽下都受计算限制。在这里,我们使用的是加速器的峰值FLOPs/秒,增加带宽或提高强度都不会带来额外收益。

上图中,随着强度增加(从左向右移动),我们最初看到算法性能(以FLOPs/秒为单位)线性增加,直到达到硬件的临界算术强度,TPU v5e的情况下为240。任何强度较低的算法都将受带宽(BW)限制,并受限于峰值内存带宽(红色显示)。任何右侧的算法都将充分利用我们的FLOPs(绿色显示)。这里,算法1受通信限制,仅使用硬件总FLOPs/秒的一小部分。算法2受计算限制。我们通常可以通过增加算法的算术强度或增加可用内存带宽(从BW1移动到BW2)来提高算法性能。

矩阵乘法

让我们看看我们即将最喜欢的算法:矩阵乘法(又称matmul)。我们将X * YZ表示为X形状为bf16[B, D],Y形状为bf16[D, F],而Z形状为bf16[B, F]。要进行矩阵乘法,我们需要加载2DF + 2BD字节,执行2BDF FLOPs,并写回2BF字节。因此:

$$ \begin{equation} \text{强度}(\text{matmul}) = \frac{2BDF}{2BD + 2DF + 2BF} = \frac{BDF}{BD + DF + BF} \end{equation} $$

如果我们假设我们的本地"批量大小"B相对于DF较小,我们可以得到一个很好的简化。那么我们得到

$$ \begin{equation} \frac{BDF}{BD + DF + BF} \approxeq \frac{BDF}{DF} = B \end{equation} $$

$$ \begin{equation} \text{强度}(\text{matmul}) > \text{强度}(\text{TPU}) \implies B > \frac{1.97e14}{8.20e11} = 240 \end{equation} $$

这对于Transformer矩阵乘法来说是一个合理的假设,因为对于我们的大多数模型,我们的本地批量大小(以token为单位)B < 1024,但DF > 8000。因此,当我们的本地批量大小大于240个token时,我们就变成了计算受限,这是一个非常简单的规则!

💡 要点:对于大多数TPU上的bfloat16矩阵乘法,要使其受计算限制,我们需要使我们的本地批量大小(以token为单位)大于240。

这有一些值得注意的注意事项,我们将在下面的问题中探讨,特别是关于量化(例如,如果我们量化我们的激活值但仍然执行全精度FLOPs),但这是一个值得记住的好规则。对于GPU,这个数字略高(接近300),但通常也得出相同的结论。我们将在下一章讨论GPU和TPU的更低级别细节。

网络通信屋顶线

到目前为止我们讨论的所有屋顶线都是内存带宽屋顶线,都在单个芯片内。这不应该被视为规则。事实上,我们在本书中关注的大多数屋顶线涉及芯片之间的通信:通常是涉及在多个TPU上分片的矩阵的矩阵乘法。

举一个有些刻意的例子,假设我们想要乘以两个大矩阵 X ∼ bfloat16[B, D] 和 Y ∼ bfloat16[D, F],它们均匀地分布在2个TPU/GPU上(沿着 D 维度)。要进行这种乘法(我们将在第3节中看到),我们可以在每个TPU上乘以每个矩阵的一半(例如 X[:, :D // 2] @ Y[:D // 2, :]),然后将得到的"部分和"复制到另一个TPU并将它们加在一起。假设我们可以在每个方向上复制 4.5e10 字节,并在每个芯片上执行 1.97e14 FLOPs/秒。T_{math}T_{comms}是多少?

T_{math}显然是之前的一半,因为每个TPU只做一半的工作,

$$ T_\text{math} = \frac{2BDF}{2 \cdot \text{Accelerator FLOPs/s}} = \frac{BDF}{1.97e14} $$

那么Tcomms呢?这现在指的是芯片之间的通信时间!这只是发送的总字节数除以网络带宽,即

$$ T_\text{comms} = \frac{2BF}{\text{Network Bandwidth}} = \frac{2BF}{4.5e10} $$

因此,当Intensity(matmul (2-chips)) > Intensity(TPU w.r.t. inter-chip network)时,我们变成计算受限(现在是相对于芯片间网络而言),或等价地,当 \frac{BDF}{2BF} = \frac{D}{2} > \frac{1.97e14}{4.5e10} = 4377 或 D > 8755时。请注意,与之前不同,临界阈值现在取决于 D 而不是 B!试着思考为什么会这样。这只是一个例子,但我们强调这种类型的屋顶线对于了解何时可以在多个TPU上并行操作至关重要。

第1章练习题

问题1 [int8矩阵乘法]:假设我们想要在int8精度(每个参数1字节)而不是bfloat16中进行X[B, D]⋅_D Y[D, F] \to Z[B, F]的计算。

  1. 需要从内存中加载多少字节?需要写回内存多少字节?

  2. 总共执行了多少操作?

  3. 算术强度是多少?

  4. 对于T_{math}T_{comms}的屋顶线估计是什么?整个操作运行时间的合理上下限是什么?

假设我们的HBM带宽是8.1e11字节/秒,我们的int8峰值操作速度是3.94e14

点击此处查看答案
  1. 因为我们在int8中存储参数,每个参数1字节,所以我们从HBM加载了BD + DF字节,并写回BF字节。

  2. 这与bfloat16中的情况相同,但理论上int8操作/秒应该更快。所以这仍然是2BDF个FLOPs。

  3. 算术强度是2BDF/(BD + DF + BF)。如果我们像上面一样假设B ≪ DB ≪ F,我们得到算术强度为2B,这意味着我们的规则变成B > HBM int8算术强度/2。使用给定的数字,这个int8强度是3.94e14 / 8.1e11 = 486,所以规则是B > 486/2 = 243。请注意,这基本上没有变化!

  4. T_{math} = 2BDF/3.94e14T_{comms} = (BD + DF + BF)/8.1e11,所以一个合理的下限是max(T_{math}, T_{comms}),上限是T_{math} + T_{comms}

问题2 [int8 + bf16矩阵乘法]:在实践中,我们经常对权重和激活进行不同的量化,所以我们可能会以非常低的精度存储权重,但保持激活(和计算)在更高的精度。假设我们想要将权重量化为int8,但保持激活(和计算)在bfloat16中。在什么批量大小下我们会变成计算受限?假设bfloat16 FLOPs/秒为1.97e14

提示:这具体意味着bfloat16[B, D] * int8[D, F] -> bfloat16[B, F],其中B是"批量大小"。

点击此处查看答案

再次假设B很小,我们有2BDF个bfloat16 FLOPs,但只有DF个权重(而不是bfloat16中的2DF)。这意味着当2B > 240或B > 120时,我们变成计算受限。这个值低得多,意味着如果我们可以进行int8权重量化(这相对容易做到)但仍然进行bfloat16 FLOPs,我们在效率上会获得有意义的提升(尽管int8操作会更好)。

问题3:对于上述问题,为几个D和F值绘制峰值FLOPs与B的屋顶线图。

问题4:如果我们想执行int8[B, D] *_D int8[B, D, F] \to int8[B, F],其中我们想象每个批次元素有一个不同的矩阵。这个操作的算术强度是多少?

点击此处查看答案

让我们先看看总FLOPs和通信量。

  1. 总FLOPs:FLOPs基本相同,因为我们做的BD × DF矩阵乘法数量相同(这在第4节中有更多讨论)。所以这只是2BDF。

  2. 总通信量:我们这里有更多的通信量:BD + BDF + BF

  3. 因此,我们的算术强度现在实际上是2BDF/(BD + BDF + BF)。由于BDF在分母中占主导地位,这大约是2。所以它不再依赖于批量大小,而基本上是常数。这很糟糕,因为它意味着无论如何我们基本上总是通信受限的。

问题5 [GPU的内存屋顶线]:使用NVIDIA为H100提供的规格表,计算矩阵乘法将变为计算受限的批量大小。请注意,Tensor Core FLOPs数字是真实值的两倍,因为它们只有在结构化稀疏性的情况下才能实现。

点击此处查看答案

从规格表中,我们看到报告的bfloat16 FLOPs值是1.979e15 FLOPs/秒,带有星号注明"带稀疏性"。没有稀疏性的真实值是其一半,接近1e15 FLOPs/秒。内存带宽是3.35TB/秒,或3.35e12字节/秒。因此Bcrit是1e15 / 3.35e12 = 298,与TPU非常相似。

第2章:如何思考TPU

什么是TPU?

TPU基本上是一个专门用于矩阵乘法的计算核心(称为TensorCore),连接着一堆高速内存(称为高带宽内存或HBM)。这是一个图表:

图示:TPU芯片的基本组件。TensorCore是左侧的灰色框,包含矩阵乘法单元(MXU)、向量单元(VPU)和向量内存(VMEM)。

你可以将TensorCore基本上看作是一个非常优秀的矩阵乘法机器,但它还有一些其他值得注意的功能。TensorCore有三个关键单元:

  • MXU(矩阵乘法单元)是TensorCore的核心。对于大多数TPU世代,它使用脉动阵列每8个周期执行一次bfloat16[8,128] @ bf16[128,128] -> f32[8,128]矩阵乘法(详见附录B)。

    • 在TPU v5e上,以1.5GHz运行时,每个MXU约有5e13 bf16 FLOPs/s。大多数TensorCore有2或4个MXU,因此例如TPU v5e的总bf16 FLOPs/s是2e14
    • TPU还支持更低精度的矩阵乘法,具有更高的吞吐量(例如,每个TPU v5e芯片可以执行4e14 int8 OPs/s)。
  • VPU(向量处理单元)执行一般的数学运算,如ReLU激活或向量之间的逐点加法或乘法。规约(求和)也在这里执行。附录C提供了更多详情。

  • VMEM(向量内存)是位于TensorCore内的片上暂存器,靠近计算单元。它比HBM小得多(例如,TPU v5e上为128 MiB),但对MXU的带宽要高得多。VMEM的运作有点像CPU上的L1/L2缓存,但更大且由程序员控制。HBM中的数据需要先复制到VMEM中,然后TensorCore才能对其进行任何计算。

TPU在矩阵乘法方面非常非常快。这主要是它们所做的事情,而且做得很好。TPU v5p,迄今为止最强大的TPU之一,每核心每秒可以执行2.5e14 bf16 FLOPs,或每芯片每秒5e14 bf16 FLOPs。一个拥有8960个芯片的单一pod每秒可以执行4 exaflops。这非常强大。这是世界上最强大的超级计算机之一。而且Google有很多这样的设备。

上图还包括一些其他组件,如SMEM和标量单元,它们用于控制流处理,在附录C中有简要讨论,但理解这些并不是关键。另一方面,HBM很重要且相对简单:

  • HBM(高带宽内存)是一大块快速内存,用于存储TensorCore使用的张量。HBM的容量通常在数十GB级别(例如,TPU v5e有16GiB的HBM)。
    • 当计算需要时,张量会从HBM通过VMEM(见下文)流入MXU,结果再从VMEM写回HBM。
    • HBM和TensorCore之间(通过VMEM)的带宽被称为"HBM带宽"(通常约为1-2TB/秒),它限制了内存受限工作负载中计算的速度。

通常,所有TPU操作都是流水线化和重叠的。要执行矩阵乘法X \cdot A \to Y,TPU首先需要将矩阵AX的块从HBM复制到VMEM,然后加载到MXU中,MXU将8x128(对于X)和128x128(对于A)的块相乘,然后将结果一块一块地复制回HBM。为了高效地完成这一过程,矩阵乘法采用流水线方式,使得与VMEM之间的复制与MXU的工作重叠。这允许MXU持续工作而不必等待内存传输,使矩阵乘法受计算限制而非内存限制。

这里是一个如何从HBM执行元素级乘积的例子:

图示:一个动画显示在TPU上执行的逐点乘积,其中字节从HBM加载。注意字节是如何以块的形式从内存中流出,而部分结果在整个数组完全具体化之前就被流水线式地传回。

矩阵乘法看起来几乎相同,只是它会加载到MXU而不是VPU/向量单元,并且加载和存储的顺序会不同,因为相同的权重块用于激活的多个块。你可以看到数据块流入VMEM,然后进入VREG(向量寄存器),然后进入向量单元,然后回到VMEM和HBM。正如我们将要看到的,如果从HBM到VMEM的加载比向量单元(或MXU)中的FLOPs慢,我们就会变成"带宽受限",因为我们无法给VPU或MXU提供足够的工作。

💡 关键要点:TPU非常简单。它们将权重从HBM加载到VMEM,然后从VMEM加载到脉动阵列中,该阵列每秒可以执行约200万亿次乘加运算。HBM ↔︎ VMEM和VMEM ↔︎ 脉动阵列的带宽设定了TPU能够高效执行哪些计算的基本限制。

VMEM和算术强度:VMEM比HBM小得多,但它对MXU的带宽要高得多。正如我们在第1节中看到的,这意味着如果一个算法能将其所有输入/输出放入VMEM,它就不太可能遇到通信瓶颈。当计算具有较低的算术强度时,这特别有帮助:VMEM带宽比HBM带宽高约22倍,这意味着从VMEM读取/写入的MXU操作只需要10-20的算术强度就能达到峰值FLOPs利用率。这意味着如果我们能将权重放入VMEM而不是HBM,我们的矩阵乘法在更小的批处理大小下也能受FLOPs限制。这也意味着本质上具有较低算术强度的算法仍然可以高效运行。VMEM就是太小了,这通常是一个挑战。

一个TPU芯片通常(但并非总是)由两个共享内存的TPU核心组成,可以被视为一个大型加速器,FLOPs翻倍(称为"megacore"配置)。自TPU v4起就是如此。较旧的TPU芯片它们拥有独立的内存,被视为两个独立的加速器(TPU v3及更早版本)。针对推理优化的芯片如TPU v5e每个芯片只有一个TPU核心。

芯片在一个"托盘"上以4个一组的方式排列通过PCIe网络连接到CPU主机。这是大多数读者所熟悉的格式,4个芯片(8个核心,尽管通常被视为4个逻辑megacore)通过Colab或单个TPU-VM暴露。对于像TPU v5e这样的推理芯片,我们每个主机有2个托盘而不是1个,但每个芯片只有1个核心,给我们8个芯片 = 8个核心。

PCIe带宽是有限的:与HBM ↔︎ VMEM链接一样,CPU ↔︎ HBM PCIe连接有特定的带宽,限制了从主机内存到HBM或反向加载的速度。例如,TPU v4的PCIe带宽在每个方向上为16GB/秒,所以比HBM慢近100倍。我们可以将数据加载/卸载到主机(CPU)RAM中,但速度不会很快。

TPU网络连接

芯片通过ICI网络在Pod中相互连接。在较旧的世代(TPU v2和TPU v3)、推理芯片(例如TPU v5e)和Trilium(TPU v6e)中,ICI("芯片间互连")连接4个最近的邻居(通过边缘链接形成2D环面)。TPU v4和TPU v5p连接到最近的6个邻居(形成3D环面)。请注意,这些连接通过它们的主机,它们是芯片之间的直接链接。

环面结构将任意两个节点之间的最大距离从N减少到N/2,使通信更快。TPU还具有"扭曲环面"配置,将环面包裹在类似莫比乌斯带的拓扑中,进一步减少节点之间的平均距离。

TPU pod(通过ICI连接)可以变得非常大:最大pod大小(称为超级pod)对于TPU v4是16x16x16,对于TPU v5p是16x20x28。这些大型pod由4x4x4芯片的可重构立方体组成,通过光学环绕链接连接,我们可以重新配置这些链接来连接非常大的拓扑结构。

也可以请求较小的拓扑结构(例如2x2x12x2x2),但没有环绕连接。这是一个重要的注意事项,因为它通常会使大多数通信的时间翻倍。任何完整立方体的倍数(例如4x4x44x4x8)都将通过光学开关提供环绕连接。

TPU v5e和Trillium pod由单个16x16的2D环面组成,沿着任何大小为16的轴有环绕连接(这意味着一个8x16在长轴上有环绕连接)。TPU v5e和v6e(Trillium)不能扩展超过16x16环面,但pod仍然可以通过标准数据中心网络(DCN)相互通信,DCN连接TPU主机彼此。同样,可以请求较小的拓扑结构,但维度<16的没有环绕连接。

这种最近邻居连接是TPU和GPU之间的关键区别。GPU以全连接配置(称为节点)连接多达256个H100,而不是使用本地连接。一方面,这意味着GPU可以在单个低延迟跳跃中在节点内发送任意数据。另一方面,TPU在连接在一起时便宜得多且更简单,并且可以扩展到更大的拓扑结构,因为每个设备的链接数量是恒定的。

ICI相对于DCN非常快,但仍然比HBM带宽慢。例如,TPU v5p具有:

2.5e12字节/秒(2.5 TB/秒)的每芯片HBM带宽。

9e10字节/秒(90 GB/秒)的每轴ICI带宽,每个芯片有3个轴。

2.5e10字节/秒(25 GB/秒)的每主机DCN(出口)带宽。由于我们通常每个主机有8个TPU,这实际上更接近于每芯片3.1e9字节/秒。

这意味着当我们将模型拆分到多个芯片上时,我们需要小心避免因较慢的跨设备通信而造成MXU的瓶颈。

多片训练:一组通过ICI连接的TPU被称为。不同的片可以通过DCN相互连接,例如连接不同pod上的片。由于DCN是比ICI慢得多的连接,我们应该尽量限制计算等待DCN数据的时间。DCN是主机到主机的连接,所以要通过DCN在TPU之间传输缓冲区,我们首先需要通过PCIe传输到主机,然后通过网络出口,再通过目标主机网络入口,最后通过PCIe进入HBM。

关键要点

TPU是简单的设备,在大多数情况下可以被视为连接到内存(超快速)、通过ICI连接到其他芯片(相当快速)以及通过DCN连接到数据中心其余部分(较快速)的矩阵乘法单元。

  • 通信受到我们各种网络带宽的限制,按速度排序:

    • HBM带宽:在TensorCore和其关联的HBM之间。
    • ICI带宽:在TPU芯片和其最近的4或6个邻居之间。
    • PCIe带宽:在CPU主机和其关联的芯片托盘之间。
    • DCN带宽:在多个CPU主机之间,通常是未通过ICI连接的主机。
  • 在一个片内,TPU仅通过ICI连接到其最近的邻居。这意味着在片内远距离芯片之间通过ICI的通信需要先经过中间芯片进行跳转。

  • 权重矩阵需要在两个维度上填充到至少128的大小(在TPU v6上是256)以填满MXU(实际上,较小的轴会被填充到128)。

  • 低精度矩阵乘法往往更快。对于支持的世代,TPU可以执行int8或int4 FLOPs的速度大约是bfloat16 FLOPs的2倍/4倍。VPU操作仍在fp32中执行。

  • 为了避免TPU计算单元的瓶颈,我们需要确保每个通道上的通信量与其速度成正比

  • 以下是我们芯片的一些具体数据:

型号Pod大小主机大小每芯片HBM容量每芯片HBM带宽(字节/秒)每芯片FLOPs/秒(bf16)每芯片FLOPs/秒(int8)
TPU v332x324x232GB9.0e111.4e141.4e14
TPU v4p16x16x162x2x132GB1.2e122.75e142.75e14
TPU v5p16x20x282x2x196GB2.8e124.59e149.18e14
TPU v5e16x164x216GB8.1e111.97e143.94e14
TPU v6e16x164x232GB1.6e129.20e141.84e15

主机大小指的是连接到单个主机的TPU拓扑结构(例如,TPU v5e有一个CPU主机连接到4x2拓扑结构中的8个TPU)。以下是互连数据:

型号ICI带宽/链接(单向,字节/秒)ICI带宽/链接(双向,字节/秒)
TPU v31e112e11
TPU v4p4.5e109e10
TPU v5p9e101.8e11
TPU v5e4.5e109e10
TPU v6e9e101.8e11

我们同时包括单向带宽和双向带宽,因为单向带宽更符合硬件实际情况,但双向带宽在涉及完整环的方程中更常见。

PCIe带宽通常在每芯片1.5e10字节/秒左右,而DCN带宽通常在每主机2.5e10字节/秒左右。我们包括单向和双向带宽以求完整。通常,当我们可以访问完整的环绕环时,双向带宽是更有用的数字,而单向带宽更符合硬件实际情况。

第2章练习题

这些数字有点枯燥,但它们可以让你对模型性能做出基本的屋顶线估计。让我们通过几个问题来解释为什么这很有用。在第3部分中你会看到更多例子。

问题1 [限制LLM延迟]:假设你想从一个跨32个TPU v4p分割的200B参数模型中进行采样,采用bf16格式。将所有参数从HBM加载到脉动阵列需要多长时间?提示:使用上面的数字。

点击此处查看答案

答案:我们正在加载sizeof(bf16) * 200e9 = 400e9字节到32个芯片上,意味着每个芯片12.5e9字节,每个芯片的HBM带宽为1.23e12。因此加载大约需要10ms。

这很酷,因为这是模型采样延迟的合理下限。每个采样步骤都需要从HBM加载所有参数,所以它不可能少于10毫秒。实际上,在小批量大小的情况下,这个数值是可以接近达到的。

问题2 [TPU细节]:考虑一个完整的TPU v5e pod。总共有多少CPU主机?有多少TPU TensorCore?整个pod的总FLOPs/s是多少?总HBM是多少?对TPU v5p pod做同样的练习。

点击此处查看答案

答案:对于TPU v5e,每个pod是16x16,每个主机是4x2切片,所以我们有1616 / 8 = 32个主机。对于TPU v5e,每个TPU只有一个核心,所以我们有256个TensorCore。总FLOPs/s是1616*2e14 = 5.1e16(bfloat16格式)。每个芯片有16GB的HBM,所以总共是256 * 16 = 4TB内存。

对于完整的TPU v5p pod,我们有16x20x28个芯片,每个主机是2x2x1,所以我们有162028 / 2*2 = 2,240个主机。对于TPU v5p,每个TPU有两个TensorCore,所以我们有8960 * 2 = 17,920个核心。总FLOPs/s是8960 * 4.5e14 = 4e18(bfloat16格式)。每个芯片有96GB的HBM,所以总共是8960 * 96 = 860TB内存。

问题3 [PCIe运算强度]:假设我们被迫在主机DRAM中存储一个大型权重矩阵A,类型为bfloat16[D, F],以及一批激活x,类型为bfloat16[B, D],并希望对它们进行矩阵乘法。这在单个主机上运行,我们使用连接到它的单个TPU v6e芯片。你可以假设B ≪ D,且F = 4D(我们将在未来章节中看到为什么这些是合理的假设)。我们需要的最小批量大小B是多少才能在PCIe上保持FLOPs受限?假设PCIe带宽为1.5e10字节/秒。

点击此处查看答案

答案:我们必须执行2BDF次浮点运算,每个芯片每秒可以执行9.2e14次浮点运算。这需要2BDF/9.2e14秒来执行。我们必须从DRAM加载2DF + 2BD字节,并将2BF字节写回。我们受PCIe传输速度的瓶颈限制,因此需要2 \cdot (BD + DF + BF)/1.5e10秒将数据传输到TPU并从TPU传回。由于我们希望计算时间长于权重加载时间,假设我们可以将所有权重加载与计算重叠,我们希望2BDF/9.2e14 > 2 \cdot (BD + DF + BF)/1.5e10。使用我们的假设B ≪ DF = 4D,我们可以简化为

$$ \frac{8BD^2}{9.2e14} > \frac{8D^2}{1.5e10} $$

$$ B > \frac{9.2e14}{1.5e10} \simeq 61,000 $$

问题4 [通用矩阵乘法延迟]:假设我们要将一个权重矩阵int8[16384, 4096]乘以一个大小为int8[B, 4096]的激活矩阵,其中B是某个未知的批量大小。假设我们首先在1个TPUv5e上。

  1. 这个乘法需要多长时间,作为B的函数?提示:计算从HBM加载数组需要多长时间以及乘法实际需要多长时间可能会有帮助。哪个是瓶颈?

  2. 如果我们想从VMEM运行这个操作呢?作为B的函数,它需要多长时间?

点击此处查看答案

答案:(1) 我们需要执行的浮点运算数量是2 \cdot 4096 \cdot 16384 \cdot B = 1.3e8 \cdot B。所以T_{math} = (1.3e8 \cdot B)/3.94e14秒。我们需要从HBM加载16384 \cdot 4096 + 4096 \cdot B字节到VMEM,并将16384 \cdot B字节从VMEM写回HBM。这意味着T_{comms} = (6.7e7 + 2e4 \cdot B)/8.1e11秒。假设通信和计算尽可能重叠,整个乘法大约需要

$$ \max{T_{\text{math}}, T_{\text{comms}}} = \max\left{\frac{6.7e7 + 2e4\cdot B}{8.1e11}, \frac{1.3e8 \cdot B}{3.94e14}\right} $$

当\frac{6.7e7 + 2e4\cdot B}{8.1e11} < \frac{1.3e8 \cdot B}{3.94e14}时,我们将受FLOPs限制,或等效地,B > 271。这略大于我们下面推导的240数字,因为我们考虑了DF的完整影响。

(2) 如果我们从VMEM加载,让我们考虑VMEM到MXU的带宽是HBM ↔︎ VMEM带宽的22倍。这将我们的数据加载分母从8.1e11变为1.78e13,我们得到B > 11。_注意,在实践中,我们不能将所有VMEM带宽专用于加载W,所以实际上会更接近20。

问题5 [ICI带宽]:假设我们有一个TPU v5e 4x4切片。假设我们想从TPU{0,0}TPU{3, 3}发送一个类型为bfloat16[8, 128, 8192]的数组。假设TPU v5e的每跳延迟为1μs

  1. 第一个字节何时到达目的地?

  2. 整个传输需要多长时间?

点击此处查看答案

答案:在TPUv5e中,我们有2D连接。因为我们只有一个4x4切片(没有大小为16的轴),所以没有环绕连接。因此,目标芯片有两个端口可以接收数据,同样,源芯片也有两个端口可以发送数据。我们需要传输的数据量是2 * 8 * 128 * 8192 = 1.7e7字节。我们可以同时从两个端口传输(即向右发送一半数组,向下发送一半),所以我们每秒传输2 * 4.5e10 = 9e10字节,这意味着传输整个数组大约需要1.7e7 / 9e10 = 188us(假设我们受带宽限制)。在4x4切片中,芯片(0, 0)和(3, 3)之间有六个跳转,因为对于少于16个芯片的轴没有环绕链接。由于每个跳转的延迟约为1μs,第一个字节将在约6us后到达,整个传输将花费188us

问题6 [综合起来,难]:假设你有一个大矩阵Aint8[128 * 1024, 128 * 1024],均匀分片在TPU v5e 4x4切片上,但卸载到每个芯片的主机DRAM上。假设你想将整个数组复制到TPU{0, 0}并将其乘以向量bf16[8, 128 * 1024]。这需要多长时间?提示:使用上面的数字。

点击此处查看答案

答案:让我们首先概述我们必须执行的操作。我们的数组大约是16GB。从上表可知,TPU v5e主机有4x2拓扑,所以4x4有2个主机。因此,由于我们的数组均匀分片,每个主机实际上包含数组的1/2,即8GB。我们需要将这些块全部复制到TPU{0,0},这给我们两个选择:

  1. 我们可以通过DCN复制,然后通过PCIe将整个未分片数组加载到HBM中。

  2. 我们可以将分片数组加载到相应的TPU上,然后通过ICI执行收集操作,然后在TPU{0,0}上执行矩阵乘法。很明显,选项(2)更好。与ICI相比,DCN速度较慢,我们更喜欢通过多个PCIe链接(而不仅仅是主机0上的8个)加载大型数组。这是系统部分的示意图。如上所述,请注意TPU通过ICI连接到它们的邻居(甚至跨主机),所有TPU都连接到它们的主机CPU(通过PCIe),主机通过DCN连接。

现在让我们计算每个部分需要多长时间:

  1. PCIe加载:我们通过8个PCIe链接加载16GB / 2 = 8GB的块,每个链接带宽为1.5e10字节/秒。因此这将花费约66ms。

  2. ICI复制:每个TPU现在有16GB / 16 = 1GB的数组。我们的ICI带宽是每链接10e10字节/秒双向,你会注意到从上图中只有TPU v5e上4个ICI链接中的2个在这种拓扑中使用。由于TPU{0,0}需要沿2个轴以4.5e10字节/秒/链接的速度接收总共15GB,我们可以将时间下限为15e9 / (4.5e10 * 2) = 167ms。实际上,这可能无法实现,因为负载非常不均匀,但它可能在2倍范围内。正如你将在第2节中看到的,执行完整的AllGather也大约需要16e9 / (4.5e10 * 2),所以这接近最优。

  3. HBM → MXU加载:为了执行最终的矩阵乘法,我们需要通过HBM带宽将这些16e9字节加上bf16[8, 128 * 1024]数组(另外2MB,所以可忽略不计)加载到MXU中,这将花费16e9 / 8.1e11 = 19ms

  4. FLOPs:我们执行总共2 \cdot 8 \cdot 128 \cdot 1024 \cdot 128 \cdot 1024 = 2.7e11 FLOPs,由于我们可以执行1.97e14 bf16 FLOPs/s,我们得到1.3ms。总时间的上限是所有这些时间的总和,但由于TPU通常可以重叠这些操作,我们可以将其视为由最慢部分瓶颈的流水线问题。假设这是正确的,那么答案约为150-200ms。

第2章附录A:简单谈谈GPU

与TPU相比,GPU具有更简单的通信模型和更复杂的编程模型。

  • GPU在概念上与TPU相似:它们也作为连接到CPU的加速器运行。

  • 不同之处在于计算是通过更多数量的"流式多处理器"(等同于TensorCore)进行的,这些处理器连接到DRAM(等同于HBM)。每个流式多处理器(SM)都有一个小型L1缓存,用于加速数据访问和寄存器溢出。L1缓存使用的内存部分也可以声明为共享内存,允许线程块中的任何线程访问,用于用户定义的缓存、并行规约和同步等。最后,还有一个由所有SM共享的附加L2缓存。

  • 主要区别在于NVIDIA GPU通常通过开关(NVLink → NVSwitch)形成8-256个GPU的"小组",允许该"小组"内任何GPU之间进行点对点通信,但这意味着超过256个之间的通信会明显变慢 - 这意味着训练超过256个通常需要流水线并行来扩展,这更为复杂(相比之下,PaLM是在两组各3072个TPU芯片上训练的)

  • 对于常见的神经网络操作,如AllReduce,全连接并不具有优势(因为无论如何都必须发生相同的通信模式),但它确实允许在更多GPU上存储MoE模型并更有效地传输专家

  • 每个GPU需要一个成本与GPU本身相似的开关,使得像ICI这样的片上互连更便宜

  • NVIDIA深度学习性能

  • NVSwitch

  • 张量并行/流水线并行的转换点非常不同!

第2章附录B:脉动阵列如何工作?

TPU MXU的核心是一个128x128脉动阵列(TPU v6e上为256x256)。当完全饱和时,脉动阵列每8个时钟周期可以执行一次bfloat16[8,128] @ bf16[128x128] -> f32[8,128]乘法运算。

  • 从本质上讲,脉动阵列是一个2D的128x128=16,384)ALU网格,每个ALU能够执行乘法和加法操作。

  • 权重(W128x128输入)从上方传下来(称为RHS),而输入(X8x128输入)从左侧传入(称为LHS)。

这是一个简化的动画,展示了一组权重(蓝色)与一组激活值(绿色)相乘。你会注意到权重(RHS)首先被对角线地部分加载,然后激活值也是对角线地被送入。在下面的每一帧中,我们将所有重叠的绿色和蓝色单元相乘,将结果与从上方传入的任何残差相加,然后依次将结果向下传递一个单元。

这是此动画的更一般版本,显示计算结果的输出流:

这是一个图表,显示了如何在多个RHS和LHS数组上进行流水线处理:

当权重(RHS)和激活值(LHS)被加载时,存在一个初始流水线气泡。在初始气泡之后,可以加载新的输入和权重,而不会产生额外的气泡。

这是一个bf16[2, 3] x bf16[3, 3]矩阵乘法的不太理想的动画,你可以将其想象为一个2x3权重矩阵与批次大小为1、尺寸为3的输入激活值的矩阵乘法。与之前的幻灯片相比,这个是旋转了的,输入流向右侧而不是向下,但你大致可以看到结构。

我们可以有效地通过流水线方式乘以大矩阵,而不会产生太大的流水线气泡。话虽如此,重要的是我们的矩阵形状要大于MXU的边长,通常是128x128。一些TPU(自TPU v3起)有多个MXU,TPU v3有2个,TPU v4/5有4个,所以我们需要确保平铺维度大于128 * MXU数量。这里有一个很好的动画演示。

Trillium(TPU v6e)有一个256x256的脉动阵列,这意味着它每个周期可以执行4倍的FLOPs。这也意味着张量的维度需要增加一倍才能充分利用MXU。

这篇博客文章有另一个固定权重矩阵的脉动阵列乘法的优秀动画。

第2章附录C:TPU内部结构

标量核心

TPU标量核心处理所有指令,并执行所有从HBM到向量内存(VMEM)的传输。标量核心还负责为芯片的VPU、MXU和XLU组件获取指令。这导致的一个副作用是,TPU的每个核心每个周期只能创建一个DMA请求。

为了理解这一点,一个单独的4标量核心控制着一个由2048个ALU组成的VPU、4个MXU、2个XLU和多个DMA引擎。每单位计算的控制高度倾斜的特性是硬件效率的来源,但也限制了以任何有趣方式进行数据依赖向量化的能力。

VPU

TPU向量核心由一个二维向量机器(VPU)组成,执行向量操作如vadd(向量加法)或vmax(元素级最大值),以及一组向量寄存器VREGs,为VPU和MXU保存数据。VPU实际上是一个形状为(8, 128)的2D向量算术单元,其中128维被称为lane,8维被称为sublane。在v4上,每个(lane, sublane)对包含2个标准浮点和整数ALU。从软件角度看,这创造了一个8x128向量单元的外观,在v4中总共有2048个浮点加法器。TPU v4有32个大小为(8, 128)的VREGs,VPU从中加载并写入。

VPU在其每个ALU中执行大多数算术指令只需一个周期(如vadd或向量加法),延迟为2个周期,所以例如在v5中,你可以在每个周期中将4对f32值从VREGs中加在一起。一个典型的VPU指令可能看起来像{v2 = vadd.8x128.f32 v0, v1},其中v0和v1是输入VREG,v2是输出VREG。

所有lane和sublane每个周期以纯SIMD方式执行相同的程序,但每个ALU可以执行不同的操作。因此,我们可以例如在单个周期中处理1个vadd和1个vsub,每个操作都对两个完整的VREG进行操作,并将输出写入第三个VREG。

第3章:分片矩阵及其乘法方法

分区表示法和集体操作

当我们在一万个TPU上训练LLM时,我们从抽象层面上仍在执行与单个TPU训练相同的计算。区别在于我们的数组无法适应单个TPU的HBM,所以我们必须将其拆分。我们称这为"分片"或"分区"我们的数组。

这是一个跨4个TPU分片的2D数组A示例:

图示:形状为A[I, J]的示例数组被分片到4个设备上。两个维度均匀地跨2个设备进行分片,分片方式为A[I_X, J_Y]。每个TPU持有总内存的1/4。

注意分片数组仍然具有与未分片数组相同的全局逻辑形状,比如(4, 128),但它也有一个设备本地形状,如(2, 64),这给出了每个TPU实际持有的字节大小(在上图中,每个TPU持有总数组的¼)。现在我们将这一概念推广到任意数组。

分片的统一表示法

我们使用命名轴表示法的变体来描述张量如何在设备上以块的形式进行分片:我们假设存在一个2D或3D的设备网格,称为设备网格,其中每个轴都被赋予网格轴名称例如XY和Z。然后,我们可以通过描述数组的每个命名维度如何在物理网格轴上进行分区,来指定矩阵数据如何在设备网格上布局。我们称这种分配为分片

示例(上图):对于上图,我们有:

  • 分片:A[I_X, J_Y],这告诉我们将第一个轴I沿着网格轴X分片,将第二个轴J沿着网格轴Y分片。这种分片告诉我们每个分片持有数组的1/(|X| \cdot |Y|)。

  • 网格:上面的设备网格Mesh(devices=((0, 1), (2, 3)), axis_names=('X', 'Y')),告诉我们我们有4个TPU在一个2x2网格中,轴名为XY

综合起来,我们知道数组的本地形状(单个设备持有的分片大小)是(|I|/2, |J|/2),其中|I|是A的第一维大小,|J|是A的第二维大小。

示例(沿1个轴的2D分片)A[I_{XY}, J]将第一维度(I)沿着X和Y两个硬件轴进行分片。每个设备的字节数与前一种分片相同,但本地形状不同。现在是(|I|/(|X| \cdot |Y|), |J|)。

可视化这些分片:让我们尝试通过查看一个在4个设备上拆分的2D数组来可视化这些分片:

我们将矩阵的完全复制形式简单地表示为A[I, J],没有分片分配。这意味着每个设备都包含整个矩阵的完整副本。

当我们希望表示其中一个维度已经在网格轴上分区时,我们使用网格轴下标来表示。例如,A[I_X, J]表示I逻辑轴已经在X网格维度上分区,但J维度分区,块在Y网格轴上保持部分复制

A[I_X, J_Y]表示I逻辑轴已在X网格轴上分区,J维度已在Y网格轴上分区。

我们在下图中说明其他可能性:

这里A[I_{XY}, J]表示我们将XY网格轴视为更大的扁平化维度,并在所有设备上对I命名轴进行分区。多个网格轴下标的顺序很重要,因为它指定了网格上分区的遍历顺序。

最后,请注意,我们不能有多个命名轴沿相同的网格维度进行分片。例如,A[I_X, J_X]是一种无意义的、禁止的分片。一旦一个网格维度被用来分片数组的一个维度,它就在某种意义上被"用完了"。

小测验:A为一个形状为int8[128, 2048]的数组,分片为A[I_{XY}, J],网格为Mesh({'X': 2, 'Y': 8, 'Z': 2})(总共32个设备)。A在每个设备上使用多少内存?A在所有设备上总共使用多少内存?

点击此处查看答案。

答案:我们的数组A在X和Y上分片并在Z上复制,所以每个设备上的形状为int8[128 / (2 * 8), 2048] = int8[8, 2048],大小为8 * 2048 = 16,384字节。因为它在Z上是复制的,而在Z平面内它在X和Y上完全分片,所以每个Z平面有一个副本,总共有2个这样的平面,所以总大小(跨所有设备)是128 * 2048 * 2 = 512kiB

简单补充:我们如何在代码中描述这个?

JAX使用的命名分片语法与我们上面描述的抽象语法非常匹配。我们将在第10节中更详细地讨论这一点,但这里是一个快速预览。你可以在Google Colab 这里尝试,并分析结果以了解JAX如何处理不同的分片。这段代码做了3件事:

  1. 创建一个jax.Mesh,将我们的8个TPU映射到一个4x2网格,并将名称'X'和'Y'分配给两个轴。

  2. 创建矩阵A和B,其中A沿其两个维度分片,B沿输出维度分片。

  3. 编译并执行一个简单的矩阵乘法,返回一个分片数组。

import jax
import jax.numpy as jnp
import jax.sharding as shd
# 创建我们的网格!我们在TPU v2-8 4x2切片上运行,轴名为'X'和'Y'。
assert len(jax.devices()) == 8
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# 一个帮助定义分片的小工具函数。PartitionSpec是我们的分片(从轴到名称的映射)。
def P(*args):
  return shd.NamedSharding(mesh, shd.PartitionSpec(*args))

# 我们在非收缩维度上对A和B进行分片,并在收缩维度上对A进行分片。
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))

# 我们可以对这些分片数组执行矩阵乘法!out_shardings告诉我们我们希望输出如何分片。
# JAX/XLA为我们处理其余的分片。
compiled = jax.jit(lambda A, B: jnp.einsum('BD,DF->BF', A, B), out_shardings=P('X', 'Y')).lower(A, B).compile()
y = compiled(A, B)

JAX的一个很棒的特点是这些数组的行为就像它们没有被分片一样!B.shape会告诉我们全局或逻辑形状(2048, 8192)。我们必须实际查看B.addressable_shards才能看到它是如何在本地分片的。我们可以对这些数组执行操作,JAX会尝试弄清楚如何广播或重塑它们以执行操作。例如,在上面的例子中,A的本地形状是[2, 1024]B的是[2048, 4096]。JAX/XLA会自动在这些数组之间添加必要的通信以执行最终的乘法。

使用分片数组的计算

如果你有一个分布在多个设备上的数据数组,并希望对其执行数学运算,那么分片数据和计算会带来哪些开销?

显然,这取决于所涉及的计算。

  • 对于元素级操作,在分布式数组上操作没有额外开销

  • 当我们希望对驻留在多个设备上的元素执行操作时,情况会变得复杂。幸运的是,对于大多数机器学习来说,几乎所有计算都以矩阵乘法的形式进行,而且它们相对容易分析。

本节的其余部分将讨论如何乘以分片矩阵。粗略地说,这涉及移动矩阵的块,以便你可以完全乘法或求和每个块。每种分片方式都涉及不同的通信。例如,A[I_X, J] \cdot B[J, K_Y] \to C[I_X, K_Y]可以在没有任何通信的情况下相乘,因为收缩维度(J,即我们实际求和的维度)是未分片的。然而,如果我们想要输出未分片(即A[I_X, J] \cdot B[J, K_Y] \to C[I, K]),我们需要将AC复制到每个设备。这两种选择有不同的通信成本,所以我们需要计算这些成本并选择较低的一个。

你可以从"块矩阵乘法"的角度来思考这个问题。

首先,让我们回顾一下"块矩阵"的概念,或者说矩阵的嵌套矩阵:

$$ \begin{equation} \begin{pmatrix} a_{00} & a_{01} & a_{02} & a_{03} \ a_{10} & a_{11} & a_{12} & a_{13} \ a_{20} & a_{21} & a_{22} & a_{23} \ a_{30} & a_{31} & a_{32} & a_{33} \end{pmatrix}

\left( \begin{matrix} \begin{bmatrix} a_{00} & a_{01} \ a_{10} & a_{11} \end{bmatrix} \ \begin{bmatrix} a_{20} & a_{21} \ a_{30} & a_{31} \end{bmatrix} \end{matrix} \begin{matrix} \begin{bmatrix} a_{02} & a_{03} \ a_{12} & a_{13} \end{bmatrix} \ \begin{bmatrix} a_{22} & a_{23} \ a_{32} & a_{33} \end{bmatrix} \end{matrix} \right)

\begin{pmatrix} \mathbf{A_{00}} & \mathbf{A_{01}} \ \mathbf{A_{10}} & \mathbf{A_{11}} \end{pmatrix} \end{equation} $$

矩阵乘法有一个很好的特性,即当矩阵乘数以块的形式表示时,其乘积可以按照标准规则用块矩阵乘法表示:

$$ \begin{equation} \begin{pmatrix} A_{00} & A_{01} \ A_{10} & A_{11} \end{pmatrix} \cdot \begin{pmatrix} B_{00} & B_{01} \ B_{10} & B_{11} \end{pmatrix}

\begin{pmatrix} A_{00}B_{00} + A_{01}B_{10} & A_{00}B_{01} + A_{01}B_{11} \ A_{10}B_{00} + A_{11}B_{10} & A_{10}B_{01} + A_{11}B_{11} \end{pmatrix} \end{equation} $$

这意味着实现分布式矩阵乘法归结为通过网络移动这些分片块,对块执行本地矩阵乘法,并将其结果相加。问题在于需要添加什么通信,以及它有多昂贵。

方便的是,我们可以将所有可能的分片归纳为大约4种需要考虑的情况,每种情况都有一个规则,说明我们需要添加什么通信

  1. 情况1两个输入都没有沿收缩维度分片。我们可以在没有任何通信的情况下乘以本地分片。

  2. 情况2一个输入有一个分片的收缩维度。我们通常沿收缩维度对分片输入进行"AllGather"操作。

  3. 情况3两个输入都沿收缩维度分片。我们可以乘以本地分片,然后对结果进行"AllReduce"操作。

  4. 情况4两个输入都有沿同一轴分片的非收缩维度。我们必须首先对两个输入中的一个进行AllGather操作,否则无法继续。

你可以将这些视为需要遵循的规则,但了解这些规则为何成立以及它们的成本有多高也很有价值。现在我们将详细讨论每一种情况。

情况1:两个乘数都没有分片的收缩维度

引理:当乘以分区张量时,计算是有效的,输出遵循输入的分片除非收缩维度被分片或两个张量都有沿同一轴分片的非收缩维度。例如,这样运行良好

$$ A[I_X, J] \cdot B[J, K_Y] \to C[I_X, K_Y] $$

完全不需要任何通信,并产生一个跨X和Y硬件维度分片的张量。试着思考为什么会这样。基本上,计算独立于分片,因为每个批次条目都有一些要收缩的轴的本地块,可以进行乘法和规约。以下任何情况都运行良好并遵循此规则:

$$ \begin{align*} \mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I, K] \ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I_X, K]\ \mathbf{A}[I, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I, K_Y]\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I_X, K_Y] \end{align*} $$

因为AB都没有分片的收缩维度J,我们可以简单地执行输入的本地块矩阵乘法,结果已经按照所需的输出分片进行分片。当两个乘数都有沿同一轴分片的非收缩维度时,情况就不再如此(详见无效分片部分)。

情况2:一个乘数有分片的收缩维度

让我们考虑一个简单的情况,在收缩维度J上分片的A与完全复制的B进行分布式矩阵乘法:

$$ A[I, J_X] \cdot B[J, K] \to C[I, K] $$

我们不能简单地对本地AB块执行本地矩阵乘法,因为我们缺少来自A收缩轴的完整数据。通常,我们首先在本地"AllGather"A的分片,然后才与B相乘:

$$ AllGather_X[I, J_X] \to A[I, J] \ A[I, J] ⋅ B[J, K] \to C[I, K] $$

AllGather移除沿一个轴的分片,并将分布在设备上的分片重新组装到该轴上的每个设备上。使用上面的表示法,AllGather从一组轴中移除下标,例如

$$ AllGather_{XY}(A[I_{XY}, J]) \to A[I, J] $$

我们也不必删除给定维度的所有下标,例如A[I_{XY}, J] \to A[I_Y, J]也是一个AllGather,只是仅在单个轴上。

请注意,我们可能还希望使用AllGather来移除非收缩维度分片,例如矩阵乘法:

$$ A[I_X, J] \cdot B[J, K] \to C[I, K] $$

我们同样会沿X进行AllGather以移除输出分片,但在这种情况下,我们可以在矩阵乘法之前或之后执行此操作,而在对收缩维度进行AllGather的情况下,我们必须在执行矩阵乘法之前这样做。

AllGather实际上是如何执行的?要沿单个轴执行AllGather,我们需要在该轴上传递所有分片,直到每个设备都有一个副本。图1显示了一个示例。8个设备中的每一个开始时拥有数组的1/8,最终都拥有所有副本。一种有效的方法是让每个设备沿着分片维度环传递其分片,可以是单向也可以是双向。如果我们单向传递,每条链路需要N-1跳,每跳大小为总大小/N;否则,我们有\lceil \frac{N}{2} \rceil跳,每条链路大小为2 ⋅ 总大小/N

这需要多长时间?让我们计算双向AllGather需要多长时间。设V为数组中的字节数,|X|为收缩维度上的分片数。从上图可知,每跳在每个方向上发送V/|X|字节,所以每跳需要

$$ T_{hop} = \frac{2 \cdot V}{|X| \cdot W_\text{ICI}} $$

其中W_{ICI}双向ICI带宽。我们需要发送总共|X|/2跳才能到达每个TPU,所以总归约需要

$$ T_{total} = \frac{2 \cdot V \cdot |X|}{2 \cdot |X| \cdot W_\text{ICI}} $$

$$ T_{total} = \frac{V}{W_\text{ICI}} $$

注意,这不依赖于|X|这很引人注目,因为它意味着即使我们的TPU只是局部连接的,连接的局部性并不重要。我们只受每个链路速度的瓶颈。

💡 要点:在吞吐量受限的情况下执行AllGather(或ReduceScatter或AllReduce)时,实际通信时间仅取决于数组的大小和可用带宽,而不取决于我们的数组分片的设备数量!

关于ICI延迟的说明:通过ICI链路的每一跳都有一些固有开销,无论数据量多少。这通常约为1微秒。这意味着当我们的数组A非常小,每跳用时少于1微秒时,我们可能进入"延迟受限"的状态,此时计算确实取决于|X|。

点击此处了解完整详情。

Tmin为单跳的最小时间。那么

$$ T_{hop} = \max \left[ T_{min}, \frac{2 \cdot V}{|X| \cdot W_\text{ICI}} \right] $$

$$ T_{total} = \max \left[ \frac{T_{min} \cdot |X|}{2}, \frac{V}{W_\text{ICI}} \right] $$

因为我们执行|X|/2跳。对于大型归约或收集,我们完全受带宽限制。我们发送的数据量如此之大,以至于每跳的开销基本可以忽略不计。但对于小型数组(例如,从模型中采样时),这不可忽略,且ICI带宽无关紧要。我们完全受延迟限制。另一种说法是,给定特定的TPU,例如具有4.5e10单向ICI带宽的TPU v5e,发送任何小于4.5e10 * 1e-6 = 45kB的缓冲区都将受延迟限制。

当我们在多个轴上进行AllGather时会发生什么?当我们在多个轴上进行收集时,我们有多个ICI维度来执行收集。例如,AllGather_{XY}([B, D_{XY}])在两个硬件网格轴上操作。这将可用带宽增加n_{axes}倍。

点击此处了解完整详情。

一般来说,我们有

$$ T_{total} = \max \left[ \frac{T_{min} \cdot \sum_{i} |X_i|}{2}, \frac{V}{W_\text{ICI} \cdot n_\text{axes}} \right] $$

其中\sum_i |Xi|/2是TPU网格中最长路径的长度。

小测验2 [AllGather时间]:使用第2部分中的数字,在具有2D网格{'X': 8, 'Y': 4}的TPUv5e上执行AllGather_Y([E_Y, F]) \to [E, F]需要多长时间,其中E = 2048F = 8192,使用bfloat16?E = 256F = 256又如何?

点击此处查看答案。

答案:让我们先计算一些基本量:

  1. TPU v5e的2个轴中每一个都有4.5e10字节/秒的单向ICI带宽。

  2. 在(a)的bfloat16中,我们有A[E_Y, F],所以每个设备持有形状为bfloat16[512, 8192]的数组,大小为512 * 8192 * 2 = 8.4MB。总数组大小为2048 * 8192 * 2 = 34MB。

对于第(1)部分,我们可以使用上面的公式。由于我们在一个轴上执行AllGather,我们有T_{comms} = 34e6/9e10 = 377μs。为了检查我们是否不受延迟限制,我们知道在大小为4的轴上,最多有3跳,所以我们的延迟限制约为3微秒,因此我们距离限制还很远。然而,TPU v5e只有在一个轴大小为16时才有环绕连接,所以这里我们实际上不能做完全双向的AllGather。我们需要3跳才能让数据从边缘到达另一边,所以理论上我们更接近T_{comms} = 3 * 8.4e6/4.5e10 = 560μs这里来自这个Colab实际配置文件,显示680μs,考虑到我们可能无法获得100%的理论带宽,这是合理的!对于第(2)部分,每个分片大小为64 * 256 * 2 = 32kB。32e3 / 4.5e10 = 0.7us,所以我们受延迟限制。由于我们有3跳,这将花费大约3 * 1微秒 = 3微秒。实际上,接近8微秒。

情况3:两个乘数都在收缩维度上进行分片

第三个基本情况是当两个乘数都在它们的收缩维度上沿着相同的网格轴进行分片时:

$$ A[I, J_X] \cdot B[JX, K] \to C[I, K] $$

在这种情况下,本地分片块矩阵乘法至少是可能执行的,因为它们将共享相同的收缩索引集合。但每个乘积只代表完整所需乘积的部分和,并且沿着X维度的每个设备将留有这个最终所需乘积的不同部分和。这种情况非常常见,因此我们扩展了我们的表示法来明确标记这种情况:

$$ A[I, J_X] \cdot_{LOCAL} B[J_X, K] \to C[I, K]{U_X} $$

表示法{ U_X }读作"沿X网格轴未归约",指的是操作在某种意义上"不完整"的状态,因为它只会在最终求和之后才完成。\cdot_{LOCAL}语法表示我们执行本地求和但保留结果未归约。

这可以看作是关于矩阵乘法和外积的以下结果:

$$ A \cdot B = \sum_{i=1}^{P} \underbrace{A_{:,i} \otimes B_{i,:}}_{\in \mathbb{R}^{n \times m}} $$

其中⊗是外积。因此,如果X轴上的TPU iA的第i列和B的第i行,我们可以进行本地矩阵乘法得到A:,iBi, : ∈ ℝn × m。这个矩阵在每个条目中都有A • B在该条目处的和的第i项。我们仍然需要对P执行求和,这是我们在网格轴X上分片的,以获得完整的A • B。如果我们按块(即分片)写AB,然后对结果的每个分片求和,这种方法同样适用。

我们可以使用沿X轴的完全AllReduce来执行这个求和:

$$ \begin{align*} A[I, J_X] \cdot_\text{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] { U_X } \ \textbf{AllReduce}_X C[I, K] { U_X } \rightarrow &\ C[I, K] \end{align*} $$

AllReduce移除部分和,使得沿该轴的每个设备都具有相同的完全求和值。AllReduce是我们将在本节中讨论的几个关键通信的第二个,第一个是AllGather,其他是ReduceScatter和AllToAll。AllReduce接受一个具有未归约(部分求和)轴的数组,通过在未归约轴周围传递这些分片并累积结果来执行求和。其签名为

$$ \textbf{AllReduce}_Y A[I_X, J]{U_Y} \to A[I_X, J] $$

这意味着它只是移除{U_Y}后缀,但在其他方面保持结果不变。

AllReduce的开销有多大?对于AllReduce执行方式的一种心智模型是,每个设备将其分片发送给其邻居,并对它接收到的所有分片进行求和。显然,这比AllGather更昂贵,因为每个"分片"具有与完整数组相同的形状。通常,AllReduce的开销是AllGather的两倍。看到这一点的一种方式是注意AllReduce可以表示为两个其他原语的组合:ReduceScatterAllGather。与AllReduce一样,ReduceScatter解决数组上的部分和,但导致输出沿给定维度'散布'或分区。AllGather收集所有这些部分并'解分区/解分片/复制'沿该物理轴的逻辑轴。

$$ \begin{align*} \textbf{ReduceScatter}_{Y,J} : A[I_X,J] {U_Y} \rightarrow &\ A[I_X, J_Y] \ \textbf{AllGather}_Y : A[I_X, J_Y] \rightarrow &\ A[I_X, J] \end{align*} $$

ReduceScatter呢?正如AllReduce移除下标(上面的F_Y \to F),ReduceScatter对未归约/部分求和的数组进行求和,然后沿相同的网格轴散布(分片)不同的逻辑轴。[F]{U_Y} \to [F_Y]。动画展示了这是如何完成的:注意它与AllGather非常相似,但不是保留每个分片,而是将它们相加。因此,其延迟大致相同,不包括执行归约所需的时间。

每个跳跃的通信时间只是每分片字节V除以带宽,与AllGather一样,所以我们有

$$ T_{\text{comms per AllGather or ReduceScatter}} = \frac{V}{W_\text{ICI}} $$

$$ T_{\text{comms per AllReduce}} = 2 \cdot \frac{V}{W_\text{ICI}} $$

其中W_{ICI}是双向带宽,只要我们有一个完整的环进行归约。

情况4:两个乘数在非收缩维度上沿相同轴进行分片

在对张量进行分片时,每个网格维度最多只能出现一次。执行上述规则有时会导致违反这一规则的情况,例如:

$$ A[I_X, J] \cdot B[J, K_X] \to C[I_X, K_X] $$

这是无效的,因为沿维度X的特定分片,比如i,将拥有C的第(i, i)个分片,即对角线条目。在所有分片中没有足够的信息来恢复除结果的对角线条目之外的任何内容,因此我们不能允许这种分片方式。

解决这个问题的方法是对某些维度执行AllGather操作。这里我们有两种选择:

$$ \begin{align*} \textbf{AllGather}_X A[I_X, J] \rightarrow &\ A[I, J] \ A[I, J] \cdot B[J, K_X] \rightarrow &\ C[I, K_X] \end{align*} $$

$$ \begin{align*} \textbf{AllGather}_X B[J, K_X] \rightarrow &\ B[J, K] \ A[I_X, J] \cdot B[J, K] \rightarrow &\ C[I_X, K] \end{align*} $$

在这两种情况下,结果的形状中X只会出现一次。我们选择哪一种将基于后续操作需要的分片方式。

TPU通信原语深入探讨

前面的4个案例介绍了几种用于执行分片矩阵乘法的"核心通信原语":

  1. AllGather: 从分片中移除下标,收集所有分片。

  2. ReduceScatter: 通过沿该轴对分片求和来移除数组的"未归约"后缀,使数组在第二个轴上保持分片状态。

  3. AllReduce: 移除"未归约"后缀,使数组沿该轴保持非分片状态。

还有一种在混合专家(MoE)模型和其他计算中出现的核心通信原语需要提及:AllToAll

我们的最后一个通信原语:AllToAll

在考虑分片矩阵乘法时不会自然出现,但在实践中经常使用的最后一个基本集合操作是AllToAll集合操作,或更准确地说是分片转置或重新分片操作的特殊情况。例如:

$$ \textbf{AllToAll}_{X,J} A[I_X, J] \to A[I, J_X] $$

AllToAll通常用于重新排列不同区域之间具有不兼容布局方案的分片计算中的分片布局。它们在考虑分片混合专家模型时自然出现。你可以将AllToAll理解为将下标从一个轴移动到另一个轴。因为all to all不需要在环中复制每个分片的所有数据,所以它实际上比allgather更便宜(便宜四分之一)。

关于ReduceScatter的更多内容

ReduceScatter是一个比它最初看起来更基础的操作,因为它实际上是AllGather的导数,反之亦然。即,如果在前向传播中我们有:

$$ \textbf{AllGather}_X A[I_X] → A[I] $$

然后我们对反向模式导数A'(在每个分片上通常会有所不同)进行ReduceScatter以推导出分片的A'

$$ \textbf{ReduceScatter}_X A'[I]{U_X} \to A'[I_X] $$

同样,前向传播中的ReduceScatter_X(A[I]{U_X}) \to A[I_X])意味着在反向传播中的AllGather_X(A'[I_X]) \to A'[I]。

将AllReduce转换为AllGather和ReduceScatter还有一个便利的特性,就是我们可以推迟最终的AllGather到稍后的时刻。通常情况下,我们宁愿不支付在设备间复制完整矩阵乘积的成本。相反,即使在这种组合两个具有分片收缩维度的乘数的情况下,我们也希望保持分片状态:

$$ A[I, J_X] \cdot B[J_X, K] \to C[I, K_X] $$

在这种情况下,我们也可以执行ReduceScatter而不是AllReduce,然后可以选择在稍后的时间执行AllGather,即:

$$ \begin{align*} A[I, J_X] \cdot_{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] { U_X } \ \textbf{ReduceScatter}_{X,K} C[I, K] { U_X } \rightarrow &\ C[I, K_X] \end{align*} $$

注意,ReduceScatter引入了一个分片维度,因此在这种情况下自然可以沿IK命名维度进行分片。在使用ReduceScatter时,我们通常需要选择哪个命名维度引入新的分片(尽管这个选择通常由更大的建模上下文决定)。这就是为什么我们使用\textbf{ReduceScatter}_{X,K}语法来指定要分片的轴。

我们学到了什么?

  • 数组的分片由网格(Mesh)分片(Sharding)指定,其中网格命名TPU网格的物理硬件轴,而分片将网格轴名称分配给数组的逻辑轴。

    • 例如,A[I_{XY}, J]描述了一个抽象数组A,其第一维度沿两个网格轴X和Y进行分片。结合Mesh(mesh_shape=(4, 8), axis_names=('X', 'Y'))或简写为Mesh({'X': 4, 'Y': 8}),这告诉我们数组沿第一维度分成了32份。
  • 分片数组的算术运算与未分片数组完全相同,除非沿分片轴执行收缩操作。在这种情况下,我们必须引入一些通信。我们考虑四种情况:

    1. 两个数组都没有沿收缩维度分片:不需要通信。
    2. 一个数组沿收缩维度分片(或收缩维度沿不同轴分片):我们在执行操作前对其中一个输入执行AllGather。
    3. 两个数组在收缩维度上以相同方式分片:我们先在本地乘以分片,然后执行AllReduce或ReduceScatter。
    4. 两个数组沿非收缩维度的同一网格轴分片:我们首先对其中一个输入执行AllGather。
  • TPU大致使用4种核心通信原语

    1. AllGather: [A_X, B] → [A, B]
    2. ReduceScatter: [A, B] {U_X} \to [A, B_X]
    3. AllToAll: [A, B_X] → [A_X, B]
    4. AllReduce: [A_X, B]{U_Y} \to [A_X, B](技术上不是原语,因为它结合了ReduceScatter + AllGather)

  • 这些操作的成本和延迟不取决于轴的大小(只要它们受带宽限制),而只取决于输入数组的大小和链路的带宽。对于单向的AllGather/ReduceScatter:

$$ T_{\text{comm per AllGather or ReduceScatter}} = \frac{\text{Data volume}}{\text{bandwidth}} \cdot \frac{\text{Axis} - 1}{\text{Axis}} \longrightarrow \frac{\text{Data volume}}{\text{bandwidth (bidirectional)}} $$

  • AllReduce由ReduceScatter后跟AllGather组成,因此成本是上述的2倍。AllToAll只需要将分片部分传递环形,因此成本是AllGather的¼。以下是总结:
操作描述语法运行时间
AllGather收集沿轴分片的数组的所有分片,移除下标。[A_X,B] \to [A,B]字节数 / (双向ICI带宽 * 轴数)
ReduceScatter沿一个轴对部分求和的数组进行求和,并沿另一个轴进行分片(添加下标)。[A, B] {U_X} \to [A_X, B]与AllGather相同
AllReduce沿一个轴对部分求和的数组进行求和。移除{ Ux }。结合AllGather和ReduceScatter。[A_X, B]{U_Y} \to [A_X, B]2 * AllGather
AllToAll收集(复制)一个轴并沿同一轴对不同维度进行分片。[A,B_X] → [A_X,B]对于双向环,为AllGather / 4

第3章练习题

以下是基于本节内容的一些有启发性的练习题。我们目前不会提供所有答案,但会在可能的情况下撰写更多答案。

问题1 [复制分片]:一个数组以A[IX, J, K, …]的方式分片(即,仅沿X轴分片),网格为Mesh({'X': 4, 'Y': 8, 'Z': 2})A在所有芯片上占用的总字节数与数组单个副本大小的比率是多少?

点击此处查看答案

我们的数组仅沿X轴分片,其大小为4,所以每个分片的大小实际上是[I/4, J, K, …] = sizeof(A)/4。由于我们的数组在Y和Z轴上是复制的,总大小为Y ⋅ Z ⋅ sizeof(A),所以总大小与单个芯片大小的比率是Y ⋅ Z ⋅ sizeof(A)/sizeof(A) = 16

问题2 [AllGather延迟]:在TPUv4p 4x4x4切片上执行AllGather_X([B_X, D_Y]),网格为Mesh({'X': 4, 'Y': 4, 'Z': 4}),如果B = 1024D = 4096,使用bfloat16格式,需要多长时间?AllGather_{XY}([B_X, D_Y])呢?AllReduce_Z([B_X, D_Y]{U_Z})呢?

点击此处查看答案

我们在所有轴上都有环绕连接,因为我们有一个完整的4x4x4立方体,所以我们可以使用9e10的双向带宽。

  1. 因为我们只在一个轴上进行收集,而另一个轴是分片的,我们实际上是在1个轴上收集2BD/Y字节。由于TPU v4p的ICI带宽是9e10字节/秒(双向),这将花费2BD/(9e10 ⋅ Y) = 2 ⋅ 1024 ⋅ 4096/(9e10 ⋅ 4) = 23μs

  2. 我们有两倍于之前的带宽,但我们要对整个数组进行AllGather,所以T = 2BD / (2 * W) = 210244096 / (2 * 9e10) = 46us。这远高于4us的延迟界限(每跳1us),所以我们没问题。

  3. AllReduce的成本是AllGather的两倍。每个分片的大小为2BD/(X * Y),所以成本大约是4BD/(X * Y * W),或大约4 * 1024 * 4096 / (16 * 9e10) = 11.6us

问题3 [延迟受限的AllGather]:假设我们执行的是AllGatherX([BX]),但B很小(比如128)。在TPUv4p 4x4x4切片上,网格为Mesh({'X': 4, 'Y': 4, 'Z': 4}),使用bfloat16格式,这需要多长时间?提示:你可能受到延迟限制。

点击此处查看答案

我们的bfloat16数组总共只使用256字节,每个设备只有64字节。由于我们在TPU v4p上有一个大小为4的轴,我们有环绕链接,所以可以在两个方向上发送数组。以4.5e10的单向带宽,每跳大约需要64 / 4.5e10 ~ 0,所以我们肯定受到延迟限制。计算跳数,我们只需2跳就能完成完整收集,所以大约2us是个不错的估计。

问题4 [矩阵乘法策略]:要执行X[B, D]⋅DY[DX, F] → Z[B, F],本节中我们建议执行AllGatherX(Y[DX, F])并乘以完全复制的矩阵(情况2,策略1)。或者,你可以像X[B, D_X] \cdot_D Y[D_X, F] \to Z[B, F] {U_X}(情况4,策略2)那样乘以本地分片,然后\text{AllReduce}_X(Z[B, F] { U_X})。每种方法执行了多少FLOP和通信?哪种更好,为什么?

点击此处查看答案

让我们先看看基线(策略1)。如我们所示,AllGather的成本是2DF/Wici。一旦我们有了完全复制的数组,总计算时间是2BDF/C(其中C是我们的加速器FLOP/s,因为每个TPU做相同的FLOP)。所以我们有

$$ T_\text{total (Strategy 1)} = \max\left(\frac{2BDF}{C}, \frac{2DF}{W_\text{ici}}\right) $$

相比之下,新策略(策略2)对2BF字节进行AllReduce,成本为4BF/Wici,但因为计算是分片的,所以FLOP减少了1/X。这意味着我们执行2 ⋅ BDF/X FLOP,而产生的AllReduce在bfloat16中通信2 ⋅ 2 ⋅ BF字节。因此,策略2(无AllGather,只有稍后的AllReduce)的总时间大约是

$$ T_\text{total} = \max\left(\frac{2BDF}{X \cdot C}, \frac{4BF}{W_\text{ici}}\right) $$

问题是:哪个更大?D/(XC) > 2/Wici时,策略(2)受计算限制,或者当D/2X > C/Wici ≈ 2550 → X < D/(2 * 2550)时。我们可能合理预期D ≈ 8k,所以这意味着大约X < 2,这不太可能 - 因此我们基本上总是用策略2受通信限制。使用基线(策略1),当B < C/Wici = 2550时我们受通信限制,这种情况经常但并非总是如此。

所以如果B < 2550,我们在两种情况下都受通信限制,我们有

$$ T_\text{comms for Strategy 2} < T_\text{comms for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2DF}{W_\text{ici}} $$

D > 2B且2B < 5100时,这是成立的。这种情况经常出现,所以如果我们的批次较小,策略2有时会更好。当我们的批次较大(B > 2550)时,我们有

$$ T_\text{comms for Strategy 2} < T_\text{math for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2BDF}{C} $$

当2/Wici < D/C时,或者当D > 2 * 2550 = 5100时,这是成立的,对于大型模型通常如此。所以这种替代策略对大型模型通常更好,除非D很小。

为什么我们不总是这样做?实际上,我们有时可能会这样做,但矩阵乘法的一个输入在其收缩维度上沿着另一个输入未分片的轴分片的情况很少见。例如,如果我们正在做FSDP(在第5节中解释),我们会在数据维度上分片我们的参数,但我们的激活也会沿数据维度分片。所以从这个意义上说,这种情况不太常见。

问题5 [最小延迟]:假设我想在TPUv5p 4x4x4上以最低可能的延迟执行矩阵乘法A[B, D]⋅DB[D, F] → C[B, F]。我的输入应该如何分片?总FLOP和通信时间是多少?

问题6:假设我们想在TPUv5e 4x4上执行A[IX, JY]⋅JB[JY, K] → C[IX, K]。我们执行什么通信?通信与计算分别花费多少时间?

  • 如果是A[I, J]⋅B[J, K] → C[I, K]呢?这是训练的最标准设置,我们结合数据、张量和零分片。XJXYXY

  • 如果是A[I, J]⋅B[J, K] → C[I, K]呢?这是推理的标准方式,我们进行纯张量并行(+数据)。XJYXY

问题7:典型的Transformer块有两个矩阵B[D, F]和C[F, D],其中FD。以批次大小B,整个块是CBx,其中x[B, D]。让我们选择D = 8192,F = 32768,B = 128,并假设所有内容都是bfloat16格式。假设我们在TPUv5e 2x2切片上运行,但假设每个TPU只有300MB的可用内存。B、C和输出应该如何分片以保持在内存限制以下,同时最小化总时间?通信和FLOP分别花费多少时间?

问题8 [挑战]:使用上面的简短代码片段作为模板,分配一个分片数组并使用pmap或shard_map对4个主要通信原语(AllGather、AllReduce、ReduceScatter和AllToAll)进行基准测试。你需要使用jax.lax.all_gatherjax.lax.psumjax.lax.psum_scatterjax.lax.all_to_all。你理解这些函数的语义吗?它们需要多长时间?

问题9 [分片矩阵乘法的另一种策略?]上文中我们声称,当矩阵乘法的仅有一个输入沿其收缩维度分片时,我们应该对分片矩阵进行AllGather并在本地执行结果收缩。你可能想到的另一种策略是执行分片矩阵乘法,然后对结果进行AllReduce(就像两个输入都沿收缩维度分片一样),即A[I, JX]*JB[J, K] → C[I, K],通过:

  1. C[I, K] \{ U_X \} = A[I, J_X] \cdot B[J_X, K]

  2. C[I, K] = \text{AllReduce}(C[I, K] \{ U_X\})

回答以下问题:

  1. 明确写出矩阵A[N, M]和B[M, K]的这种算法,使用索引准确显示在每个设备上执行什么计算。假设A在ND设备上以A[I, J]的方式分片,你希望输出在所有设备上复制。X

  2. 如果你不需要最终结果在每个设备上复制,而是分片(沿N或K维度),上述算法如何变化?

  3. 仅从通信成本角度看上述策略(部分(b),而非(a)),这种通信成本与先对A进行AllGather然后执行矩阵乘法的算法的通信成本相比如何?

点击此处查看答案
  1. 首先计算外积,将结果存储在O[N, K]中:o = ∑ab。注意重复的索引不是正在收缩的那个,因为我们是在做外积。这里的总和范围是我们正在使用的特定设备上存储的i值集合。例如,如果我们有一个大小为16的收缩轴,和4个设备,那么在设备0上,i的范围是{0, 1, 2, 3};在设备1上,i的范围是{4, 5, 6, 7};在设备2上,i的范围是{8, 9, 10, 11};在设备3上,i的范围是{12, 13, 14, 15}。然后对每个设备上的O[N, K]部分和进行AllReduce,形成完整的O[N, K]。kjikiij

  2. 不必在步骤2中执行AllReduce,我们可以使用更便宜的ReduceScatter,沿任一轴:[N, K] { U_X } \to [N_X, K]或[N, K] { U_X } \to [N, K_X]。

  3. 如上文所述,执行AllGather(当我们受吞吐量限制时)的成本与ReduceScatter相同;它仅由我们处理的完整矩阵大小决定。因此,在gather-then-matmul算法中,这取决于NM(因为我们对A进行AllGather);在matmul-then-reduce-scatter算法中,这取决于NK(因为我们对O进行reduce-scatter)。所以两种算法的通信成本比率是M/K

问题10:AllToAll的乐趣:在上表中提到,执行AllToAll所需的时间是执行AllGather或ReduceScatter时间的1/4(在我们受吞吐量限制的情况下)。在这个问题中,我们将看到这个因子4的来源,以及如果我们只有单向ICI链接而不是双向ICI链接,这个因子会如何变化。 让我们先从单向情况开始。假设我们有D个设备形成环形拓扑,如果我们在N x N矩阵A上执行AllGather或ReduceScatter,该矩阵以A[I, J]的方式分片(假设D能整除N)。描述这两种集合操作涉及的通信,并计算在整个算法执行过程中通过单个ICI链接传输的标量(浮点数或整数)总数。

第4章:你需要了解的所有Transformer数学

计数点积

让我们从向量xy和矩阵AB的以下形状开始:

$$ \def \red#1{\textcolor{red}{#1}} \def \green#1{\textcolor{green}{#1}} \def \blue#1{\textcolor{blue}{#1}} \def \purple#1{\textcolor{purple}{#1}} \def \orange#1{\textcolor{orange}{#1}} \def \gray#1{\textcolor{gray}{#1}} \begin{array}{cc} \textrm{array} & \textrm{shape} \ \hline x & \textrm{[P]} \ y & \textrm{[P]} \ A & \textrm{[N P]} \ B & \textrm{[P M]} \ \hline \end {array} $$

  • 点积xy需要P加法乘法,或总共2P次浮点运算。

  • 矩阵向量乘积Ax沿A的行进行N次点积,需要2NP次浮点运算。

  • 矩阵矩阵乘积ABB的每一列进行M次矩阵向量乘积,总共需要2NPM次浮点运算。

  • 一般来说,如果我们有两个高维数组CD,其中一些维度是收缩维度,一些是批处理维度。(例如C[\textcolor{blue}{GH}IJ\textcolor{red}{KL}], D[\textcolor{blue}{GH}MN\textcolor{red}{KL}]),那么这种收缩的浮点运算成本是所有CD维度乘积的两倍,其中批处理和收缩维度只计算一次(例如2\textcolor{blue}{GH}IJMN\textcolor{red}{KL})。注意,只有当一个维度在两个乘数中都出现时,它才是批处理维度。(还需注意,如果没有收缩维度,这只是一个元素级乘积,那么系数2不适用。)

$$ \begin{array}{ccc} \textrm{Operation} & \textrm{FLOPs} & \textrm{Data} \ \hline x \cdot y & 2P & 2P \ A x & 2NP & NP + P \ AB & 2NPM & NP + PM \ [c_0,...,c_N] \cdot [d_0,...,d_N] & 2 \prod c_i \times \prod_{\substack{d_j \notin \textcolor{blue}{BATCH} \ d_j \notin \textcolor{red}{CONTRACT}}} d_j & \prod c_i + \prod d_j \ \hline \end {array} $$

请注意,对于矩阵矩阵乘法,计算量按立方级O(N³)增长,而数据传输仅按平方级O(N²)增长 - 这意味着随着我们扩大矩阵乘法的规模,达到计算饱和限制变得更容易。这是极其不寻常的,很大程度上解释了为什么我们使用以矩阵乘法为主的架构 - 它们适合被扩展!

前向和反向浮点运算

在训练过程中,我们并不特别关心给定矩阵乘法的结果;我们真正关心的是它的导数。这意味着我们在反向传播过程中执行明显更多的浮点运算。

如果我们想象B只是更大网络中的一个矩阵,而A是我们的输入激活值,且C = A B,那么损失L相对于B的导数由链式法则给出:

$$ \frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial B} = A^T \left(\frac{\partial L}{\partial C}\right) $$

这是一个外积,需要2NPM次浮点运算来计算(因为它在N维度上收缩)。同样,损失相对于A的导数是

$$ \frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial A} = \left(\frac{\partial L}{\partial C}\right) B^T $$

同样需要2NPM次浮点运算,因为dL/dC是一个大小为[N, M]的(协)向量。虽然这个量不是相对于参数的导数,但它用于计算网络前面层的导数(例如,就像dL/dC用于计算上面的dL/dB一样)。

将这些加起来,我们看到在训练期间,我们总共有6NPM次浮点运算,而推理期间为2NPM:前向传播中2NPM,反向传播中4NPM。由于PM是矩阵中的参数数量,这是著名的Transformer训练期间浮点运算量近似公式 6 参数数量 词元数量的最简形式:每个词元需要6 * 参数数量的浮点运算。我们将在下面给出更准确的推导。

Transformer核算

Transformer是未来。好吧,至少它们是现在的主流。也许几年前,它们还只是众多架构中的一种。但今天,值得了解这种架构的几乎每一个细节。我们不会重新介绍这种架构,但这篇博客原始Transformer论文可能是有用的参考资料。

这是Transformer解码器架构的基本图示:

图示:该图展示了标准Transformer的一层,从上到下流动。我们使用单字母约定来描述Transformer中数组的形状和布局,再次以红色显示收缩维度,以蓝色显示批处理维度。在给定操作中,输入形状在左上角给出,参数形状在右上角给出,结果形状在下方,例如BTD是门控爱因斯坦求和的输入形状,DF是权重形状。

注意[门控爱因斯坦求和]:上图使用了"门控爱因斯坦求和",我们将上投影矩阵分成两个矩阵(上面的WIn1和WIn2),它们的输出以元素方式相乘,作为一种"门控函数"。并非所有大语言模型都使用这种方式,因此有时你会看到单个WIn矩阵和总MLP参数计数为2DF而非3DF。在这种情况下,通常会调整D和F的大小,以保持参数计数与3矩阵情况相同。尽管如此,LLAMA、DeepSeek和许多其他模型都使用某种形式的门控爱因斯坦求和。

注意2 [MHA注意力]:对于自注意力,T和S是相同的,但对于交叉注意力,它们可能不同。对于普通的多头注意力(MHA),N和K是相同的,而对于多查询注意力(MQA) K=1,对于分组MQA(GMQA),K只需要能够整除N。

全局 FLOPs 和参数计算

对于下面的内容,我们将计算每层的 FLOPs,以避免必须在各处添加 L 因子。

MLPs

Transformer 的 MLPs 通常由 2 个输入矩阵乘法(以元素方式组合)和一个输出矩阵乘法组成:

$$ \begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \ \hline \ A[B,T,\textcolor{red}{D}] \cdot W_{in1}[\textcolor{red}{D}, F] & 6BTDF & DF \[10pt] A[B,T,\textcolor{red}{D}] \cdot W_{in2}[\textcolor{red}{D}, F] & 6BTDF & DF \[10pt] \sigma\left(A_{in1}\right)[B,T, F] * A_{in2}[B,T, F] & \textcolor{gray}{O(BTF)} \[10pt] A[B,T,\textcolor{red}{F}] \cdot W_{out}[\textcolor{red}{F}, D] & 6BTDF & DF \[10pt] \hline \ & \approx 18BTDF & 3DF \end{array} $$

注意力

对于具有不同 QKV 头数的通用分组查询注意力情况,假设 QKV 投影的头维度 H 相等,并估计 QKVO 矩阵乘法的成本:

$$

\begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \ \hline \ A[B,T,\textcolor{red}{D}] \cdot W_{Q}[\textcolor{red}{D}, N, H] & 6BTDNH & DNH \[10pt] A[B,T,\textcolor{red}{D}] \cdot W_{K}[\textcolor{red}{D}, K, H] & 6BTDKH & DKH \[10pt] A[B,T,\textcolor{red}{D}] \cdot W_{V}[\textcolor{red}{D}, K, H] & 6BTDKH & DKH \[10pt] A[B,T,\textcolor{red}{N}, \textcolor{red}{H}] \cdot W_{O}[\textcolor{red}{N}, \textcolor{red}{H}, D] & 6BTDNH & DNH \[10pt] \hline \ & 12BTD(N+K)H & 2D(N+K)H \end{array} $$

点积注意力操作更为微妙,它实际上是在 BK 维度上批处理的 THHS 矩阵乘法,一个 softmax,以及再次在 BK 维度上批处理的 TSSH 矩阵乘法。我们用蓝色突出显示批处理维度:

$$ \begin{array}{cc} \textrm{operation} & \textrm{train FLOPs} \ \hline \[3pt] Q[\textcolor{blue}{B}, T, \textcolor{blue}{K}, G, \textcolor{red}{H}] \cdot K[\textcolor{blue}{B}, S, \textcolor{blue}{K}, \textcolor{red}{H}] & 6BTSKGH = 6BTSNH \[3pt] \textrm{softmax}_S ;; L[B, T, S, K, G] & \textcolor{gray}{O(BTSKG) = O(BTSN)} \[3pt] S[\textcolor{blue}{B}, T, \textcolor{red}{S}, \textcolor{blue}{K}, G] \cdot V[\textcolor{blue}{B}, \textcolor{red}{S}, \textcolor{blue}{K}, H] & 6BTSKGH = 6BTSNH \[3pt] \hline \ & \approx 12BTSNH = 12BT^2NH \ \end{array} $$

其他操作

Transformer 中还有其他几种操作。层归一化(Layernorm)相对便宜,可以在一阶成本估计中忽略。还有最终的巨大(尽管不是每层都有)反嵌入矩阵乘法。

$$ \begin{array}{ccc} \textsf{operation} & \textsf{train FLOPs} & \textsf{params} \ \hline \ \textrm{layernorm}D ;; A[B,T,\textcolor{red}{D}] & \textcolor{gray}{O\left(BTD\right)} & \textcolor{gray}{D} \[10pt] A[B,T,\textcolor{red}{D}] \cdot W{unembed}[\textcolor{red}{D}, V] & 6BTDV & DV \ \end{array} $$

Transformer FLOPs 的一般经验法则

如果我们忽略短上下文训练的点积注意力成本,那么所有层的总 FLOPs 为

$$ \begin{align*} (18BTDF + 12BTD(N+K)H)L = 6 BT * (3DF + 2D(N+K)H)L \ = 6 * \textrm{num tokens} * \textrm{parameter count} \end{align} $$

这导致了一个著名的经验法则,用于估计密集 Transformer 的 FLOP 计数,忽略注意力 FLOPs。(反嵌入是另一个简单的矩阵乘法,具有 6BSDV FLOPs 和 DV 参数,并遵循相同的经验法则。)

注意力随上下文长度的分数成本

如果我们考虑上面的点积注意力,并假设 F = 4DD = NH(这是典型的)以及 N = K

$$ \small{\frac{\textrm{attention FLOPs}}{\textrm{matmul FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{4*18 BTD^2 + 24 BTD^2} = \frac{12BT^2D}{96 BTD^2} = \frac{T}{8D}} $$

所以结论是,点积注意力 FLOPs 只有在 T>8D 时才在训练期间占主导地位。对于 D ~ 8k,这将是 ~64K 个词元。这是有道理的,因为它意味着随着 MLP 大小的增加,注意力 FLOPs 变得不那么关键。对于大型模型,注意力的二次成本实际上不是长上下文训练的巨大障碍。然而,对于较小的模型,例如 Gemma-27B,D=4608,这意味着注意力在大约 32k 序列长度左右开始占主导地位。Flash Attention 也有助于减轻长上下文的成本,我们在附录

各种数学问题

稀疏性和专家混合模型

我们不能不简要讨论专家混合模型(MoE),它用一组可以动态路由的独立MLP替代标准Transformer中的单个密集MLP块。粗略地说,MoE就是每层有E个MLP块的普通密集模型,而不是只有一个。每个词元激活这些专家中的k个,通常k = 2。与密集版本相比,这将参数数量增加了O(E)倍,同时将每个词元激活的参数总数乘以k

图片:一个具有n个专家的MoE层示例。门控专家将每个词元路由到其中的k个专家,这k个MLP的输出被求和。我们的参数数量是每个专家大小的n倍,但每个词元只使用其中的k个。图片来源.

与密集模型相比,MoE引入了新的通信操作,主要是两个AllToAll操作(一个在MoE块之前,一个在之后),它们将词元路由到正确的专家并将其带回原始设备。然而,正如我们在上一节中看到的,每个AllToAll的成本仅为单轴上可比较的AllGather的1/4(对于双向环)。

梯度检查点

反向传播作为一种算法是用内存换计算。不需要O(nlayers2)个浮点运算的反向传递,它需要O(nlayers)内存,保存前向传递过程中生成的所有中间激活值。虽然这比二次方计算要好,但内存消耗非常大:一个批量中包含B * T = 4M(总共400万个词元)、L=64和D=8192的模型,如果要避免所有不必要的反向传递计算,将需要保存大约2 20 B * T * D * L = 84TB的bfloat16激活值。这里的20是对上面Transformer图中每个中间节点的(大致)计数,例如

$$ f(x) = exp (g(x)) $$

$$ \frac{df}{dx} = \exp(g(x)) \cdot \frac{dg}{dx} $$

因此为了避免重新计算,我们需要从前向传递中保存g(x)和exp(g(x))。为了避免保存这么多内存,我们可以选择只保存部分中间激活值。以下是我们使用的几种策略。

  • 块重新计算:只保存每层的输入。这是我们使用的最激进方法,每层只保存1个检查点,意味着在上面的例子中我们只需保存4.2TB。这迫使我们在反向传递中重复基本上所有前向传递的浮点运算,这意味着我们的浮点运算从6ND增加到大约8ND

  • 只保存大型矩阵乘法:另一个简单的策略是只保存大型矩阵乘法的输出。这让我们避免在反向传递期间重新计算任何大型矩阵乘法,但仍然需要重新计算其他激活函数和注意力的部分内容。这将每层的20个检查点减少到接近7个。

这绝不是全面的。在使用JAX时,这些通常由jax.remat/jax.checkpoint控制(您可以在这里阅读更多信息)。

键值(KV)缓存

正如我们将在第7节中看到的,LLM推理有两个关键部分,预填充和生成。

  • 预填充处理长提示并将其注意力激活保存在键值缓存(KV Cache)中,以便在生成时使用,特别是注意力块中的键值投影。

  • 生成将多个KV缓存批量处理在一起,并从每个缓存中采样词元。

每个KV缓存实际上是一个大小为[2, S, L, K, H]的数组,其中2表示键和值。这非常大!int8格式的键值缓存总大小是2SLKH。对于一个中等规模的模型,上下文长度为8k,64层,且KH = NH = D = 8192,这是2 ⋅ 8192 ⋅ 64 ⋅ 8192 = 8GiB。你可以看到为什么我们想使用KN的GMQA。

从本节中你应该获得什么?

  • Transformer的整体参数和FLOPs计算相对简单,在此总结如下,假设使用MHA(批次大小为B,词汇量为V,序列长度为T,D=d,且F=d):模型前馈
组件每层参数每层训练FLOPs
MLP3DF18BTDF
注意力4DNH24BTDNH + 12BT²NH
其他DBTD
词汇DV(总计,非每层)12BTDV
  • MLP块的参数数量主导了总参数数量,且MLP块也主导了FLOPs预算,只要序列长度T<8D

  • 对于合理的上下文长度,训练期间的总FLOPs预算可以很好地用6·参数数量·词元数量来近似。

  • 在推理过程中,我们的KV缓存大约为每个缓存2·S·L·N·H,尽管架构修改通常可以减少这一数量。

一些需要解决的问题

问题1:一个模型,其中 D = 4096,F = 4 ⋅ DV = 32,000,以及 L = 64,有多少参数?这些参数中有多大比例是注意力参数?我们的KV缓存每个词元有多大?你可以假设 N ⋅ H = D 并且使用int8 KV的多头注意力。

点击此处查看答案。
  1. 总参数数量大致为 L ⋅ (3DF + 4DNH + D) + 2DV。对于给定的数字,这是 64 ⋅ (3 ⋅ 4e3 ⋅ 16e3 + 4 ⋅ 4e3 ⋅ 4e3 + 4e3) + 2 ⋅ 4e3 ⋅ 32e3 = 16e9,即160亿参数。

  2. 注意力参数与总参数的比例一般是 4DNH/(4DNH + 3DF) = 4D/(4D + 12D) = 1/4。这意味着大约1/4的参数用于注意力。222

  3. 每个词元的KV缓存为 2 ⋅ LNH = 2 ⋅ 64 ⋅ 4096(以int8计),即 512kB/词元

问题2:{'X': 4, 'Y': 8, 'Z': 4} 上执行 A[BX, DY] *D W[DY, F] 需要多少总浮点运算?每个TPU执行了多少浮点运算?

点击此处查看答案。

该操作的总"理论"浮点运算是 2 ⋅ BDF。然而,由于计算没有在Z维度上分片,我们实际上进行了Z倍的额外浮点运算,意味着总共有 2 ⋅ BDFZ 浮点运算。由于计算在其他维度上是分片的,每个设备的总计大约是 2 ⋅ BDF/(XY)。

问题3:执行 A[I, J, K, L] * B[I, J, M, N, O] → C[K, L, M, N, O] 涉及多少浮点运算?

点击此处查看答案。

根据上面的规则,我们有I和J作为收缩维度,K、L、M、N和O作为非收缩维度。我们没有"批处理维度",所以这只是 2 ⋅ IJKLMNO,所有轴的总和。如果我们有一个共享轴,它只会被计算一次。

问题4:自注意力的算术强度是多少(忽略Q/K/V/O投影)?给出Q和KV长度T和S的函数答案。在什么上下文长度下,注意力是计算受限的?考虑到我们TPU的HBM带宽,随着上下文长度的增长,绘制注意力对前馈网络块的有效相对成本。

点击此处查看答案。

自注意力需要加载 QKV 激活,然后计算 softmax(QK) ⋅ V,然后将结果写回HBM。这将使用Flash Attention完成,所以这个数学计算有一些注意事项,但基本上在bf16中自注意力执行

Q[B,T,N,H]→reshapeQ[B, T, K, G, H] ⋅ K[B, S, K, H] → O[B, T, S, K, G]

U = softmaxS(O[B, T, S, K, G])

U[B, T, S, K, G] ⋅ V[B, S, K, H] → X[B, T, K, G, H]

所以我们的总字节数是 2 * sizeof(Q) + 2 * sizeof(K or V) = 4BTNH + 4BSKH = 4BHK * (TG + S),总浮点运算是 4BTSNH + O(BTSN),算术强度是 4BTSKGH/(4BHK * (TG + S))。

所以基本上,在预填充阶段,我们有 S = T,所以我们有一个算术强度为 4BT2KGH/4BHKT ⋅ (G + 1) = TG/(G + 1) = O(T)。在生成阶段,T = 1,所以我们有 4BSKGH/(4BHK ⋅ (G + S)) = SG/(G + S) → G,假设 S 非常大。根据你如何解释这个问题,在预填充或训练过程中,假设没有序列分片,自注意力在S=240时是计算受限的。在生成过程中,我们永远不会计算受限,因为 G 很小。然而,你可以看到增加 G 会使我们更接近计算受限。

问题5:在什么序列长度下,自注意力的浮点运算等于QKVO投影的浮点运算?

点击此处查看答案。

这纯粹是一个何时 24BTDNH = 12BT2NH 的问题。简化后我们得到 2D = T,所以例如对于 D = 4096,这是8192。这告诉我们,对于大多数合理的上下文长度,矩阵乘法浮点运算更大。

问题6:假设我们在前向传递过程中只保存Transformer层中7个主要矩阵乘法的输出(Q、K、V、O加上三个前馈网络矩阵)。在反向传播过程中,我们需要多少额外的浮点运算来"重新具体化"?

问题7:DeepSeek v3表示它在14.8T个词元上训练了2.79M H800小时(来源)。考虑到它有37B激活参数,他们大致达到了什么硬件利用率?提示:注意他们使用了没有结构化稀疏性的FP8浮点运算。

点击此处查看答案。

这里的规格表中,我们发现有稀疏性的FP8性能为3,026 TFLOPs/s,没有稀疏性时通常是这个的一半(1.513e15 FLOPs/s)。2.79M H800小时意味着 2.79e6 * 1.513e15 * 60 * 60 = 1.52e25 总浮点运算。考虑到37B的激活参数数量,这次训练应该使用了大约 6 * 37e9 * 14.8e12 = 3.3e24 浮点运算。这意味着浮点运算利用率约为 3.3e24 / 1.52e25 = 21.7%

问题8:混合专家模型(MoE)有 E 份标准密集MLP块的副本,每个词元激活其中的 k 个专家。在TPU v5e上使用int8权重的MoE需要多大的批次大小(以词元计)才能达到计算受限?对于DeepSeek,它有256个(路由)专家和 k = 8,这个数字是多少?

点击此处查看答案。

因为我们有 E 份每个专家的副本,在int8中,我们需要加载 EDF 字节。因为每个词元激活 k 个专家,我们有 2 ⋅ kBDF 浮点运算。要在bfloat16浮点运算中达到计算受限,我们需要算术强度超过240,这发生在 (2 ⋅ kBDF)/EDF > 240 或 kB/E > 120。

因此,我们需要 B > 120 ⋅ E/k 才能达到计算受限。对于DeepSeek,这给我们 B > 120 ⋅ 256/8 = 3840。这在生成时是一个非常大的批次大小。

第4章附录 A:Flash Attention 是如何工作的?

传统上对扩展 Transformer 到非常长的上下文的反对意见是,注意力的浮点运算和内存使用量随上下文长度呈二次方增长。虽然注意力 QK 乘积确实具有 [B, S, T, N] 的形状,其中 B 是批次大小,S 和 T 是 Q 和 K 的序列维度,N 是头的数量,但这种说法有一些严重的注意事项:

  1. 正如我们在第 4 部分中所指出的,尽管这是二次方的,但注意力浮点运算仅在 S > 8 · D 时占主导地位,尤其是在训练期间,单个注意力矩阵的内存与内存中所有的权重和激活检查点相比很小,特别是在分片时。

  2. 我们不需要具体化完整的注意力矩阵来计算注意力!我们可以计算局部和与最大值,避免具体化超过数组的一小部分。虽然总浮点运算仍然是二次方的,但我们大大减少了内存压力。

这第二个观察最初由 Rabe 等人 2021 提出,后来在 Flash Attention 论文(Dao 等人 2022)中再次提出。基本思想是以 K/V 的块计算注意力,我们计算局部 softmax 和一些辅助统计数据,然后将它们传递到下一个块,该块将它们与其局部块组合。具体来说,我们计算

  1. M:在序列维度上 q · k 的运行最大值

  2. O:在序列维度上运行的完整注意力 softmax

  3. L:运行的分母 ∑(q · k − 运行最大值)ii

有了这些,我们可以用常量内存计算新的最大值、新的运行和以及新的输出。为了粗略描述这是如何工作的,注意力大致是这个操作:

\text{Attn}(Q, K, V) = \sum_i \frac{\exp(Q \cdot K_i - \max_j Q \cdot K_j) V_i}{\sum_l \exp(Q \cdot K_l - \max_j Q \cdot K_j)}

为了数值稳定性减去最大值,并且可以加上而不影响结果,因为 ∑iexp (ai + b) = exp (b)∑exp (a)。仅看上面的分母,如果我们想象有两个连续的键向量块,K1 和 K2,我们为每个计算局部 softmax 和 L1 和 L2

$$ L1 = ∑iexp (Q · Ki1 − maxjQ · Kj1) $$

$$ L2 = ∑iexp (Q · Ki2 − maxjQ · Kj1) $$

然后我们可以通过使用以下方式将这些组合成这两个块的完整 softmax 和

$$ Lcombined = exp (M1 − max (M1, M2)) · L1 + exp (M2 − max (M1, M2)) · L2 $$

其中

$$ M1 = maxjQ · Kj1 和 M2 = maxjQ · Kj2 $$

这也可以用于完整的 softmax,给我们一种累积任意大的 softmax 和的方法。这是 Flash Attention 论文中的完整算法。

从硬件角度看,这让我们可以将 Q 的块放入 VMEM(上面算法称为片上 SRAM)中,这样我们在每次迭代中只需要加载 KV 块,降低了算术强度。我们也可以在 VMEM 中保存运行统计数据。

最后一个值得强调的微妙点是一个注意力 softmax 属性,它被用来使 Flash VJP(反向模式导数)计算在训练中变得实用。如果我们定义一个中间 softmax 数组为:

$$ S_{ij} = \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_j}} $$

在注意力中,我们从反向模式 dOV 数组获得 dS

$$ dSij = dOid·dVjd = ∑ddOidVjd $$

在这个梯度反向传播到 Q 和 K 期间

$$ d(qi · kj) = (dSij − Sij·jdSij)Sij $$

我们利用一个恒等式,它允许我们将沿着大键长度维度的收缩与沿特征深度维度的局部收缩交换。

$$ \begin{align*} S_{ij} \cdot_j dS_{ij} &= \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} \sum_d dO_{id} V_{jd} \ &= \sum_d dO_{id} \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} V_{jd} \ &= \sum_d dO_{id} O_{id} \ &= dO_{id} \cdot_d O_{id} \end{align*} $$

这种替换对于能够实现序列块局部计算 VJP 至关重要,并且支持更多巧妙的分片方案,如环形注意力。

第5章:如何并行化Transformer进行训练

我们所说的扩展是什么?

"模型扩展"的目标是能够增加用于训练或推理的芯片数量,同时实现吞吐量的成比例、线性增长(我们称之为强扩展)。虽然单个芯片上的性能取决于内存带宽和FLOP之间的权衡,但集群级别的性能则取决于通过与有用的FLOPS重叠来隐藏芯片间通信。这并非易事,因为增加芯片数量会增加通信负载,同时减少我们可用于隐藏通信的每设备计算量。正如我们在第3节中所见,分片矩阵乘法通常需要昂贵的AllGather或ReduceScatter操作,这可能会阻碍TPU执行有用的工作。本节的目标是找出这些操作何时变得过于昂贵

在本节中,我们将讨论四种常见的并行方案:(纯)数据并行、完全分片数据并行(FSDP / ZeRO分片)、张量并行(也称为模型并行)和(简要介绍)流水线并行。对于每种方案,我们将展示产生的通信成本以及这些成本开始成为计算瓶颈的时机。在本节中,您可以仅关注芯片间通信成本,因为只要我们有足够大的单芯片批量大小,从HBM到MXU的数据传输已经与计算重叠。

我们将使用以下符号来简化本节中的计算。

符号含义(模型参数)
Ddmodel(隐藏维度/残差流维度)
Fdff(前馈维度)
B批量维度(批量中的标记数量;总数,非每设备)
T序列长度
L模型中的层数
符号含义(硬件特性)
C每芯片FLOPS/秒
W网络带宽(双向,通常下标为如Wici或Wdcn
X沿网格轴X的芯片数量
Y沿另一个网格轴(标记为Y)的芯片数量
Z沿第三个网格轴(标记为Z)的芯片数量

为简单起见,我们将Transformer近似为MLP块的堆栈 — 如我们在第4节中所见,对于较大的模型,注意力占FLOP的比例相对较小。我们还将忽略门控矩阵乘法,使每层具有以下简单结构:

图示:简化的Transformer层。我们将每个FFW块视为两个矩阵的堆栈Win:bf16[D, F](上投影)和Wout:bf16[F, D](下投影),输入为In:bf16[B, D]。

以下是我们将讨论的4种并行方案。每种方案可以被视为由上图中InWin、Wout和Out的独特分片方式定义。

1. 数据并行:激活沿批量维度分片,参数和优化器状态在每个设备上复制。通信仅在反向传播过程中发生。

$$ In[BX, D]⋅DWin[D, F]⋅FWout[F, D] → Out[BX, D] $$

2. 完全分片数据并行(FSDP或ZeRO-3):激活沿批量维度分片(类似纯数据并行),参数沿相同网格轴分片并在前向传播使用前及时进行AllGather操作。优化器状态也沿批量维度分片。减少重复内存。

$$ In[BX, D]⋅DWin[DX, F]⋅FWout[F, DX] → Out[BX, D] $$

3. 张量并行(也称为Megatron分片或模型并行):激活沿D(dmodel)分片,参数沿F(dff)分片。在每个块前后对激活进行AllGather和ReduceScatter操作。与FSDP兼容。

$$ In[B, DY]⋅DWin[D, FY]⋅FWout[FY, D] → Out[B, DY] $$

4. 流水线并行:权重沿层维度分片,激活被微批处理并沿层维度滚动。流水线阶段之间的通信最小(仅通过单跳移动激活)。滥用符号表示:

$$ In[LZ, B, D][i]⋅DWin[LZ, D, F][i]⋅FWout[LZ, F, D][i] → Out[LZ, B, DY][i] $$

数据并行

$$ 语法: In[BX, D]⋅DWin[D, F]⋅FWout[F, D] → Out[BX, D] $$

当你的模型能够在单个芯片上运行,即使只有很小的批量大小(>240个标记,以便计算受限),你也应该始终使用简单的数据并行。纯数据并行将我们的激活分散到任意数量的TPU上,只要TPU的数量小于我们的批量大小。前向传播不涉及通信,但在每一步结束时,每个TPU都会对其梯度执行AllReduce操作,以便在更新参数之前同步它们。

图示:纯数据并行的示意图(前向传播)。我们的激活(左侧)在批量维度上完全分片,而权重完全复制,因此每个TPU都有相同的权重副本。这意味着我们的权重总内存增加了N倍,但前向传播不需要通信。

这里是前向和后向传播的完整算法。为了简洁起见,我们简化表示法,将dL/dOut写为dOut。

纯数据并行算法:

前向传播:需要计算Loss[BX]

  1. Tmp[B, F] = In[B, D] * W[D, F]XXDin

  2. Out[B, D] = Tmp[B, F] * W[F, D]XXFout

  3. Loss[B] = …X

后向传播:需要计算dWout[F, D], dWin[D, F]

  1. dOut[B, D] = …X

  2. dW[F, D] {U} = Tmp[B, F] * dOut[B, D]outXXBX

  3. dW[F, D] = AllReduce(dW[F, D] {U})(不在关键路径上,可以异步完成)outoutX

  4. dTmp[B, F] = dOut[B, D] * W[F, D]XXDout

  5. dW[D, F] {U} = In[B, D] * dTmp[B, F]inXXBX

  6. dW[D, F] = AllReduce(dW[D, F] {U})(不在关键路径上,可以异步完成)ininX

  7. dIn[B, D] = dTmp[B, F] * W[D, F](前面层所需)XXFin

我们忽略损失函数的细节,并将Tmp = Win ⋅ In简写。请注意,虽然我们的最终损失是平均AllReduce(Loss[BX]),但我们只需要在后向传播时计算AllReduce来平均权重梯度。

注意,前向传播没有通信 — 所有通信都在后向传播中!后向传播还有一个很好的特性,即AllReduce操作不在"关键路径"上,这意味着每个AllReduce可以在方便的时候执行,不会阻塞你执行后续操作。如果总通信成本超过了我们的总计算成本,整体通信成本仍然可能成为瓶颈,但从实现角度来看,它更加宽容。我们将看到模型/张量并行没有这个特性。

为什么要这样做?纯数据并行通过在批量维度上拆分激活来减少激活内存压力,只要我们有更多芯片来分割批量维度,就可以几乎任意增加批量大小。在训练期间,当我们的激活通常主导内存使用时,这非常有用。

为什么不这样做?纯数据并行对减少模型参数或优化器状态的内存压力没有帮助,这意味着纯数据并行对于大规模有趣模型(其参数+优化器状态无法装入单个TPU)很少有用。为了给出规模感,如果我们使用bf16参数和fp32 Adam优化器状态进行训练,我们能容纳的最大模型有TPU内存/10个参数,例如在拥有96GB HBM和纯数据并行的TPUv5p pod上,这大约是90亿参数。

💡 要点:使用Adam和纯数据并行训练的最大模型具有参数数量 = 每设备HBM/10。对于TPU v5p,这大约是90亿参数。

要使其对训练中的实际模型有用,我们至少需要部分分片模型参数或优化器。

何时会被通信瓶颈?如上所示,我们每层有两个AllReduce操作,每个大小为2DF(对于bf16权重)。什么时候数据并行会使我们受通信限制?

如上表所示,设C = 每芯片FLOP,Wici = 双向网络带宽,X = 批量分区的分片数量。让我们计算执行相关矩阵乘法所需的时间Tmath和所需的通信时间Tcomms。由于此并行方案在前向传播中不需要通信,我们只需要计算后向传播的这些量。

通信时间:从前面的部分我们知道,在1D网格中执行AllReduce所需的时间仅取决于被AllReduce的数组的总字节数和ICI带宽Wici;具体来说,AllReduce时间是2 ⋅ 总字节数/Wici。由于我们需要为Win和Wout进行AllReduce,我们每层有2个AllReduce。每个AllReduce都是针对权重矩阵的,即DF参数的数组,或2DF字节。将这些全部放在一起,单层中AllReduce的总时间是

$$ \begin{align} T_\text{comms} &= \frac{2 \cdot 2 \cdot 2 \cdot D \cdot F}{W_\text{ici}}. \ \end{align} $$

矩阵乘法时间:每层在前向传播中包括两个矩阵乘法,或在后向传播中包括四个矩阵乘法,每个都需要2(B/X)DF FLOP。因此,对于后向传播中的单个层,我们有

$$ \begin{align} T_\text{math} &= \frac{2 \cdot 2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \ \end{align} $$

由于我们重叠执行,每层的总时间是这两个量的最大值:

$$ \begin{aligned} T &\approx \max(\frac{8 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{8 \cdot D \cdot F}{W_\text{ici}}) \ T &\approx 8 \cdot D \cdot F \cdot \max(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}) \end{aligned} $$

Tmath/Tcomms > 1时,我们变为计算受限,或者当

$$ \begin{align} \frac{B}{X} &> \frac{C}{W_\text{ici}}. \end{align} $$

结论是,要在数据并行下保持计算受限,我们需要每设备批量大小B/X超过ICI操作强度C/Wici。这最终是由于计算时间与每设备批量大小成比例,而通信时间与此量无关(因为我们传输的是模型权重)。注意B > C/Wici条件与单设备计算受限规则B > 240的相似性;在那种情况下,规则也来自计算时间与批量大小成比例,而数据传输大小(在BF, D情况下)与批量大小无关。

让我们放入一些实际数字来了解规模。对于TPUv5p,C=4.6e14W=2 * 9e10用于ICI上的1D数据并行,所以我们每芯片的批量大小必须至少为2,550才能避免受通信限制。由于我们可以在多个轴上进行数据并行,如果我们将TPUv5p pod的所有三个轴都用于纯数据并行,我们的带宽Wici会增加3倍,可以将批量大小降至每TPU仅850或每pod(8960个芯片)每批760万个标记!这告诉我们,纯数据并行很难成为瓶颈!

💡 关于上下文并行的说明:在本节中,我们使用B表示标记中的总批量大小。然而,显然我们的批量由K个序列组成,每个序列有T个标记,那么我们如何做到这一点?对于MLP来说,标记就是标记!它们是属于同一批次还是不同批次并不重要。因此,我们或多或少可以自由地在批量和序列维度上进行数据并行:我们称之为上下文并行或序列并行,但你可以将其视为另一种数据并行。注意力比MLP更复杂,因为我们进行一些跨序列计算,但这可以通过在注意力期间收集KV或Q并仔细重叠FLOP和通信(通常使用所谓的"环注意力")来处理。在整个本节中,我们将完全忽略序列维度,假设有一定量的批量或序列并行。

完全分片数据并行(FSDP)

$$ 语法: 输入[BX, D]⋅DWin[DX, F]⋅FWout[F, DX] → 输出[BX, D] $$

完全分片数据并行(通常称为FSDP或ZeRO分片)将模型优化器状态和权重分散到数据并行分片上,并在需要时高效地收集和分散它们。与纯数据并行相比,FSDP大幅减少了每设备内存使用,并节省了反向传播的浮点运算,同时几乎没有额外开销。

图示: FSDP沿数据维度分片Win的收缩维度和Wout的输出维度。这减少了内存,但(如第3节所述)在执行矩阵乘法前需要我们收集权重W。注意,激活值(左侧)没有沿收缩维度分片,这就是迫使我们进行收集的原因。注意,我们的权重优化器状态也同样沿收缩维度分片。

您会记得(来自第3节),AllReduce可以分解为AllGather和ReduceScatter。这意味着,与标准数据并行中进行完整梯度AllReduce不同,我们可以在芯片间分片权重和优化器状态,在前向传播期间对每层进行AllGather,并在反向传播期间对权重进行ReduceScatter,而不会产生额外成本。

这是FSDP的完整算法。

完全分片数据并行(FSDP):

前向传播:需要计算Loss[BX]

  1. W[D, F] = AllGather(W[D, F])(不在关键路径上,可以在上一层期间完成)ininX

  2. Tmp[B, F] = In[B, D] * W[D, F](现在可以丢弃Win[D, F])XXDin

  3. W[F, D] = AllGather(W[F, D])(不在关键路径上,可以在上一层期间完成)outoutX

  4. Out[B, D] = Tmp[B, F] * W[F, D]XXFout

  5. Loss[B] = …X

反向传播:需要计算dWout[F, DX], dWin[DX, F]

  1. dOut[B, D] = …X

  2. dW[F, D] {U} = Tmp[B, F] * dOut[B, D]outXXBX

  3. dW[F, D] = ReduceScatter(dW[F, D] {U})(不在关键路径上,可以异步完成)outXoutX

  4. W[F, D] = AllGather(W[F, D])(可以提前完成)outoutX

  5. dTmp[B, F] = dOut[B, D] * W[F, D] (这里可以丢弃Wout[F, D])XXDout

  6. dW[D,F] {U} = dTmp[B, F] * In[B, D]inXXBX

  7. dW[D, F] = ReduceScatter(dW[D, F] {U}) (不在关键路径上,可以异步完成)inXinX

  8. W[D, F] = AllGather(W[D, F])(可以提前完成)ininX

  9. dIn[B, D] = dTmp[B, F] * W[D, F](前面层需要)(这里可以丢弃Win[D, F])XXFin

这也被称为"ZeRO分片",来源于"零开销分片"(ZeRo Overhead sharding),因为我们不执行任何不必要的计算或存储任何不必要的状态。ZeRO-{1,2,3}分别用于指代以这种方式分片优化器状态、梯度和权重。由于所有方法的通信成本相同,我们基本上总是可以进行ZeRO-3分片,它将参数、梯度和优化器状态分片到一组设备上。

为什么要这样做?标准数据并行涉及大量重复工作。每个TPU对完整梯度进行AllReduce,然后更新完整的优化器状态(所有TPU上的相同工作),然后更新参数(再次完全重复)。对于ZeRO分片(分片梯度/优化器状态),与其进行AllReduce,您可以对梯度进行ReduceScatter,仅更新您的优化器状态分片,更新参数分片,然后根据前向传播的需要对参数进行AllGather。

何时会受通信瓶颈限制?我们的相对浮点运算和通信成本与纯数据并行完全相同,因为反向传播中的每个AllReduce都变成了AllGather + ReduceScatter。回想一下,AllReduce是通过AllGather和ReduceScatter实现的,每个操作成本减半。这里我们建模前向传播,因为它与反向传播具有相同的浮点运算与通信比率:

$$ \begin{aligned} T_{math} &= \frac{2 \cdot 2 \cdot B \cdot D \cdot F}{X \cdot C} \ T_{comm} &= \frac{2 \cdot 2 \cdot D \cdot F}{W_\text{ici}} \ T &\approx \max\left(\frac{4 \cdot B \cdot D \cdot F}{X \cdot C}, \frac{4 \cdot D \cdot F}{W_\text{ici}}\right) \ T &\approx 4 \cdot D \cdot F \cdot \max\left(\frac{B}{X \cdot C}, \frac{1}{W_\text{ici}}\right) \end{aligned} $$

因此,与纯数据并行一样,当B/X > C/Wici时,我们受计算限制,即当每设备批量大小B/X超过"ICI操作强度"C/Wici(v5p上为4.59e14 / 1.8e11 = 2550)时。这对我们来说很棒,因为这意味着如果我们的每设备批量大小足够大,可以在纯数据并行下受计算限制,我们可以——无需担心离开计算限制区域——直接升级到FSDP,从而为我们节省大量参数和优化器状态内存!虽然我们确实在前向传播中增加了通信,但这个成本并不重要,因为它只是与前向传播的浮点运算重叠。

💡 要点:当每设备批量大小小于2550/n轴时,FSDP和纯数据并行在TPUv5上都会受到带宽限制。

例如,DeepSeek-V2(少数几个发布训练批量大小信息的近期强大模型之一)使用了约4000万个标记的批量大小。这将允许我们扩展到大约47,000个芯片,或约5个TPUv5 pod,然后我们才会遇到带宽限制。

对于LLaMA-3 70B,它训练了大约6.3e24 (15e12 * 70e9 * 6)浮点运算,我们可以将1600万个标记的批次分散到大约16e6 / (2550 / 3) = 18,823个芯片上(大约2个pod,每个pod 8960个芯片),每个芯片有4.59e14浮点运算,以50%的峰值浮点运算利用率运行(通常称为MFU),大约17天内就能完成训练。不错!但让我们探索如何做得更好。

关于临界批量大小的说明:有些反直觉的是,随着总批量大小减小(固定芯片数量),我们受通信瓶颈的限制会更严重。只要我们能够不断增加批量大小,数据并行和FSDP就能让我们扩展到任意多的芯片!然而,实际上,随着批量大小增加,我们往往会在训练中看到收益递减,因为我们的梯度变得几乎无噪声。我们有时也会看到训练不稳定。因此,在"无限计算资源"情况下寻找最佳分片方案的游戏通常从由缩放定律确定的固定批量大小和已知(大量)芯片数量开始,然后寻找一种划分方式,使我们能够将那么小的批量大小适配到这么多芯片上。

张量并行

$$ 语法: In[B, DY]⋅DWin[D, FY]⋅FWout[FY, D] → Out[B, DY]\(我们使用Y最终与FSDP结合) $$

在完全分片数据并行AllReduce中,我们跨芯片移动权重。我们还可以对模型的前馈维度进行分片,并在层处理期间移动激活值——这被称为"1D模型并行"或Megatron分片。这可以使每个pod的高效批量大小变小。下图展示了以这种方式分片的单个矩阵示例:

图示:基本张量并行的示例。由于我们只在Y维度上分片激活(与FSDP中在X维度上分片不同),我们在X维度上复制激活。使用我们的标准语法,这是A[B, DY] * B[D, FY] -> C[B, FY]。因为我们只在一个收缩维度上分片,我们通常在矩阵乘法前对激活A进行AllGather操作。

如上所述,In[B, DY] D Win[D, FY] F Wout[FY, D] -> Out[B, DY]意味着我们必须在第一次矩阵乘法之前收集我们的激活。当激活比权重小时,这比ZeRO分片更便宜。这通常只有在添加一定量的ZeRO分片时才成立(这减少了收集的大小)。这是我们倾向于混合ZeRO分片和模型并行的原因之一。

这是张量并行的算法!

张量并行:

前向传播:需要计算Loss[B]

  1. In[B, D] = AllGather(In[B, D]) (在关键路径上)Y

  2. Tmp[B, F] = In[B, D] * W[D, F] (不在收缩维度上分片,所以没有通信)YDinY

  3. Out[B, D] {U} = Tmp[B, F] * W[F, D]YYFoutY

  4. Out[B, D] = ReduceScatter(Out[B, D] {U}) (在关键路径上)YY

  5. Loss[B] = …

反向传播:需要计算dWout[FY, D], dWin[D, FY]

  1. dOut[B, D] = …Y

  2. dOut[B, D] = AllGather(dOut[B, D]) (在关键路径上)Y

  3. dW[F, D] = Tmp[B, F] * dOut[B, D]outYYB

  4. dTmp[B, F] = dOut[B, D] * W[F, D] (这里可以丢弃dOut[B, D])YDoutY

  5. In[B, D] = AllGather(In[B, D]) (这可以通过与前向传播中的(1)共享来跳过)Y

  6. dW[D, F] = dTmp[B, F] * In[B, D]inYYB

  7. dIn[B, D] {U.Y} = dTmp[B, F] * W[D, F] (前一层所需)YFinY

  8. dIn[B, D] = ReduceScatter(dIn[B, D] {U.Y}) (在关键路径上)Y

张量并行的一个优点是它与Transformer前向传播中的两个矩阵交互良好。天真地说,我们会在每两个矩阵之后进行AllReduce。但在这里,我们首先执行In[B, DY] * Win[D, FY] -> Tmp[B, FY],然后执行Tmp[B, FY] * Wout[FY, D] -> Out[B, DY]。这意味着我们在开始时对In进行AllGather,在结束时对Out进行ReduceScatter,而不是执行AllReduce。

这有多昂贵?让我们只建模前向传播 - 反向传播只是这里每个操作的转置。在1D模型并行中,我们在第一次矩阵乘法之前对激活进行AllGather,在第二次之后进行ReduceScatter,每次发送两个字节(bf16)。让我们弄清楚什么时候会受到通信瓶颈。

$$ \begin{align} T_{math} & = \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} \ T_{comms} & = \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\ \textnormal{T} & \approx \max \left(\frac{4 \cdot B \cdot D \cdot F}{Y \cdot C}, \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}}\right) \end{align} $$

注意到我们希望计算成本大于通信成本,我们得到:

$$ \begin{align} \frac{4 \cdot B \cdot D \cdot F}{Y \cdot C} > \frac{2 \cdot 2 \cdot (B \cdot D)}{W_\text{ici}} \end{align} $$

$$ \begin{align} \frac{F}{Y \cdot C} > \frac{1}{W_\text{ici}} \end{align} $$

$$ \begin{align} F > Y \cdot \frac{C}{W_\text{ici}} \end{align} $$

因此,例如,对于TPUv5p,C/Wici = 2550(以bf16为单位),所以我们只能进行张量并行,最多到Y < F/2550。当我们有多个ICI轴时,我们的Tcomms减少了naxes倍,所以我们得到Y < naxes * F/2550。

💡 要点:当Y > naxes * F/2550时,模型并行变得受通信限制。对于大多数模型,这在8到16路模型并行之间。

注意,这与计算精度无关,因为例如对于int8,在TPUv5p上,Cint8/Wici是5100而不是2550,但通信量也减半了,所以两个因素抵消了。

让我们考虑一些例子:

  • 在TPUv4p上使用LLaMA 3-70B,其中D = 8192,F ≈ 30,000,我们可以舒适地进行8路模型并行,但在16路模型并行上将受到通信限制。8路模型分片所需的F为20k。

  • 对于Gemma 7B,F ≈ 50k,所以我们在19路模型并行时才会受到通信限制。这意味着我们可能可以进行16路并行并仍然获得良好的性能。

混合FSDP和张量并行

$$ 语法: 输入[BX, DY]⋅DWin[DX, FY]⋅FWout[FY, DX] → 输出[BX, DY] $$

FSDP和张量并行的优点是它们可以结合使用。通过沿两个轴对WinWout进行分片,我们既节省了内存又减少了计算量。因为我们沿X轴分片B,从而减小了模型并行AllGather的大小,同时因为沿Y轴分片F,减少了FSDP的通信开销。这意味着两者的结合可以让我们达到比上文所述更低的有效批量大小。

图示:结合FSDP和张量并行的示意图。与其他情况不同,这里没有模型参数的重复。

这里是混合FSDP+张量并行的完整算法。虽然我们有大量通信,但所有AllGather和ReduceScatter操作都变小了,因为我们已经对激活进行了批量分片,对权重进行了更多的张量分片!

前向传播:需要计算Loss[B]

  1. In[B, D] = AllGather(In[B, D]) (在关键路径上)XYXY

  2. W[D, F] = AllGather(W[D, F]) (可以提前完成)inYXinXY

  3. Tmp[B, F] = In[B, D] * W[D, F]XYXDinY

  4. W[F, D] = AllGather(W[F, D]) (可以提前完成)outYXoutYX

  5. Out[B, D] {U.Y} = Tmp[B, F] * W[F, D]XXYFoutY

  6. Out[B, D] = ReduceScatter(Out[B, D] {U.Y}) (在关键路径上)XYYX

  7. Loss[B] = …X

反向传播:需要计算dWout[FY, DX], dWin[DX, FY]

  1. dOut[B, D] = …XY

  2. dOut[B, D] = AllGather(dOut[B, D]) (在关键路径上)XYXY

  3. dW[F, D] {U.X} = Tmp[B, F] * dOut[B, D]outYXYBX

  4. dW[F, D] = ReduceScatter(dW[F, D] {U.X})outYXXoutY

  5. W[F, D] = AllGather(W[F, D]) (可以提前完成)outYXoutYX

  6. dTmp[B, F] = dOut[B, D] * W[F, D] (这里可以丢弃dOut[B, D])XYXDoutY

  7. In[B, D] = AllGather(In[B, D]) (不在关键路径上+可以与前一层的(2)共享)XYXY

  8. dW[D, F] {U.X} = dTmp[B, F] * In[B, D]inYXYBX

  9. dW[D, F] = ReduceScatter(dW[D, F] {U.X})inXYXinY

  10. W[D, F] = AllGather(W[D, F]) (可以提前完成)inYXinXY

  11. dIn[B, D] {U.Y} = dTmp[B, F] * W[D, F] (前面层需要)XXYFinY

  12. dIn[B, D] = ReduceScatter(dIn[B, D] {U.Y}) (在关键路径上)XYYX

FSDP和MP的最佳组合是什么?一个简单但关键的原则是FSDP移动权重,而模型并行移动激活。这意味着随着批量大小缩小(尤其是当我们进行更多数据并行时),模型并行变得更便宜,因为每个分片的激活变小了。

  • 模型并行执行AllGather([B, D]),随着X增大而缩小。YXY

  • FSDP执行AllGather([D, F]),随着Y增大而缩小。XXY

因此,通过结合两者,我们可以将每个副本的最小批量大小进一步降低。我们可以按照上面的方式计算FSDP和MP的最佳比例:

X为专用于FSDP的芯片数量,Y为专用于张量并行的芯片数量。设N为我们切片中的总芯片数,且N = XY。设MXMY分别为我们执行FSDP和MP的网格轴数量(这些大致应该总和为3)。我们将纯粹建模前向传播,因为它每FLOP的通信量最大。然后将上面算法中的通信量相加,我们得到

$$ T_\text{FSDP comms}(B, X, Y) = \frac{2\cdot 2\cdot D \cdot F}{Y \cdot W_\text{ici} \cdot M_X} $$

$$ T_\text{MP comms}(B, X, Y) = \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y} $$

同样,我们的总FLOP时间为

$$ T_\text{math} = \frac{2\cdot 2 \cdot B \cdot D \cdot F}{N \cdot C}. $$

为简化分析,我们做两个简化:首先,我们允许XY取非整数值(只要它们是正数且满足XY = N);其次,我们假设不在XY轴上重叠通信。在第二个假设下,总通信时间为

$$ Tcomms = TFSDP comms + TMP comms. $$

在询问我们何时受计算限制之前,让我们找出XY的最优值以最小化总通信量。由于我们的FLOP与XY无关,最佳设置就是简单地最小化通信量。为此,让我们用XN(固定值,即我们系统中的芯片数量)而不是XY来表示上述Tcomms:

$$ T_\text{comms} (X) = \frac{F \cdot X}{N \cdot M_X} + \frac{B}{X \cdot M_Y} $$

对该表达式关于X求导并令导数等于零,得到最优值Xopt

$$ \begin{align*} \frac{d}{dX} T_\text{comms} (X_{opt}) = \frac{F}{N \cdot M_X} - \frac{B}{X_{opt}^2 \cdot M_Y} \rightarrow \ X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N} \end{align*} $$

这非常有用!它告诉我们,对于给定的BFN,最佳的FSDP数量是多少。让我们了解一下规模。代入实际值,即N = 64(对应4x4x4芯片阵列),B = 48,000,F = 32,768,大约得到X ≈ 13.9。所以我们会选择X为16,Y为4,接近我们计算的最优值。

💡 要点:一般来说,在训练期间,FSDP的最佳数量是X_{opt} = \sqrt{\frac{B}{F} \frac{M_X}{M_Y} N}。

现在让我们回到我们一直在问所有并行策略的问题:在什么条件下我们会受计算限制?由于我们可以重叠FLOP和通信,当

$$ TFSDP comms + TMP comms < Tmath $$

时我们受计算限制,这给我们

$$ \frac{2\cdot 2\cdot D \cdot F}{Y \cdot W_\text{ici} \cdot M_X} + \frac{2 \cdot 2 \cdot B \cdot D}{X \cdot W_\text{ici} \cdot M_Y} < \frac{2\cdot 2 \cdot B \cdot D \cdot F}{N \cdot C} $$

αC/Wici,即ICI算术强度,我们可以简化:

$$ \frac{F}{Y \cdot M_X} + \frac{B}{X \cdot M_Y} < \frac{B \cdot F}{N \cdot \alpha} $$

将我们计算的Xopt代入上面的方程(并注意Yopt = N/Xopt)得到关于批量大小B的以下条件:

$$ \sqrt{\frac{4 \cdot B\cdot F}{M_X \cdot M_Y \cdot N}} < \frac{B \cdot F}{N \cdot \alpha}, $$

其中左侧与通信时间成正比,右侧与计算时间成正比。注意,虽然计算时间随批量大小线性增长(无论并行方式如何都是如此),但通信时间随批量大小的平方根增长。因此,计算时间与通信时间的比率也随批量大小的平方增长:

$$ \frac{T_\text{math}}{T_\text{comms}} = \frac{\sqrt{BF}\sqrt{M_X M_Y}}{2\alpha \sqrt{N}}. $$

为确保这个比率大于1,使我们受计算限制,我们需要

$$ \frac{B}{N} > \frac{4\alpha^2}{M_X M_Y F} $$

参见附录C获取此关系的另一种推导。要获取近似数字,再次代入F = 32,768,α = 2550,MXMY = 2(3D网格必须如此)。这大约得到B/N > 400。与纯数据并行(或FSDP)情况相比,这大约获得了两倍的提升,在那种情况下,假设3D网格,我们计算出B/N必须超过约850才能受计算限制。

💡 要点:将张量并行与FSDP结合允许我们将B/N降至2 ⋅ 25502/F。这让我们能够处理每芯片仅400的批量大小,比单纯使用FSDP能达到的小约两倍。

下面我们绘制了在代表性4x4x4芯片阵列上混合FSDP + MP的FLOP与通信时间比率,并与仅使用模型并行和仅使用数据并行(FSDP)进行比较。虽然在非常大的批量大小下,纯FSDP并行占主导地位,但在批量大小与芯片数量之比约在400到850之间的区域,需要混合FSDP + MP策略才能受计算限制。

图示:在TPUv5p 4x4x4切片上F=30k时,最佳混合FSDP/MP的FLOP与通信时间比率。如预期,模型并行与批量大小有固定比率;理想的混合FSDP + MP与\\sqrt{B}成比例,而FSDP与B成比例。然而,在中等批量大小区域,只有FSDP + MP能达到大于1的比率。

这是另一个TPU v5p 16x16x16的例子,显示不同分片方案下批量大小对FLOP和通信时间的影响。

图示:不同并行方案的通信时间。黑色虚线是矩阵乘法FLOP所需时间,所以任何高于此线的曲线都受通信限制。我们注意到所有策略在批量大小低于1.5e6时都变得受通信限制,这与我们预期的4096  2  2550^2 / (8192 * 4) = 1.6e10一致。

黑色曲线是模型FLOP所花费的时间,这意味着任何批量大小下此曲线低于所有通信成本的情况都严格受通信限制。你会注意到黑色曲线与绿色曲线在大约1.6e10处相交,这与预测一致。

放大来看,我们可以看到在1M到6M每切片批量大小之间,将两个轴用于FSDP,并使用光学开关重新配置拓扑以获得8长的轴用于模型分片,将给我们最低的通信量,而在6M到100M之间,纯FSDP组合最佳。这与我们上面的计算一致!

这里有一个交互式动画,可以展示不同批量大小下的总计算时间和通信时间:

你会注意到这通常与上面的结论一致(最小值在FSDP=256,MP=16附近),每个轴的数量略有不同导致有些微小差异。

流水线并行

你可能会注意到我们在前面的部分完全避免了讨论流水线并行。流水线并行是GPU并行中的主导策略,在TPU上则不那么必要。简而言之,流水线训练涉及将模型的层分散到多个设备上,并在前向和后向传递过程中在流水线阶段之间传递激活值。算法大致如下:

  1. 在TPU 0上初始化你的数据,权重在层维度上进行分片(W[L, D, F]用于与FSDP和张量并行的流水线)。inZXY

  2. 在TPU 0上执行第一层,然后将生成的激活值复制到TPU 1,重复此过程直到到达最后一个TPU。

  3. 计算损失函数及其导数∂L/∂xL

  4. 对于最后一个流水线阶段,计算导数∂L/∂W和∂L/∂x,然后将∂L/∂x复制到前一个流水线阶段,重复直到到达TPU 0。LL − 1L − 1

这里是一些(可工作的)Python伪代码

这个伪代码应该可以在Cloud TPU VM上运行。虽然它不是很高效或实际,但它让你了解数据如何在设备间传播。

batch_size = 32d_model = 128d_ff = 4 * d_model
num_layers = len(jax.devices())
key = jax.random.PRNGKey(0)
# 假设每一层只是一个矩阵乘法。x = jax.random.normal(key, (batch_size, d_model))
weights = jax.random.normal(key, (num_layers, d_model, d_model))
def layer_fn(x, weight):
  return x @ weight
# 假设我们有 num_layers == num_pipeline_stagesintermediates = [x]
for i in range(num_layers):
  x = layer_fn(x, weights[i])
  intermediates.append(x)
  if i != num_layers - 1:
    x = jax.device_put(x, jax.devices()[i+1])
def loss_fn(batch):
  return jnp.mean(batch ** 2)  # 构造一个虚拟的损失函数loss, dx = jax.value_and_grad(loss_fn)(x)
for i in range(0, num_layers, -1):
  _, f_vjp = jax.vjp(layer_fn, intermediates[i + 1], weights[i])
  dx, dw = f_vjp(dx)  # 计算 jvp dx @ J(L)(x[i], W[i])  weights[i] = weights[i] - 0.01 * dw  # 更新我们的权重  if i != 0:
    dx = jax.device_put(dx, jax.devices()[i-1])

为什么这是个好主意? 流水线并行有很多好处:它在流水线阶段之间的通信成本很低,这意味着即使使用低带宽互连,也可以训练非常大的模型。这在GPU上通常非常有用,因为它们不像TPU那样通过ICI进行密集连接。

为什么这很困难/令人烦恼? 你可能已经注意到上面的伪代码中TPU 0几乎总是空闲的!它只在流水线的第一步和最后一步工作。这段空闲期称为流水线气泡,处理起来非常烦人。通常,我们首先尝试通过微批处理来缓解这个问题,即通过流水线发送多个小批次,使TPU 0至少在总步骤时间的较大部分保持利用。

第二种方法是仔细重叠前向矩阵乘法Wi@xi、后向dx矩阵乘法Wi@∂L/∂xi + 1和dW矩阵乘法∂L/∂xi + 1@xi。由于每个都需要一些FLOPs,我们可以重叠它们以完全隐藏气泡。以下是最近DeepSeek v3论文中展示的"无气泡"流水线调度图:

图示:DeepSeek v3流水线调度(来自他们的最新论文)。橙色是前向矩阵乘法,绿色是dL/dx矩阵乘法,蓝色是dL/dW矩阵乘法。通过优先处理后向dL/dx乘法,我们可以避免"搁置"FLOPs。

因为这对TPU来说不那么关键(TPU有更大的互连pod),我们不会深入探讨这个问题,但理解关键的流水线瓶颈是一个很好的练习。

在 Pod 之间扩展

让我们退一步看一个具体例子,比如在 TPU v5p 上训练 LLaMA-3 70B。LLaMA-3 70B 的 F ≈ 30,000。根据上面的部分,我们知道以下内容:

  • 当我们的模型并行度大于 Y > n * F/2550 ≊ n * 11 时,我们将受到 ICI 的限制。轴轴

  • 纯 FSDP 在批量大小 < 2550/n 时受到 ICI 限制。这意味着如果我们想要以 BS=2M 进行训练,我们最多只能使用约 2400 个芯片,这大约是 TPU v5p pod 的四分之一。轴

  • 混合 FSDP + 模型并行在批量大小 < 2 · 2550/30,000 = 432 时受到 ICI 限制,所以这让我们可以扩展到大约 9k 个芯片!然而,TPU v5p pod 的最大尺寸是 8k 芯片,超过这个数量,我们必须扩展到带宽更低的数据中心网络(DCN)。2

所以这给了我们一个很好的方案,可以在单个 pod 上使用 BS=3.5M。我们会使用上面的方程,得到大约 X (FSDP) = 1024 和 Y (MP) = 8。如果模型更大,可以将模型分片扩展到 16。我们有一些空间可以将批量大小降低到 BS=1.5M,在该 pod 上仍然受计算限制,但我们已接近下限。

要扩展到超过一个 pod,我们需要通过 DCN 进行扩展。因为 DCN 的带宽较低,通常太慢而无法进行有用的 FSDP。相反,我们在 DCN 轴上进行纯数据并行,在 pod 内进行 FSDP。让我们计算一下数据中心网络(DCN)是否足够支撑。

使用 DCN 上的纯数据并行,我们需要在每个步骤中同步权重和优化器状态(当模型完成其反向传递时,我们需要完成 AllReduce)。我们实际上可以借用上面纯数据并行部分的数学公式,它告诉我们当每个 pod 的批量大小 < Cpod/Wdcn 时,我们会受到通信限制,其中右侧是整个 pod 的总计算量和总带宽。

  • 我们的总 DCN 入口+出口带宽是每主机 2.5e10,每主机有 4 个芯片。这给我们大约 2000 个主机,总共 5e13 字节的带宽。

  • 这里的 C 是 pod 大小乘以每芯片计算量,即 8k * 4.5e14 = 3.8e18 FLOP。pod

如前所述,当 Tmath < Tcomms 时我们会遇到瓶颈,这发生在我们的每 pod 批量大小 < C/WDCN = 3.8e18/5e13 = 76,000(我们的 pod 级 DCN 操作强度)。对于 LLaMA-3,这不会是个问题,因为我们的每 pod 批量大小远高于此,但如果我们在较小的切片上训练(例如 v5e),这可能会成为一个问题。

总结:这意味着我们可以相当任意地跨 pod 扩展,例如,使用 10 个 8960 芯片的 pod,我们可以在 89,600 个芯片上实现约 4000 万 token 的全局批量大小,在大约 2 天内训练完 LLaMA-3 70B。

TPU上LLM训练的要点

  • 增加并行度或减少批量大小都倾向于使我们更受通信限制,因为它们减少了每个芯片执行的计算量。

  • 在合理的上下文长度(约32k)范围内,我们可以将Transformer建模为MLP块堆栈,并通过它们如何对每层的两个/三个主要矩阵乘法进行分片来定义各种并行方案。

  • 在训练期间,我们考虑4种主要的并行方案,每种都有其自身的带宽和计算需求(数据并行、FSDP、模型并行)。

策略描述
数据并行激活按批次分片,其他所有内容完全复制,我们在反向传递期间进行全归约梯度。
FSDP激活、权重和优化器按批次分片,权重在使用前进行收集,梯度进行规约-散射。
模型并行(又称Megatron,张量)激活沿dmodel分片,权重沿dff分片,激活在Win之前收集,结果在Wout之后进行规约-散射。
混合FSDP + 模型并行上述两者结合,其中FSDP收集模型分片的权重。

以下是每种方法的"公式":

$$ \small \begin{array}{cc} \text{Strategy} & \text{Formula}\ \hline \text{DP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \cdot_F W_\text{out}[F, D] \rightarrow \text{Out}[B_X, D] \ \text{FSDP} & \text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X] \rightarrow \text{Out}[B_X, D] \ \text{MP} & \text{In}[B, D_Y] \cdot_D W_\text{in}[D, F_Y] \cdot_F W_\text{out}[F_Y, D] \rightarrow \text{Out}[B, D_Y] \ \text{MP + FSDP} & \text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X] \rightarrow \text{Out}[B_X, D_Y] \ \hline \end{array} $$

  • 这些策略中的每一种都有一个限制,在此限制下它会变得网络/通信受限,这基于它们的每设备计算和通信。以下是每层的计算和通信,假设X是FSDP,Y是模型并行。

$$ \small \begin{array}{ccc} \text{Strategy} & \text{Compute per layer} & \text{Comms per layer} \ & \text{(ignoring gating einsum)} & \text{(bytes, forward + backward pass)}\ \hline \text{DP} & 4BDF/X + 8BDF/X & 0 + 8DF \ \text{FSDP} & 4BDF/X + 8BDF/X & 4DF + 8DF \ \text{MP} & 4BDF/Y + 8BDF/Y & 4BD + 4BD \ \text{FSDP + MP} & 4BDF/(XY) + 8BDF/(XY) & (4BD/X + 4DF/Y) + (8BD/X + 8DF/Y) \ \hline \end{array} $$

  • 纯数据并行很少有用,因为模型及其优化器状态使用的字节数 = 参数数量的10倍。这意味着我们很少能在内存中放入超过几十亿的参数。

  • 当每个分片的批量大小 < C/W(网络的算术强度)时,数据并行和FSDP变得通信受限。对于ICI,这是2,550,对于DCN,这是75,000。这可以通过更多的并行轴来增加。

  • 当|Y| > F/2550时,模型并行变得通信受限。对于大多数模型,这大约是8-16路。这与批量大小无关。

  • 混合FSDP + 模型并行允许我们将批量大小降低到低至2 ⋅ 2550²/F ≈ 400。这相当接近我们变得HBM带宽受限的点(约200)。

  • 跨pod的数据并行在变得DCN受限之前,每个pod需要最小批量大小约为75,000。

  • 基本上,如果你的批量大小很大或模型很小,事情就很简单。你可以做数据并行或FSDP + 跨DCN的数据并行。中间部分是事情变得有趣的地方。

第5章练习题

让我们使用LLaMA-2 13B作为本节的基本模型。以下是一些细节:

超参数
n_layers (L)40
d_model (D)5,120
ffw_multiplier (F / D)2.7
n_heads (N)40
n_kv_heads (K)40
d_qkv (H)128
n_embeddings (V)32,000

问题1:LLaMA-2 13B有多少参数(我知道这很傻但请计算一下)?注意,正如在Transformer数学中所述,LLaMA-3有3个大型FFW矩阵,两个上投影和一个下投影。我们在本节中忽略了两个"门控"爱因斯坦求和矩阵,但它们在本节中的行为与Win相同。

点击此处查看答案。
  • FFW参数:3LDF = 8.5e9

  • 注意力参数:4DNHL = 4.2e9

  • 词汇表参数:2VD = 0.3e9

  • 总计:8.5e9 + 4.2e9 + 0.39e9 = 13.1e9,与预期一致!

问题2:假设我们使用BS=16M个token进行训练并使用Adam优化器。暂时忽略并行性,模型参数、优化器状态和激活值总共使用多少内存?假设我们将参数存储为bf16格式,优化器状态存储为fp32格式,并在每层的三个大型矩阵乘法后进行三次激活检查点。

点击此处查看答案。

参数(bf16)和两个优化器状态(fp32,一阶和二阶矩累积器)的总内存使用量为(2 + 4 + 4) * 13e9 ~ 130GB。前两个矩阵乘法后的激活形状为BF,最后一个矩阵乘法后为BD(根据上面的Transformer图),因此bf16的总内存为2 ⋅ L ⋅ (BD + 2 * BF) = 2LB ⋅ (D + 2F)或2 * 40 * 16e6 * 5,120 * (1 + 2 * 2.7) ~ 4.2e13 = 42TB,因为B=16e16。所有其他激活值基本上可以忽略不计。

问题3:假设我们想要使用32k序列长度和总批量大小为3M个token在TPUv5p 16x16x16切片上进行训练。假设我们想使用bfloat16权重和float32优化器,如上所述。

  1. 我们能使用纯数据并行吗?为什么能或为什么不能?

  2. 我们能使用纯FSDP吗?为什么能或为什么不能?使用纯FSDP,每个设备将使用多少内存(假设我们只在3个大型FFW矩阵之后进行梯度检查点)。

  3. 我们能使用混合FSDP + 模型并行吗?为什么能或为什么不能?如果可以,XY应该是多少?每个设备将存储多少内存?仅使用屋顶线FLOP估计并忽略注意力机制,每个训练步骤将花费多长时间?

点击此处查看答案。

首先,让我们写下一些数字。使用32k序列长度和3M批量大小,我们的序列批量大小为96。在TPU v5p 16x16x16切片上,我们有393TB的HBM。

  1. 我们不能使用纯数据并行,因为它在每个芯片上复制参数和优化器状态,这些已经约为130GB(来自Q2),超过了我们每个芯片拥有的HBM(96GB)。

  2. 让我们先看纯粹的内存。在Q2中将BS=16M替换为3M,我们得到~7.86e12总检查点激活,加上1.3e11优化器状态,这使我们达到几乎正好8e12 = 8TB。TPUv5p切片总共有393TB的HBM,所以我们安全地低于HBM限制。接下来让我们看看我们是否会受通信或计算限制。使用4096个芯片和3个并行轴,我们可以做的最小批量大小为850 * 4096 = 3.48M个token。这略高于我们的3M批量大小。所以我们实际上是受通信限制的,这很遗憾。因此,总的答案是不,我们不能仅使用FSDP

  3. 现在我们知道我们的主要关注点是受通信限制,所以让我们代入一些数字。首先,从上面的判别式中,我们知道使用混合FSDP + 模型并行时,我们的每芯片批量大小需要大于2 ⋅ 25502/F = 940,这实际上比纯FSDP略差。显然,这有点像我们做的一些近似的产物,但这表明混合FSDP + 模型并行实际上并没有好多少。部分原因是F太小,我们无法进行完整轴的模型并行。解决这个问题的一种方法是做4个芯片的张量并行的小子环,并将第一个轴的剩余带宽专用于FSDP。我们不会计算出具体的数学,但检查我们可能可以在不受通信限制的情况下做到这一点是很好的。

问题4:如果我们想降到批量大小1M怎么办?这如何影响问题3的答案?批量大小10M又如何?

第5章附录 A - 关于 FSDP 的更多内容

这里有一个很好的额外图表,展示了 FSDP 如何分片参数/梯度。图中的行依次是纯数据并行、ZeRO-1/2/3。没有太多理由不使用 ZeRO-3,因为它实际上具有相同的通信负载。

图示:展示纯数据并行、ZeRO-1/2/3 的参数、梯度和优化器状态内存的图表。图片来源

第5章附录 B - 推导反向传播所需的通信

上面,我们将 Transformer 层前向传播简化为 Out[B, D] = In[B, D] D Win[D, F] F Wout[F, D]。我们如何推导反向传播所需的通信?

这自然遵循前一节中单个矩阵乘法 Y = X * A 的规则:

$$ \frac{dL}{dA} = \frac{dL}{dY}\frac{dY}{dA} = X^T \left(\frac{dL}{dY}\right) $$

$$ \frac{dL}{dX} = \frac{dL}{dY}\frac{dY}{dX} = \left(\frac{dL}{dY}\right) A^T $$

使用这些,我们得到以下公式(让 Tmp[B, F] 表示 In[B, D] * Win[D, F]):

  1. dWout[F, D] = Tmp[B, F] *B dOut[B, D]

  2. dTmp[B, F] = dOut[B, D] *D Wout[F, D]

  3. dWin = dTmp[B, F] *B Tmp[B, F]

  4. dIn[B, D] = dTmp[B, F] *F Win[D, F]

  5. dW[F, D] = Tmp[B, F] * dOut[B, D]

  6. dTmp[B, F] = dOut[B, D] * W[F, D]

  7. dW = dTmp[B, F] * Tmp[B, F]

  8. dIn[B, D] = dTmp[B, F] * W[D, F]

请注意,这些公式是数学表达式,没有提及分片。反向传播的任务是计算这四个量。因此,要确定所需的通信,我们只需取上述四个等式中要进行矩阵乘法的所有量(Tmp、dOut、Wout、Win)的分片方式,这些由我们的并行化方案指定,然后使用分片矩阵乘法的规则来确定我们必须进行的通信。注意,dOut 的分片方式与 Out 相同。

第5章附录 C - 混合 FSDP + 模型并行的批量大小约束的替代推导

上面我们推导出,当使用 FSDP + 模型并行的组合时,在以下条件下我们可以受计算约束:

$$ \frac{B}{N} > \frac{4\alpha^2}{M_X M_Y F} $$

这里我们提供这一事实的替代推导。我们首先将通信时间设为等于计算时间,并寻找使这种等式不可能成立的条件。

$$ \frac{F}{Y \cdot M_X} + \frac{B}{X \cdot M_Y} = \frac{B \cdot F}{N \cdot \alpha} $$

由于 XY = N,我们可以用 X 重写:

$$ \frac{FX}{N \cdot M_X} + \frac{B}{X \cdot M_Y} = \frac{B \cdot F}{N \cdot \alpha}, 或 $$

$$ X^2 \frac{F}{N \cdot M_X} + \frac{B}{M_Y} - X \frac{B \cdot F}{N \cdot \alpha} = 0. $$

由于这是关于 X 的二次方程,我们将没有解的点是判别式变为零的点。这发生在

B2 ⋅ F2 ⋅ MX2 ⋅ MY2 − 4 ⋅ α2 ⋅ FBNMYMX = 0

或通过简化

BFMXMY − 4 ⋅ α2 ⋅ N = 0

这给我们

$$ B = \frac{4 \cdot \alpha^2 \cdot N}{F \cdot M_X \cdot M_Y} $$

所以我们的总批量大小除以芯片总数不能低于

$$ \frac{4 \alpha^2}{F \cdot M_X \cdot M_Y}, $$

正如我们上面推导的那样。

第6章:在TPU上训练LLaMA 3

我们在本节中的目标是将上一节的结果应用于一个非常实际的问题:训练LLaMA 3系列(群体)模型。与前面的章节不同,我们希望你自己完成大部分工作。因此,我们隐藏了每个部分的答案,这样你可以先尝试自己回答。试着拿起笔,动手计算吧!

LLaMA 3是什么样的?

LLaMA-3模型系列包括3个主要模型:LLaMA 3 8B、70B和405B。我们将主要关注70B,并在最后的问题部分留给你探索8B和405B。以下是LLaMA 3-70B的架构,摘自LLaMA的HuggingFace页面

超参数数值
nlayers (L)80
dmodel (D)8,192
dff(F)28,672
nheads (N)64
nkv_heads (K)8
dqkv (H)128
nembeddings (V)128,256

为了突显这些信息多么容易找到,这里是配置本身,以及对应关系:

为许多不同的开源LLM制作一个包含这些数字的大表格是很有用的,这样你可以快速比较它们所做的设计决策。

计算参数和浮点运算

问题:从这个表格中,我们能计算出LLaMA 3-70B的参数数量吗?🤫 让我们应用第4节的内容,看看能否得到70B!

参数公式数量
FFW参数d_modeld_ff3 (用于gelu + 输出投影) * n_layers8,1928,1923.5380 =56.3e9
词汇表参数2 (输入和输出嵌入)n_embeddingsd_model2128,2568,192 =2.1e9
注意力参数n_layers[ 2 (用于q嵌入和连接输出投影)d_modeln_headsd_qkv + 2 (用于k和v)d_modeln_kv_heads * d_qkv]80(28,19264128 + 28,1928 * 128) =12e9
56.3e9 + 2.1e9 + 12e9 =70.4e9

太好了!我们得到了预期的数字。正如预期的那样,你会注意到FFW参数完全主导了整体参数数量,尽管注意力参数也不少。

💡 要点:MLP模块中的3个大型权重矩阵比Transformer中的所有其他数组都大得多,以至于在推理模型内存或浮点运算时,我们通常几乎可以忽略所有其他参数。对于LLaMA 3-70B,它们在70B参数中占了56B。

现在让我们看看浮点运算!记住第4节中关于训练的一般规则。

问题:LLaMA-3在每个训练步骤中每个token执行多少浮点运算?这有助于我们确定整个训练过程的成本。

在你思考后点击这里查看答案!

答案:如第4节所示,我们每个token大约执行6 ⋅ 参数数量的浮点运算,所以这里大约是6 * 70e9 = 4.2e11浮点运算/token。这大约是每个token每步半个TFLOP。假设我们受计算限制,在单个TPU v5p芯片上,假设完美的浮点运算利用率,这大约需要4.2e11 / 4.59E+14 = 1ms

问题:LLaMA 3训练了约15万亿个token。总共需要多少浮点运算?

在你思考后点击这里查看答案!

答案:这很简单,就是4.2e11 * 15e12 = 6.3e24浮点运算。6.3 yottaFLOPs。这太多了!在单个TPU上,这将需要6.3e24 / 4.59E+14 = 435年。这也太多了!

问题:假设我们想在一个完整的TPU v5p集群上训练,该集群有16x20x28 = 8960个芯片。在bfloat16下以40% MFU训练,假设我们受计算限制,这需要多长时间?

在你思考后点击这里查看答案!

答案:我们知道每个TPU v5p每秒可以执行4.59e14次浮点运算。在40% MFU下,这大约需要T = 6.3e24 / (8960 * 4.59e14 * 0.4) = 3.8e6秒这大约是44天!假设我们真的能达到40% MFU,这相当合理。

问题:LLaMA 3-70B预训练时的批量大小约为4M个token。我们至少需要多少TPU才能以这种批量大小进行训练?你可以假设bfloat16参数和float32优化器状态,并且每层检查点梯度4次。

在你思考后点击这里查看答案!

答案:这个问题主要是关于内存使用,因为这是可用计算的唯一严格约束。在训练期间,我们有三个主要的HBM用途:模型参数、优化器状态和梯度检查点。如果我们假设bfloat16权重、float32优化器状态和一个非常保守的梯度检查点方案(每层4次),我们有:

参数 | 2 * 70GB | ~140GB |优化器状态 | 8 * 70GB | ~560GB |梯度检查点 | 2 8192 4e6 4 80 | ~20.9TB |总计 | | ~21.6TB |

这里的总内存约为21.6TB。你会注意到,即使使用非常保守的检查点方案,梯度检查点也在内存图中占主导地位。从技术上讲,我们可以每层使用1个检查点,或者做微批处理,但这是一个合理的图景。根据这些假设,由于每个TPU v5p有96GB的HBM,我们需要21.6e12 / 96e9 = 225个TPU。这实际上不是很多!

为什么我们不这样做呢?因为这将花费44天 * 8960 / 225 = 1752天来训练。那是6年半。太久了。不过,这清楚地表明,我们使用这些大型集群不是因为受内存限制,而是因为我们需要额外的浮点运算能力。

问题:在与上述问题相同的假设下,如果我们使用8960个TPU v5p芯片,每个芯片将使用多少内存?

在你思考后点击这里查看答案!

答案:我们的总内存仍然约为21.6TB,所以每个芯片我们将使用约2.4GB,这基本上不算什么。如果我们做更激进的检查点,例如每层12个检查点,我们仍然只有每个芯片8GB。在这些规模下,我们在训练期间远没有受到内存限制。

💡 要点:从技术上讲,即使在非常小的拓扑上训练非常大的模型也是可能的,但缺点是它们可能需要很长时间。能够计算训练运行的总浮点运算量使我们能够通过假设适度的MFU和已知拓扑来大致估计其训练时间。

如何为训练对LLaMA 3-70B进行分片

让我们继续使用上面的设置,假设我们想要在8960个芯片的TPU v5p集群上训练LLaMA 3-70B,批量大小为4M个token(每批1024个序列,每个序列长度为8192)。让我们讨论一下这个模型最佳的分片策略。

问题:在上述假设下,我们能否仅使用FSDP训练我们的模型?首先,假设我们不能进行任何序列/上下文并行化。这应该是你首先想到的方法,因为它简单且如果可行的话不会引入额外的通信开销。

在你思考后点击这里查看答案!

答案:这个回答会有点学究气。如上所述,LLaMA 3-70B最初是用长度为4K的序列训练的,所以4M个token的批量大小给了我们1024的序列批量大小。这意味着我们实际上只能在最多1024个芯片上进行纯数据并行/FSDP,因为我们只有这么多序列可以进行数据并行。所以从"完全数据并行且无额外通信"的简单意义上来说,答案是否定的。下一个问题将回答这个问题的一个稍微不那么学究的版本。

问题:让我们放宽不进行任何序列分片的要求。如果我们允许自己在批次序列轴上进行FSDP,我们能否仅使用FSDP在8960个芯片上训练LLaMA 3-70B?

在你思考后点击这里查看答案!

答案:现在我们允许自己进行序列/上下文并行,我们可以扩展得更多。首先让我们计算每个设备的批量大小。如果我们进行8960路FSDP,我们最终每个TPU的批量大小为4 * 1024 * 1024 / 8960 = 468个token。我们从上一节知道,当每个设备批量大小小于2550/n轴时,FSDP会受到ICI限制。由于我们可以在完整的3D集群中使用3个轴,这将给我们一个850的下限,而我们远远低于这个值。所以答案是否定的,即使有3个轴。我们将完全受到通信限制。

问题:现在让我们看看混合张量并行和FSDP。是否存在某种组合让我们保持计算受限?如果有,我们应该采用多少FSDP和张量并行?

在你思考后点击这里查看答案!

答案:首先让我们检查这是否能够适配。我们知道,如果每个芯片的批量大小小于2 ⋅ 25502/F = 453,我们将受到通信限制。如上所述,我们略高于这个值。太好了!现在为了选择最佳的FSDP数量,我们可以使用公式

$$ X_{opt} = \sqrt{\frac{2BN}{F}} = \sqrt{\frac{2 \cdot 4.19e6 \cdot 8960}{28672}} = 1618 $$

四舍五入到2的合理倍数,大约给我们2048路FSDP和4路模型并行。这应该效果不错!

💡 要点:我们可以在完整的TPU v5p集群上使用数据并行(1024路)、序列并行(2路)和张量并行(4路)的混合方式训练具有4M token批量大小的LLaMA-3,而不会受到通信限制。如果我们尝试使用纯FSDP或FSDP+序列并行,我们将受到通信限制。我们在前一节中推导的方程非常实用。

第6章练习题

问题1 [将LLaMA 70B扩展到更多芯片]:假设我们想在4个集群上训练LLaMA 3-70B,批量大小相同。我们会使用什么并行方案?我们会受到计算限制还是通信限制?大致需要多长时间来训练?确保使用正确的屋顶线界限。

问题2 [LLaMA 405B]:

  1. 使用LLaMA 3-405B的配置,像上面一样列出所有关键超参数的表格。这个模型总共有多少参数?每个训练步骤有多少FLOP?如果我们训练15T个token,我们会执行多少FLOP?

  2. 假设我们想在8个TPU v5p集群上训练。我们会使用什么并行方案?训练需要多长时间?我们会受到计算限制还是通信限制?

第7章:关于Transformer推理的一切

Transformer推理基础

假设你已经训练了一个Transformer模型,现在想用它来生成一些新序列。归根结底,基准分数的提高和损失曲线的下降只是代理指标,真正的考验是当模型实际应用时会发生什么!

采样在概念上很简单。我们输入一个序列,我们喜欢的Transformer会输出log p(下一个标记i|前面的标记),即所有可能的下一个标记的对数概率。我们可以从这个分布中采样并获得一个新标记。附加这个标记并重复这个过程,我们就能获得一个标记序列,作为提示的延续。

图示:从Transformer进行朴素采样。蓝色的logits给我们提供了可以采样的下一个标记的分布。注意,每一步都会重新处理整个前缀,导致算法的运行时间为Θ(n2)。

我们刚才描述的是Transformer采样的朴素实现,虽然它能工作,但我们在实践中从不这样做,因为我们每次生成一个标记时都要重新处理整个序列。这个算法在生成n个标记时,FFW部分的复杂度是O(n2),注意力机制的复杂度是O(n3)!

我们如何避免这种情况?与其每次都进行完整的前向传播,事实证明我们可以保存每次前向传播中的一些中间激活值,从而避免重新处理先前的标记。具体来说,由于在点积注意力中,给定标记只关注之前的标记,我们可以简单地将每个标记的键和值投影写入一个称为KV缓存的新数据结构中。一旦我们保存了过去标记的键/值投影,未来的标记就可以简单地计算它们的qikj乘积,而不需要对早期标记执行任何新的浮点运算。太棒了!

考虑到这一点,推理有两个关键部分:

  • 预填充:给定一个长提示,我们同时处理提示中的所有标记,并将生成的激活值(特别是键-值投影)保存在"KV缓存"中。我们还保存最后一个标记的logits。

  • 生成:给定KV缓存和前一个logits,我们从logits中递增地采样一个标记,将该标记反馈给Transformer,并为下一步产生一组新的logits。我们还将该新标记的KV激活附加到KV缓存中。我们重复这个过程,直到遇到特殊的&lt;EOS&gt;标记或达到某个最大长度限制。

以下是使用KV缓存采样的示意图:

图示:使用KV缓存进行高效的Transformer采样。预填充处理我们的提示并将所有每个标记的键-值激活保存在缓存中。生成使用这个缓存(和最后标记的logits),采样一个新标记,并将该新标记传递给模型,关注KV缓存并将新标记的键-值投影保存回缓存。这在MLP块中是一个O(n)算法。

通过使用KV缓存进行采样,我们已经将生成n个标记的时间复杂度降低到FFW上的O(n)和注意力上的O(n2),因为我们从不重新处理前一个标记。然而,生成一个序列仍然需要许多前向传递——这就是当你查询Gemini或ChatGPT时结果流式返回给你的情况。每个标记(通常)都是对一个庞大模型的单独(但部分缓存的)Transformer调用。

我们很快就会看到,预填充生成是非常不同的任务——Transformer推理实际上是两个伪装成一个的任务!与训练相比,KV缓存也是一个新颖且重要的复杂性来源。

我们实际上想优化什么?

在我们进一步讨论之前,值得强调推理中一个全新的方面:延迟。虽然在训练过程中,我们只关心吞吐量(每秒处理的总令牌数),但在推理过程中,我们必须关注生成令牌的速度(包括首个令牌的时间(TTFT)每个令牌的延迟)。例如:

  • 离线批量推理用于评估和数据生成,只关心推理的整体成本,而不关注单个样本的延迟。

  • 聊天界面/流式任务需要在规模上经济高效地运行,同时具有低TTFT,并且生成令牌的速度足够快,超过人类阅读速度。

  • 边缘推理(例如在笔记本电脑上运行的llama.cpp)只需要以尽可能低的延迟为一个用户提供服务,可能还有严格的硬件限制。

最大化硬件利用率仍然至关重要,有助于降低成本和TTFT,但与训练不同,它不一定会为所有情况下的个别用户带来更好的体验。加速器、系统和模型架构层面的许多优化都在延迟、吞吐量、上下文长度甚至模型质量之间做出权衡。

Transformer的更细粒度视图

到目前为止,我们主要将Transformer视为前馈块的堆叠。虽然从FLOPs和内存角度来看这通常是合理的,但对于正确建模推理来说这是不够的。正如我们在第4部分中看到的,Transformer前向传递的主要组成部分是:

  1. 一系列线性操作,包括MLP(WW)和注意力QKV投影与输出投影(WWWW)。这些都涉及从HBM读取参数和一批激活值,进行一些FLOPs,然后将结果写回HBM。inoutQKVO

  2. 点积注意力。我们需要从HBM读取一批键值投影和一批查询激活,进行一些内积和softmax操作,然后将注意力结果写回HBM。

  3. 其他所有内容,包括应用层归一化、激活函数、令牌采样、更新KV缓存和位置嵌入。这些确实需要一些FLOPs,但被上述操作所主导或融合到其中。

在接下来的几节中,我们将在预填充和生成的背景下查看每一个部分,并询问什么可能会成为我们性能的瓶颈。在单个加速器内,我们是受计算限制还是受内存限制?我们想强调预填充与生成的答案会有多么不同。

线性操作:什么成为我们的瓶颈?

我们所有的线性操作在概念上都是相同的,无论它们位于MLP块还是注意力机制中。它们的算术密度取决于批量大小。我们在第1部分中已经进行了这些计算,但值得重复一下。让我们看一个bf16[B, D]批次乘以bf16[D, F]矩阵的单个矩阵乘法。这可能是大型MLP块或较小的注意力投影之一(WQWKWVWO)。要执行这个矩阵乘法,我们需要将这两个数组从HBM加载到MXU中,进行乘法运算,然后将结果写回HBM。如前所述,我们有:

$$ T_\text{math} = \frac{\text{总FLOP数}}{\text{TPU FLOP/秒}} = \frac{2BDF}{\text{TPU FLOP/秒}} $$

$$ T_\text{comms} = \frac{\text{总字节数}}{\text{HBM带宽}} = \frac{2BD + 2FD + 2BF}{\text{HBM带宽}} $$

TPU可以通过在计算的同时加载来重叠这些操作,因此要受计算限制,我们需要Tmath ≥ Tcomms,或:

$$ \frac{2BDF}{2BD + 2DF + 2BF} \geq \frac{\text{TPU FLOP/秒}}{\text{HBM带宽}} = \frac{1.97E+14}{8.20E+11} = 240 $$

其中右侧是我们硬件的算术密度。现在假设DF相比B非常大(通常我们的批次最多为500,而DF > 10k),我们可以通过使用\small{2BD + 2DF + 2BF \approxeq 2DF}这一事实来简化分母,得到

$$ \begin{align*} \frac{2BDF}{2BD + 2DF + BF} \approxeq \frac{2BDF}{2DF} \geq \frac{\text{TPU FLOP/秒}}{\text{HBM带宽}} \ = \frac{1.97E+14}{8.20E+11} \implies B \geq 240 = B_{\text{crit}} \end{align*} $$

结论:要使任何矩阵乘法受计算限制,我们的总令牌批量大小必须大于Bcrit,这取决于硬件和量化。对于TPU v5e上的bf16激活,这个值是240个令牌。这适用于Transformer中的任何简单矩阵乘法(例如MLP块或注意力投影)。

在训练期间,所有矩阵乘法的算术密度都很高,因为我们在非常大的批次上重复使用相同的权重。这种高算术密度也适用于预填充阶段,因为用户提示通常有数百甚至数千个令牌长。如前所述,TPUv5e的硬件算术密度为240,因此如果将长度超过240个令牌的序列输入到在bf16下运行的密集模型中,我们预计会受计算限制,一切正常。技术上可以将短于此长度的提示批处理在一起以实现更高的利用率,但这通常不是必要的。

结论:在预填充期间,所有矩阵乘法基本上都是受计算限制的。因此,简单地最大化硬件利用率或MFU(模型FLOP利用率)足以最大化每芯片吞吐量(成本)和延迟(以TTFT形式)。除非提示非常短,否则在提示级别进行批处理只会增加延迟,而预填充吞吐量的改进很小。

然而,在生成阶段,对于每个请求,我们只能一次执行一个令牌的前向传递,因为步骤之间存在顺序依赖关系!因此,我们只能(容易地)通过批处理多个请求并在批次维度上并行化来实现良好的利用率。我们稍后会详细讨论这一点,但实际上,在不影响延迟的情况下批处理许多并发请求是很困难的。因此,在生成阶段要让硬件FLOP饱和要困难得多。

结论:我们的总令牌批量大小必须大于Bcrit,才能使生成阶段的线性/前馈操作受计算限制(TPU v5e上的bf16参数为240)。由于生成是按令牌串行发生的,这需要我们将多个请求批处理在一起,这很困难!

值得注意的是这个数字有多大!生成批量大小为240意味着同时生成240个并发请求,以及240个单独的密集模型KV缓存。这意味着在实践中很难实现,除了在一些批量推理设置中。相比之下,在预填充期间处理超过240个令牌是相当常见的,尽管随着稀疏性的增加需要一些注意。

请注意,这个确切的数字将根据量化类型和硬件而有所不同。加速器通常可以在较低精度下提供更多的算术运算。例如,如果我们使用int8参数但在bf16中进行计算,临界批量大小会降至120。使用int8激活和int8参数时,它会跳回到240,因为TPUv5e可以提供400 TOPs/s的int8 x int8计算能力。

注意力机制怎么样?

当我们查看点积注意力操作时,情况变得更加复杂,特别是因为我们必须考虑KV缓存。让我们只看一个具有纯多头注意力的注意力头。在单个Flash Attention融合中,我们:

  1. 从HBM读取形状为bf16[B, T, D]的Q激活。

  2. 从HBM读取KV缓存,这是一对bf16[B, S, D]张量。

  3. QK矩阵乘法中执行2BSTD FLOP。使用Flash Attention,我们不需要将bf16[B, S, T]注意力矩阵写回HBM。

  4. 在注意力AV矩阵乘法中执行2BSTD

  5. 将结果bf16[B, T, D]张量写回HBM。

综合起来,我们得到:

$$ \text{多头注意力算术强度} = \frac{4BSTD}{4BSD + 4BTD} = \frac{ST}{S+T} $$

对于预填充,S = T,因为我们正在做自注意力,所以这简化为T²/2T = T/2。这很好,因为它意味着预填充期间注意力的算术强度是Θ(T)。这意味着注意力很容易受计算限制。只要我们的序列长度足够大,我们就没问题!

但由于生成阶段的序列维度很小,且BD维度相互抵消,我们可以做出近似:

$$ S \gg T = 1 \implies \frac{ST}{S+T} \approx 1 $$

这很糟糕,因为它意味着我们无法做任何事情来提高生成期间注意力的算术强度。我们在加载巨大的KV缓存的同时只做了极少量的FLOP。所以在注意力阶段我们基本上总是受内存带宽限制!

💡 要点:在预填充期间,对于任何合理的序列长度(大约 > 480个令牌),注意力通常受计算限制,而在生成期间,我们的算术强度低且恒定,所以我们总是受内存带宽限制。

从概念上讲,为什么会这样?主要是因为在模型的线性部分我们受计算限制,因为参数(内存带宽消耗大的组件)被多个批次项目重复使用。然而,每个批次项目都有自己的KV缓存,所以更大的批次大小意味着更多的KV缓存。除非架构被积极调整,否则我们在这里几乎总是受内存限制。

这也意味着一旦参数内存变得与KV缓存内存相当,增加批次大小对吞吐量的收益将会递减。递减收益对你的影响程度取决于单个序列的参数与KV缓存字节的比率,即大约是2DF/SHK的比率。由于HKD,这大致取决于FS(序列长度)的比率。这也取决于使KV缓存变小的架构修改(我们稍后会详细说明)。

LLM延迟和吞吐量的理论估计

从这些数学计算中,我们可以得到在优化时应该瞄准的步骤时间的相当好的界限。(注意:如果我们希望读者从整个章节中获取一件事,那就是以下内容)。对于生成阶段中常见的小批量大小,我们可以通过假设在注意力和MLP模块中都受内存带宽限制来确定每步延迟的下限:

$$ \begin{equation*} \text{理论最小步骤时间} = \frac{\text{批量大小} \times \text{KV缓存大小} + \text{参数大小}}{\text{总内存带宽}} \end{equation*} $$

同样,对于吞吐量:

$$ \begin{equation*} \text{理论最大令牌/秒} = \frac{\text{批量大小} \times \text{总内存带宽}}{\text{批量大小} \times \text{KV缓存大小} + \text{参数大小}} \end{equation*} $$

最终,随着批量大小的增长,FLOP开始主导参数加载,所以在实践中我们有更一般的等式:

$$ \begin{align} \tiny \text{理论步骤时间(一般)} = \underbrace{\frac{\text{批量大小} \times \text{KV缓存大小}}{\tiny \text{总内存带宽}}}{\text{注意力(始终受带宽限制)}} + \underbrace{\max\left(\frac{2 \times \text{批量大小} \times \text{参数数量}}{\text{总FLOP/秒}}, \frac{\text{参数大小}}{\text{总内存带宽}}\right)}{\tiny \text{MLP(可能受计算限制)}} \end{align} $$

其中注意力组件(左侧)从不受计算限制,因此不需要FLOP屋顶线。这些对于粗略计算非常有用,例如:

小测验:假设我们想在TPU v5e 4x4切片上使用int8格式(具有bf16 FLOP)对一个30B参数密集模型进行批量大小为4个令牌的生成步骤,上下文长度为8192且每个令牌的KV缓存为100 kB。这个操作的延迟合理下限是多少?如果我们想对256个令牌的批次进行采样呢?

点击这里查看答案。

答案:在int8中,我们的参数将使用30e9字节,按照给定的规格,每个KV缓存将使用100e3 * 8192 = 819MB。我们有16个芯片,每个具有8.1e11字节/秒的带宽和1.97e14 bf16 FLOP/秒。根据上述等式,由于我们的批量大小较小,我们预计步骤时间至少为(4 * 819e6 + 30e9) / (16 * 8.1e11) = 2.5 ms。在256个令牌的情况下,我们的MLP块将深入计算限制区域,所以我们的步骤时间大约为(256 * 819e6) / (16 * 8.1e11) + (2 * 256 * 30e9) / (16 * 1.97e14) = 21ms

如你所见,这里存在吞吐量和延迟之间的明显权衡。小批量处理速度快但无法充分利用硬件。大批量处理速度慢但效率高。以下是为一些较早的PaLM模型计算的延迟-吞吐量帕累托前沿(来自ESTI论文):

图示:几种PaLM模型的成本(即吞吐量)与延迟的帕累托前沿。注意芯片数量(C)和批量大小(B)如何沿着帕累托前沿移动,绿点(C:32 B:16的PaLM 540B)除外,因为可用内存不足,阻碍了该设置支持良好的批量大小,导致吞吐量下降。注意吞吐量通常在批量大小240之后趋于平稳。int8权重提供了更好的延迟-吞吐量帕累托最优解,但没有提供更好的最大吞吐量。

我们不仅可以通过批量大小作为调节旋钮来权衡延迟和吞吐量,如果我们发现自己受到HBM限制,我们可能更喜欢更大的拓扑结构而非更小的,这样我们可以容纳更大的批量。下一章将更详细地探讨这一点。

💡 要点:如果你关心生成吞吐量,请使用可能的最大每芯片批量大小。任何超过TPU算术强度(Bcrit,通常为120或240)的每芯片批量大小都将最大化吞吐量。你可能需要增加拓扑结构来实现这一点。较小的批量大小将允许你以吞吐量为代价改善延迟。

从硬件角度来看,这里有一些注意事项。点击这里了解一些细节。

这些都相当理论化。实际上,我们通常不会看到明显的屋顶线,原因有几个:

  • 我们假设HBM读取将与FLOP完美重叠的假设并不现实,因为我们的编译器(XLA)是可能出错的。

  • 对于分片模型,XLA通常也无法有效地将模型分片矩阵乘法的ICI通信与FLOP本身重叠,因此在批量大小超过32时,我们常常在线性运算上遭受延迟损失。

  • 大于理论屋顶线的批量大小由于重叠不完美,仍会看到吞吐量的一些改善,但这是一个很好的启发式方法。

内存情况如何?

我们已经花了一些时间研究带宽和FLOP,但还没有讨论内存。在推理阶段,由于我们的新数据结构KV缓存,内存情况看起来大不相同。在本节中,让我们选择一个真实的模型(LLaMA 2-13B)来展示情况有多不同:

超参数
层数 (L)40
模型维度 (D)5,120
前馈网络乘数 (F // D)2.7
注意力头数 (N)40
KV头数 (K)40
QKV维度 (H)128
词表大小 (V)32,000

推理过程中什么占用内存?显然,首先是我们的参数。计算这些参数,我们有:

参数公式大小(字节)
前馈网络参数d_model² x ffw_multiplier x 3(用于gelu + 输出投影)x n_layers5,120 x 5,120 x 2.7 x 3 x 40 =8.5e9
词表参数2(输入和输出嵌入)x n_embeddings x d_model2 x 32,000 x 5,120 =0.3e9
注意力参数[2(q和输出)x d_model x n_heads x d_qkv + 2(k和v)x d_model x n_kv_heads x d_qkv] x n_layers(2 x 5,120 x 40 x 128 + 2 x 5,120 x 40 x 128) x 40 =4.2e9

将这些参数加起来,我们得到8.5e9 + 4.2e9 + 0.3e9 = 总共13e9参数,正如预期。如前几节所述,在训练期间,我们可能会将参数存储在bfloat16中,优化器状态存储在float32中。这可能会使用约100GB的内存。这与我们的梯度检查点相比微不足道,后者可能使用几个TB。

推理有何不同?在推理期间,我们存储一份参数副本,假设用bfloat16格式。这使用了26GB内存——实际上通过量化我们通常可以做得更好。没有优化器状态或梯度需要跟踪。由于我们不进行检查点(不保留用于反向传播的激活值),无论是预填充还是生成阶段,我们的激活值占用空间都可以忽略不计。如果我们预填充8k个标记,单个激活值仅使用约8,192 x 5,120 x 2字节 = 80MB的内存。更长的预填充可以分解为许多较小的前向传递,所以对于更长的上下文也不是问题。生成使用的标记甚至更少,所以激活值可以忽略不计。

主要区别在于KV缓存。这些是所有过去标记的键和值投影,其大小仅受允许的最大序列长度限制。T个标记的总大小为

KV缓存大小 = 2 ⋅ 每个浮点数字节数 ⋅ HKLT

其中H是每个头的维度,K是KV头的数量,L是层数,2是因为同时存储键和值。

这很快就会变得非常大,即使批量大小和上下文长度适中。对于LLaMA-13B,单个8192序列在bf16格式下的KV缓存为

8192(T)× 40(K)× 128(H)× 40(L)× 2(字节)× 2 = 6.7GB

仅4个这样的缓存就超过了我们参数的内存使用量!需要明确的是,LLaMA 2在更长上下文的KV缓存大小方面并未优化(情况并非总是如此糟糕,因为通常K要小得多,如LLaMA-3中所示),但这仍然很有说明意义。在内存或延迟估计中,我们不能忽视这些。

LLaMA 2-13B的吞吐量和延迟建模

让我们看看如果我们在不同批量大小下在8个TPU v5e上尝试完美高效地执行生成会发生什么,最大到之前推导出的理论最大吞吐量的临界批量大小(240)。

批量大小18163264240
KV缓存内存(GiB)6.753.6107.2214.4428.81608
总内存(GiB)32.779.6133.2240.4454.81634
理论步骤时间(毫秒)4.9812.1320.3036.6569.33249.09
理论吞吐量(标记/秒)200.61659.30787.99873.21923.13963.53

8个TPU v5e提供给我们128GiB的HBM,6.5TiB/s的HBM带宽(每个0.82TiB/s)和1600TF/s的计算能力。

对于这个模型,增加批量大小确实能提高吞吐量,但我们很快就会遇到收益递减。批量大小超过16时内存溢出,需要比240批量大小多一个数量级的内存。更大的拓扑结构可以改善延迟,但我们已经达到了每芯片吞吐量的瓶颈。

假设我们保持参数总数不变,但神奇地使KV缓存缩小5倍(比如,使用1:5的GMQA,这意味着我们有8个KV头共享给40个Q头——详见下一章)。

批量大小18163264240
KV缓存内存(GiB)1.3410.7221.4442.8885.76321.6
总内存(GiB)27.3436.7247.4468.88111.76347.6
理论步骤时间(毫秒)4.175.607.2310.5017.0452.99
理论吞吐量(标记/秒)239.941,429.192,212.483,047.623,756.624,529.34

使用较小的KV缓存,我们仍然面临收益递减,但每芯片的理论吞吐量继续扩展到批量大小240。我们可以容纳更大的64批量大小,而且在所有批量大小下延迟也始终更好。延迟、最大吞吐量和最大批量大小都有显著改善!实际上,后来的LLaMA世代正是使用了这种优化——LLaMA-3 8B有32个查询头和8个KV头(来源)。

💡 要点:除了参数外,KV缓存的大小对模型的最终推理性能有很大影响。我们希望通过架构决策和运行时优化的组合来控制它。

提高生成吞吐量和延迟的技巧

自从原始的 Attention Is All You Need 论文发布以来,许多提高模型效率的技术已经被开发出来,这些技术通常针对KV缓存。一般来说,更小的KV缓存使得在不影响延迟的情况下增加生成步骤的批量大小和上下文长度变得更容易,并且使围绕Transformer的系统(如请求缓存)更易于管理。撇开对质量的影响不谈,我们可能会看到:

分组多查询注意力(又称GMQA、GQA):我们可以减少KV头的数量,并在注意力机制中让多个Q头共享它们。在极端情况下,可以在所有Q头之间共享单个KV头。与纯MHA相比,这将KV缓存减少了Q:KV比例的倍数,并且观察到模型的性能对这种变化相对不敏感。

这也有效地增加了注意力计算的算术强度(参见第4节中的问题4)。

混入一些局部注意力层:局部注意力将上下文限制在小到中等大小的最大长度。在训练时间和预填充时间,这涉及将注意力矩阵掩蔽为对角条带而不是三角形。这有效地限制了局部层的KV缓存的最大长度。通过在模型中混入一些局部层和一些全局层,在超过局部窗口的上下文中,KV缓存的大小大大减少。

跨层共享KV:模型可以学习以某种模式在层之间共享相同的KV缓存。虽然这确实减少了KV缓存的大小,并在增加批量大小、缓存、离线存储等方面提供了好处,但共享的KV缓存可能需要从HBM多次读取,所以它不一定能改善步骤时间。

左:多层纯全局注意力。右:一个全局/局部交错模式与相邻层共享的例子。来源:Character.ai博客

量化:推理通常对参数和KV的精度不太敏感。通过量化参数和KV缓存(例如,转为int8、int4、fp8等),我们可以节省两者的内存带宽,减少达到计算屋顶线所需的批量大小,并节省内存以便在更大的批量大小下运行。量化的另一个优势是,即使模型没有通过量化进行训练,它通常也可以在训练后应用。

使用不规则HBM读取和分页注意力:在上面的计算中,我们为每个KV缓存分配了8k的上下文,但通常不需要从内存中读取整个KV缓存——请求的长度分布范围很广,并且不使用模型的最大上下文,因此我们通常可以实现只读取KV缓存非填充部分的内核(例如Flash Attention变体)。

分页注意力是对此的改进,它将KV缓存存储在类似操作系统的页表中,并且基本上避免了KV缓存的填充。这增加了很多复杂性,但意味着每个批次只使用它所需的内存量。这是一种运行时优化,因此它对架构是无差别的。

图示:在生成过程中,单个标记(第四个)关注多个KV缓存块/页面。通过对KV缓存进行分页,我们避免了加载或存储超出需要的内存。取自PagedAttention论文

总体概况:总的来说,这些KV缓存优化可以使KV缓存大小比标准MHA Transformer减少一个数量级以上。这可以导致Transformer整体成本提高一个数量级。

在多个加速器上分布式推理

到目前为止,我们只是粗略地讨论了如何扩展到单个芯片之外。参照第5节,让我们探索可用的不同策略及其权衡。一如既往,我们将分别研究预填充和生成阶段。

预填充

从屋顶线模型的角度来看,预填充几乎与训练相同,几乎所有相同的技术和权衡都适用 — 模型(Megatron)并行、序列分片(对于足够长的上下文)、流水线、甚至FSDP都是可行的!你只需要保留KV值以便之后进行生成。与训练一样,增加芯片数量使我们能够获得更多的FLOPs/s(可能降低TTFT),但会增加通信开销(可能降低每个芯片的吞吐量)。

分片预填充的一般规则:这里有一套预填充的一般规则。我们假设我们只对单个序列进行预填充(没有批次维度):

  1. 模型分片:我们通常首先进行一定程度的模型并行,直到我们受到ICI限制。正如我们在第5节中看到的,对于1轴,这大约是F/2550(通常是4-8路分片)。

  2. 序列并行:超过这个限制后,我们进行序列并行(类似于数据并行,但在序列维度上分片)。虽然序列并行在注意力机制中引入了一些额外的通信,但在较长上下文中通常相当小。与训练一样,我们可以重叠通信和计算(分别使用Megatron的集体矩阵乘法和环形注意力)。

💡 要点:在预填充期间,几乎任何适用于训练的分片策略都可以正常工作。先进行模型并行直到ICI限制,然后进行序列并行。

生成

生成比预填充更复杂。一方面,更难获得大批量,因为我们需要将多个请求批处理在一起。延迟目标更低。这些因素加在一起,意味着我们通常更受内存限制,对通信开销更敏感,这限制了我们的分片策略:

  1. FSDP是不可能的:由于我们在从HBM加载参数和KV缓存到MXU时受内存限制,我们不希望通过ICI移动它们,因为ICI比HBM慢几个数量级。我们希望移动激活值而不是权重。这意味着类似FSDP的方法通常完全不适用于生成阶段。

  2. 没有理由进行数据并行:纯数据并行没有帮助,因为它复制了我们的参数,不能帮助我们更快地加载参数。更好的做法是启动多个模型副本。

  3. 没有序列 = 没有序列分片。序列分片祝你好运。

这主要留给我们的是密集模型生成的模型分片变体。与预填充一样,我们能做的最简单的事情是简单的模型并行(激活完全复制,MLP的权重在隐藏维度上完全分片),当我们受到ICI限制时达到4-8路。然而,由于我们通常受内存带宽限制,我们实际上可以超越这个限制来改善延迟!

关于生成阶段的ICI限制说明:在训练期间我们希望受计算限制,所以我们的屋顶线模型关注ICI通信何时比我们的FLOPs花费更长时间。然而,在生成期间,如果我们受参数加载的内存带宽限制,我们可以增加模型分片超过这个点,以最小的吞吐量成本改善延迟。更多的模型分片给我们提供了更多HBM来加载权重,而我们的FLOPs不重要。让我们看看在模型并行成为瓶颈之前我们能做多少。

$$ \begin{align}T_\text{HBM comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{ICI comms} = \frac{2BD}{W_\text{ici}}\end{align} $$

$$ T_\text{ICI comms} > T_\text{HBM comms} \rightarrow \frac{W_\text{hbm}}{W_\text{ici}} > \frac{F}{Y \cdot B} \rightarrow Y > F / (B \cdot \beta) $$

其中β = Whbm/Wici。这个数字对于TPU v5e和TPU v6e通常约为8。这意味着,例如,如果F是16,384,B是32,理论上我们可以进行高达16384 / (32 * 8) = 64路的模型并行,而不会对吞吐量造成显著影响。这假设我们可以完全将KV缓存分片为64路,这很困难:我们在下面讨论这一点。

对于注意力层,我们还以Megatron风格在头部上对注意力WQWO进行模型分片。KV权重相当小,复制它们通常比超过K路分片更经济。

💡 要点:我们在生成期间的唯一选择是模型并行的变体。我们的目标是移动激活值而不是KV缓存或参数,因为后者更大。当我们的批量大小较大时,我们进行模型并行直到FLOPs-ICI限制(F/α)。当我们的批量大小较小时,我们可以通过更多的模型分片来改善延迟(以适度的吞吐量成本)。当我们想要进行比KV头数更多路的模型分片时,我们也可以沿批次维度分片KV。

分片KV缓存

我们还有一个需要分片的额外数据结构——KV缓存。同样,我们几乎总是倾向于避免复制缓存,因为它是注意力延迟的主要来源。为此,我们首先沿着头部维度对KV进行Megatron分片。这限制在K路分片,所以对于头数较少的模型,我们尽可能地沿头部维度分片,然后沿批次维度分片,即KV[2, BZ, S, KY, H]。这意味着KV缓存完全分布式存储。

图示:注意力机制的比较:(a)具有纯模型分片的多头注意力和(b)具有KV缓存批次分片的多查询注意力。注意我们需要两个额外的AllToAll操作来将激活从模型分片转移到批次分片,使它们能够作用于KV缓存。

这样做的代价是每个注意力层需要两次AllToAll操作——一次是将Q激活转移到批次分片以便使用批次分片计算注意力,另一次是将批次分片的注意力输出转回纯模型分片。

这里是完整算法!

我们将写出具有YZ两种模型并行的完整注意力算法。我为同时使用K表示键张量和KV头部维度而道歉。设M = N/K

  1. X[B, D] = …(已有激活,从上一层未分片)

  2. K[B, S, K, H], V[B, S, K, H] = …(已有KV缓存,批次分片)ZYZ

  3. Q[B, N, H] = X[B, D] * W[D, N, H]YZQYZ

  4. Q[B, N, H] = AllToAll(Q[B, N, H])ZYZ->BYZ

  5. Q[B, K, M, H] = Reshape(Q[B, N, H])ZYZY

  6. O[B, S, K, M] = Q[B, K, M, H] * K[B, S, K, H]ZYZYHZY

  7. O[B, S, K, M] = Softmax(O[B, S, K])ZSZY

  8. O[B, K, M, H] = O[B, S, K, M] * V[B, S, K, H]ZYZSZY

  9. O[B, K, M, H] = AllToAll(O[B, K, M, H])YZZ->MZY

  10. O[B, N, H] = Reshape(O[B, K, M, H])YZYZ

  11. X[B, D] {U} = W[N, H, D] * O[B, N, H]YZOYZN,HYZ

  12. X[B, D] = AllReduce(X[B, D] { U})YZ

这相当复杂,但你大致可以看出它的工作原理。新的通信成本适中,因为它们操作的是我们较小的激活值,而作为回报,我们在加载KV时节省了大量内存带宽(KV是静态的)。

  • 序列分片:如果批次大小太小,或者上下文很长,我们可以对KV缓存进行序列分片。同样,我们需要在跨分片累积注意力时付出集体计算的代价。首先,我们需要AllGather Q激活,然后以类似于Flash Attention的方式累积KV。

设计高效的推理引擎

到目前为止,我们已经了解了如何单独优化和分片各个预填充和生成操作。要有效地使用它们,我们需要设计一个推理引擎,该引擎可以在我们选择的延迟/吞吐量帕累托前沿上提供这两种操作。

最简单的方法是先运行一批预填充,然后运行一批生成:

图示:在最简单的设置中,请求被聚合,服务器在运行一批预填充和调用生成函数之间交替,直到所有序列完成。

这种方法实现起来很简单,是大多数代码库中的第一个推理设置,但它有多个缺点:

  1. 延迟非常糟糕。我们将预填充和生成的批处理大小耦合在一起。在大型预填充批处理大小下,首个令牌的生成时间(TTFT)非常糟糕——你需要完成所有预填充才能让任何用户看到任何令牌。在小批量大小下,生成吞吐量非常糟糕。

  2. 我们将较短的生成阻塞在较长的生成上。许多序列会比其他序列更早完成,在生成过程中留下空批处理槽位,进一步损害生成吞吐量。随着批处理大小和生成长度的增加,这个问题会更加严重。

  3. 预填充需要填充。预填充被填充到最长序列,我们浪费了大量计算。虽然有解决方案,但历史上XLA使得跳过这些浮点运算变得相当困难。随着批处理大小和预填充序列长度的增加,这个问题会变得更糟。

  4. 我们被迫在预填充和生成之间共享分片。预填充和生成都位于同一个切片上,这意味着我们对两者使用相同的拓扑和分片(除非你保留两份权重副本),这通常对性能不利,例如生成需要更多的模型分片。

因此,这种方法仅推荐用于边缘应用(通常只关心服务单个用户并使用具有更低FLOPs/字节的硬件)以及Transformer代码库生命周期早期的快速迭代(因其简单性)。

一种稍好的方法是以批处理大小1执行预填充(在这种情况下,它受计算限制但具有合理的延迟),但在生成期间将多个请求批处理在一起:

这将避免批处理预填充造成的TTFT浪费,同时保持生成吞吐量高。我们称之为交错配置,因为我们"交错"预填充和生成步骤。这对于像评估这样的批量生成应用非常强大,其中吞吐量是主要目标。协调器可以配置为在任何生成槽位打开时优先考虑预填充,确保即使对于非常大的生成批处理大小也能保持高利用率。我们还可以避免将预填充填充到最大长度,因为它不与其他请求批处理在一起。

主要缺点是,当服务器执行预填充时,所有其他请求的生成都会暂停,因为所有计算资源都将被预填充消耗。正在解码响应的用户A将被正在进行预填充的用户B阻塞。这意味着即使TTFT有所改善,令牌生成平均来说也会不稳定且缓慢,这对许多应用来说不是一个好的用户体验——其他用户的预填充处于请求整体延迟的关键路径上。

为了解决这个问题,我们将解码和预填充分开。虽然Transformer推理可以在一台服务器上完成,但从延迟角度来看,通常在两组TPU/GPU上执行这两种不同的任务更好。预填充服务器生成KV缓存,这些缓存通过网络发送到生成服务器,后者将多个缓存批处理在一起,并为每个缓存生成令牌。我们称之为"分离式"服务。

这提供了几个优势:

  1. 大规模低延迟:除非预填充容量不足,否则一个用户的请求永远不会被另一个用户的请求阻塞。请求应立即预填充,然后发送到生成服务器,然后立即插入到生成缓冲区中。如果我们预计会有很多并发请求进来,我们可以独立于生成服务器的数量来扩展预填充服务器的数量,这样用户就不会在预填充队列中停留很长时间。

  2. 专业化:通常,预填充和生成的延迟最优参数分片策略/硬件拓扑相当不同(例如,更多的模型并行对生成有用,但对预填充没用)。强制两种操作使用相同的分片会损害两者的性能,而拥有两组权重会使用更多内存。此外,通过将预填充移到自己的服务器上,它不需要保存任何KV缓存,除了当前正在处理的那个。这意味着我们有更多的内存可用于历史缓存(见下一章)或优化预填充延迟。

一个缺点是KV缓存现在需要通过网络传输。这通常是可以接受的,但再次提供了减小KV缓存大小的动机。

💡 要点:对于对延迟敏感、高吞吐量的服务,我们通常必须将预填充和生成分离到不同的服务器上,预填充以批处理大小1运行,而生成则将许多并发请求批处理在一起。

连续批处理

上面的问题(2)引发了连续批处理的概念。我们优化并编译:

  • 具有可变上下文长度的多个预填充函数,并将其插入到某个KV缓冲区中,设定最大批处理大小和上下文长度/页面数。

  • 一个生成函数,它接收KV缓存,并为所有当前活跃的请求执行生成步骤。

然后,我们将这些函数与一个协调器结合起来,该协调器对传入请求进行排队,根据可用的生成槽位调用预填充和生成功能,处理历史缓存(见下一章),并流式输出标记。

前缀缓存

由于预填充计算成本高且受计算能力限制(给我们的余量更少),减少其成本的最佳方法之一就是减少使用它。由于LLM是自回归的,查询["I", "like", "dogs"]和["I", "like", "cats"]在前两个标记中产生的KV缓存是完全相同的。这意味着,原则上,如果我们先计算"I like dogs"的缓存,然后计算"I like cats"的缓存,我们只需要做1/3的计算。通过重用缓存,我们可以节省大部分工作。这在几种特定情况下特别强大:

  1. 聊天机器人:大多数聊天机器人对话涉及严格追加到自身的来回对话。这意味着如果我们能保存每个对话轮次的KV缓存,我们可以跳过除最新标记之外的所有计算。

  2. 少样本提示:如果我们有任何类型的少样本提示,这些可以免费保存和重用。系统指令通常也具有这种形式。

这种方法难以实现的唯一原因是内存限制。如我们所见,KV缓存很大(通常有许多GB),而且为了使缓存有用,我们需要保留它们直到后续查询到达。通常,预填充服务器上任何未使用的HBM都可用于本地缓存系统。此外,加速器通常在其CPU主机上有大量内存(例如,一个8xTPUv5e服务器有128GiB的HBM,但大约450GiB的主机DRAM)。这种内存比HBM慢得多——通常太慢而无法执行生成步骤——但对于缓存读取来说足够快。在实践中:

  • 因为KV缓存与处理初始请求的TPU集合相关联,我们需要某种亲和性路由来确保后续查询到达相同的副本。这可能会导致负载平衡问题。

  • 较小的KV缓存(再次)很有帮助——它使我们能够在相同空间内保存更多KV缓存,并减少读取时间。

  • KV缓存及其查找可以很自然地存储在树或字典树中。驱逐可以基于LRU(最近最少使用)原则进行。

图示:以LRU字典树实现的KV前缀缓存。我们可以通过共享前缀来避免KV内存重复。来源:Character.ai博客。

让我们看一个实现:JetStream

谷歌已经开源了一个实现这种逻辑的库,称为JetStream。服务器有一组"预填充引擎"和"生成引擎",通常位于不同的TPU切片上,由单个控制器协调。预填充发生在"预填充线程"中,而生成则发生在"生成线程"中。我们还有一个"传输线程",负责协调KV缓存从预填充切片到生成切片的复制。

引擎接口(在这里实现)是任何LLM必须提供的通用接口。关键方法有:

  • prefill: 接收一组输入标记并生成KV缓存。

  • insert: 接收一个KV缓存并将其插入到generate正在生成的KV缓存批次中。

  • generate: 接收一组批处理的KV缓存,并为每个批处理条目生成一个标记,为每个标记将单个标记的KV缓存附加到解码状态。

我们还有JetStream的PyTorch版本,可在这里获取。

第7章练习题

我将基于LLaMA-2 13B发明一个新模型用于本节。以下是详细信息:

超参数数值
n_layers (L)64
d_model (D)4,096
d_ff (F)16,384
n_heads (N)32
n_kv_heads (K)8
d_qkv (H)256
n_embeddings (V)32,128

问题1:上述模型有多少参数?每个标记的KV缓存有多大?假设我们共享输入和输出投影矩阵。

点击此处查看答案。

参数数量:

  • MLP参数数量:L * D * F * 3

  • 注意力参数数量:L 2 D * H * (N + K)

  • 词汇表参数:D * V(因为我们共享这些矩阵)

因此,我们的总参数数量是L * D * (3F + 2H * (N + K)) + D * V。代入上面的数字,我们得到64 * 4096 * (3*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 18.4e9。因此,该模型约有184亿参数。

问题2:假设我们想在TPUv5e 4x4切片上提供这个模型的服务,并且可以在这个拓扑上完全分片我们的KV缓存。假设我们对所有内容使用int8,能容纳的最大批处理大小是多少?如果我们将KV头的数量降到1个会怎样?

问题3:假设我们完全受HBM带宽限制。将所有参数从HBM加载到MXU需要多长时间?这是每步延迟的一个很好的下限。

问题4:假设我们想在TPUv5e 4x4切片上提供这个模型的服务。我们应该如何分片它?提示:也许先回答这些问题:

  1. 这个模型在ICI上的张量并行性上限是多少?

  2. 我们如何分片KV缓存?

对于这种分片,生成的每步大致延迟是多少?

问题5:假设上述模型实际上是一个MoE。MoE模型实际上是一个具有E个FFW块副本的密集模型。每个标记通过k个FFW块,这些k个输出被平均以产生最终输出。让我们使用E=16k=2,配合上述设置。

  1. 它有多少参数?

  2. 需要多大的批处理大小才能受到FLOPs限制?

  3. 每个标记的KV缓存有多大(假设没有局部注意力)?

  4. 使用T个标记的前向传递涉及多少FLOPs?

问题6:对于MoE,我们可以进行"专家分片",将专家分布在网格的一个轴上。在我们的标准表示法中,我们的第一个FFW权重形状为[E, D, F],我们将其分片为[EZ, DX, FY],其中X仅在训练期间用作我们的FSDP维度。假设我们想在TPU v5e上进行推理:

  1. 在Y=8,Z=16的TPU v5e 8x16切片上,上述模型的HBM权重加载时间是多少?每个TPU有多少可用HBM?

  2. 我们可以将模型放在的最小切片是多大?

问题7 [2D模型分片]:这里我们将详细计算ESTI论文中所称的2D权重静态分片的数学原理。我们在附录B中简要描述了这一点,但先尝试自己解决这个问题,看看你能否推导出数学公式。2D权重静态分片的基本思想是沿着DF轴分片我们的权重,使每个块大致是正方形的。这减少了通信负载,使我们能够稍微扩展得更远。

以下是2D权重静态的算法:

  1. In[B, DX] = AllGatherYZ(In[B, DXYZ])

  2. Tmp[B, FYZ] {U.X} = In[B, DX] *D Win[DX, FYZ]

  3. Tmp[B, FYZ] = AllReduceX(Tmp[B, FYZ] {U.X})

  4. Out[B, DX] {U.YZ} = Tmp[B, FYZ] *F W2[FYZ, DX]

  5. Out[B, DXYZ] = ReduceScatterYZ(Out[B, DX] {U.YZ})

  6. In[B, D] = AllGather(In[B, D])

  7. Tmp[B, F] {U.X} = In[B, D] * W[D, F]

  8. Tmp[B, F] = AllReduce(Tmp[B, F] {U.X})

  9. Out[B, D] {U.YZ} = Tmp[B, F] * W2[F, D]

  10. Out[B, D] = ReduceScatter(Out[B, D] {U.YZ})

你的目标是计算出该算法的Tmath和Tcomms,并找出何时它将优于传统的3D模型分片?

点击此处查看答案!

让我们计算Tmath和Tcomms。所有的FLOPs都是完全分片的,所以和之前一样,我们有Tmath = 4BDF/(NC),但我们的通信现在是

$$ \begin{align*} T_\text{2D comms} = \frac{2BD}{2X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} + \frac{2BD}{2X \cdot W_\text{ici}} = \frac{2BD}{X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} \end{align*} $$

我们注意到AllReduce的成本是两倍,并且我们根据每个操作执行的轴数来缩放我们的通信。假设我们有自由选择我们的拓扑结构,并且假设F = 4D(如在LLaMA-2中),我们声称(通过一些基本的微积分)XYZ的最优值是X = \sqrt{N / 8},YZ = \sqrt{8N},因此总通信量为

$$ T_\text{2D comms} = \frac{2B}{W_\text{ici}} \left(\frac{D}{X} + \frac{8D}{YZ}\right) = \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \approx \frac{11.3 BD}{\sqrt{N} \cdot W_\text{ici}} $$

首先,从上面复制,正常的1D模型并行性将有Tmodel parallel comms = 4BD/(3 ⋅ Wici),那么新的通信何时更小?我们有

$$ \begin{align*} T_\text{model parallel comms} > T_\text{2D comms} \iff \frac{4BD}{3 \cdot W_\text{ici}} > \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \ \iff N > 128 \cdot \left(\frac{3}{4}\right)^2 = 81 \end{align*} $$

对于一般的F,我们声称这个条件是

$$ N > 32 \cdot \left(\frac{F}{D}\right) \cdot \left(\frac{3}{4}\right)^2 $$

所以这告诉我们,如果我们有超过81个芯片,使用这种新方案会更好。现在这是一个有点奇怪的结果,因为我们历史上发现在大约~20路张量并行性时受到ICI限制。但在这里,即使我们受到通信限制,我们的总通信量仍然随着芯片总数的增加而减少!这告诉我们,我们可以持续增加芯片数量,增加批处理大小,进行更多参数扩展,并看到延迟降低。

第7章附录 A:批量大小 > 240 的规则有多真实?

我们上面提供的简单规则,即批量大小必须大于 240 个标记才能受计算限制,大致上是正确的,但它忽略了 TPU 在其他操作(如设备间通信)未使用所有可用 HBM 时预取权重的能力。

这里是一个小型 Transformer 的层时间(以微秒为单位)的实证图,其 dmodel 为 8192,dff 为 32768,且每层仅有 2 个矩阵乘法。这来自这个 Colab 笔记本。你会看到步骤时间在批量大小约 240 之前增长非常缓慢,之后呈线性增长。

这是以标记/微秒为单位的实际吞吐量。这使得论点相当清晰。由于我们的层约有 6 亿参数,在这里被 4 路分片,我们预计最小延迟约为 365 微秒。

因此至少在这个模型中,我们确实看到吞吐量增加直到每个数据并行分片约 BS240。

第7章附录 B:2D 权重固定分片

随着拓扑结构的增长,如果我们能够访问更高维度的网格(如 TPU 的网格),可以通过引入第二个分片轴进一步优化,称为"2D 权重分片"。我们称之为"2D 权重固定",在高效扩展 Transformer 推理论文中有更详细的描述。

因为我们在 Megatron 中只分片隐藏的 F 维度,随着芯片数量增加,使用 1D 分片时,它可能会变得比 Edmodel 维度)小得多。这意味着在较大的批量大小下,在 MLP 的第一层应用后,通过隐藏维度执行部分集体操作可能更经济。

此图显示:

  1. 1D 权重固定分片,即纯 Megatron 分片,其中激活在 AllGather 后完全复制,权重在隐藏 F 维度上完全分片。

  2. 2D 权重固定分片,其中权重在隐藏 F 和归约 E 维度上都进行分片,激活在 E 维度上分片。我们在第一层之前在 (yz) 轴上执行 AllGather,然后在 (x) 轴上执行 ReduceScatter。

对于注意力层,Megatron 风格的分片对于较小数量的芯片也相对简单。然而,Megatron 是在 nheads 维度上进行的,这限制了可能的分片数量。通过修改 2D 分片(不是分片隐藏维度,而是分片 nheads 维度),我们获得了进一步扩展的能力。

第7章附录 C:延迟限制通信

回顾一下,在第 3 节中,我们推导了在每个 TPU 上对大小为 B 的张量执行 AllGather 所需的时间,这是在 X 个芯片上,通过全双工带宽为 WICI 且延迟为 Tmin 的 1D 环形链路进行的。

$$ T_{total} = \max\left(\frac{T_{min} \cdot |X|}{2}, \frac{B}{W_{ICI}}\right) $$

对于大的 B 值,时钟时间保持相对恒定,因为当你向系统添加更多芯片时,执行操作所需的数据移动量和可用的总带宽同时增加。

由于延迟优化推理期间移动的数据量相对较低,激活上的集体操作通常受延迟项限制(尤其是对于小批量大小)。可以通过计算完成前需要的跳数来轻松可视化延迟。

在 TPU 上,如果通信的张量大小依赖部分小于每跳 1 微秒(一跳是两个相邻设备间的通信),我们可能会受到实际调度集体操作的固定开销的瓶颈限制。使用 4.5e10 单向 ICI 带宽,当(字节数/n分片数)/4.5e10 < 1e-6 时,ICI 通信变为延迟限制。对于 8 路 Megatron 分片,这是当 buffer_size < 360kB 时。这在推理期间实际上并不那么小:使用 BS=16D=8192 的 int8 格式,我们的激活将使用 16*8192=131kB,所以我们已经受到延迟限制。

结论:当总字节数 < WICI × 1e-6 时,我们的通信变为延迟限制。例如,使用模型并行度 Y,当 Y > BD/45,000 时,我们在 int8 中受到限制。

这里可以与计算屋顶线进行类比 — 我们正在承担一些小操作的固定成本(通信的延迟,矩阵乘法的内存带宽)。

第7章附录 D:推测性采样

当我们真正关心端到端延迟时,还有一个额外的技巧可以使用,称为推测性采样。回顾一下,我们通常从大型 Transformer 一个接一个地生成标记:

使用推测性采样,我们使用更小、更便宜的模型生成标记,然后用大模型检查结果。这在贪婪解码中最容易理解:

  1. 我们从一些更小、更便宜的模型中贪婪采样。理想情况下,我们使用训练来匹配较大模型的模型,例如通过蒸馏,但它也可以简单到仅使用 n-gram 或对小型文本语料库进行标记匹配。

  2. 在生成 K 个标记后,我们使用大模型计算到目前为止我们生成的所有标记的下一个标记对数概率。

  3. 由于我们是贪婪解码,我们可以检查由较小模型生成的标记是否在所有可能的标记中具有最高概率。如果某个标记错误,我们取最长的正确前缀并用正确的标记替换第一个错误的标记,然后返回步骤 (1)。如果所有标记都正确,我们可以使用最后一个正确的对数概率采样一个额外的标记,然后返回步骤 (1)。

为什么这能提高延迟性能?这个方案仍然需要我们为每个标记做一次大模型的前向传递的等效 FLOP,但因为我们可以将大量标记批处理在一起,我们可以在一次前向传递中完成所有这些 FLOP,并利用我们不受计算限制的事实来免费评分更多标记。

平均而言,每个被接受的标记在 FLOP 方面变得更昂贵(因为有些会被拒绝,而且我们必须调用一个草稿模型),但我们从硬件中榨取了更多 FLOP,而且小模型很便宜,所以总体上我们是赢家。由于所有内容都已由大模型检查过,我们完全不改变采样分布(尽管对于非贪婪方式,确切的轨迹会有所不同)。

对于正常的自回归采样,每秒标记数与步骤时间相同。我们仍然受制于这里算术强度部分的理论最小步骤时间(实际上,推测性采样步骤时间通常比正常自回归采样慢得多,但因为我们平均每步获得超过 1 个标记,所以可以获得更好的每秒标记数)。

图示:这个图显示了 Chinchilla(DeepMind 的 700 亿参数模型)与 40 亿参数起草器(小模型)的每步延迟和推测成功率。对于 XSum(一个自然语言数据集),理想的推测量是提前约 3-4 个标记,而 HumanEval(一个编码数据集)更可预测,并且从更激进的推测中获益。

这对非贪婪解码如何工作?这有点复杂,但本质上归结为一个受 Metropolis-Hastings 启发的算法,其中有 P草稿模型(选择的标记) 和 P目标模型(选择的标记) 从对数概率派生,如果这些概率的比率小于某个阈值,则概率性地拒绝所选标记。

两篇论文同时推导了这一点,并有很好的实际工作示例。

结论:推测性采样是另一个强大的杠杆,用于交换吞吐量以获得更好的每标记延迟。然而,在批量大小有限的情况下(例如,小型硬件占用空间或大型 KV 缓存),它成为双赢策略。

第8章:在TPU上部署LLaMA 3

本节将探讨部署LLaMA-3所需的条件以及如何高效地进行部署。与之前的"应用"部分一样,请先尝试用笔和纸自行推导答案,然后再查看解答!

LLaMA服务的情况如何?

让我们回顾一下LLaMA 3-70B的结构(参考第6节):

超参数数值
nlayers (L)80
dmodel (D)8,192
dff(F)28,672
nheads (N)64
nkv heads (K)8
dqkv (H)128
nembeddings (V)128,256

让我们从一个简单的问题开始:我们应该使用什么硬件来部署?答案基本上是,每美元FLOPs最便宜的那个。这并不总是正确的,有时候更多的HBM或ICI带宽比FLOPs更关键,但这是一个很好的启发式方法。因此,我们通常希望在TPU v5e上部署,这是我们当前专用的推理芯片(成本来自截至2025年2月的Google Cloud定价):

TPU类型bfloat16 FLOPs/sGoogle Cloud美元/小时FLOPs / $
H1009.9e14$10.83.3e17
v5p4.59e14$4.23.9e17
v5e1.97e14$1.25.8e17

每个TPU v5e有16GB的HBM,这将要求我们相当激进地对模型进行分片。让我们先思考一些可能对我们有影响的基本数值:

问题:LLaMA 3-70B的KV缓存每个token有多大?你可以假设我们将它们存储在int8中。这决定了我们在给定拓扑上的批量大小可以有多大。

思考完毕后点击这里!

LLaMA 3-70B有8个KV头,所以每个token的大小是2 * K * H * L = 2 * 8 * 128 * 80 = 160kB

注意这有多大!如果我们有32k个token的序列长度(这很常见),这将使用162e3 * 32,768 = 5.3GB/序列。对于BS=240,这是1.3TB!由于TPU v5e每个只有16GB,我们大约需要(70e9 + 1.3e12) / 16e9 = 86个TPU v5e芯片才能容纳这么多内存。还要注意与70GB的模型参数相比,这有多大。

问题:假设我们想要以32的批量大小和8192的序列长度部署L3 70B,所有内容(参数和KV)都使用int8。这将使用多少总内存?我们可以在最小的什么规模的切片上部署它?

答案

由于我们的KV在int8中是160e3字节,我们的总KV内存是160e3 * 8192 * 32 = 41.9e9字节。我们的参数是70e9字节,因为每个参数有1个字节。因此,我们的总内存使用量是41.9e9 + 70e9 = 112GB

我们可以使用的最小切片将有112e9 / 16e9 = 7个TPU,或者(四舍五入到偶数大小),TPU v5e 4x2。这将是一个紧凑的配置,考虑到其他开销,我们可能无法完全适应,所以我们可能至少需要一个4x4(或者降低批量大小)。

问题:在TPU v5e 4x2上使用这个批量大小和量化,我们每个解码步骤预期的延迟大约是多少?吞吐量(tokens/秒/芯片)是多少?4x4呢?假设我们在bfloat16中执行FLOPs,并且一切都是完全分片的。

答案

我们可以引用上一节的公式

$$ \begin{align*} \tiny \text{理论步骤时间(一般)} = \underbrace{\frac{\text{批量大小} \times \text{KV缓存大小}}{\tiny \text{总内存带宽}}}{\text{注意力(始终受带宽限制)}} + \underbrace{\max\left(\frac{2 \times \text{批量大小} \times \text{参数数量}}{\text{总FLOPs/s}}, \frac{\text{参数大小}}{\text{总内存带宽}}\right)}{\tiny \text{MLP(可能受计算限制)}} \end{align*} $$

这里我们的临界批量大小将约为120,因为我们的参数在int8中,但我们的FLOPs在bfloat16中。我们也可以手动计算RHS最大值,但这基本上是我们已经做过几次的计算。所以我们的矩阵乘法和FLOPs都深处内存带宽限制区域。

严格来看内存带宽,我们的步骤时间基本上是(KV大小 + 参数大小) / (8 * HBM带宽) = 112e9 / (8 * 8.1e11) = 17ms所以理论上我们的步骤时间约为17ms。我们的吞吐量将是32 / .017 = 1882 tokens / 秒,或1882 / 8 = 235 tokens / 秒 / 芯片

这里有一个注意事项,即检查我们的矩阵乘法是否可能受ICI限制。我们可以在这里专用2个轴,所以理论上当Y > 2 * F/2550 = 2 * 28672/2550 = 22时,我们受ICI限制,所以我们很好!

如果我们在4x4上运行,从ICI角度我们仍然没问题,所以我们的延迟会降至17 / 2 = 8.5ms,但每芯片的吞吐量将保持不变。

考虑吞吐量

让我们花一点时间纯粹思考吞吐量。当我们优化吞吐量时,我们希望计算受限,意味着我们尽可能接近利用全部TPU MXU容量。通常这意味着我们希望批量大小尽可能大,这样我们就能完成尽可能多的工作。

问题:在TPU v5e上,使用bfloat16权重和激活值,我们的批量大小需要多大才能在矩阵乘法中受计算限制?如果我们使用int8权重但在bfloat16中执行FLOPs呢?如果是int8权重配合int8 FLOPs又如何?

答案

如第7节所述,对于任何BDF的bfloat16矩阵乘法,我们有

$$ \begin{equation*} T_\text{math} > T_\text{comms} \leftrightarrow \frac{2BDF}{2DF} \geq \frac{\text{TPU bfloat16 FLOPs/s}}{\text{HBM bandwidth}} = 240 \end{equation*} $$

当我们的权重在int8中时,分母损失一个因子2,所以我们有2BDF/DF = 2B > 240,或者等价地B > 120,是之前临界批量大小的一半。这对我们非常有帮助!当我们使用int8权重和int8 FLOPs时,我们必须使用TPU FLOPs/s的int8值,它从bfloat16的1.97e14增加到3.94e14,几乎翻倍。这意味着我们回到了起点,约B > 240。

int8权重和bfloat16 FLOPs的情况非常常见,因为无损量化参数通常比执行低精度算术更容易。

问题:使用bfloat16、int8和int4(同时适用于KV和参数)以及8k上下文,我们可以在最小的什么TPU v5e拓扑上部署LLaMA 3-70B?对于这个问题,你可以认为KV缓存忽略不计。

答案

这很简单!如果我们可以接受很小的批量大小,那么唯一的限制就是将参数内存放入HBM,即只是ceil(num_params * sizeof(dtype) / HBM per TPU,或ceil(70e9 * sizeof(dtype) / 16e9)四舍五入到最近的合理拓扑(2的某个倍数):

dtype参数大小每token的KV大小(字节)最小TPU v5e数量实际最小切片KV缓存的剩余HBM8k情况下的KV缓存数量
bf16140GB324kB8.754x4 = 16个芯片11643
int870GB162kB4.384x2 = 8个芯片6852
int445GB81kB2.812x2 = 4个芯片1967

这非常酷!它告诉我们如果需要,我们可以在TPU v5e 2x2上部署LLaMA 70B。不过你会注意到KV缓存的数量非常小。那就是我们的批量大小!这意味着我们的FLOPs利用率会很差。我们会很乐意使用更大的拓扑来将批量大小提高到240。

问题:假设我们在这些拓扑上使用最大批量大小,每个生成步骤我们可以预期什么样的延迟?

答案

这也很简单,因为我们选择的批量大小填满了所有HBM!这只是一个问题,即加载一个完整TPU v5e字节到MXU需要多长时间。这就是v5e HBM / v5e HBM内存带宽 = 16GB / 8.2e11 = 19ms,所以理论上是每步19ms。假设我们的生成中位长度为512个token,那大约是每次解码9秒。注意,使用更小的批量大小可能获得略微更好的延迟,例如如果我们只考虑int4的模型参数,我们的最小延迟约为10ms/步,因为HBM不再是满的。

💡 要点:我们总是可以通过询问从HBM加载所有模型参数到MXU需要多长时间来获得解码延迟的下限。当我们的KV缓存很小时,你可以将每层视为只是加载权重块然后丢弃它们。除非我们使用大批量大小或大量设备间通信,否则这通常是一个合理的界限(在1.5倍之内)。当我们的批量大小更大时,我们需要同时建模KV缓存加载,因为它会主导参数。

同样,在FLOPs受限区域(例如训练或大批量推理),我们可以使用总FLOPs/(NC) = 2 ⋅ 参数数量 ⋅ B/(NC)下限,这假设没有通信。

问题:对于上述每种情况,每芯片的吞吐量是多少(以查询/芯片为单位)?假设我们的中位解码长度为512个token。

答案

这是一个重要问题,因为它与每token成本直接相关。

根据我们对中位解码长度的假设,我们的吞吐量就是B/(每步延迟 ⋅ 中位步数 ⋅ N) ≊ 43/(0.019 512 N)。这给我们大约(4.42/N) QPS,所以带入N我们得到:

dtypeQPS / 芯片
bfloat160.27
int80.66
int41.72

注意,这相当乐观,因为它完全忽略了前向传递的工作内存(分配给激活和注意力的内存)。使用Flash Attention这不是荒谬的,但也不现实。实际数字可能是这个的1/2左右。为了获得绝对最大吞吐量,我们可能需要将芯片数量增加一倍以上,并显著增加批量大小。

问题:如果我们将上述每个示例的拓扑加倍,我们的峰值吞吐量会如何变化?

答案

如果我们在bfloat16中使用4x8切片,我们将有186GB剩余用于KV缓存,这将让我们将批量大小提高到161。然后,由于我们的步骤时间保持不变,我们将有吞吐量16.54 / num_chips,或

dtypeQPS / 芯片
bfloat16(在4x8上)0.51
int8(在4x4上)1.03
int4(在2x4上)2.06

进一步增加将带来更大的收益!重要的启示是,如果我们受KV缓存大小限制,最小拓扑不是性能最佳的拓扑

问题:现在让我们深入探讨分片问题。假设我们想在TPU v5e 4x8上以bfloat16格式提供服务。在生成过程中,我们应该在TPU v5e 4x8上使用什么样的模型分片?我们能避免受通信限制吗?

答案

正如上一节所讨论的,在生成过程中我们只有一个分片选择:模型并行性。在变成通信受限之前,我们能做多少?正如我们在上一节中讨论的,当满足以下条件时,我们的模型大约会变成通信受限:

$$ Y > \frac{F \cdot n_\text{axes}}{2550} $$

对于LLaMA 3-70B,我们有F = 28,672,所以如果我们做2个轴的模型分片,这给我们大约Y = 28672 ⋅ 2/2550 = 22,因此一般来说,我们可以扩展到大约16个芯片而不会受到通信限制,这让我们可以使用4x4但不是4x8。通常,由于我们不能完美地重叠计算,即使这个估计也过于乐观。

💡 要点:我们实际上不能在4x8上使用纯模型并行性提供服务。我们能做的最好的是4x2或者可能是4x4。

然而,正如我们所讨论的,当我们的批量大小较小时,我们通常可以做更多的模型并行而不会显著影响吞吐量,因为我们的模型是内存带宽受限而不是FLOPs受限。我们之前说这个值大约是Y = F/(8 ⋅ B),所以如果我们使用批量大小64,理论上我们可以达到Y = 28,672 / (8 * 64) = 56的模型并行度,然后才会变成ICI受限。为了进行理性检查,我们可以看一下单个矩阵乘法的Tici通信,Thbm通信和Tmath。我们清楚地有:

$$ \begin{align}T_\text{ici comms} = \frac{2BD}{W_\text{ici}} && T_\text{hbm comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{math} = \frac{2BDF}{Y \cdot C}\end{align} $$

对于4x8,这将给我们Tici comms = (2 * 64 * 8192) / 9e10 = 11usThbm comms = (2 * 8192 * 28,672) / (32 * 8.1e11) = 18us,以及Tmath = (2 * 64 * 8192 * 28,672) / (32 * 1.97e14) = 4us,所以理论上我们仍然是HBM带宽受限,这很好!*注意,从4x4扩展到4x8可能从吞吐量角度来看帮助不大,但它会减少我们的延迟!

如果我们看看int8和int4配置,我们可以使用纯模型并行性。所以我们达到了一个量化实际上提供了超越更快FLOPs的有意义优势的点:它让我们在变成通信受限之前可以使用更大的批量大小。*所以这个故事的结尾是,我们不能在4x8上达到峰值吞吐量,但对于int8和int4配置,我们可以使用纯模型并行性

提示:最大的有用模型并行度取决于dff和你分片模型的轴数。最大值通常在8到32之间,取决于模型大小。你可以超出这个限制来改善延迟,但会付出一些吞吐量成本。

预填充怎么样?

我们在这里基本上忽略了预填充,因为它简单得多。让我们将几个概念放在一起,思考端到端的情况。

问题:假设我们在预填充期间达到40%的FLOPs利用率。在16个TPU v5e芯片上,长度为8192的预填充需要多长时间?

答案

在8k个token的情况下,我们完全是计算受限,所以我们只需要考虑FLOPs。我们知道我们的模型有70e9个参数,所以每次前向传递使用2 * 70e9 * B个FLOPs。假设40%的MFU(FLOPs利用率),这给我们的运行时间约为2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s。与我们之前看到的数字相比,这实际上相当多!

问题:假设我们的中位预填充长度为8192个token,中位解码长度为4096个token。假设我们的生成批量大小为32。平均每步有多少序列完成解码?平均每步从我们的KV缓存中驱逐多少个token?

答案

这相当直接。由于我们的中位解码长度为4096个token,一个序列大约每1/4096个token完成一次。给定批量大小32,这意味着我们每步有32 / 4096个序列被驱逐。由于我们的KV缓存长度大约是8192 + 4096,这是32 * (8192 + 4096) / 4096 = 96个token每步被驱逐。一般公式是B * (P + G)/G,其中PG是预填充和生成长度。

问题:假设我们使用分离式服务,中位预填充长度为8192,中位解码长度为512。假设使用上面计算的bfloat16的预填充和生成延迟。你需要什么比例的预填充:生成服务器才能使两者都充分饱和。

答案

这是一个有趣的问题。让P是预填充服务器的数量,G是生成服务器的数量。所以一般来说,这是一个管道问题,我们以P / prefill_latency的速率输入序列,并以B * G / (generate_latency * median_decode_length)的速率消费它们。我们计算出每个预填充步骤910ms,批量大小43(让我们称之为32)的每个解码步骤19ms。因此,我们需要P / 0.91 = 32 * G / (0.019 * 512)P = 3G,即我们需要大约3倍于生成服务器的预填充服务器!

可视化延迟与吞吐量之间的权衡

让我们继续以LLaMA 70B为例,来实际查看不同批量大小在生成过程中的延迟和吞吐量。正如我们在上一节针对PaLM模型所展示的,这为我们提供了吞吐量/延迟的帕累托前沿。让我们假设使用16路张量并行,因为这是在MLP块中保持计算密集型的合理上限。我们将在这里使用TPU v5e 4x4拓扑。滑块控制序列长度,以便您可以看到更大KV缓存的影响。

  • 看看成本与延迟之间的权衡是多么显著。以增加一倍的每token延迟为代价,我们可以实现大约100倍的每token成本降低。此外,我们的延迟可以从低批量大小时的5.5毫秒到非常大批量时的20毫秒不等。

  • 注意在2k上下文长度时,当达到BS 120屋顶线时(这里是120,因为我们使用int8权重但bf16 FLOPs),吞吐量实际上在每芯片每毫秒约1个token处趋于平稳。然而,随着序列长度增加,我们无法再将这个批量大小装入内存,因此我们永远无法达到完全饱和点。

  • 注意对于相同的吞吐量,大批量大小下的延迟要高得多,因为KV加载变得占主导地位(而不是参数加载)。

我们可以通过将成本和延迟的来源分解为参数加载时间、KV加载时间和FLOPs时间来更好地理解这一点。红色区域是我们预期在MLP块中计算密集型的区域。

这讲述了一个很有说服力的故事。您可以看到,最初,参数加载占据了延迟的绝大部分,直到批量大小变得足够大,FLOPs和KV加载变得更加显著。值得注意的是,在所有大于2048的序列长度下,我们在KV缓存加载上花费的时间比在FLOPs上的时间更多!因此,虽然我们可以通过增加批量大小来提高硬件利用率,但在长上下文长度下,KV加载始终主导总步骤时间。

结论:对于LLaMA 3-70B,在几乎所有这些配置中,我们都强烈受KV缓存内存带宽限制(和HBM限制),这突显了减小KV缓存大小对生成吞吐量的重要性。还要注意延迟/吞吐量权衡仍然是多么显著。

这段代码相当简单。

以下是计算这些屋顶线的代码:

import numpy as np
num_chips = 16  # 我们固定16作为我们进行的总模型并行度
param_size = 70e9  # int8意味着每个参数1字节
sequence_length = 8192  # 这个可以变化
hbm_bandwidth = 8.20E+11  # v5e
flops = 1.97E+14  # v5e
param_size = bytes_per_param * param_count

def kv_cache_size(bs):
    return 2 * bs * 128 * 8 * 80

def min_topology(bytes):
    return 2 ** np.ceil(np.log2(bytes / 16e9))

def get_max_batch_size(max_num_chips: int = 16):
  # 对于topo_sizes中的num_chips:  
  batch_sizes = np.arange(1, 1024, 4)
  kv_sizes = kv_cache_size(sequence_length * batch_sizes)
  num_chips = min_topology(kv_sizes + param_size)
  max_idx = np.where(num_chips <= max_num_chips)[0][-1]
  return max_idx

max_idx = get_max_batch_size(num_chips, sequence_length, param_size)  # 获取能够容纳的最大批量大小
batch_sizes = np.arange(1, 512, 1)[:max_idx]
kv_sizes = kv_cache_size(sequence_length * batch_sizes)
kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)
param_comms_time = param_size / (num_chips * hbm_bandwidth)
param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])
flops_time = 2 * param_count * batch_sizes / (num_chips * flops)  # 在二维感知上大致正确
mlp_time = np.maximum(flops_time, param_comms_time)
attn_time = kv_comms_time  # 对于生成过程总是受带宽限制
latency = 1000 * (mlp_time + attn_time)
throughput = batch_sizes / (latency * num_chips)

注意我们如何非常明确地将延迟分为两个来源:KV加载和参数加载,以及延迟如何受到FLOPs或通信的限制,取决于哪个更大。

第8章练习题

这里有一些练习题。其中一些重复了上面已经解答的内容,但可能在教学上有用。

问题1:LLaMA 3-405B的每个前向传递每个token使用多少FLOPs?假设我们受FLOPs限制,在TPU v5e上N个芯片上单个前向传递的下限是多少?如果我们受通信限制呢?忽略模型无法适应单个芯片的事实。

问题2:假设我们想要使用BS240、int8权重和int8 KV缓存来提供LLaMA 3-8B服务。(a)模型参数(b)KV缓存和(c)峰值工作激活(大致)各使用多少字节?我们可以在什么最小拓扑上运行这个?

问题3:您将如何在TPU v5e上提供LLaMA 3-405B服务?假设使用int8权重和bfloat16 FLOPs。假设我们有每token 15ms的严格限制,我们能够实现的最高吞吐量配置是什么?理论最小步骤时间是多少?

第9章:如何分析TPU代码

TPU软件栈的全局视图

Google提供了许多用于编程TPU的API,从高级JAX代码到低级Pallas或HLO。大多数程序员专门编写JAX代码,这让你可以编写抽象的NumPy风格线性代数程序,这些程序会自动编译以在TPU上高效运行。

这里有一个简单的例子,一个用于矩阵相乘的JAX程序:

import jax
import jax.numpy as jnp
def multiply(x, y):
  return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))

通过调用jax.jit,我们告诉JAX跟踪这个函数并生成一个称为StableHLO的低级IR,这是一个平台无关的机器学习计算IR,然后由XLA编译器转换为HLO。编译器运行多个过程来确定融合、布局和其他因素,这些因素会产生在JAX性能分析中可观察到的HLO。这个HLO以LLVM风格的图形视图表示JAX代码中的所有核心线性代数操作(矩阵乘法、点式操作、卷积等)。例如,上述程序的简化HLO版本如下:

ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1), metadata={op_name="y"}
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16]{1,0} %Arg_1.2),
  %Arg_0.1 = f32[128,256]{1,0} parameter(0), metadata={op_name="x"}
  ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1},
}

我们稍后会解释HLO的语法,但现在只需注意它与上面的JAX代码非常匹配。例如,

ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1}

就是上面的实际矩阵乘法,它分别沿0和1维度相乘两个f32矩阵。

为了将这个HLO转换为可以在TPU上执行的代码,XLA编译器首先将它降级为LLO(低级优化器)IR。LLO直接编程TPU,调度内存之间的复制,将数组推入脉动阵列等。LLO代码包含将缓冲区推入脉动阵列、拉取结果以及调度在TPU内存不同部分之间通信的DMA的原语。一旦降级为LLO,它就被编译成加载到TPU IMEM并执行的机器代码。

当程序运行速度比我们希望的慢时,我们主要在JAX级别上改进性能。然而,这样做通常要求我们了解一些HLO的语义以及代码在TPU上的实际运行方式。当低级出现问题时,我们会拉出另一个逃生舱并在Pallas中编写自定义内核。要查看程序的HLO及其运行时统计信息,我们使用JAX分析器。

JAX分析器:多功能TPU分析工具

JAX提供了一个多功能的TPU分析器,包含许多有用的工具,用于理解程序运行时TPU上发生的情况。你可以使用jax.profiler模块来追踪程序运行过程,记录从每个子组件的持续时间、每个程序的HLO、内存使用情况等所有内容。例如,以下代码将把追踪结果转储到/tmp/tensorboard文件中,可以在TensorBoard中查看(这里有一个逐步指南)。

import jax
with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (1024, 1024))
  y = x @ x
  y.block_until_ready()

# 现在你可以在Google Colab中加载TensorBoard
#
# !pip install tensorboard-plugin-profile
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/tensorboard
#
# 或者在外部使用
#
# > tensorboard --logdir=/tmp/tensorboard
#

以下是你可以在分析器中做的概述:

进入TensorBoard后,分析器有几个关键选项卡,帮助你理解你的程序:

  1. 跟踪查看器以时间线的形式显示TPU上实际发生的详细时间线。

  2. 图形查看器显示HLO图,让你看到程序的各个部分如何相互馈送以及如何分片。

  3. 内存分析和内存查看器:这些显示你的程序使用了多少内存。

虽然共享分析结果有点困难,但这里有一个Perfetto链接,其中至少包含了一个简单Transformer的跟踪查看器组件。这个Colab让你生成完整的JAX/TensorBoard跟踪并进行试验。

跟踪查看器

跟踪查看器可能是分析器中最有用的部分。下面的例子显示了一个简单的Transformer,并标注了各个部分。名称来自代码中提供的标签。

跟踪查看器显示了每个TPU核心上所有操作的时间顺序时间线。我们这里只看TPU:0,因为通常所有TPU执行相同的指令。几个关键注意点:

  1. 顶行(XLA Ops)显示了实际的TPU操作(名称是HLO名称)。其他所有内容都是基于jax.named_scopejax.named_call和Python堆栈跟踪的近似追踪。

  2. 注意到重复的块,我们可以在这里隔离单个层。我们还可以看到(通过查看代码/理解Transformer的工作原理)哪些部分是注意力机制,哪些部分是MLP。

  3. 通过点击XLA操作,我们可以查看它来自代码中的哪个位置(有助于理解跟踪)并查看到图形查看器的链接。

提示:你可以使用"电子游戏"风格的控制来导航跟踪查看器,用A/D左右平移,W/S放大和缩小。这些控制使导航变得更加容易。

如何解读XLA操作

HLO实际上并不难读懂,而且对于理解上面跟踪中的特定部分非常有帮助。这里有一个名为fusion.3的操作示例。

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3

让我们将其分解为各个部分。

  • 操作名称: fusion.3

    • 点或融合操作是一组最多包含1个矩阵乘法和可能有一堆相关的逐点VPU操作的操作集合。
  • 形状/布局: bf16[32,32,4096]

    • 这是操作的输出形状。我们可以看到数据类型是bf16(每个参数2字节)和[32,32,4096]是形状。
  • 布局: {2,1,0:T(8,128)(2,1)}

    • {2,1,0:T(8,128)(2,1)}告诉我们内存中轴的顺序(列优先,行优先等)和数组填充。下面有更多解释。
  • 内存位置: S(1)

    • S(1)告诉我们这个数组存在于VMEM中。S(0)(有时省略)是HBM。S(2)和S(3)是其他内存空间。
  • 参数: bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32

    • 这个操作有一个输入,一个名为fusion.32的bf16数组,具有特定形状。这告诉我们哪个函数输入到这个函数中。

让我们更深入地理解这种表示法。以下是一个简单的例子:

f32[3,5]{1,0:T(2,2)}

这再次告诉我们,这个操作返回一个形状为[3, 5]的float32数组,具有特定的平铺{1,0:T(2,2)}。虽然平铺不太重要,但简单来说,平铺告诉我们一个N维数组在内存中如何按顺序排列。下面是一个图表,显示了这个数组的布局方式:

{1,0:T(2,2)}中,1,0部分告诉我们数组维度在物理内存中的排序,从最小到最大。你可以从右到左读取这部分,并在f32[3,5]中找出相应的维度,以确定数组的物理布局。在这个例子中,物理布局是[3,5],与逻辑形状相同。之后,T(2,2)告诉我们数组是以(2, 2)大小的块进行平铺的,在每个块内,数组先有行(行优先),然后是列,即(0, 0)后面是(0, 1),然后是(1, 0)(1,1)。由于T(2, 2)平铺,数组被填充到[4, 6],内存使用量扩大约1.6倍。对于上面给出的大型bf16数组,bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)},我们使用T(8,128)(2,1),这告诉我们数组有两个层次的平铺,一个外部的(8, 128)平铺和一个内部的(2, 1)平铺(用于bf16,使我们的加载始终是4字节的倍数)。例如,这里是bf16[4,8]{1,0,T(2,4)(2,1)}(颜色是(2,4)平铺,红色框是(2,1)平铺):

平铺会影响张量块加载到VMEM的效率,XLA有时会在程序内引入复制来"重新平铺"或"重新布局"张量,有时会带来不小的开销。

图形查看器

虽然上面的一些融合操作看起来很复杂,但XLA图形查看器使它们更容易解析。例如,这里是一个相当复杂的融合视图:

盯着一堆HLO图表并尝试将HLO操作映射到你正在分析的代码上非常有帮助。通过悬停在一个框上,你通常会看到函数定义所在的代码行。

查看真实示例性能分析

这个Colab有一个用于假Transformer的示例性能分析。这里是一个Perfetto链接,如果你赶时间可以至少查看Trace Viewer。我比平常花了更多精力使用jax.named_scope调用来注释跟踪,这样你就可以识别发生了什么。

查看这个性能分析并尝试真正理解每个部分在做什么。让我们从FFW块开始,逐一分解:

这里我们放大了FFW块。你会看到上投影Op是一个融合(矩阵乘法),输入为bf16[8, 1024, 8192]bf16[8192, 16384],输出为bf16[32, 1024, 16384]。我知道(因为我写了这段代码)这是4路DP、2路MP分片矩阵乘法的局部视图,所以我们实际上在做

X: bf16[32, 1024, 8192] * W_{in}: bf16[8192, 32768] -> Tmp: bf16[32, 1024, 32768]

我们预期这需要多长时间?首先,我们每个数据并行分片的批量大小是8 * 1024 = 8192,所以我们应该是完全受计算约束的。这是在8个TPUv2核心上(Google Colab上免费提供),所以我们预计它需要大约2 * 32 * 1024 * 8192 * 32768 / (23e12 * 8) = 95.6ms,这几乎就是它实际花费的时间(96ms)。太棒了!这意味着我们获得了极佳的FLOPS利用率!

通信情况如何?你会注意到第二个矩阵乘法末尾隐藏的小融合。如果我们点击它,你会看到

%fusion.1 = bf16[8,1024,4096]{2,1,0:T(8,128)(2,1)} fusion(bf16[8,1024,8192]{2,1,0:T(8,128)(2,1)} %fusion.31), kind=kCustom, calls=%all-reduce-scatter.1

这基本上是一个小型ReduceScatter(这是GraphViewer);

我们预期这需要多长时间?好的,我们在TPUv2 4x2上执行ReduceScatter,这应该只需要在1.2e11双向带宽上跳一次。数组大小为23210248192,批处理轴分成4份,所以每个分片是2810248192=134MB。所以这应该大约需要1.1ms。实际上需要多长时间?性能分析报告为1.13ms。所以我们非常接近理论上限!

我们也来看看注意力机制!这是注意力组件的性能分析:

我点击了Q投影操作,它使用形状为[d_{model} = 8192, n_{heads} = 32, d_{qkv} = 256]的矩阵W_Q。我们沿头部维度进行Megatron分片。尝试做同样的练习,计算这些应该需要多长时间。

内存分析

内存分析使查看程序内存随时间的变化变得容易。这对调试OOM(内存不足)很有帮助。你可以在这里看到约7.5GB分配给模型参数,还有约10GB可用。所以我们可以在内存中放入更多内容。

第9章练习题

问题1:查看这个Colab/性能分析文件,找出看起来可疑的地方以及这里发生了什么。你能准确地告诉我正在发生什么计算,以及每个操作在做什么吗?每个矩阵的真实形状是什么,它们是如何分片的?尝试先查看性能分析文件,不要阅读代码。

点击此处查看答案

这是两个矩阵乘法,具体来说是这样的:

def matmul(w1, w2, x):
  return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))

你可以看到一个归约操作,两个大型融合操作和一个全归约操作。第一个大型融合是:

%fusion.1 = bf16[4096]{0:T(1024)(128)(2,1)} fusion(bf16[4096,8192]{1,0:T(8,128)(2,1)} %param.1, bf16[8192]{0:T(1024)(128)(2,1)} %reduce.6), kind=kLoop, calls=%fused_computation.1

这告诉我们每个分片的形状是bf16[8192] * bf16[4096, 8192] -> bf16[4096](在8192维度上)。通过观察最终的AllReduce操作,其replica_groups={{0,16,32,48,64,80,96,112}, ...},我们可以判断我们正在进行8路模型并行,所以真实的形状是[8, 8192] * bf16[32,768, 8192] -> bf16[8, 32,768]

问题2:前面提到的Transformer Colab实现了一个简单的模拟Transformer。按照Colab中的说明,使用GSPMD分区对朴素Transformer进行基准测试。每个部分需要多长时间?理论上应该需要多长时间?使用了什么分片策略?尝试修复分片策略!提示:使用jax.lax.with_sharding_constraints来约束行为。有了这个修复,你能获得的最佳MXU是多少?

作为参考,初始版本每层大约需要184ms,优化后的性能分析每层需要67ms。完成这些后,尝试仔细查看性能分析,看看你是否能仅从性能分析中回答以下问题:

  • 这是什么分片策略?

  • 批处理大小、d_{model}d_{ff}是多少?

  • 注意力机制与MLP块分别占用多少比例的时间?

  • 在理论性能上限下,每个操作应该占用多少比例的时间?

注意:自从这个问题编写以来,XLA编译器已经有所改进。初始版本现在每层大约需要90ms,优化后的性能分析只比初始版本好约10ms(每层80ms)。尽管如此,尝试一下看看你是否能做得更好,这仍然很有价值。

第10章:在JAX中编程TPU

如何高效地使用JAX编程TPU!本节大部分内容摘自这里

JAX中的并行计算是如何工作的?

JAX支持两种多设备编程思路:

  1. 编译器,掌控一切!让编译器自动分割数组并决定添加什么通信来促进给定程序的执行。这使您可以在单个设备上编写程序,然后无需任何更改即可自动在数百台设备上运行。

  2. 让我直接表达我的意思,该死!虽然编译器很好,但它们有时会做错事并添加您不想要的通信。有时我们希望对我们正在做的事情非常明确。

相应地,JAX为这两种思路提供了两种API:jitjax.jit)和shard_mapjax.experimental.shard_map.shard_map)。

  1. jax.jit允许您指定程序输入和输出的分片(通过in_shardingsout_shardings),并使用GSPMD编译器推断其余部分。虽然它并不完美,但通常能很好地自动将您的程序扩展到任意数量的芯片上。

  2. jax.experimental.shard_map.shard_map是更明确的对应方案。您可以获得程序的设备本地视图,并且必须明确编写所需的任何通信。有一个分片数组并希望每个设备上都有完整的数组?添加jax.lax.all_gather。想要在设备间对数组求和?添加jax.lax.psum(一个AllReduce)。编程更难但不太可能做出您不想要的事情。

jax.jit:自动并行解决方案

jax.jit在JAX内部扮演两个角色。顾名思义,它将Python函数"即时"编译成字节码(通过XLA/HLO/LLO)以加快运行速度。但如果输入已分片或用户指定了in_shardingout_sharding,它还允许XLA将计算分布到多个设备上并根据需要添加通信。例如,以下是如何使用jax.jit编写分片矩阵乘法:

import jax
import jax.numpy as jnp
import jax.sharding as shd

# 在TPU v5e 2x2上运行。这为硬件的两个物理轴分配名称。
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'))
def P(*args):
  return shd.NamedSharding(mesh, shd.PartitionSpec(*args))

# 我们创建一个在设备间分片的矩阵W和输入激活In。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P('Y', None))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# 我们可以在这里显式编译分片矩阵乘法函数。这会添加所有必要的通信(例如矩阵乘法后的AllReduce)。
jit_matmul = jax.jit(matmul_square, out_shardings=P('X', None)).lower(In, W).compile()
out = jit_matmul(In, W)

这将自动在任何分片上运行,并在我们的设备上分割计算。但在硬件层面上实际发生了什么?

  1. 首先我们创建了在我们设备间分片的矩阵In和W。W沿收缩维度分片成2份,而In则在四个方向上分片(沿着收缩维度和输出维度)。这对应于一种分片模式W[D, F]和In[B, D],也就是一种模型并行和数据并行的组合。

  2. 如果我们在本地运行(即在单个设备上),matmul_square只会简单地对输入进行平方并执行简单的矩阵乘法。但因为我们指定了out_shardingsP('X', None),我们的输出将沿批次维度分片但在模型维度上复制,这将需要执行AllReduce操作来计算。

使用我们之前章节的符号,这可能会执行类似以下操作:

  1. Out[B, F] { U } = In[B, D] * W[D, F]XYXYDY

  2. Out[B, F] = AllReduce(Out[B, F] { U })XXY

jax.jit会为我们自动添加这些操作!我们实际上可以通过jit_matmul.as_text()打印出HLO代码,并看到以下HLO(大幅缩略):

# 这个融合操作是对分片输入和矩阵的实际矩阵乘法
%fusion = bf16[4,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[4,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)

# 我们在设备间对部分求和结果进行规约
ROOT %AllReduce = bf16[4,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[4,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)

我们可以看到上面的矩阵乘法(fusion)和AllReduce。特别注意形状。bf16[4, 1024]是激活的本地视图,因为我们的batch_size=8被分割到2个设备上,我们的d_model=2048同样被分成2份。

这真是太神奇了!无论我们的程序多么复杂,GSPMD和jit都会尝试为所有中间激活找到分片并根据需要添加通信。话虽如此,GSPMD也有其缺陷。它会犯错。有时你查看性能分析时会发现出了问题。一个巨大的AllGather占用了80%的性能分析,而实际上它并不需要。当这种情况发生时,我们可以通过显式地为中间张量添加注释来尝试纠正编译器,使用jax.lax.with_sharding_constraint。例如,使用两个矩阵乘法,我可以强制中间激活沿y维度分片(这并不是个好主意)如下所示:

import jax
import jax.numpy as jnp

def matmul(x, Win, Wout):
  hidden = jnp.einsum('bd,df->bf', x, Win)
  hidden = jax.lax.with_sharding_constraint(hidden, P('x', 'y'))
  return jnp.einsum('bf,df->bd', hidden, Wout)

这构成了jit世界中JAX并行编程的大约60%,因为这是我们干预编译器的唯一方式。值得在Colab中尝试使用with_sharding_constraint并了解它的工作原理。当我们使用jax.jit编写LLM时,我们控制分片的90%工作是更改输入和输出分片(通过in_shardingsout_shardings)以及使用with_sharding_constraint注释中间张量,以确保正确的通信发生。有关更多jax.jit的例子,这是一个很棒的文档可以阅读

shard_map:对程序的显式并行控制

虽然GSPMD是"编译器接管一切"的模式,而JAX的shard_map则把一切都交给你来控制。你可以像在jax.jit中一样指定输入的分片,但随后你需要显式地编写所有通信。与jax.jit为你提供程序的全局跨设备视图不同,shard_map给你一个局部的每设备视图。

这里有一个例子。试着思考一下这个函数做了什么:

import jax
import jax.numpy as jnp
import jax.lax
import jax.sharding as shd

from jax.experimental.shard_map import shard_map as shmap

P = shd.PartitionSpec
mesh = jax.make_mesh(axis_shapes=(2,4), axis_names=('x','y'))

x = jnp.arange(0, 512, dtype=jnp.int32, device=jax.NamedSharding(mesh, P(('x', 'y'))))

# 这个函数将在数组的1/8上操作。
def slice_and_average(x):
  assert x.shape == (512 // 8,)
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = shmap(slice_and_average, mesh, in_specs=P(('x', 'y')), out_specs=P(None,))(x)
assert out.shape == (4,)

这做了什么? slice_and_average在每个TPU上运行,处理数组的1/8,从中我们切片出前4个元素并在整个网格上对它们取平均。这意味着我们实际上是在做mean(x[:4], x[64:68], x[128:132], …)。这非常酷,因为这在JAX中不是一个容易表达的操作。

为什么不用jax.jit? 如果我们使用jax.jitslice_and_average会看到数组的全局视图(完整的[512,]数组)。我们就必须切片出这个非均匀的切片,然后执行平均操作,而XLA必须正确解释这一切。XLA可能会添加错误的通信或产生混淆。在这里,我们看到的是局部视图,并且只编写我们需要的通信。

示例[集体矩阵乘法]:以一个更现实的例子,假设我们要实现模型并行,其中激活最初是按模型分片的,即A[BX, DY] * W[D, FY] -> Out[BX, FY]。天真的做法是先对A进行AllGather,然后进行局部矩阵乘法:

  1. A[B, D] = AllGather(A[B, D])XYXY

  2. Out[B, F] = A[B, D] * W[D, F]XYXDY

可惜,这样做不好,因为它不允许我们将通信与计算重叠。重叠它们可以通过"集体矩阵乘法"来实现,如Wang等人2023所述。该算法基本如下:

  • 对于每个Y分片,执行局部A块与局部W块的矩阵乘法,生成形状为[B / X, F / Y]的结果。同时,置换A使你在本地获得下一个块,执行矩阵乘法,并对结果求和。

我们可以用shard_map相当容易地实现这一点:

import functools

import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'))
def P(*args):
  return shd.NamedSharding(mesh, shd.PartitionSpec(*args))
  
B, D, F = 1024, 2048, 8192
A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))

A = jax.device_put(A, P('X', 'Y'))
W = jax.device_put(W, P(None, 'Y'))

@functools.partial(jax.jit, out_shardings=P('X', 'Y'))
def matmul(lhs, rhs):
  return lhs @ rhs
  
def collective_matmul_allgather_lhs_contracting(lhs, rhs):
  # lhs是循环操作数;rhs是本地操作数  
  axis_size = jax.lax.psum(1, axis_name='Y')  # 在这个例子中axis_size = 4
  idx = jax.lax.axis_index('Y')
  
  chunk_size = lhs.shape[1]
  assert rhs.shape[0] % chunk_size == 0
  
  def f(i, carrys):
    accum, lhs = carrys
    rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    # 对一个块进行矩阵乘法
    update = lhs @ rhs_chunk
    # 向左循环移位
    lhs = jax.lax.ppermute(
      lhs,
      axis_name='Y',
      perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    return accum + update, lhs

  accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
  accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)
  # 在最后一次置换后计算最后一个块,使lhs保持在我们发现它的状态
  i = axis_size - 1
  rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
  update = lhs @ rhs_chunk
  return accum + update

jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_contracting, mesh,
  in_specs=(shd.PartitionSpec('X', 'Y'), shd.PartitionSpec(None, 'Y')), out_specs=shd.PartitionSpec('X', 'Y')))
  
shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)

np.testing.assert_array_equal(shmapped_out, expected_out)

这相当不错!我们可以对此进行基准测试,发现它也快得多!这里是默认jit矩阵乘法的性能分析,它需要311微秒,开头有一个大的阻塞AllGather:

这里是上面的版本,只需要244微秒。你可以看到分析中没有AllGather。全是有用的工作!我们的FLOPs利用率也高得多。

值得注意的是,在收缩维度上没有分片的矩阵乘法时间是224微秒,所以我们非常接近未分片的基准。这是一个很好的例子,说明了你可能会进行何种性能工程来提高TPU利用率。有关shard_map的更多示例,这篇笔记很棒

现在这里有几个有用的实践问题,尝试使用jax.jitshard_map来实现!

第10章练习题

这里有一些随机的JAX相关问题。我稍后会添加更多。对于所有这些问题,你需要在Colab中使用一些TPU。你可以使用带有TPUv2-8的公共Colab。从现在开始,我们假设你有N个可用设备。

问题1:对于接下来的几个部分,我们将让A是一个形状为float32[SX, DY]的激活数组,其中X * Y = N。请执行以下操作:

  1. 编写一个JAX函数,计算每个X分片的平均值,即返回一个大小为[X, DY]的数组,其中arr[i]是分片i的平均值。使用jax.jitshard_map两种方式实现。对每种方法进行性能分析,看看它们各自花了多长时间。是否添加了任何通信?提示:本不应该有,但有时XLA还是会添加。 这里是答案。

  2. 编写一个JAX函数,返回roll(x, shift) - x,其中shift是在每个分片X内。我不会让你用jax.jit来实现这个问题,所以只需使用shard_map即可。

问题2:在这里,我们将一起构建一个基本的"专家混合"模型。让W: float32[E_X, D, F_Y]是一组E个"专家"矩阵。让A如上所述(我们的激活值),让B是一组"路由分配",其中B[i]是范围[0, E)内的整数,告诉我们要用哪个矩阵处理该激活值。我们想在JAX中编写一个函数,返回Out[i] = W[B[i]] @ A[i]

  1. 首先,让我们完全忽略分片。使所有这些张量足够小,以便它们适合一个设备。编写此函数的本地实现。确保你不会具体化形状为[S, D, F]的数组!提示:尝试将令牌排序到形状为[E, S, D]的新缓冲区中,注意掩码(为什么我们需要第二维的大小为S?)。

  2. 如果你只是对上述方法使用jax.jit,会发生一些事情。分析一下它决定进行什么通信。这需要多长时间?

  3. 你会注意到上述方法的一个问题是,它可能会在本地收集完整的激活集A,即AllGather_X([S_X, D_Y])。这不仅在通信方面代价高昂,如果我们无法在本地适配完整的激活集,在内存方面也极其昂贵。使用shard_map和显式通信实现上述功能。

    1. 对于第一步,使用jax.lax.all_gather并按(a)中的方式重新排序可能最简单。
    2. 对于第二步,尝试避免具体化任何大小为[E, S, D]的数组,即尝试使用jax.lax.all_to_alljax.lax.while_loop内以不规则方式执行计算。这样,你可以避免具体化完整的激活并浪费计算资源在填充上。这比你的原始实现快多少?
  4. 大多数MoE会路由到多个(k)专家,然后对结果取平均值。重构上述代码以实现这一点。在这种情况下,让B: int32[S, k]表示要路由到的k个专家。

问题3:上面的集体矩阵乘法示例实际上与真实的LLM非常相关。让我们调整示例以实现完整的Transformer堆栈。

  1. 作为练习,让我们先实现一个AllReduce集体矩阵乘法,即A[B_X, D_Y] *_D W[D_Y, F] \to Out[B_X, F]。注意输出不是复制的。上面讨论了朴素算法,基本上就是本地矩阵乘法后跟一个AllReduce。尝试制作一个通信重叠的"集体"版本的这个操作。提示:在输出维度上进行分块,随意使用jax.lax.psum(也就是AllReduce)。 注意:由于XLA处理这个的方式,它可能实际上并不比基线快。

  2. 上面的AllReduce集体矩阵乘法的补充是ReduceScatter集体矩阵乘法,如Tmp[B_X, F_Y] *_F W2[F_Y, D] \to Out[B_X, D_Y]。这在Transformer的下投影矩阵中出现。在JAX中实现一个集体的、重叠的版本。小心只传递你需要的最小数据量。提示:尝试在累积结果时对其进行置换。

  3. 将这两者结合到一个端到端的Transformer块中,该块执行In[B_X, D_Y] *D W{in}[D, F_Y] *F W{out}[F_Y, D] \to Out[B_X, D_Y],并具有重叠通信。这比jax.jit实现快多少?

问题4:上面实现的所有集体矩阵乘法都是单向的:它们只在一个方向上置换。重写集体AllReduce矩阵乘法和集体ReduceScatter矩阵乘法以使用双向通信。这些有多快?

第11章:结论和进一步阅读

感谢您阅读这一系列文章,并恭喜您坚持读到最后。在我们结束之前,有几点致谢:

致谢

本文档代表了谷歌DeepMind众多人员的重要集体投入,我们想简要致谢!

  • James Bradbury、Reiner Pope和Blake Hechtman最初提出了本手稿中的许多想法,并且他们很早就理解了Transformer的系统视角。

  • Sholto Douglas撰写了本文档的第一个版本,并负责启动了这个项目。他比任何人都更负责本文档的整体叙述。

  • Jacob Austin领导了将第一版从粗略笔记转变为更加完善和全面的作品的工作。他完成了大部分的编辑、排版和发布工作,并协调了其他作者的贡献。

  • 大部分图表和动画由Anselm Levskaya和Charlie Chen制作。

  • Charlie Chen撰写了推理部分并绘制了许多推理图表。

  • Roy Frostig在出版、编辑和旅程的许多其他步骤中提供了帮助。

我们还要感谢许多在整个过程中提供关键反馈的人,特别是Zak Stone、Nikhil Sethi、Caitlin Stanton、Alex Dimitriev、Sridhar Lakshmanamurthy、Albert Magyar、Diwakar Gupta、Jeff Dean、Corry Wang、Matt Johnson、Peter Hawkins等许多人。感谢Ruiqi Gao在HTML格式方面的帮助。

感谢你们所有人!

延伸阅读

这里有一系列相关的文章,包括以下内容:

在这一领域,仍有大量空间可以进行全面的写作,因此我们希望这份手稿能够鼓励更多的创作!我们也相信这是一个值得研究的有成果的领域。在许多情况下,即使没有大量硬件加速器,也可以进行这方面的研究。

反馈

请留下评论或问题,以便我们进一步改进。您可以通过以下方式联系我们的通讯作者Jacob Austin:jaaustin [at] google [dot] com,或者通过在GitHub上提交问题、拉取请求或讨论来建议编辑。

如何引用

在学术资料中进行引用时,请按以下方式引用本作品:

Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或作为BibTeX条目:

@article{scaling-book,
  title = {How to Scale Your Model},
  author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
  and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
  publisher = {Google DeepMind},
  howpublished = {Online},
  note = {Retrieved from https://jax-ml.github.io/scaling-book/},
  year = {2025}
}

参考资料