AveragedModel in PyTorch

In PyTorch, torch.optim.swa_utils.AveragedModel implements averaged models for both Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

Official documentation
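Before the full training-loop examples, here is a minimal self-contained sketch (not from the original post; the single-weight Linear model and the toy weight values are illustrative assumptions) contrasting the two update rules: by default, AveragedModel keeps an equal-weight running mean of the parameter snapshots it is fed, while get_ema_multi_avg_fn(decay) makes it keep an exponentially decaying average instead.

import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

# A toy model with a single scalar weight so the averages are easy to check.
model = nn.Linear(1, 1, bias=False)

swa_model = AveragedModel(model)  # default: equal-weight running mean
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

for w in (1.0, 2.0):  # pretend two training steps produced these weights
    with torch.no_grad():
        model.weight.fill_(w)
    swa_model.update_parameters(model)
    ema_model.update_parameters(model)

# The first update_parameters call copies the weights into the averaged
# model; averaging only starts with the second call.
print(swa_model.module.weight.item())  # 1.5   = (1.0 + 2.0) / 2
print(ema_model.module.weight.item())  # 1.001 = 0.999 * 1.0 + 0.001 * 2.0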

SWA example

In the example below, swa_model is the SWA model that accumulates the running averages of the weights. We train the model for a total of 300 epochs, switching to the SWA learning-rate schedule and starting to collect SWA averages of the parameters at epoch 160:

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        # After swa_start: average the weights and hold the SWA learning rate
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        # Before swa_start: follow the cosine annealing schedule as usual
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
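
The final update_bn call matters because swa_model's weights are never used for forward passes during training, so any BatchNorm layers it contains still hold running statistics computed for the old weights. As a rough sketch of what this step does, here is a simplified re-implementation (for illustration only; it assumes the loader yields (input, target) pairs, while the real helper also accepts bare-tensor loaders and a device argument):

import torch

def refresh_bn_stats(loader, model):
    """Recompute BatchNorm running statistics with one pass over the data."""
    momenta = {}
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()   # zero out running_mean / running_var
            momenta[m] = m.momentum
            m.momentum = None         # None => use a cumulative moving average
    if not momenta:                   # nothing to do for BN-free models
        return
    was_training = model.training
    model.train()
    with torch.no_grad():
        for input, _ in loader:       # assumes (input, target) batches
            model(input)
    for m, momentum in momenta.items():
        m.momentum = momentum         # restore the original momenta
    model.train(was_training)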

EMA example

In the example below, ema_model is the EMA model, which accumulates an exponentially decaying average of the weights with a decay rate of 0.999: each update computes ema_param = 0.999 * ema_param + 0.001 * param. We train the model for a total of 300 epochs and start collecting EMA averages immediately.

loader, optimizer, model, loss_fn = ...
ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999),
)

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
        # Refresh the EMA weights after every optimizer step
        ema_model.update_parameters(model)

# Update bn statistics for the ema_model at the end
torch.optim.swa_utils.update_bn(loader, ema_model)
# Use ema_model to make predictions on test data
preds = ema_model(test_input)
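
If the network contains BatchNorm layers, an alternative worth noting (a sketch, not from the original post) is to let AveragedModel average the buffers together with the parameters via its use_buffers argument, so the running statistics stay up to date during training and the separate update_bn pass can be skipped:

ema_model = torch.optim.swa_utils.AveragedModel(
    model,
    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999),
    use_buffers=True,  # also average BatchNorm running_mean / running_var
)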