导读:本期聚焦于小伙伴创作的《Python中PyTorch分布式训练时如何仅在主进程保存模型避免冲突》,敬请观看详情,探索知识的价值。以下视频、文章将为您系统阐述其核心内容与价值。如果您觉得《Python中PyTorch分布式训练时如何仅在主进程保存模型避免冲突》有用,将其分享出去将是对创作者最好的鼓励。

在PyTorch分布式训练的实际落地过程中,多进程并行执行任务时如果都触发模型保存操作,大概率会出现文件覆盖、写入中断等问题,因此仅让主进程执行保存是更稳妥的方案。

Python中PyTorch分布式训练时如何仅在主进程保存模型避免冲突

PyTorch分布式训练的进程标识

PyTorch的分布式训练框架会给每个参与训练的进程分配唯一的rank值,其中rank=0的进程就是默认的主进程。我们可以通过torch.distributed模块获取当前进程的rank信息,以此作为判断主进程的依据。

首先需要初始化分布式环境,初始化完成后才能正确获取rank值,初始化的代码通常放在训练脚本的最开始部分:

import torch
import torch.distributed as dist

def init_distributed():
    # 初始化进程组,使用nccl后端适合GPU训练,cpu训练可以用gloo
    dist.init_process_group(backend='nccl')
    # 获取当前进程的rank
    local_rank = int(dist.get_rank())
    return local_rank

if __name__ == '__main__':
    rank = init_distributed()
    print(f"当前进程rank为: {rank}")

仅在主进程保存模型的实现逻辑

判断当前进程是主进程之后,只需要把模型保存的代码放在对应的条件判断分支中即可。需要注意的是,模型保存前最好先同步所有进程的状态,避免主进程保存时其他进程还在更新模型参数。

基础保存示例

下面是一个完整的分布式训练片段,包含主进程判断和模型保存逻辑:

import torch
import torch.distributed as dist
import os

def save_model_on_main_process(model, save_path, epoch):
    # 先同步所有进程,确保所有进程的参数更新都已完成
    dist.barrier()
    # 判断当前是否为rank 0的主进程
    if dist.get_rank() == 0:
        # 创建保存目录,避免目录不存在报错
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # 保存模型参数,推荐保存state_dict而不是整个模型
        torch.save(model.state_dict(), save_path)
        print(f"第{epoch}轮模型已保存到{save_path}")
    # 再次同步,避免主进程保存完成后其他进程还在等待
    dist.barrier()

# 模拟训练流程
if __name__ == '__main__':
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    # 模拟模型
    model = torch.nn.Linear(10, 2)
    # 模拟训练循环
    for epoch in range(10):
        # 模拟训练逻辑
        # ...
        # 每5轮保存一次模型
        if epoch % 5 == 0:
            save_path = f"./checkpoints/model_epoch_{epoch}.pth"
            save_model_on_main_process(model, save_path, epoch)
    dist.destroy_process_group()

多机训练的主进程判断

如果是多机分布式训练,有时候需要区分全局主进程和每台机器的本地主进程,此时可以用dist.get_rank()获取全局rank,用os.environ.get('LOCAL_RANK')获取本地rank:

import os
import torch.distributed as dist

def is_global_main_process():
    return dist.get_rank() == 0

def is_local_main_process():
    return int(os.environ.get('LOCAL_RANK', 0)) == 0

# 多机场景下如果只想让第一台机器的主进程保存,用全局主进程判断即可
if is_global_main_process():
    print("这是全局主进程,执行保存操作")

相关注意事项

  • 保存模型前务必调用dist.barrier()同步所有进程,否则可能出现主进程保存了还在更新中的参数,导致模型不可用。
  • 推荐使用torch.save(model.state_dict(), path)的方式保存模型,而不是保存整个模型对象,这样后续加载时兼容性更好。
  • 如果使用了torch.nn.parallel.DistributedDataParallel包装模型,保存时可以直接保存model.module.state_dict(),也可以先获取原始模型再保存,两种方式都可行。
  • 保存路径尽量不要使用相对路径,避免不同进程的当前工作目录不一致导致保存位置错误,建议使用绝对路径。

常见问题解答

如果忘记判断主进程会有什么后果

多进程同时写入同一个模型文件时,会出现文件内容错乱,保存的pth文件无法被正常加载,严重时会直接抛出文件写入异常导致训练中断。

主进程保存后需要通知其他进程吗

建议在主进程保存完成后再次调用dist.barrier()同步所有进程,让其他进程知道保存操作已经完成,避免出现时序问题。

分布式训练中的进程同步是非常重要的环节,除了模型保存,其他涉及文件读写、日志打印的操作也建议只在主进程执行,减少不必要的资源消耗和冲突风险。

PyTorch分布式训练模型保存主进程rank修改时间:2026-06-19 22:57:28

免责声明:​ 已尽一切努力确保本网站所含信息的准确性。网站内容多为原创整理与精心编撰,观点力求客观中立。本站旨在免费分享,内容仅供个人学习、研究或参考使用。若引用了第三方作品,版权归原作者所有。如内容涉及您的权益,请联系我们处理。
内容垂直聚焦
专注技术核心技术栏目,确保每篇文章深度聚焦于实用技能。从代码技巧到架构设计,为用户提供无干扰的纯技术知识沉淀,精准满足专业提升需求。
知识结构清晰
覆盖从开发到部署的全链路。AI、前端、编程、数据库、服务器、建站、系统层层递进,构建清晰学习路径,帮助用户系统化掌握开发与运维所需的核心技术。
深度技术解析
拒绝泛泛而谈,深入技术细节与实践难点。无论是数据库优化还是服务器配置,均结合真实场景与代码示例进行剖析,致力于提供可直接应用于工作的解决方案。
专业领域覆盖
精准对应开发生命周期。从前端界面到后端编程,从数据库操作到服务器运维,形成完整闭环,一站式满足全栈工程师和运维人员的技术需求。
即学即用高效
内容强调实操性,步骤清晰、代码完整。用户可根据教程直接复现和应用于自身项目,显著缩短从学习到实践的距离,快速解决开发中的具体问题。
持续更新保障
专注既定技术方向进行长期、稳定的内容输出。确保各栏目技术文章持续更新迭代,紧跟主流技术发展趋势,为用户提供经久不衰的学习价值。