Datawhale分享
作者:真中合欢,来源:知乎
为什么会有这篇文章:虽然工作内容不是infra,但是我比较喜欢研究训练方法,魔改训练框架造轮子。正好最近看到OpenRLHF用ray管理VLLM的方案,感觉很有意思,遂研究了一下,发现VLLM的TP切分和Megatron是一套逻辑,用torch的rpc也可以代替ray的远程调用,所以打算用Megatron+TorchRPC+VLLM实现一套类似的框架,后期再把VLLM原地换掉直接megatron推理。在开始这个大工程之前,正好有机会写下这篇文章,就算是开工仪式了。
本文的主要内容
本文主要是从编程的角度,对LLM训练框架所涉及的一些前置编程知识进行讲解,并且会举一些应用技巧,对应到当前的LLM训练框架,辅助理解训练框架的代码逻辑。举个例子,下面是一段megatron初始化多卡通信组的代码:
rank = torch.distributed.get_rank()
for ranks in rank_generator.get_ranks('dp'): # 迭代生成所有数据并行ranks的列表
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
) # 根据数据并行ranks创建通信组
group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
if rank in ranks: # 如果当前rank属于这个数据并行ranks,则保存创建的通信组
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
读完本文你会了解:
-
什么是rank、什么是world_size,什么是通信组group -
为什么这段代码是先建通信组,再根据rank in ranks决定是否保存?判断rank in ranks 了再创建group行不行? -
在子进程和子线程中创建通信组有什么区别和要注意的地方。 -
backend=”gloo”:什么是backend,gloo backend 是干什么的,为什么不用nccl backend。
再比如下面是一段deepspeed在参数上注册的回调函数:
def create_reduce_and_remove_grad_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.bit16_groups): # 遍历混精优化器的参数
for param in param_group:
if param.requires_grad:
def wrapper(param, i):
param_tmp = param.expand_as(param) # 在原始的参数上建立一个视图
grad_acc = param_tmp.grad_fn.next_functions[0][0] # 获得这个这个视图的前一个算子
def reduce_partition_and_remove_grads(*notneeded):
self.reduce_ready_partitions_and_remove_grads(param, i)
self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) # 在视图的前一个算子上注册回调函数
self.grad_accs.append(grad_acc)
wrapper(param, i)
读完本文你会了解到:
-
register_hook是干什么,什么时候触发。 -
通常都是直接在参数上注册hook,为什么这里要在参数的视图的前一个算子注册hook:param.expand_as(param).grad_fn.next_functions[0][0]
下面开始正文。
训练框架是干什么的
目前LLM训练有两大主流框架:Deepspeed和Megatron-LM。前者的主要提出和维护者是微软的工程师,后者是英伟达的工程师。两个框架从底层原理到设计语言可以说是大相径庭。训练框架的主要目标有2:一是在有限的GPU中尽可能地塞入一个大号模型,二是高效地利用多GPU进行训练。完成第一个目标主要依赖的是模型切分,或者更笼统地说是降低单卡显存占用。完成第二个目标依赖的是异步、高重叠度、高带宽的数据通信。Deepspeed在降低显存这方面应用的技术主要有Zero-1、2、3, 序列并行、CPU Offload,高效通信方面主要是依赖register_hook回调函数的异步通信、多cuda事件流、GPU计算重叠、连续锁页缓存。Megatron在降低显存方面的主要技术有Distributed Optimizer、Tensor Model Parallel、Pipeline Model Parallel、序列并行,在高效通信方面主要的技术有P2P通信、重叠流水线并行、梯度缓存。在设计语言上,DeepSpeed属于外挂框架,框架并不介入模型前向的计算图,因此对模型结构一般没有特殊要求,核心代码通过大量回调函数和torch派生类封装,并不暴露给用户。Megatron则是属于内嵌框架,直接改变模型计算图,因此限制模型结构必须是类Transformer的结构,代码全部暴露在外,不怎么依赖回调函数。外挂框架的好处是兼容性好,对新手友好,缺点是启动慢、计算速度略低,适合数据量不大、不关注训练加速技术,专注于模型和数据迭代的人。内嵌框架的优点是启动、训练效率高一点,对于想魔改底层的人更友好,但是对于想轻微修改训练逻辑的不太友好,适合大规模训练和想要了解、改动并行训练代码的人。
Hello World
我们以hello world来开场:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
print(f'Hello World from rank:{torch.distributed.get_rank()} out of {torch.distributed.get_world_size()} worlds.')
假设代码命名为t1.py,则可以用 torchrun –nproc-per-node 2 t1.py 来启动。(如果不做特殊说明,后面都用这组参数启动。后面给出的demo大多数没有GPU也可以跑,只需要将张量的创建位置改为cpu,以及通信相关的后端全部使用gloo后端即可)
正常情况下,可以看到程序打印出下面对代码稍作解释。
启动命令
torchrun是安装torch后自带的命令,作用是帮助我们以多进程方式启动指定脚本。
在分布式环境中,我们首先要区分多机和多卡,多机是指多台运行训练任务的服务器,多卡是指一台机器上有多个显卡。我们一般称一台服务器为node,有几台服务器就有几个node。
那么 –nproc-per-node 从字面意思上就能看出表示每台服务器上启动多少个进程。一般来说,我们启动的进程数与服务器上的gpu数相同,也就是8卡机 nproc-per-node 设为8。当然一台8卡机你可以只启动2个进程,也可以启动100个进程,这个并不是硬限制。但是保持进程数和卡数相同可以简化我们的编程逻辑,让每个进程只负责一张卡上的计算。
除了nproc-per-node,启动命令还有几个常见参数:
-
–master_addr和–master_port:当启动的是多机环境时,需要用这两个参数指定主机的IP地址和端口号。这两个参数在所有master和slaver机器上都是一样的,不需要修改。 -
–nnodes:当多机启动分布式时,使用这个参数,表明总共有多少台机器。master会根据这个参数配置来等待slaver链接,直到slaver数量凑够nnodes才会启动。这个参数在每台机器上都是一样的,不需要修改。 -
–node_rank:表示当前node是全部node中的第几个,这个参数需要根据启动的机器进行更改,master的rank必须是0。
上述启动参数要放在torchrun之后,脚本路径之前。写在脚本路径之前的参数不会被传递给脚本运行的环境,所以脚本中不要处理这些参数。
rank & world size
在启动的训练脚本中,需要先初始化分布式环境,也就是下面这几行:
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
如果是使用torchrun启动的脚本,torch会帮你在环境变量中写入RANK和WORLD_SIZE这两个变量,在脚本中可以读出来。它们分别表示当前进程的序号和总进程数。这个进程序号是全局的进程序号,也就是说如果每台服务器启动8个进程,总共10台服务器参与训练,那么第十台机器的8个进程rank分别是72、73、74、75、76、77、78、79。world_size也是全局的进程数,而不是单台服务器上的进程数。
获得rank和world_size后,就可以用torch.distributed.init_process_group初始化分布式环境。
在初始化之后的任何地方,都可以用torch.distributed.get_rank()和torch.distributed.get_world_size()来获取当前进程序号和总进程数。
另外需要提的一点是,torch并没有local_rank 这个概念。这是一些训练框架自己定义的,通常用来表示当前进程是这台服务器上的第几个进程。
后端 backend
在初始化命令中,我们还指定了backend参数,这个参数表示的是分布式环境默认使用的通信后端是什么,一般可以选择 nccl、gloo和mpi后端。gpu to gpu的通信选择nccl后端。cpu to cpu的通信选择gloo后端,一般不太会用到mpi后端。
nccl通信会比gloo通信快,所以应该尽量使用nccl进行通信。但是在有些时候,比如读取数据的阶段,如果多个进程之间需要通信,一般使用用gloo通信。因为这时,我们并不希望这些还没有开始正向传播的张量过早出现在gpu上,避免过多占用显存。在后面讲group的部分,我们会提到怎么让训练进程使用nccl后端,数据读取进程使用gloo后端。
训练脚本开始阶段的一些小细节
set device
分布式训练中,在我们创建gpu张量时,经常使用torch.cuda.current_device()获取当前rank使用的cuda设备,然后直接在这个设备上创建。使用这个函数的前提是之前通过set_device显式设定过默认设备。另外还有一些算子会要求用户必须执行过set_device。
所以一个好习惯是,在执行主要的训练逻辑之前,尽早设置默认的cuda设备:
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
如果不手动设置,current_device默认会返回cuda:0。
固定随机种子
分布式训练时随机种子是一个容易忽略的事,但却是很重要的事。在脚本初始化阶段设置一个固定的随机种子有两个好处,一个是让实验可复现,另一个就是让不同rank的模型以相同的方式初始化,不然还要用通信操作同步一下各个模型来保证模型初始化状态的一致性。
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
通信算子
通信算子指的是在分布式训练中,能够在不同rank之间进行数据交换的算子。torch提供了很多通信相关的算子,最常用可以划分为5类:规约、聚集、广播、点对点通信和其他。
规约
规约最典型的就是torch.distributed.all_reduce,这个算子能够对不同rank之间的数据进行求和、求均值、最大值等操作,特点是不论有多少个rank,最终结果只有一个tensor。例如下面这段代码,作用是对所有rank之间的”tensor“求和:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
tensor = torch.tensor([rank +。1], dtype=torch.long, device='cuda')
torch.distributed.all_reduce(tensor)
print(f'rank:{rank} {tensor}')
rank0的tensor=1,rank1的tensor=2,求和结果为3,因此能看到这样的输出:规约操作基本可以算是最常用的通信算子。例如rank之间通过规约操作同步梯度、计算平均loss,再比如分布式softmax利用规约计算最大值、归一化分母等等。
聚集
聚集操作中最典型的是torch.distributed.all_gather(),能够把不同rank的数据收集到一起,特点是参与通信的rank有多少,就会得到多少个tensor。例如下面这段代码,作用是将所有rank的tensor收集到一起:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
tensor = torch.tensor([rank+1], dtype=torch.long, device='cuda')
tensors = [torch.empty_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensors,tensor)
print(f'rank:{rank} {tensors}')
由于“所有”rank都收集了“所有”tensor,所以每个rank都会拥有其他所有rank和自己的tensor,因此会打印出如下内容:all_gather可以模拟all_reduce,比如sum(tensors)就等于all_reduce sum,max(tensors) 就是all_reduce max。还有一个需要注意的点是不能使用列表乘法创建tensors:
tensors = [torch.empty_like(tensor)] * world_size
因为除了all_gather_object等支持python变量的通信算子,大部分都是原地操作,而列表乘法创建的tensors每一个tensor都是指向同一个对象,与原地操作会发生冲突。聚集相比规约更加灵活,但是规约的效率和显存占用通常会更好。
广播
广播的典型操作是 torch.distributed.broadcast(),作用是将指定rank的tensor发送给其他所有rank,比如下面这段代码,是将rank0的tensor发送给其他所有rank,这样最终所有rank的tensor都是一样的
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
tensor = torch.tensor([rank+1], dtype=torch.long, device='cuda')
torch.distributed.broadcast(tensor,0)
print(f'rank:{rank} {tensor}')
执行后你应该能看到下面这样的结果:广播的特点是数据由一个节点到所有节点,通常用于将只会出现在某一个rank的信号发送到全部rank。例如动态数据量训练,是否到达最后一个batch,会以最后一个数据并行rank能否获得一个完整的micro batch作为信号。此时可以利用广播操作,将最后一个数据并行rank的结束信号广播到其他所有rank。再比如在megatron中,当使用tensor并行时,仅第一个tensor rank会获取数据,其他tensor rank的数据是由第一个tensor rank广播来的。
点对点通信、p2p
除了上述几个所有rank之间的通信,有时我们也需要两个rank之间的两两通信,这时就会用到p2p通信。p2p通信的发送方使用torch.distributed.send(),接收方使用torch.distributed.recv()。比如下面这段代码,功能是所有偶数rank将tensor发送到下一个奇数rank:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
if rank % 2 == 0:
tensor = torch.tensor([999], dtype=torch.long, device='cuda')
torch.distributed.send(tensor, rank+1)
else:
tensor = torch.empty(1, dtype=torch.long, device='cuda')
torch.distributed.recv(tensor,rank-1)
print(f'rank:{rank} {tensor}')
运行会得到以下结果上面大多数算子都是原地操作,这就带来一个问题,原地操作要求先创建一个空张量,等待通信算子把数据放进来。但是一个rank怎么知道被通信过来的张量shape是什么样的,怎么提前创建这个空张量?尤其是在广播和p2p通信时经常会碰到这个问题。一个常用的方案是在通信以前,先通信一个固定ndim的张量用来表示接下来要通信的张量的shape,然后再通信真正的数据,比如下面这样:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
if rank % 2 == 0:
tensor = torch.randn(1, 4, dtype=torch.float16, device='cuda')
shape_tensor = torch.tensor(tensor.size(), dtype=torch.long, device='cuda')
torch.distributed.send(shape_tensor, rank+1)
torch.distributed.send(tensor, rank+1)
else:
shape_tensor = torch.empty(2, dtype=torch.long, device='cuda')
torch.distributed.recv(shape_tensor, rank-1)
tensor = torch.empty(torch.Size(shape_tensor), dtype=torch.float16, device='cuda')
torch.distributed.recv(tensor, rank-1)
print(f'rank:{rank} {tensor}')
这种方法有点像定义一种通信协议,第一次握手通信shape,第二次通信数据。如果在使用时ndim也固定不下来,或者tensor的 dtype也需要通信,那么我们就可以像定义通信协议一样,定义一个长一点的定长shape tensor,比如shape=(10,) ,用前9位表示接下来要通信的数据张量的shape,不足的位置补0,最后一位用一个数字表示数据类型。点对点通信主要的应用场景一个是pipeline并行,由上一个pipe发送数据到下一个pipe,一个是蒸馏的teacher 模型发送 probabilities 给 student,以及类似的reference-polocy model。
其他通信算子
还有一个经常使用的是同步屏障 torch.distributed.barrier(),这个操作不通信任何数据,作用是确保所有进程都运行到此处后再开始之后的动作。比如当存储checkpoint时,我们通常只会让第一个数据并行的rank进行保存,其他rank此时就应该使用同步屏障等待第一个rank保存结束。这样可以避免其他rank提前开始新的计算或提前结束导致保存失败,例如:
import torch
import os
import time
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
if rank == 0:
time.sleep(20) # 用sleep20秒模拟rank0在做一些高耗时行为,比如存储checkpoint
else:
a = 1 + 2
torch.distributed.barrier()
除了上面这几种,还有all-to-all算子和scatter算子。在LLM场景中经这两个算子经常出现在序列并行的正、反向传播中。
在使用通信算子时,需要确保所有相关rank都要执行到这一段代码,不然已经执行到这一步的rank会hang住一直等待,导致程序无法继续。
关于通信模式
上面提到的通信算子基本都对应有自己特殊的通信模式,比如reduce算子的背后有tree-reduce和ring-reduce通信模式,广播、p2p等等也都分别对应自己的通信模式。在LLM训练框架开发这个层面,我们可能需要关注下ring-reduce和tree-reduce这两种通信模式。按理说这两种通信模式属于all_reduce的底层实现,torch这个层面应该不需要过多关注,但是我在实际开发中多次遇到相关问题,比如:
-
all_reduce会根据通信数据量、GPU的数量、NVLink连接结构来推断用ring还是tree,这导致在模型并行数、卡数变化时,同一个all_reduce操作使用不同的通信模式,有时候会导致一些内存溢出的bug,尤其是当使用tree-reduce且通信环境比较复杂的时候。 -
ring-reduce和tree-reduce 在不同类型的GPU上精度损失是不一样的,知乎上有一篇很精彩的debug文就是关于追查这个问题的:AI训练与计算:由A800平台训练InternLM-7B无法收敛引发的思考(https://zhuanlan.zhihu.com/p/701623664) -
我在torch 2.3 版本还遇到了communicator 创建了错误数量的ranks、cuBus错误复用等问题,2.4得到修复。
当你怀疑可能自己也遇到了这类问题时,可以考虑通过环境变量NCCL_ALGO强制指定使用tree还是ring。或者尝试用all_gather替换all_reduce,或者更新torch版本。
当你打算涉足通信,就意味着你即将离开torch稳定的后方,站上痛苦debug的前线。
通信组
上面的提到的算子,除了p2p通信,基本都是所有rank参与的通信,如果只能这样未免有些太死板了。想要前5个rank all_reduce,后5个 rank all_gather应该怎么做?广播只广播给奇数rank怎么做?一部分操作在gpu上用nccl通信,一部分在cpu上用gloo通信怎么实现?如果2个rank同时用p2p通信不同张量,怎么做区分?这里就轮到通信组登场了。
注:之后的例子会更复杂一点,之后的所有脚本都会启动4个rank,也就是在此之后的所有 –nproc-per-node都为4
创建单个通信组
通信组是通过下面这种形式创建的:
ranks = [0,1]
group = torch.distributed.new_group(ranks,backend='nccl')
上面这段代码的意思是创建1个通信组,包含 rank0和rank1,以nccl作为后端。在使用通信算子时,我们可以通过group参数指定通信组,这样数据交换就只会发生在组内。比如当你这样使用barrier时,可以指定通信组,这样只要组内的rank都到达barrier,就可以继续,而不是等待所有rank都到达barrier才能继续:
import torch
import os
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
ranks = [0,1]
group = torch.distributed.new_group(ranks)
if rank in ranks:
torch.distributed.barrier(group=group)
print(f'rank:{rank} finish')
else:
torch.distributed.barrier()
print(f'rank:{rank} finish')
执行这段代码可以看到:因为rank0、1的barrier指定了通信组,所以只要0、1两个rank运行到barrier就可以继续,打印出finish。rank2、3的barrier没有指定通信组,因此他们会等待所有rank到达,但是rank0、1并不会进入这个else分支,所以rank2、3会卡住。
所有通信相关算子都支持指定通信组,作用都是类似的,就是把之前全部rank间的通信变为组内rank间的通信。对于p2p通信来说,group还有一个重要作用就是可以用来作为通信标识符。如果两个rank间同时进行多个p2p通信,不同的group可以用于区分不同的通信。
创建多个通信组
通信组可以创建多个,并且每个可以指定使用不同的后端,方便进行cuda或cpu张量的混合通信:
group1 = torch.distributed.new_group([0,1])
group2 = torch.distributed.new_group([2,3])
group3 = torch.distributed.new_group([1,2,3])
group4 = torch.distributed.new_group([0,3], backend='gloo')
在使用通信组时,rank自身必须是这个通信组的成员,比如rank3不能使用group1。那么既然rank3不能使用group1,那rank3能不能干脆就不创建group1,只创建group2、3、4呢?答案是不行。torch对创建通信组有两个要求:
-
通信组中涉及到的所有rank必须全都执行创建通信组的代码 -
通信组在所有rank上的创建顺序必须是一样的
比如在rank0创建的第一个通信组是[0, 1],那么rank1创建的第一个也必须这个,rank2、3也一样,即使他们不能用这个通信组。所以多通信组的创建代码一般是长这样的:
group = None
rank = torch.distributed.get_rank()
for ranks in [[0,1],[2,3]]:
_group = torch.distributed.new_group(ranks)
if rank in ranks:
group = _group
这样保证group的创建顺序是一致的,并且只保留自己这个rank能用的组。想一想这个设计也是合理的,想象一下两个并行的进程,相互之间是访问不到对方变量的,那两个进程怎么知道对方是用哪个通信组呢?首先构成通信组的rank不行,因为可以两个组可以由相同的rank组成,这个不具有唯一性。创建时间也不行,因为每个rank并不是完全同步的。可以人为地为每个组指定一个唯一id,那自增id也可以。
说到这里,其实基础的torch分布式训练功能就都讲完了,下面我们做个把这些功能都用上的小demo。
分布式训练demo
我们以语言模型为背景,实现一个蒸馏框架,这里teacher和student模型是模拟语言模型输入输出的假模型。蒸馏框架支持数据并行,训练数据也是模拟的。分布式优化器要自己实现。框架要支持teacher和student的重叠计算和数据通信。(如果有兴趣的话,推荐在megatron框架里实现一个蒸馏框架)
下面我们一步一步来完成,部分上面已经实现过的函数这里就不再实现了。写demo我习惯讲一步写一步,并且不给出可以直接执行的完整代码。
模拟数据
我们写一个模拟语言模型输入的迭代器来生成假的数据。语言模型的输入是token序列,也就是shape = [batch_size, seq_length] 值为 0 – vocab_size 的整型张量,这里seq_length 我们每次随机一个值。
import torch
class Dataloader:
def __init__(self,batch_size, max_length, vocab_size):
self.batch_size = batch_size
self.max_length = max_length
self.vocab_size = vocab_size
def __iter__(self):
while True:
length = torch.randint(2,self.max_length,size=(1,))
input_ids = torch.randint(0,self.vocab_size,size=(self.batch_size,length),device='cpu')
yield input_ids
teacher&student模型
teacher和student模型都是语言模型,输入是token序列,输出是token序列的词表概率,也就是输出一个 shape = [batch_size, seq_length, vocab_size] 的浮点张量。这里我们假设参数就假设只有一个lm head头。
class Model(torch.nn.Module):
def __init__(self, vocab_size):
super().__init__()
self.lm_head = torch.nn.Parameter(torch.randn(1, vocab_size, dtype=torch.float16))
def forward(self,input_ids:torch.Tensor):
logits = input_ids.unsqueeze(-1).to(self.lm_head.dtype) @ self.lm_head
probs = logits.softmax(-1)
return probs
分布式优化器
我们要实现一个分布式优化器用来更新student模型的参数。因为我们使用的是数据并行,梯度分散在每张卡上,在更新模型参数前,需要先进行一次all_reduce,把所有梯度加在一起。这里优化器需要传入一个通信组参数,包含所有学生rank,因为teacher模型没有优化器,不参与all_reduce。
class DistrubutedAdam(torch.optim.Adam):
def __init__(self, *args, group=None, **kwargs):
self.group = group
super().__init__(*args, **kwargs)
def step(self, closure=None):
if closure is not None:
closure.mean().backward()
for group in self.param_groups:
for param in group['params']:
if param.grad is not None:
torch.distributed.all_reduce(param.grad, group=self.group)
super().step()
创建通信组
根据我们的任务目标,我们需要2种通信组:
-
数据并行的通信组,用于区分模型的角色(student还是teacher),以便优化器只对当前角色的全部模型进行规约。 -
teacher-student通信组,用于标记需要相互之间传递数据的teacher和student。这个通信组在本例里这种简单通信设定下不是必要的,不创建也可以。
我们这样规定我们的rank划分方式:
-
world_size 总数要求是偶数,前一半用来当student,后一半是teacher。 -
student0和teacher0配对通信数据,student1与teacher1配对通信数据,…
通信组的创建方式如下:
data_parallel_group = None
teacher_student_group = None
def create_group():
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert world_size % 2 == 0
student_ranks = list(range(world_size // 2))
teacher_ranks = list(range(world_size // 2, world_size))
global data_parallel_group
global teacher_student_group
for ranks in [student_ranks,teacher_ranks]:
group = torch.distributed.new_group(ranks, backend='nccl')
if rank in ranks:
data_parallel_group = group
for ranks in zip(student_ranks,teacher_ranks):
group = torch.distributed.new_group(ranks, backend='nccl')
if rank in ranks:
teacher_student_group = group
数据发送&接收
teacher模型计算出的prob要发送给student模型。每条数据只需被一个teacher模型计算,teacher模型计算结果也只需要发给一个student,所以这里用p2p通信。我们的输入序列长度是随机变化的,所以这里需要用到上面提到的动态 shape 的 p2p通信。之前说p2p算子用的是send和recv,但是这两个算子是同步算子,这里我们用异步算子。
tensor发送
先发送shape,再发送真实tensor:
def send_tensor(tensor, dst):
shape = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
ops = []
ops.append(torch.distributed.P2POp(
torch.distributed.isend,
shape,
dst,
group=teacher_student_group
))
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
ops = []
ops.append(torch.distributed.P2POp(
torch.distributed.isend,
tensor,
dst,
group=teacher_student_group
))
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
tensor接收
这里实现的简单点,tensor的ndim和dtype作为参数传入
def recv_tensor(src, ndim, dtype):
shape = torch.empty(ndim, dtype=torch.int64, device='cuda')
ops = []
ops.append(torch.distributed.P2POp(
torch.distributed.irecv,
shape,
src,
group=teacher_student_group
))
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
tensor = torch.empty(torch.Size(shape), dtype=dtype, device='cuda')
ops = []
ops.append(torch.distributed.P2POp(
torch.distributed.irecv,
tensor,
src,
group=teacher_student_group
))
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return tensor
注:异步p2p操作是先创建算子(op),再批量执行(batch_isend_irecv),可以增加并行度。比如在megatron的pp平行中,向后一张卡发送计算结果和接收后一张卡回传梯度是同时进行的。
主函数
上面各个重要模块的功能已经实现完了,最后还剩下主函数。首先还是环境初始化,设置默认cuda设备。然后创建通信组,定义vocab size为20,计算teacher rank的 offset:
if __name__ == '__main__':
rank = int(os.getenv('RANK','0'))
world_size = int(os.getenv('WORLD_SIZE','1'))
torch.distributed.init_process_group(rank=rank, world_size=world_size, backend='nccl')
devices = torch.cuda.device_count()
torch.cuda.set_device(rank % devices)
create_group()
vocab_size = 20
teacher_offset = world_size // 2
下面是student的逻辑:
-
rank小于 offset的都是student,首先设置随机种子,然后初始化数据集、模型、优化器。 -
load到input_ids之后,student先发送给teacher,然后再开始计算,计算得到prob后再从teacher接收tensor,来实现student和teacher的重叠。 -
计算kl散度作为loss,反向传播,然后调用optimizer.step更新模型参数。 -
打印变量,让我们看到每个rank的loss和总的平均loss,以及打印当前模型的前两个参数。 -
loss在下降,且每个rank模型参数都始终是相同的我们的代码才算是正确的。
if rank < teacher_offset:
# student
torch.random.manual_seed(1)
dataloader = Dataloader(1,200,vocab_size)
model = Model(vocab_size).half().cuda()
optimizer = DistrubutedAdam(model.parameters(), lr=1e-2, group=data_parallel_group, eps=1e-4)
for i,input_ids in enumerate(dataloader):
if i % teacher_offset != rank:
continue
optimizer.zero_grad()
input_ids = input_ids.cuda()
send_tensor(input_ids, rank + teacher_offset)
student_probs = model(input_ids)
teacher_probs = recv_tensor(rank + teacher_offset, 3, torch.float16)
kl_loss = teacher_probs * ((teacher_probs + 1e-5).log() - (student_probs + 1e-5).log())
kl_loss = kl_loss.sum(-1).mean() / torch.distributed.get_world_size(data_parallel_group)
kl_loss.backward()
optimizer.step()
reporting_kl_loss = kl_loss.clone()
torch.distributed.all_reduce(reporting_kl_loss, group=data_parallel_group)
print(f'rank:{rank} reporting kl loss:{reporting_kl_loss} kl loss:{kl_loss} weight:{model.lm_head.data[0,:2]}',flush=True)
torch.distributed.barrier(group=data_parallel_group)
if i >= 10:
break
else:
# teacher
下面是teacher
一般所有rank的随机种子我们都设成一样的就行,这里故意把teacher的随机种子设成不一样的,避免teacher和student计算出来的prob完全一致,kl始终为0。
if rank < teacher_offset:
# student
else:
# teacher
torch.random.manual_seed(2)
model = Model(vocab_size).half().cuda()
model.eval()
while True:
input_ids = recv_tensor(rank - teacher_offset, 2, torch.int64)
teacher_probs = model(input_ids)
send_tensor(teacher_probs, rank - teacher_offset)
运行后观察到这样的输出:把训练10条数据退出的逻辑去掉,应该能看到loss很快降到1e-4以下,这样demo就算完成了。
这个demo并不是最优的蒸馏框架。首先teacher和student的参数量不一样,且teacher不进行反向传播,因此两者的计算速度不一样,一个teacher配一个student效率并不一定高。可能会需要几个teacher对1个student,或者1个teacher对几个student的情况,这是完全体框架要实现的。其次,目前主流模型的词表大小在15万左右,训练数据的长度一般是8k,也就是最后的teacher_probs 是一个 8k * 150k的float张量,通信成本太高。一种优化策略是sample一部分,比如取top-n,或者放回、不放回采样。另一个策略是把teacher的lm_head层放到student所在的rank。在计算logits之前,把teacher的hidden_states通信过来,在student本地乘lm_head算出logits。
register_hook
register_hook虽然不是分布式相关的功能,但基本每个框架都会用到。register_hook 的作用是在参数或算子上注册一个回调函数,当该参数或算子的梯度计算完成,但还没有赋值给grad的时候调用。如果回调函数有返回值,会使用返回值替换原本的梯度。
import torch
def print_grad(grad):
print(grad)
return grad / 2
w = torch.nn.Parameter(torch.randn(2, 2))
w.register_hook(print_grad)
loss = (w - 1) ** 2
print('before backward')
loss.mean().backward()
print('after backward')
print(w.grad)
上面这段代码会打印出下面这样的内容:用0替换掉梯度里的nan值是一些文章介绍register_hook给出的例子,但是实际编程我不推荐这么做。我建议遇到nan直接抛出异常,不要改成某个安全值然后更新模型,否则模型出了问题完全无法定位。
register_hook不仅可以把回调函数注册在参数上,还可以注册在算子上,这也是各个框架对register_hook的主要用法。比如下面这个操作,就是注册在了加法算子上:
import torch
def parameter_hook(grad):
print('parameter hook')
def operator_hook(*grads):
print('operator hook' )
w = torch.nn.Parameter(torch.randn(2, 2))
w.register_hook(parameter_hook)
print('first')
y = w + 1
op1 = y.grad_fn
print(op1)
op1.register_hook(operator_hook)
y.sum().backward()
print('second')
z = w + 1
op2 = z.grad_fn
print(op2)
z.sum().backward()
运行时你会看到如下结果:算子一般都是一次性的,且是先执行算子的回调再执行参数的回调。但是有一个特殊的算子是梯度累积算子,它的回调函数发生在参数的回调函数之后,且这个算子不会每次都创建新的。
import torch
def parameter_hook(grad):
print('parameter hook')
def operator_hook(*grads):
print('operator hook' )
w = torch.nn.Parameter(torch.randn(2, 2))
w.register_hook(parameter_hook)
y = w + 1
op = y.grad_fn.next_functions[0][0]
print(op)
op.register_hook(operator_hook)
print('first')
y.sum().backward()
print('second')
z = w + 1
op2 = z.grad_fn.next_functions[0][0]
print(op2)
z.sum().backward()
很多框架会围绕着梯度累计算子的这个特性展开。为了获得梯度累积算子,需要创建一个计算图。一般用expand_as,这个计算结果的grad_fn指向的是expand_as自己,next_functions指向的是上一个算子,也就是梯度累积算子:
grad_acc_op = w.expand_as(w).grad_fn.next_functions[0][0]
然后可以利用闭包注册一个hook,让hook能够直接访问参数而不仅仅是梯度。
def make_grad_hook(param):
def hook(*grads):
print(param.grad)
return hook
grad_acc_op.register_hook(make_grad_hook(w))
这样就可以做一些骚操作了。
我们可以简单看下megatron和deepspeed的代码,看看他们是怎么用的。下面这段是megatron的用法:megatron的优化器并不使用param.grad,而是自己在参数上注册了一个param.main_grad,用它来累积梯度。这个main_grad不会删除,只会累积和清0。然后在最后一个microbatch的时候做allreduce,而不是等优化器来做:deepspeed也有类似的逻辑:deepspeed因为要支持连续内存和cpu offload,逻辑更加复杂,会等积攒的grad数量够了一批一批的操作:会使用不同事件流(stream)来增加计算和梯度累积的重叠度:使用了锁页内存,并且会在cpu和gpu之间来回来去传递张量来进行计算。有兴趣的读者可以自己研究一下,deepspeed的代码我一直没机会仔细读一遍。
多进程
在理解torch分布式训练时,多进程这个概念是一直伴随我们左右的。使用torchrun启动脚本,就是以多进程方式启动脚本。这里我们还可以再深入了解一下torch与多进程。
首先来看看python原生的多进程启动方式:
import multiprocessing as mp
def main(rank, world):
print(rank, world)
if __name__ == '__main__':
world_size = 4
ps = [mp.Process(None, main, args=(rank, world_size)) for rank in range(world_size)]
for p in ps:
p.start()
for p in ps:
p.join()
这种多进程能不能拿来作为torch的分布式训练环境呢?当然是可以的,只需要这样操作:
import multiprocessing as mp
import torch
import torch.distributed
def main(rank, world_size, master_addr='127.0.0.1', master_port=29500):
init_method = f'tcp://{master_addr}:{master_port}'
torch.distributed.init_process_group(rank=rank,world_size=world_size,init_method=init_method,backend='nccl')
print(f'rank:{torch.distributed.get_rank()} world_size:{torch.distributed.get_world_size()}')
if __name__ == '__main__':
world_size = 4
ps = [mp.Process(None, main, args=(rank, world_size)) for rank in range(world_size)]
for p in ps:
p.start()
for p in ps:
p.join()
这里init_method就是用来在初始化阶段,实现进程发现的方法,除了tcp,还可以用本地文件发现或者环境变量。除了multiprocessing 的多进程,用subprocess 的多进程也是一样可以的。当然torch也提供了一种启动多进程的方法:
import torch
def main(rank, world_size, master_addr='127.0.0.1', master_port=29500):
init_method = f'tcp://{master_addr}:{master_port}'
torch.distributed.init_process_group(rank=rank,world_size=world_size,init_method=init_method,backend='nccl')
print(f'rank:{torch.distributed.get_rank()} world_size:{torch.distributed.get_world_size()}')
if __name__ == '__main__':
world_size = 4
torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)
你也可以在多进程里再创建子进程再初始化环境:
import multiprocessing as mp
import torch
import torch.distributed
def sub_process(rank, world_size, master_addr='127.0.0.1', master_port=29500):
init_method = f'tcp://{master_addr}:{master_port}'
torch.distributed.init_process_group(rank=rank, world_size=world_size, init_method=init_method, backend='nccl')
torch.distributed.barrier()
print(f'rank:{torch.distributed.get_rank()} world_size:{torch.distributed.get_world_size()}')
def main(rank, world_size, master_addr='127.0.0.1', master_port=29500):
init_method = f'tcp://{master_addr}:{master_port}'
process = mp.Process(None, sub_process, args=(rank + world_size, 2 * world_size,))
process.start()
torch.distributed.init_process_group(rank=rank, world_size=2*world_size, init_method=init_method, backend='nccl')
torch.distributed.barrier()
print(f'rank:{torch.distributed.get_rank()} world_size:{torch.distributed.get_world_size()}')
if __name__ == '__main__':
world_size = 2
ps = [mp.Process(None, main, args=(rank, world_size)) for rank in range(world_size)]
for p in ps:
p.start()
for p in ps:
p.join()
在torch的多进程中,再次启动子进程有一点需要注意的地方。那就是如果在启动子进程之前触发了任何与cuda相关的操作,比如使用了set_device,或者在cuda上创建了一个张量,那么子进程中就不能再使用cuda。比如下面这段代码:
import multiprocessing as mp
import torch
import torch.distributed
def sub_process():
tensor = torch.tensor([2]).cuda(0)
if __name__ == '__main__':
torch.cuda.set_device(0)
process = mp.Process(None,sub_process)
process.start()
process.join()
运行时会报错:这个报错的意思是,cuda环境只能初始化一次,并且与进程绑定。Linux上创建的子进程默认使用的是fork的方式。fork创建的子进程会继承父进程的内存空间,因此已经绑定了父进程的cuda环境被继承给了子进程,子进程使用cuda就会报错。报错中要求子进程以spawn方式启动,是因为spawn方式启动的子进程使用的是全新的解释器,cuda还处于未初始化的状态。
这里用的时候需要权衡清楚,究竟需不需要子进程继承父进程的内存,以及是否需要在子进程使用cuda。子进程如果用spawn方式启动不继承父进程,可能需要单独初始化分布式环境,父进程的全局变量子进程也用不了。如果用fork方式启动继承父进程内存,意味着继承了父进程创建的各种变量,以及父进程初始化过的分布式环境但是不能用cuda。另外需要注意的是,就算用fork方式启动,子进程也继承不了父进程创建的通信组,但是会继承“通信组的创建顺序”。意思是如果rank0顺序创建了5个group,rank1创建的3个group,然后用fork方式启动了一个子进程,子进程又创建了2个,这2个会去对应rank0的第4、5个group。
说了这么多,多进程好像挺麻烦,那他相比torchrun有啥好处呢?
好处就是可以更加灵活的使用init_process_group初始化环境,以区分不同角色。比如上面的我们的demo中的这个蒸馏场景,我们是4个rank,分成2个student2个teacher,通信还不是很复杂。那如果student和teacher不是各占一卡,而是用了3d混合并行占很多卡,相互之间还有tp、pp的通信,通信逻辑就很复杂了。我们可以考虑给teacher和student分别用一套不同的rank、world_size个init_method初始化,让他们在这个分布式环境中只能看见自己这个角色的进程,这样就只需要实现自己的3d混合并行就可以了。
再比如如果你想在自己的训练环境中引入一个VLLM模型。VLLM内部是会调用init_process_group创建自己的环境,使用自己的rank和world_size来实现TP并行的,和你训练环境的init_process_group是冲突的。这个时候使用自定义的多进程,就可以减小VLLM的干扰。
问题来了,每个角色都使用独立的分布式环境,相互之间怎么通信呢?这就是最后的部分,TorchRPC。
RPC
TorchRPC原本的用法是在本地创建远程变量的引用,在本地调用远程函数。但是我觉得这种编程不灵活且抽象,堪比tf的静态图,所以我这里不把rpc当作远程调用,只把他当作对p2p算子的封装,以及除init_process_group之外第二种建立rank间通信的方式。
rpc的初始化和分布式环境很像。rpc的初始化和init_process_group可以同时存在,且可以使用不同的rank和world_size:
import multiprocessing as mp
import torch
def main(rank,world_size):
torch.distributed.init_process_group(rank=0,world_size=1,backend='nccl',init_method=f'tcp://127.0.0.1:{29500+rank}')
options = torch.distributed.rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30001')
torch.distributed.rpc.init_rpc(f'worker-{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
print(f'rank: {torch.distributed.get_rank()}',
f' world_size: {torch.distributed.get_world_size()}',
f' {torch.distributed.rpc.get_worker_info()}')
torch.distributed.rpc.shutdown()
if __name__ == '__main__':
world_size = 4
ps = [mp.Process(None,main,args=(rank,world_size)) for rank in range(world_size)]
for p in ps:
p.start()
for p in ps:
p.join()
直接用python运行,打印出如下内容:这里需要注意一点,因为我们每个rank都独立各自初始化分布式环境,互不干扰,因此init_method的port要换一下。
我们在简单重写一下之前的蒸馏demo,主要为了演示在这种用法下如何使用rpc通信。
首先定义模型和模型的调用函数,再定义一个全局变量
import multiprocessing as mp
import torch
import torch.distributed
from torch.distributed import rpc
class Model:
def __call__(self, tensor) -> torch.Any:
return tensor + 1
def call_model(tensor):
return model(tensor)
model = None
定义teacher的逻辑:
def teacher(rank,world_size):
torch.distributed.init_process_group(rank=0,world_size=1,backend='nccl',init_method=f'tcp://127.0.0.1:{29500+rank}')
options = rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000')
global model
model = Model()
rpc.init_rpc('teacher', rank=rank, world_size=world_size, rpc_backend_options=options)
rpc.shutdown()
先初始化分布式环境,然后创建模型,赋值给全局变量,然后再初始化rpc,确保rpc初始化后模型一定已经准备好了,最后shutdown等待。这里初始化分布式环境只是模拟一下,RPC本身并不依赖分布式环境。
然后定义student的逻辑:
def student(rank,world_size):
torch.distributed.init_process_group(rank=0,world_size=1,backend='nccl',init_method=f'tcp://127.0.0.1:{29500+rank}')
options = rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000')
rpc.init_rpc('student', rank=rank, world_size=world_size, rpc_backend_options=options)
input_ids = torch.randn(4)
teacher_probs = rpc.rpc_async('teacher', call_model, args=(input_ids,))
# 这里student计算
student_probs = input_ids
loss = teacher_probs.wait() - student_probs
print(loss)
rpc.shutdown()
student也是先初始化分布式环境和rpc,然后模拟一下输入数据input_ids。然后通过rpc.rpc_async异步远程调用teacher进程的call_model函数。此时teacher进程的全局变量应该已经有值了,call_model可以正常返回。这里使用异步调用,不需要等待结果,直接继续。下面就是假装student在计算,得到student_probs,然后计算一下差值,使用teacher_probs.wait() 等待远程调用的结果,不出意外应该等于全1向量。
最后是启动和划分角色的代码:
def main(rank, world_size):
teacher_offset = world_size // 2
if rank < teacher_offset:
student(rank, world_size)
else:
teacher(rank, world_size)
if __name__ == '__main__':
world_size = 2
ps = [mp.Process(None,main,args=(rank,world_size)) for rank in range(world_size)]
for p in ps:
p.start()
for p in ps:
p.join()
RPC后端对NVlink、IB等都是支持的,也支持传递cuda张量,只需要在初始化rpc环境时,指定一下本机cuda和远程cuda的映射。
比如你可以在配置rpc后端时这样设置:
rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000', device_maps={'teacher':{0:1}})
这个表示把teacher进程的cuda1映射到本地的cuda0,这样本地cuda0张量通信到远端时就会被放到cuda1,不需要移动到cpu。
最后再说一下,RPC的官方用法是远程调用和远程引用,可以去看官网教程。
后记
目前已经在Megatron-RPC框架下实现了SFT、DPO、Distillation和on-policy RS,性能持平P2P通信,远超一些非IB、NVlink的通信方案,证明RPC方案确实可行。RPC是torch1.4版本就引入的特性,支持远程引用(本地创建一个模型,但是占用远程机器的显存)、链式异步调用(以异步函数的方式调用远程模型,且调用过程中支持继续调用其他远程模型)和自动求导(远程调用返回的结果可以求导,梯度传递给远程模型)。不得不感慨torch确实有前瞻性,以前都没太关注过这个特性。
(文:Datawhale)