在PyTorch的模型开发流程中,训练阶段通常会保存包含模型参数、优化器状态、当前训练轮次等信息的完整检查点,方便后续断点续训。但当模型训练完成进入推理部署阶段时,我们只需要模型本身的参数即可,优化器状态、学习率调度器状态等训练相关的冗余信息完全可以移除,这样能大幅减少检查点文件的体积,节省磁盘存储空间。

完整检查点与推理检查点的区别
首先我们需要明确两种检查点的内容差异,才能理解为什么移除优化器状态可以节省空间。完整训练检查点通常包含以下内容:
- 模型状态字典:存储所有模型层的参数,是推理必须的内容
- 优化器状态字典:存储优化器的动量、缓存等参数,仅训练时需要
- 训练轮次与损失值:记录训练进度,推理时完全无用
- 学习率调度器状态:训练时调整学习率使用,推理不需要
其中优化器状态字典的体积往往和模型参数字典相当,甚至更大,移除这部分内容后,检查点文件大小通常能减少50%以上。
保存推理检查点的核心方法
PyTorch中模型参数都存储在模型的state_dict()方法中,我们只需要单独保存这个字典,就可以得到仅包含推理所需参数的检查点,完全不需要包含优化器相关内容。
1. 基础保存方法
下面是训练完成后保存推理检查点的基础代码示例:
import torch
import torch.nn as nn
# 定义一个简单的测试模型
class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
# 初始化模型和优化器
model = TestModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 模拟训练过程,给模型赋值一些参数
dummy_input = torch.randn(1, 10)
output = model(dummy_input)
loss = output.sum()
loss.backward()
optimizer.step()
# 保存推理检查点:只保存模型的状态字典,不包含优化器状态
infer_checkpoint = model.state_dict()
torch.save(infer_checkpoint, "infer_checkpoint.pth")
# 对比:保存完整训练检查点,包含优化器状态
full_checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": 10,
"loss": loss.item()
}
torch.save(full_checkpoint, "full_checkpoint.pth")
2. 加载推理检查点的方法
保存的推理检查点加载时也非常简单,只需要把状态字典加载到模型中即可:
# 初始化新模型
new_model = TestModel()
# 加载推理检查点
infer_checkpoint = torch.load("infer_checkpoint.pth")
new_model.load_state_dict(infer_checkpoint)
# 设置为推理模式
new_model.eval()
# 测试推理
test_input = torch.randn(1, 10)
with torch.no_grad():
result = new_model(test_input)
print("推理结果:", result)
两种保存方式的文件大小对比
我们可以通过实际测试查看两种检查点的体积差异,以上面的简单模型为例,测试结果如下:
| 检查点类型 | 包含内容 | 文件大小 |
|---|---|---|
| 完整训练检查点 | 模型参数、优化器状态、训练轮次、损失值 | 约4KB |
| 推理检查点 | 仅模型参数 | 约2KB |
如果是参数量更大的模型,比如常见的ResNet50,完整检查点可能达到200MB以上,而推理检查点只需要100MB左右,节省的空间非常可观。
注意事项
- 保存推理检查点前,建议调用
model.eval()将模型设置为推理模式,避免Dropout、BatchNorm等层在推理时出现不符合预期的行为,不过这一步不影响检查点文件的大小,只影响后续推理结果的正确性。 - 如果模型使用了自定义层或者需要保存其他推理相关的配置(比如类别标签映射),可以在保存时额外添加这些轻量信息,但不要加入优化器相关的内容。
- 加载推理检查点时,确保模型的结构和保存时的结构完全一致,否则会出现参数加载失败的错误。
总结来说,PyTorch中保存推理检查点只需要提取模型自身的state_dict()进行保存,完全不需要包含优化器的状态字典,这种方式既能满足推理需求,又能最大程度节省磁盘空间,是模型部署前的推荐操作。