在PyTorch分布式训练的实际落地过程中,多进程并行执行任务时如果都触发模型保存操作,大概率会出现文件覆盖、写入中断等问题,因此仅让主进程执行保存是更稳妥的方案。

PyTorch分布式训练的进程标识
PyTorch的分布式训练框架会给每个参与训练的进程分配唯一的rank值,其中rank=0的进程就是默认的主进程。我们可以通过torch.distributed模块获取当前进程的rank信息,以此作为判断主进程的依据。
首先需要初始化分布式环境,初始化完成后才能正确获取rank值,初始化的代码通常放在训练脚本的最开始部分:
import torch
import torch.distributed as dist
def init_distributed():
# 初始化进程组,使用nccl后端适合GPU训练,cpu训练可以用gloo
dist.init_process_group(backend='nccl')
# 获取当前进程的rank
local_rank = int(dist.get_rank())
return local_rank
if __name__ == '__main__':
rank = init_distributed()
print(f"当前进程rank为: {rank}")
仅在主进程保存模型的实现逻辑
判断当前进程是主进程之后,只需要把模型保存的代码放在对应的条件判断分支中即可。需要注意的是,模型保存前最好先同步所有进程的状态,避免主进程保存时其他进程还在更新模型参数。
基础保存示例
下面是一个完整的分布式训练片段,包含主进程判断和模型保存逻辑:
import torch
import torch.distributed as dist
import os
def save_model_on_main_process(model, save_path, epoch):
# 先同步所有进程,确保所有进程的参数更新都已完成
dist.barrier()
# 判断当前是否为rank 0的主进程
if dist.get_rank() == 0:
# 创建保存目录,避免目录不存在报错
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# 保存模型参数,推荐保存state_dict而不是整个模型
torch.save(model.state_dict(), save_path)
print(f"第{epoch}轮模型已保存到{save_path}")
# 再次同步,避免主进程保存完成后其他进程还在等待
dist.barrier()
# 模拟训练流程
if __name__ == '__main__':
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
# 模拟模型
model = torch.nn.Linear(10, 2)
# 模拟训练循环
for epoch in range(10):
# 模拟训练逻辑
# ...
# 每5轮保存一次模型
if epoch % 5 == 0:
save_path = f"./checkpoints/model_epoch_{epoch}.pth"
save_model_on_main_process(model, save_path, epoch)
dist.destroy_process_group()
多机训练的主进程判断
如果是多机分布式训练,有时候需要区分全局主进程和每台机器的本地主进程,此时可以用dist.get_rank()获取全局rank,用os.environ.get('LOCAL_RANK')获取本地rank:
import os
import torch.distributed as dist
def is_global_main_process():
return dist.get_rank() == 0
def is_local_main_process():
return int(os.environ.get('LOCAL_RANK', 0)) == 0
# 多机场景下如果只想让第一台机器的主进程保存,用全局主进程判断即可
if is_global_main_process():
print("这是全局主进程,执行保存操作")
相关注意事项
- 保存模型前务必调用
dist.barrier()同步所有进程,否则可能出现主进程保存了还在更新中的参数,导致模型不可用。 - 推荐使用
torch.save(model.state_dict(), path)的方式保存模型,而不是保存整个模型对象,这样后续加载时兼容性更好。 - 如果使用了
torch.nn.parallel.DistributedDataParallel包装模型,保存时可以直接保存model.module.state_dict(),也可以先获取原始模型再保存,两种方式都可行。 - 保存路径尽量不要使用相对路径,避免不同进程的当前工作目录不一致导致保存位置错误,建议使用绝对路径。
常见问题解答
如果忘记判断主进程会有什么后果
多进程同时写入同一个模型文件时,会出现文件内容错乱,保存的pth文件无法被正常加载,严重时会直接抛出文件写入异常导致训练中断。
主进程保存后需要通知其他进程吗
建议在主进程保存完成后再次调用dist.barrier()同步所有进程,让其他进程知道保存操作已经完成,避免出现时序问题。
分布式训练中的进程同步是非常重要的环节,除了模型保存,其他涉及文件读写、日志打印的操作也建议只在主进程执行,减少不必要的资源消耗和冲突风险。