在PyTorch的深度学习开发中,模型的保存和加载是训练流程的重要收尾环节,而state_dict作为存储模型所有可学习参数(如权重、偏置)的字典对象,是实现权重序列化的标准方案。它不包含模型结构本身,只记录参数的张量数据,因此灵活性和安全性都更高。

什么是state_dict
state_dict是Python字典类型的对象,键是模型各层的参数名称,值是对应参数的张量数据。无论是完整的神经网络模型,还是单个优化器,都有对应的state_dict属性。
我们可以通过简单代码查看模型和优化器的state_dict结构:
import torch
import torch.nn as nn
# 定义一个简单的全连接网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 5)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 初始化模型和优化器
model = SimpleNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 打印模型state_dict的键
print("模型state_dict键:")
for key in model.state_dict().keys():
print(key)
# 打印优化器state_dict的键
print("n优化器state_dict键:")
for key in optimizer.state_dict().keys():
print(key)
使用state_dict保存模型
保存模型权重只需要调用torch.save()函数,传入模型的state_dict即可,通常保存为.pth或.pt格式的文件。
仅保存模型权重
最常用的保存方式是只保存模型的state_dict,这种方式占用空间小,后续加载时可以灵活适配模型结构的变化。
# 保存模型state_dict到文件
torch.save(model.state_dict(), "simple_net_weights.pth")
print("模型权重保存完成")
同时保存模型和优化器状态
如果需要中断训练后恢复继续训练,还需要保存优化器的state_dict以及当前的epoch、损失等信息。
# 假设当前训练到epoch 10,损失为0.35
epoch = 10
loss = 0.35
# 保存所有相关信息到字典
checkpoint = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss
}
# 保存检查点文件
torch.save(checkpoint, "checkpoint.pth")
print("训练检查点保存完成")
使用state_dict加载模型
加载模型时需要先定义和保存时一致的模型结构,再调用load_state_dict()方法加载权重参数。
加载仅保存的模型权重
加载前需要先初始化模型实例,然后加载state_dict,注意strict参数的使用:默认值为True,要求保存的键和模型的键完全匹配;如果模型结构有微调,可以设为False忽略不匹配的键。
# 重新初始化模型结构
loaded_model = SimpleNet()
# 加载保存的权重
loaded_model.load_state_dict(torch.load("simple_net_weights.pth"))
# 设置为评估模式,避免dropout、batchnorm等层影响推理结果
loaded_model.eval()
print("模型权重加载完成,已切换为评估模式")
加载训练检查点恢复训练
加载检查点时需要分别恢复模型、优化器的状态,以及训练进度相关的参数。
# 重新初始化模型和优化器
model = SimpleNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 加载检查点文件
checkpoint = torch.load("checkpoint.pth")
# 恢复模型和优化器状态
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# 恢复训练进度参数
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
# 设置为训练模式
model.train()
print(f"已恢复训练,当前epoch为{epoch},上次损失为{loss}")
注意事项
- 加载权重前必须保证模型结构和保存时的结构一致,否则会出现键不匹配的错误。
- 推理阶段加载权重后一定要调用
eval()方法,否则batchnorm、dropout等层会保持训练状态的逻辑,导致推理结果异常。 - 保存和加载路径要确保有读写权限,跨设备加载时如果保存时使用了GPU,加载到CPU需要先调用
torch.load("file.pth", map_location=torch.device("cpu"))。 - state_dict只保存参数数据,不保存模型类的代码,因此加载时对应的模型类定义必须存在,否则会报错。
PyTorchstate_dict模型保存模型加载序列化权重修改时间:2026-07-01 07:33:22