大模型训练(5):Zero Redundancy Optimizer(ZeRO零冗余优化器)
0 英文缩写
- Large Language Model(LLM)大型语言模型
- Data Parallelism(DP)数据并行
- Distributed Data Parallelism(DDP)分布式数据并行
- Zero Redundancy Optimizer(ZeRO)零冗余优化器
- P o s P_{os} Pos:Partition Optimizer States 对优化器状态进行切片
- P G P_G PG:Partition Gradients 对梯度进行切片
- P P P_P PP:Partition Parameters 对参数进行切片
1 背景
1.1 大模型训练现状
LLM在训练时往往需要大量内存来存储中间激活、权重等参数,百亿模型甚至无法在单个GPU上进行训练,使得模型训练在某些情况下非常低效和不可能。这就需要进行多卡,或者多节点分布式训练。在大规模深度学习模型训练中有如下几种主要的范式:
- 流水线并行
- 张量并行
- 数据并行
- 模型并行
目前训练超大规模语言模型技术路线:
- GPU(算力核心设备)
- PyTorch(框架,代码到驱动的链接者映射者)
- Megatron-LM(实现模型并行、张量并行、流水线并行,朴素数据并行)
- DeepSpeed(实现数据并行)
1.2 几种并行原理简述
-
PP并行的原理:当模型太大,一块GPU放不下时,流水线并行将模型的不同阶段(一般为不同层)放到不同的GPU上,通过将mini-batch切割成更细粒度的micro-batch,实现对训练数据的流水线处理,提升GPU计算通讯比。同时通过re-materialization机制降低显存消耗。
-
DP并行的优势(更强易用性应用更加广泛):在实际应用中,流水线并行并不特别流行,主要原因是模型能否均匀切割,影响了整体计算效率,这就需要算法工程师做手调。因此,来介绍一种应用最广泛,最易于理解的并行范式:数据并行。
-
数据并行的核心思想是:在各个GPU上都拷贝一份完整模型,各自吃一份数据,算一份梯度,最后对梯度进行累加来更新整体模型。理念不复杂,但到了大模型场景,巨大的存储和GPU间的通讯量,就是系统设计要考虑的重点了。在本文以及后续文章中,我们将递进介绍三种主流数据并行的实现方式:
- DP:最早的数据并行模式,一般采用参数服务器这一编程框架。实际中多用于单机多卡
- DDP:分布式数据并行,采用Ring AllReduce的通讯方式,实际中多用于多机场景
- ZeRO:(本文主要的核心内容)零冗余优化器。
2 DeepSpeed概述
2.1 基本理念与达成路径
DeepSpeed是由Microsoft提供的分布式训练工具,与其他框架相比,优势在支持更大规模的模型和提供更多的优化策略和工具(例如 ZeRO 和 Offload 等)。看一下官网对于这个理念的描述:
Why would you want to use DeepSpeed with just one GPU?
- It has a ZeRO-offload feature which can delegate some computations and memory to the host’s CPU and RAM, and thus leave more GPU resources for model’s needs - e.g. larger batch size, or enabling a fitting of a very big model which normally won’t fit.
- It provides a smart GPU memory management system, that minimizes memory fragmentation, which again allows you to fit bigger models and data batches.
具体点说,DeepSpeed将当前时刻,训练模型用不到的参数(包括模型参数、optimizer、梯度等),不计算或者缓存到CPU中,等到要用到了,再从其他GPU上拿或者从CPU挪到GPU。越多的参数被卸载掉,GPU的负担就越小;但随之的代价就是,更为频繁的GPU-GPU或者CPU-GPU交互,极大增加了训练推理的时间开销。因此,DeepSpeed使用的一个核心要义是,时间开销和显存占用的权衡。
2.2 基本概念
在分布式计算环境中,需要理解几个非常基础的概念:
- 节点编号(node_rank):分配给系统中每个节点的唯一标识符,用于区分不同计算机之间的通信。
- 全局进程编号(rank):分配给整个系统中的每个进程的唯一标识符,用于区分不同进程之间的通信。
- 局部进程编号(local_rank):分配给单个节点内的每个进程的唯一标识符,用于区分同一节点内的不同进程之间的通信。
- 全局总进程数(word_size):在整个系统中运行的所有进程的总数,用于确定可以并行完成多少工作以及需要完成任务所需的资源数量。
- 主节点(master_ip+master_port):在分布式计算环境中,主节点负责协调所有其他节点和进程的工作,为了确定主节点,我们需要知道它的IP地址和端口号。主节点还负责监控系统状态、处理任务分配和结果汇总等任务,因此是整个系统的关键部分。
2.3 通信策略
DeepSpeed 还提供了 mpi、gloo 和 nccl 等通信策略,可以根据具体情况进行选择和配置。
- mpi 是一种跨节点通信库,常用于 CPU 集群上的分布式训练
- gloo 是一种高性能的分布式训练框架,支持 CPU 和 GPU 上的分布式训练
- nccl 是 NVIDIA 提供的 GPU 专用通信库,被广泛应用于 GPU 上的分布式训练
在使用 DeepSpeed 进行分布式训练时,可以根据具体情况选择合适的通信库,例如在 CPU 集群上进行分布式训练,可以选择 mpi 和 gloo;如果是在 GPU 上进行分布式训练,可以选择 nccl。
2.4 为什么需要DeepSpeed
- ZeRO减少内存占用,用 3D 并行化优化并实现万亿参数模型训练:
- pytorch官方提供的分布式训练工具Accelerate只支持nvlink,而T4,3090这类显卡是PIX ,检测方式:nvidia-smi topo -m;
- DeepSpeed 实现了三种并行方法的灵活组合:ZeRO 支持的数据并行,流水线并行和张量切片模型并行。3D 并行性适应了不同工作负载的需求,以支持具有万亿参数的超大型模型,同时实现了近乎完美的显存扩展性和吞吐量扩展效率。此外,其提高的通信效率使用户可以在网络带宽有限的常规群集上以 2-7 倍的速度训练有数十亿参数的模型。
- Offload 使 GPU 单卡能够训练 10 倍大的模型为了同时利用 CPU 和 GPU 内存来训练大型模型,扩展了 ZeRO-2。我们的用户在使用带有单张英伟达 V100 GPU 的机器时,可以在不耗尽显存的情况下运行多达 130 亿个参数的模型,模型规模扩展至现有方法的10倍,并保持有竞争力的吞吐量。此功能使数十亿参数的模型训练更加大众化,并为许多深度学习从业人员打开了一扇探索更大更好的模型的窗户。
- 混合精度训练
- 通过 DeepSpeed Sparse Attention 用6倍速度执行10倍长的序列: DeepSpeed提供了稀疏 attention kernel ——一种工具性技术,可支持长序列的模型输入,包括文本输入,图像输入和语音输入。与经典的稠密 Transformer 相比,它支持的输入序列长一个数量级,并在保持相当的精度下获得最高 6 倍的执行速度提升。它还比最新的稀疏实现快 1.5–3 倍。此外,我们的稀疏 kernel 灵活支持稀疏格式,使用户能够通过自定义稀疏结构进行创新。
- 1 比特 Adam 减少 5 倍通信量: Adam 是一个在大规模深度学习模型训练场景下的有效的(也许是最广为应用的)优化器。然而,它与通信效率优化算法往往不兼容。因此,在跨设备进行分布式扩展时,通信开销可能成为瓶颈。我们推出了一种 1 比特 Adam 新算法,以及其高效实现。该算法最多可减少 5 倍通信量,同时实现了与Adam相似的收敛率。在通信受限的场景下,我们观察到分布式训练速度提升了 3.5 倍,这使得该算法可以扩展到不同类型的 GPU 群集和网络环境。
2.5 训练介绍
对应deepspeed参数
- ZeRO:这对应DeepSpeed工具中的ZeRO方式,分别是
zero_optimization.stage=0/1/2/3 - Offload:ZeRO-Offload 通过利用主机CPU上的计算和内存资源来执行优化器,从而减少此类模型的GPU计算和内存需求。卸载通过
zero_optimization.offload_optimizer.device设置 - gradient_checkpointing : 降低深度学习模型训练过程中内存消耗的技术
- 混合精度:在 DeepSpeed 中,可以通过在配置文件中设置
bf16.enabled: true来启用 BF16 混合精度训练,减少占用内存。混合精度训练是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。 - DeepSpeed的推理优化技术:
- Deep fusion:如下图,红色虚线框是以该单位为优化Kernel,对应的数字是优化的效率倍数
- Inference-customized GeMM

2.6 ZeRO思路整理
前文描述提要:存储主要分为两大块:Model States和Residual States
- Model State:指和模型本身息息相关的,必须存储的内容,具体包括:
- Optimizer States:是 Optimizer 在进行梯度更新时所需要用到的数据,例如 SGD 中的 Momentum、Adam优化算法中的Momentum(动量)和Variance(方差)。
- Gradients:模型梯度 G G G,是在反向传播后所产生的梯度信息,其决定了参数的更新方向。
- Parameters:模型参数 W W W,也就是我们在整个过程中通过数据“学习”的信息。
- Residual States:指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
- Activation:激活值。在流水线并行中曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
- Temporary Buffers:临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
- Unusable Fragment Memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
通过前文知道了什么东西会占存储,以及它们占了多大的存储之后,微软开发ZeRO是为了克服数据并行性和模型并行性的限制,同时实现两者的优点。注意到,在整个训练中,有很多states并不会每时每刻都用到,举例来说;
- Adam优化下的optimizer states只在最终做update时才用到
- 数据并行中,gradients只在最后做AllReduce和updates时才用到
- 参数 W W W只在做forward和backward的那一刻才用到
所以,ZeRO想了一个简单粗暴的办法:如果数据算完即废,等需要的时候,我再想办法从个什么地方拿回来,那不就省了一笔存储空间吗?沿着这个思路,我们逐一来看ZeRO是如何递进做存储优化的。
- 第1种优化:针对模型状态内存的优化(ZeRO-DP优化)
- ZeRO 将模型参数分成了三个部分:Optimizer States、Gradient 和 Model Parameter。
- ZeRO-0:禁用所有类型的分片,仅使用 DeepSpeed 作为 DDP
- ZeRO-1 Stage 1:即为 P o s P_{os} Pos
- Optimizer State Partitioning 只对optimizer进行切片后分布式保存(每一个节点仅存部分)
- 分割Optimizer states。优化器参数被划分到多个memory上,每个momoey上的进程只负责更新它自己那部分参数。减少了4倍的内存,通信容量与数据并行性相同
- ZeRO-2 Stage 2:即为 P G P_G PG
- Gradient Partitioning 对optimizer和grad进行切片后分布式保存(每一个节点仅存部分)
- 分割Optimizer States与Gradients。每个memory,只保留它分配到的optimizer state所对应的梯度。这很合理,因为梯度和Optimizer是紧密联系在一起的。只知道梯度,不知道Optimizer state,是没有办法优化模型参数的。
- ZeRO-3 Stage 3: 即为 P P P_P PP
- Parameter Partitioning (ZeRO stage 3) 对optimizer、grad和模型参数进行切片后分布式保存(每一个节点仅存部分)
- 分割Optimizer States、Gradients与Parameters,或者说,不同的layer. ZeRO-3会在forward和backward的时候,自动将模型参数分配到多个memory。ZeRO-Stage3将模型参数分片到不同的GPU上,通过交换节点间通信来降低显存占用,但需要进行额外的通信操作,因此可能会导致训练速度的下降。
- 第2种优化:针对残差状态内存的优化:ZeRO-R优化:
- 对residual states的优化(activation)
- 灵活设置保存部分activation或者每个GPU维护部分activation
- 固定大小的内存buffer
- 碎片化的存储空间进行重新整合
- 第3种优化:Custom mixed precision training handling:混合精度
- 第4种优化:Offload优化
- ZeRO-Offload to CPU and NVMe:
- 将模型参数分布在CPU和GPU上,通过CPU去计算一部分梯度,从而减少显存占用,但也会带来一定的计算开销。
- ZeRO-Infinity是ZeRO-3的拓展,将forward中间结果保存到内存、硬盘(NVMe)等缓存中,然后在需要时进行加载或重计算,进一步降低显存占用
- ZeRO-Offload to CPU and NVMe:
- 第5种优化:A range of fast CUDA-extension-based optimizers
看上去比较高大上,可能让你很难专心去理解,但实际上,这个概念非常简单。这只是通常的 DDP,只是没有每个 GPU 都复制完整的模型参数、梯度和优化器状态,而是每个 GPU 只存储其中的一部分。在随后的运行过程中,当需要给定层的完整层参数时,所有 GPU 同步以相互提供它们缺失的部分 —— 仅此而已。
3 混合精度模型
后文都是用如下混合精度实现方式来计算模型在训练时需要的存储大小,假设模型的参数 W W W大小是 Φ \Phi Φ (此处可以理解为参数数量),以byte为单位,存储如下:

- 必存(共计 12 Φ 12\Phi 12Φ):
- Parameters(FP32占4个字节,共 Φ \Phi Φ个): W 必存 = 4 Φ W_{必存}=4\Phi W必存=4Φ
- momentum(FP32占2个字节,共 Φ \Phi Φ个): M = 4 Φ M=4\Phi M=4Φ
- variance(FP32占2个字节,共 Φ \Phi Φ个) : V = 4 Φ V = 4\Phi V=4Φ
- 中间值(共计 4 Φ 4\Phi 4Φ):
- Parameters(FP16): W 中间 = 2 Φ W_{中间}=2\Phi W中间=2Φ
- Gradients(FP16): G = 2 Φ G=2\Phi G=2Φ
因为采用了Adam优化,所以才会出现momentum和variance,当然你也可以选择别的优化办法。因此这里为了更通用些,记模型必存的数据大小为 K Φ K\Phi KΦ 。因此最终内存开销为: 2 Φ + 2 Φ + K Φ 2\Phi+2\Phi+K\Phi 2Φ+2Φ+KΦ
4 ZeRO-DP
ZeRO-DP(Zero Redundancy Optimizer-Data Parallelism)是来自于论文《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》中的一种显存优化方法ZeRO的核心部分。通过该方法可以大幅度的优化显存占用,从而在有限的资源下训练更大的模型。ZeRO通过在数据并行进程中划分模型状态(参数,梯度和优化器状态),而不是全复制它们,从而消除了数据并行进程中的内存冗余。它在训练期间使用动态通信计划,以在分布式设备之间共享必要的状态,以保持计算粒度和数据并行性的通信量。
ZeRO驱动的数据并行性,它允许每个设备的内存使用量随数据并行性的程度线性扩展,并产生与数据并行性相似的通信量。 ZeRO支持的数据并行性可以适合任意大小的模型,只要聚合的设备内存足够大以共享模型状态即可。
针对模型状态的存储优化(去除冗余),ZeRO使用的方法是分片(partition),即每张卡只存 1 N \frac{1}{N} N1的模型状态量,这样系统内只维护一份模型状态。
4.0 存储模型与前置知识
作出如下假设:模型参数 W W W的个数为 Φ \Phi Φ ,梯度个数也为 Φ \Phi Φ ,GPU个数为 N N N
对单个GPU使用DDP来说:
- Reduce-Scatter阶段,通讯量为 ( N − 1 ) Φ N \frac{(N−1)\Phi}{N} N(N−1)Φ(每一个数据块大小 Φ N \frac{\Phi}{N} NΦ,ring上走 N − 1 N-1 N−1次就可以在某一个GPU上完成归并)(如果考虑发送与收到则需要两份如此通讯量)
- All-Gather阶段,通讯量为 ( N − 1 ) Φ N \frac{(N−1)\Phi}{N} N(N−1)Φ(每一个数据块大小 Φ N \frac{\Phi}{N} NΦ,ring上走 N − 1 N-1 N−1块数据次就可以把某一块GPU上完成归并的数据块下发至每一个GPU)(如果考虑发送与收到则需要两份如此通讯量)
- 单卡单向总通讯量为 2 ( N − 1 ) Φ N \frac{2(N−1)\Phi}{N} N2(N−1)Φ,随着 N N N的增大,可以近似为 2 Φ 2\Phi 2Φ,双向则为 4 Φ 4\Phi 4Φ
- 全卡单向总通讯量为 2 N Φ 2N\Phi 2NΦ,双向则为 4 N Φ 4N\Phi 4NΦ
一般互联带宽为双向带宽,即同时实现相同带宽的收与发。假设收与发的带宽均为 B B B。
| DP | 收 | 发 | 收发通讯时间 |
|---|---|---|---|
| Server(单卡) | N Φ N\Phi NΦ | N Φ N\Phi NΦ | N Φ B \frac{N\Phi}{B} BNΦ |
| Worker(单卡) | Φ \Phi Φ | Φ \Phi Φ | Φ B \frac{\Phi}{B} BΦ |
| 集群视角(以server为瓶颈) | N Φ N\Phi NΦ | N Φ N\Phi NΦ | N Φ B \frac{N\Phi}{B} BNΦ |
| DDP | 收 | 发 | 收发通讯时间 |
|---|---|---|---|
| 单卡Reduce-Scatter阶段 | Φ \Phi Φ | Φ \Phi Φ | Φ B \frac{\Phi}{B} BΦ |
| 单卡All-Gather阶段 | Φ \Phi Φ | Φ \Phi Φ | Φ B \frac{\Phi}{B} BΦ |
| 单卡All-Reduce | 2 Φ 2\Phi 2Φ | 2 Φ 2\Phi 2Φ | 2 Φ B \frac{2\Phi}{B} B2Φ |
| 集群视角All-Reduce阶段 | 2 N Φ 2N\Phi 2NΦ | 2 N Φ 2N\Phi 2NΦ | 2 N Φ N B = 2 Φ B \frac{2N\Phi}{NB}=\frac{2\Phi}{B} NB2NΦ=B2Φ |
DP中Server为瓶颈,搬运数据量均阻塞在Server的通讯能力上。DDP把通讯量均衡负载到了每一时刻的每个Worker上,当越来越多的GPU分布在距离较远的机器上时,DP的通讯时间是会增加的,但是DDP可以基本不变。
4.1 P o s P_{os} Pos:分割优化状态
首先,从 optimizer state 开始优化。将optimizer state分成若干份,每块GPU上各自维护一份(不再在单块GPU上完成所有的optimizer state的数据)。这样就减少了相当一部分的显存开销。如下图:

此时,整体数据并行的流程如下(数据量模型参考章节4.0、混合精度参考章节3):
- step1:每块GPU上存一份完整的参数 W 中间 W_{中间} W中间(使用4.0中数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ), W 必存 W_{必存} W必存参考step4中的optimizer states
- step2:将一个batch的数据分成 N N N份(X1、X2、X3)(上图为X=3),每块GPU各吃一份,做完一轮foward和backward后,各得一份梯度 G G G,(数据个数为 Φ \Phi Φ,数据格式为FP16)
- step3:对分散在不同GPU上的梯度做一次AllReduce,得到完整的梯度 G G G(数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ),产生单卡通讯量 2 Φ 2\Phi 2Φ (通讯量参考4.0 DDP表格中第4行,注意这里的通讯量没有考虑精度,仅衡量个数)
- step4:得到完整梯度 G G G,就可以对 W 必存 W_{必存} W必存做更新。我们知道 W 必存 W_{必存} W必存的更新由optimizer states、 W 必存 W_{必存} W必存原始值和梯度 G G G共同决定。这里就是 P o s P_{os} Pos的核心,每块GPU上只存部分optimizer states与部分 W 必存 W_{必存} W必存,因此只能将部分对应的 W 中间 W_{中间} W中间进行更新(下图蓝色部分,数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP16)
- 细节说明1: W 必存 W_{必存} W必存(数据个数 Φ \Phi Φ,数据精度FP32,实际存储空间 4 Φ 4\Phi 4Φ),momentum(数据个数 Φ \Phi Φ,数据精度FP32,实际存储空间 4 Φ 4\Phi 4Φ),variance(数据个数 Φ \Phi Φ,数据精度FP32,实际存储空间 4 Φ 4\Phi 4Φ)。因为采用了Adam优化,所以才会出现momentum和variance,当然你也可以选择别的优化办法。因此这里为了更通用些,记模型用于更新的数据大小为 K Φ K\Phi KΦ 。因此分片后实际内存开销为: K N Φ \frac{K}{N}\Phi NKΦ
- 细节说明2:为什么 G G G需要做Allreduce,同时存下完整的数据,因为对于每一片GPU都只看到了部分数据,需要所有的梯度归并之后才能看到全局的、经过所有batchsize数据反应的真是梯度,即对于每一个GPU来说确实只需要对应数据的分片梯度,但是这里有一个地方容易搞混G1G2G3说的是不同数据回传出来的所有参数的梯度,需要求和归并后,才能得到对应数据的准确分片梯度
- step5:每块GPU上都有部分 W 中间 W_{中间} W中间没有完成更新(图中白色部分)。所以我们需要对 W 中间 W_{中间} W中间做一次All-Gather,从别的GPU上把更新好的部分 W 中间 W_{中间} W中间取回来。产生单卡通讯量 Φ \Phi Φ(通讯量参考4.0 DDP表格中第3行,注意这里的通讯量没有考虑精度,仅衡量个数),完成刷新后的 W 中间 W_{中间} W中间(数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ)
step2和step3可以用下图表示,作为下图后蓝色块再allgather到所有GPU上变成全部蓝色:

做完 P o s P_{os} Pos 后,显存和通讯量的情况如下:
| 显存占用 | 实跑显存 K = 12 K=12 K=12, Φ = 7.5 B \Phi=7.5B Φ=7.5B , N = 64 N=64 N=64 | 单卡通讯量 | |
|---|---|---|---|
| DDP | ( 2 + 2 + K ) Φ (2+2+K)\Phi (2+2+K)Φ | 120GB | 2 Φ 2\Phi 2Φ |
| P o s P_{os} Pos | ( 2 + 2 + K N ) Φ (2+2+\frac{K}{N})\Phi (2+2+NK)Φ | 31.4GB | 3 Φ 3\Phi 3Φ |
表格说明如下:
- 显存占用:考虑实际的byte,考虑混合精度,考虑中间变量,具体细节参考3,分别是
- step1与step5中的: W 中间 W_{中间} W中间(数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ)
- step3中的: G G G(数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ)
- step4中的:OS与 W 必存 W_{必存} W必存分片:(数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP32,由于优化手段不一样,实际存储空间 K N Φ \frac{K}{N}\Phi NKΦ)
- 单卡通讯量:没有考虑byte,可以理解为仅包含个数信息,但是不影响比例关系比较
假设各变量大小如表格第二列所示,那么 P o s P_{os} Pos 在增加1.5倍单卡通讯开销的基础上,将单卡存储降低了接近 N N N倍。看起来是个还不错的trade-off,那么还能做得更好吗
4.2 P o s P_{os} Pos+ P G P_{G} PG :分割优化状态与梯度
现在,更近一步,把梯度也拆开,每个GPU格子维护一块梯度。

此时,整体数据并行的流程如下(数据量模型参考章节4.0、混合精度参考章节3):
- step1:每块GPU上存一份完整的参数 W 中间 W_{中间} W中间(使用4.0中数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ), W 必存 W_{必存} W必存参考step4中的optimizer states
- step2:将一个batch的数据分成 N N N份(X1、X2、X3)(上图为X=3),每块GPU各吃一份,做完一轮foward和backward后,各得一份梯度 G G G,下图中绿色+白色
- step3:对分散在不同GPU上的梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。例如对GPU1,它负责维护G1,因此其他的GPU只需要把G1对应位置的梯度发给GPU1做加总就可。汇总完毕后,保留下图中的绿色,白色块对本GPU无用,可以从显存中移除。所以不同GPU分别维护分片G1/G2/G3(数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP16,实际存储空间 2 Φ N \frac{2\Phi}{N} N2Φ),单卡通讯量 Φ \Phi Φ 。(通讯量参考4.0 DDP表格中第2行,注意这里的通讯量没有考虑精度,仅衡量个数)
- step4:每块GPU用自己对应的 O O O和 G G G去更新相应的 W 必存 W_{必存} W必存(参考章节4.1,数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP32,由于优化手段不一样,实际存储空间 K N Φ \frac{K}{N}\Phi NKΦ)
- step5(同4.1):每块GPU上都有部分 W 中间 W_{中间} W中间没有完成更新。所以我们需要对 W 中间 W_{中间} W中间做一次All-Gather,从别的GPU上把更新好的部分 W 中间 W_{中间} W中间取回来。产生单卡通讯量 Φ \Phi Φ(通讯量参考4.0 DDP表格中第2行,注意这里的通讯量没有考虑精度,仅衡量个数),完成刷新后的 W 中间 W_{中间} W中间(数据个数为 Φ \Phi Φ,数据格式为FP16,实际存储空间 2 Φ 2\Phi 2Φ)
Step2和Step3见下图:

做完 P o s P_{os} Pos+ P G P_{G} PG 后,显存和通讯量的情况如下:
| 显存占用 | 实跑显存 K = 12 K=12 K=12, Φ = 7.5 B \Phi=7.5B Φ=7.5B , N = 64 N=64 N=64 | 单卡通讯量 | |
|---|---|---|---|
| DDP | ( 2 + 2 + K ) Φ (2+2+K)\Phi (2+2+K)Φ | 120GB | 2 Φ 2\Phi 2Φ |
| P o s P_{os} Pos | ( 2 + 2 + K N ) Φ (2+2+\frac{K}{N})\Phi (2+2+NK)Φ | 31.4GB | 3 Φ 3\Phi 3Φ |
| P o s P_{os} Pos+ P G P_G PG | ( 2 + 2 + K N ) Φ (2+\frac{2+K}{N})\Phi (2+N2+K)Φ | 16.6GB | 2 Φ 2\Phi 2Φ |
和DDP相比,上述例子中存储降了8倍,单卡通讯量持平,(通信量的优化主要是因为不需要全部G进行allreduce通信)好像更牛皮了呢!那么,还可以优化吗?
4.3 P o s P_{os} Pos+ P G P_G PG+ P P P_P PP :分割优化状态、梯度与参数
看到这里,也许你有点感觉了,ZeRO的思想就是:万物皆可切,万物皆可抛。所以现在,我们把参数也切开。每块GPU置维持对应的optimizer states,gradients和parameters(即 W W W)。

此时,整体数据并行的流程如下(数据量模型参考章节4.0、混合精度参考章节3):
- step1:每块GPU上只保存参数 W 中间 W_{中间} W中间的部分切片(数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP16,实际存储空间 2 Φ N \frac{2\Phi}{N} N2Φ)
- step2:将一个batch的数据分成 N N N份,每块GPU各吃一份,做forward时,对 W 中间 W_{中间} W中间做一次All-Gather,取回分布在别的GPU上的 W 中间 W_{中间} W中间,得到一份完整的 W 中间 W_{中间} W中间,单卡通讯量 Φ \Phi Φ 。forward做完,立刻把不是自己维护的 W 中间 W_{中间} W中间抛弃。
- step3:做backward时,对 W 中间 W_{中间} W中间做一次All-Gather,取回完整的 W 中间 W_{中间} W中间,单卡通讯量 Φ \Phi Φ 。backward做完,立刻把不是自己维护的 W 中间 W_{中间} W中间抛弃。各得一份梯度 G G G
- step4:对分散在不同GPU上的梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度,其余不相关梯度扔掉。(分片 G G G:数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP16,实际存储空间 2 Φ N \frac{2\Phi}{N} N2Φ),单卡通讯量 Φ \Phi Φ
- step5:每块GPU用自己对应的 O O O和 G G G去更新相应的 W 必存 W_{必存} W必存(参考章节4.1,数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP32,由于优化手段不一样,实际存储空间 K N Φ \frac{K}{N}\Phi NKΦ)
- step6(同4.1):每块GPU上都有部分 W 中间 W_{中间} W中间没有完成更新。由于只维护部分 W 中间 W_{中间} W中间,因此无需再对 W 中间 W_{中间} W中间做任何AllReduce操作,此时就回到了step1的状态(数据个数为 Φ N \frac{\Phi}{N} NΦ,数据格式为FP16,实际存储空间 2 Φ N \frac{2\Phi}{N} N2Φ)。
做完 P o s P_{os} Pos+ P G P_{G} PG+ P P P_P PP 后,显存和通讯量的情况如下:
| 显存占用 | 实跑显存 K = 12 K=12 K=12, Φ = 7.5 B \Phi=7.5B Φ=7.5B , N = 64 N=64 N=64 | 单卡通讯量 | |
|---|---|---|---|
| 朴素DP | ( 2 + 2 + K ) Φ (2+2+K)\Phi (2+2+K)Φ | 120GB | 2 Φ 2\Phi 2Φ |
| P o s P_{os} Pos | ( 2 + 2 + K N ) Φ (2+2+\frac{K}{N})\Phi (2+2+NK)Φ | 31.4GB | 3 Φ 3\Phi 3Φ |
| P o s P_{os} Pos+ P G P_G PG | ( 2 + 2 + K N ) Φ (2+\frac{2+K}{N})\Phi (2+N2+K)Φ | 16.6GB | 2 Φ 2\Phi 2Φ |
| P o s P_{os} Pos+ P G P_G PG+ P p P_p Pp | ( 2 + 2 + K N ) Φ (\frac{2+2+K}{N})\Phi (N2+2+K)Φ | 1.9GB | 3 Φ 3\Phi 3Φ |
到这一步,我们用1.5倍的通讯开销,换回近120倍的显存。只要梯度计算和异步更新做的好,通讯时间大部分可以被计算时间隐藏,因此这样的额外通讯开销,也是划算的。
到这里,我们可以放出原始论文中的说明图了。

4.4 ZeRO VS 模型并行
模型并行,是指在forward和backward的过程中,只需要用自己维护的那块 W W W来计算就行。即同样的输入 X X X,每块GPU上各算模型的一部分,最后通过某些方式聚合结果。
大家可能会想,既然ZeRO都把参数 W W W给切了,那它应该是个模型并行呀?为什么要归到数据并行里呢?
其实ZeRO是模型并行的形式,数据并行的实质。
对ZeRO来说,它做forward和backward的时候,是需要把各GPU上维护的 W W W聚合起来的,即本质上还是用完整的 W W W进行计算。它是不同的输入 X X X,完整的参数 W W W,最终再做聚合。
5 ZeRO-R
说完了以上对model states的显存优化,现在来看对residual states的优化。
3.1 P a Pa Pa:Partitioned Activation Checkpointing
前面说过,对activation的存储是灵活的。不像optimizer states,gradients和parameters对模型更新是必须的,activation只是起到加速梯度计算的作用。因此,在哪几层保存activation,保存哪些activation都是可以灵活设置的。同样,我们也可以仿照以上切割方式,每块GPU上只维护部分的activation,需要时再从别的地方聚合过来就行。需要注意的是,activation对显存的占用一般会远高于模型本身,通讯量也是巨大的,所以这块要灵活、有效地实验设计。
3.2 C B C_B CB:Constant Size Buffer
固定大小的内存buffer,它的目的在于:
- 提升带宽利用率。当GPU数量上升,GPU间的通讯次数也上升,每次的通讯量可能下降(但总通讯量不会变)。数据切片小了,就不能很好利用带宽了。所以这个buffer起到了积攒数据的作用:等数据积攒到一定大小,再进行通讯。
- 使得存储大小可控。在每次通讯前,积攒的存储大小是常量,是已知可控的。更方便使用者对训练中的存储消耗和通讯时间进行预估。
3.3 M D M_D MD:Memory Defragmentation
在前文提过,设置机制,对碎片化的存储空间进行重新整合,整出连续的存储空间。防止出现总存储足够,但连续存储不够而引起的存储请求fail
6 ZeRO-Offload与ZeRO-Infinity
最后,简单介绍一下ZeRO-Offload。它的核心思想是:显存不够,内存来凑。
如果我把要存储的大头卸载(offload)到CPU上,而把计算部分放到GPU上,这样比起跨机,是不是能既降显存,也能减少一些通讯压力呢?
- Offload优化
- ZeRO-Offload和ZeRO-Stage3是DeepSpeed中的不同的Zero-Redundancy Optimization技术,用于加速分布式训练,主要区别在资源占用和通信开销方面。
- ZeRO-Offload将模型参数分布在CPU和GPU上,通过CPU去计算一部分梯度,从而减少显存占用,但也会带来一定的计算开销。
- ZeRO-Infinity是ZeRO-3的拓展。允许通过使用 NVMe 固态硬盘扩展 GPU 和 CPU 内存来训练大型模型。ZeRO-Infinity 需要启用 ZeRO-3。
6.1 ZeRO-Offload
ZeRO-Offload的核心思路就是让CPU和内存也参与到训练中去,回顾一下前文用到的训练流程的图,ZeRO-Offload就是把这个流程用上图的方式把fp32参数的更新和float2half操作拆分到了CPU和内存上计算,而前向和后向传播依然由GPU负责
具体做法是:
- forward和backward计算量高,因此和它们相关的部分,例如参数W(fp16),activation,就全放入GPU。
- update的部分计算量低,因此和它相关的部分,全部放入CPU中。例如W(fp32),optimizer states(fp32)和gradients(fp16)等。
具体切分如下图:


6.2 ZeRO-infinity
也是同理,它们在解决的事情都是:找个除GPU之外的地方,存数据。
7 总结
- 在DP中,每个GPU上都拷贝一份完整的模型,每个GPU上处理batch的一部分数据,所有GPU算出来的梯度进行累加后,再传回各GPU用于更新参数
- DP多采用参数服务器这一编程框架,一般由若个计算Worker和1个梯度聚合Server组成。Server与每个Worker通讯,Worker间并不通讯。因此Server承担了系统所有的通讯压力。基于此DP常用于单机多卡场景。
- 异步梯度更新是提升计算通讯比的一种方法,延迟更新的步数大小决定了模型的收敛速度。
- Ring-AllReduce通过定义网络环拓扑的方式,将通讯压力均衡地分到每个GPU上,使得跨机器的数据并行(DDP)得以高效实现。
- DP和DDP的总通讯量相同,但因负载不均的原因,DP需要耗费更多的时间搬运数据
- ZeRO:通过增加部分通信的代价,减少本GPU中不必要存储或者计算的数据,进一步缩小GPU的内存开销
8 参考
- https://arxiv.org/pdf/1910.02054.pdf
- https://blog.51cto.com/u_14691718/5631471
- https://blog.csdn.net/qq_43307074/article/details/127688761
- https://web.eecs.umich.edu/~mosharaf/Readings/Parameter-Server.pdf
- https://zh.d2l.ai/chapter_computational-performance/parameterserver.html
- https://blog.csdn.net/dpppBR/article/details/80445569
- https://arxiv.org/abs/1910.02054
- https://blog.51cto.com/u_14691718/5631471
- [LLM]大模型训练DeepSpeed(一)-原理介绍-CSDN博客
- 图解大模型训练之:数据并行下篇( DeepSpeed ZeRO,零冗余优化) - 知乎
相关文章:
大模型训练(5):Zero Redundancy Optimizer(ZeRO零冗余优化器)
0 英文缩写 Large Language Model(LLM)大型语言模型Data Parallelism(DP)数据并行Distributed Data Parallelism(DDP)分布式数据并行Zero Redundancy Optimizer(ZeRO)零冗余优化器 …...
C# 实现 “Hello World” 教程
.NET学习资料 .NET学习资料 .NET学习资料 C# 作为一种广泛应用于.NET 开发的编程语言,以其简洁、高效和类型安全等特性,深受开发者喜爱。在踏入 C# 编程领域时,编写经典的 “Hello World” 程序是重要的起点,它能帮助我们快速熟…...
LabVIEW无线齿轮监测系统
本案例介绍了基于LabVIEW的无线齿轮监测系统设计。该系统利用LabVIEW编程语言和改进的天牛须算法优化支持向量机,实现了无线齿轮故障监测。通过LabVIEW软件和相关硬件,可以实现对齿轮箱振动信号的采集、传输和故障识别,集远程采集、数据库存储…...
IM 即时通讯系统-01-概览
前言 有时候希望有一个 IM 工具,比如日常聊天,或者接受报警信息。 其实主要是工作使用,如果是接收报警等场景,其实DD这种比较符合场景。 那么有没有必要再创造一个DD呢? 答案是如果处于个人的私有化使用࿰…...
【人工智能】 在本地运行 DeepSeek 模型:Ollama 安装指南
持续更新。。。。。。。。。。。。。。。 【人工智能】 在本地运行 DeepSeek 模型:Ollama 安装指南 安装 Ollama安装 DeepSeek 模型选择版本 ,版本越高,参数越多 性能越好使用 DeepSeek 模型 安装 Ollama 访问 Ollama 官网: 前往 https://oll…...
【Linux系统】信号:信号保存 / 信号处理、内核态 / 用户态、操作系统运行原理(中断)
理解Linux系统内进程信号的整个流程可分为: 信号产生 信号保存 信号处理 上篇文章重点讲解了 信号的产生,本文会讲解信号的保存和信号处理相关的概念和操作: 两种信号默认处理 1、信号处理之忽略 ::signal(2, SIG_IGN); // ignore: 忽略#…...
探索 Copilot:开启智能助手新时代
探索 Copilot:开启智能助手新时代 在当今数字化飞速发展的时代,人工智能(AI)正以前所未有的速度改变着我们的工作和生活方式。而 Copilot 作为一款强大的 AI 助手,凭借其多样的功能和高效的应用,正在成为众…...
解锁豆瓣高清海报(二) 使用 OpenCV 拼接和压缩
解锁豆瓣高清海报(二): 使用 OpenCV 拼接和压缩 脚本地址: 项目地址: Gazer PixelWeaver.py pixel_squeezer_cv2.py 前瞻 继上一篇“解锁豆瓣高清海报(一) 深度爬虫与requests进阶之路”成功爬取豆瓣电影海报之后,本文将介绍如何使用 OpenCV 对这些海报进行智…...
我用Ai学Android Jetpack Compose之Card
这篇学习一下Card。回答来自 通义千问。 我想学习Card,麻烦你介绍一下 当然可以!在 Jetpack Compose 中,Card 是一个非常常用的组件,用于创建带有阴影和圆角的卡片式布局。它可以帮助你轻松实现美观且一致的 UI 设计,…...
NLP深度学习 DAY4:Word2Vec详解:两种模式(CBOW与Skip-gram)
用稀疏向量表示文本,即所谓的词袋模型在 NLP 有着悠久的历史。正如上文中介绍的,早在 2001年就开始使用密集向量表示词或词嵌入。Mikolov等人在2013年提出的创新技术是通过去除隐藏层,逼近目标,进而使这些单词嵌入的训练更加高效。…...
论文阅读(十):用可分解图模型模拟连锁不平衡
1.论文链接:Modeling Linkage Disequilibrium with Decomposable Graphical Models 摘要: 本章介绍了使用可分解的图形模型(DGMs)表示遗传数据,或连锁不平衡(LD),各种下游应用程序之…...
Python中容器类型的数据(上)
若我们想将多个数据打包并且统一管理,应该怎么办? Python内置的数据类型如序列(列表、元组等)、集合和字典等可以容纳多项数据,我们称它们为容器类型的数据。 序列 序列 (sequence) 是一种可迭代的、元素有序的容器类型的数据。 序列包括列表 (list)…...
PySPARK带多组参数和标签的SparkSQL批量数据导出到S3的程序
设计一个基于多个带标签SparkSQL模板作为配置文件和多组参数的PySPARK代码程序,实现根据不同的输入参数自动批量地将数据导出为Parquet、CSV和Excel文件到S3上,标签和多个参数(以“_”分割)为组成导出数据文件名,文件已…...
蓝桥杯备考:模拟算法之字符串展开
P1098 [NOIP 2007 提高组] 字符串的展开 - 洛谷 | 计算机科学教育新生态 #include <iostream> #include <cctype> #include <algorithm> using namespace std; int p1,p2,p3; string s,ret; void add(char left,char right) {string tmp;for(char ch left1;…...
使用LLaMA-Factory对AI进行认知的微调
使用LLaMA-Factory对AI进行认知的微调 引言1. 安装LLaMA-Factory1.1. 克隆仓库1.2. 创建虚拟环境1.3. 安装LLaMA-Factory1.4. 验证 2. 准备数据2.1. 创建数据集2.2. 更新数据集信息 3. 启动LLaMA-Factory4. 进行微调4.1. 设置模型4.2. 预览数据集4.3. 设置学习率等参数4.4. 预览…...
@Nullable 注解
文章目录 解释 Nullable 注解注解的组成部分:如何使用 Nullable 注解a. 标注方法返回值:b. 标注方法参数:c. 标注字段: 结合其他工具与 Nonnull 配合使用总结 Nullable 注解在 Java 中的使用场景通常与 Nullability(空…...
Arduino大师练成手册 -- 控制 AS608 指纹识别模块
要在 Arduino 上控制 AS608 指纹识别模块,你可以按照以下步骤进行: 硬件连接 连接指纹模块:将 AS608 指纹模块与 Arduino 连接。通常,AS608 使用 UART 接口进行通信。你需要将 AS608 的 TX、RX、VCC 和 GND 引脚分别连接到 Ardu…...
Mask R-CNN与YOLOv8的区别
Mask R-CNN与YOLOv8虽然都是深度学习在计算机视觉领域的应用,但它们属于不同类型的视觉框架,各有特点和优势。 以下是关于 Mask R-CNN 和 YOLOv8 的详细对比分析,涵盖核心原理、性能差异、应用场景和选择建议: 1. 核心原理与功能…...
在Ubuntu上使用Docker部署DeepSeek
在Ubuntu上使用Docker部署DeepSeek,并确保其可以访问公网网址进行对话,可以按照以下步骤进行: 一、安装Docker 更新Ubuntu的软件包索引: sudo apt-get update安装必要的软件包,这些软件包允许apt通过HTTPS使用存储库…...
MySQL的覆盖索引
MySQL的覆盖索引 前言 当一个索引包含了查询所需的全部字段时,就可以提高查询效率,这样的索引又被称之为覆盖索引。 以MySQL常见的三种存储引擎为例:InnoDB、MyISAM、Memory,对于覆盖索引提高查询效率的方式均不同,…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
TDengine 快速体验(Docker 镜像方式)
简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
遍历 Map 类型集合的方法汇总
1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...
【Linux】C语言执行shell指令
在C语言中执行Shell指令 在C语言中,有几种方法可以执行Shell指令: 1. 使用system()函数 这是最简单的方法,包含在stdlib.h头文件中: #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...
rnn判断string中第一次出现a的下标
# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...
【Linux】Linux 系统默认的目录及作用说明
博主介绍:✌全网粉丝23W,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...
逻辑回归暴力训练预测金融欺诈
简述 「使用逻辑回归暴力预测金融欺诈,并不断增加特征维度持续测试」的做法,体现了一种逐步建模与迭代验证的实验思路,在金融欺诈检测中非常有价值,本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...
WPF八大法则:告别模态窗口卡顿
⚙️ 核心问题:阻塞式模态窗口的缺陷 原始代码中ShowDialog()会阻塞UI线程,导致后续逻辑无法执行: var result modalWindow.ShowDialog(); // 线程阻塞 ProcessResult(result); // 必须等待窗口关闭根本问题:…...
针对药品仓库的效期管理问题,如何利用WMS系统“破局”
案例: 某医药分销企业,主要经营各类药品的批发与零售。由于药品的特殊性,效期管理至关重要,但该企业一直面临效期问题的困扰。在未使用WMS系统之前,其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...
