在线训练模式下,模型需要持续接收新数据并更新参数,而灾难性遗忘指的是模型在学习新任务后,对之前已学习任务的表现出现大幅下滑的现象,这是持续学习领域需要解决的核心问题之一。

灾难性遗忘的产生原因
从模型参数更新的角度来看,灾难性遗忘主要是因为新任务的梯度更新会覆盖掉旧任务对应的参数空间。当新数据和旧数据的分布差异较大时,模型为了适配新数据,会调整大量参数,这些调整恰好破坏了旧任务的知识存储结构,最终导致旧任务性能下降。
常用的防止方法
1. 正则化约束方法
这类方法的核心是在损失函数中加入正则项,限制模型更新参数时偏离旧任务的参数空间过远,典型代表是弹性权重整合(EWC)方法。EWC会给对旧任务重要的参数赋予更高的惩罚权重,避免这些参数被大幅修改。
以下是EWC损失函数的简化实现代码:
import torch
import torch.nn as nn
import torch.optim as optim
class EWC_Loss(nn.Module):
def __init__(self, model, fisher_matrix, old_params, lambda_ewc=1000):
super(EWC_Loss, self).__init__()
self.model = model
self.fisher_matrix = fisher_matrix # 旧任务对应的Fisher信息矩阵
self.old_params = old_params # 旧任务训练完成后的参数
self.lambda_ewc = lambda_ewc # EWC惩罚项的权重
self.criterion = nn.CrossEntropyLoss() # 基础分类损失
def forward(self, outputs, labels):
# 计算基础分类损失
base_loss = self.criterion(outputs, labels)
# 计算EWC正则惩罚项
ewc_loss = 0
for name, param in self.model.named_parameters():
if param.requires_grad:
# 获取该参数对应的Fisher信息和旧参数值
fisher = self.fisher_matrix[name]
old_param = self.old_params[name]
# 累加惩罚项:(Fisher * (当前参数 - 旧参数)^2) 的和
ewc_loss += (fisher * (param - old_param).pow(2)).sum()
# 总损失 = 基础损失 + EWC惩罚项
total_loss = base_loss + self.lambda_ewc * ewc_loss
return total_loss
2. 记忆回放方法
记忆回放的思路是维护一个小的记忆缓冲区,存储旧任务的部分样本,在训练新任务时,混合旧样本和新样本一起输入模型,让模型同时学习新旧知识。这种方法实现简单,效果也比较稳定。
以下是一个简单的记忆缓冲区实现示例:
import random
class MemoryBuffer:
def __init__(self, capacity=1000):
self.capacity = capacity # 缓冲区最大容量
self.buffer = [] # 存储样本,每个样本为(数据, 标签)
def add_samples(self, samples, labels):
# 添加新样本到缓冲区
for data, label in zip(samples, labels):
if len(self.buffer) >= self.capacity:
# 缓冲区满了就随机移除一个旧样本
self.buffer.pop(random.randint(0, len(self.buffer)-1))
self.buffer.append((data, label))
def sample(self, batch_size=32):
# 从缓冲区随机采样指定数量的样本
if len(self.buffer) < batch_size:
return zip(*self.buffer)
batch = random.sample(self.buffer, batch_size)
datas, labels = zip(*batch)
return datas, labels
3. 参数隔离方法
参数隔离方法会为不同的任务分配不同的模型参数,避免新任务更新旧任务的参数。比如可以通过掩码的方式,标记出对旧任务重要的参数,在新任务训练时冻结这些参数,只更新剩余参数。
方法对比
以下是三种主流方法的特性对比:
| 方法类型 | 实现难度 | 内存占用 | 适用场景 |
|---|---|---|---|
| 正则化约束 | 中等 | 低,只需要存储Fisher矩阵和旧参数 | 任务数量较少、模型参数规模不大的场景 |
| 记忆回放 | 低 | 中等,需要存储部分旧样本 | 新旧任务数据分布差异不大的场景 |
| 参数隔离 | 高 | 高,需要为不同任务维护参数子集 | 任务数量多、对旧任务性能要求极高的场景 |
实践注意事项
- 记忆回放方法中,记忆缓冲区的样本选择要尽量覆盖旧任务的分布,避免采样偏差导致旧任务知识丢失
- 正则化方法的惩罚权重需要根据实际任务调整,权重过大会导致模型学不会新任务,过小则无法缓解遗忘
- 可以结合多种方法使用,比如同时使用记忆回放和EWC正则,往往能获得更好的防遗忘效果
- 在线训练过程中需要定期评估模型在旧任务上的性能,及时发现遗忘问题并调整策略