导读:本期聚焦于小伙伴创作的《Python中PyTorch如何保存和加载模型?使用state_dict序列化权重参数该怎么做》,敬请观看详情,探索知识的价值。以下视频、文章将为您系统阐述其核心内容与价值。如果您觉得《Python中PyTorch如何保存和加载模型?使用state_dict序列化权重参数该怎么做》有用,将其分享出去将是对创作者最好的鼓励。

在PyTorch的深度学习开发中,模型的保存和加载是训练流程的重要收尾环节,而state_dict作为存储模型所有可学习参数(如权重、偏置)的字典对象,是实现权重序列化的标准方案。它不包含模型结构本身,只记录参数的张量数据,因此灵活性和安全性都更高。

Python中PyTorch如何保存和加载模型?使用state_dict序列化权重参数该怎么做

什么是state_dict

state_dict是Python字典类型的对象,键是模型各层的参数名称,值是对应参数的张量数据。无论是完整的神经网络模型,还是单个优化器,都有对应的state_dict属性。

我们可以通过简单代码查看模型和优化器的state_dict结构:

import torch
import torch.nn as nn

# 定义一个简单的全连接网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 初始化模型和优化器
model = SimpleNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 打印模型state_dict的键
print("模型state_dict键:")
for key in model.state_dict().keys():
    print(key)

# 打印优化器state_dict的键
print("n优化器state_dict键:")
for key in optimizer.state_dict().keys():
    print(key)

使用state_dict保存模型

保存模型权重只需要调用torch.save()函数,传入模型的state_dict即可,通常保存为.pth或.pt格式的文件。

仅保存模型权重

最常用的保存方式是只保存模型的state_dict,这种方式占用空间小,后续加载时可以灵活适配模型结构的变化。

# 保存模型state_dict到文件
torch.save(model.state_dict(), "simple_net_weights.pth")
print("模型权重保存完成")

同时保存模型和优化器状态

如果需要中断训练后恢复继续训练,还需要保存优化器的state_dict以及当前的epoch、损失等信息。

# 假设当前训练到epoch 10,损失为0.35
epoch = 10
loss = 0.35

# 保存所有相关信息到字典
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss
}

# 保存检查点文件
torch.save(checkpoint, "checkpoint.pth")
print("训练检查点保存完成")

使用state_dict加载模型

加载模型时需要先定义和保存时一致的模型结构,再调用load_state_dict()方法加载权重参数。

加载仅保存的模型权重

加载前需要先初始化模型实例,然后加载state_dict,注意strict参数的使用:默认值为True,要求保存的键和模型的键完全匹配;如果模型结构有微调,可以设为False忽略不匹配的键。

# 重新初始化模型结构
loaded_model = SimpleNet()

# 加载保存的权重
loaded_model.load_state_dict(torch.load("simple_net_weights.pth"))
# 设置为评估模式,避免dropout、batchnorm等层影响推理结果
loaded_model.eval()

print("模型权重加载完成,已切换为评估模式")

加载训练检查点恢复训练

加载检查点时需要分别恢复模型、优化器的状态,以及训练进度相关的参数。

# 重新初始化模型和优化器
model = SimpleNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 加载检查点文件
checkpoint = torch.load("checkpoint.pth")

# 恢复模型和优化器状态
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# 恢复训练进度参数
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

# 设置为训练模式
model.train()

print(f"已恢复训练,当前epoch为{epoch},上次损失为{loss}")

注意事项

  • 加载权重前必须保证模型结构和保存时的结构一致,否则会出现键不匹配的错误。
  • 推理阶段加载权重后一定要调用eval()方法,否则batchnorm、dropout等层会保持训练状态的逻辑,导致推理结果异常。
  • 保存和加载路径要确保有读写权限,跨设备加载时如果保存时使用了GPU,加载到CPU需要先调用torch.load("file.pth", map_location=torch.device("cpu"))
  • state_dict只保存参数数据,不保存模型类的代码,因此加载时对应的模型类定义必须存在,否则会报错。

PyTorchstate_dict模型保存模型加载序列化权重修改时间:2026-07-01 07:33:22

免责声明:​ 已尽一切努力确保本网站所含信息的准确性。网站内容多为原创整理与精心编撰,观点力求客观中立。本站旨在免费分享,内容仅供个人学习、研究或参考使用。若引用了第三方作品,版权归原作者所有。如内容涉及您的权益,请联系我们处理。
内容垂直聚焦
专注技术核心技术栏目,确保每篇文章深度聚焦于实用技能。从代码技巧到架构设计,为用户提供无干扰的纯技术知识沉淀,精准满足专业提升需求。
知识结构清晰
覆盖从开发到部署的全链路。AI、前端、编程、数据库、服务器、建站、系统层层递进,构建清晰学习路径,帮助用户系统化掌握开发与运维所需的核心技术。
深度技术解析
拒绝泛泛而谈,深入技术细节与实践难点。无论是数据库优化还是服务器配置,均结合真实场景与代码示例进行剖析,致力于提供可直接应用于工作的解决方案。
专业领域覆盖
精准对应开发生命周期。从前端界面到后端编程,从数据库操作到服务器运维,形成完整闭环,一站式满足全栈工程师和运维人员的技术需求。
即学即用高效
内容强调实操性,步骤清晰、代码完整。用户可根据教程直接复现和应用于自身项目,显著缩短从学习到实践的距离,快速解决开发中的具体问题。
持续更新保障
专注既定技术方向进行长期、稳定的内容输出。确保各栏目技术文章持续更新迭代,紧跟主流技术发展趋势,为用户提供经久不衰的学习价值。