for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
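EMA example
In the example below, ema_model is the EMA model that accumulates exponentially decayed averages of the weights with a decay rate of 0.999. We train the model for a total of 300 epochs and start collecting EMA averages immediately.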
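To make the decay rate concrete, here is a minimal pure-Python sketch of an exponentially decayed average of weights. It assumes the usual EMA rule, new_ema = decay * ema + (1 - decay) * current. The helper name ema_update and the toy weight lists are hypothetical and are not part of the PyTorch API.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Return the exponentially decayed average after seeing `weights`.

    Assumed rule: new_ema = decay * ema + (1 - decay) * current.
    """
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Start from a copy of the initial weights, then fold in one new set of weights.
# decay=0.5 is used here only so the arithmetic is easy to follow by hand.
ema = [1.0, 2.0]
ema = ema_update(ema, [2.0, 4.0], decay=0.5)
```

With a decay of 0.999, each update moves the average only 0.1% of the way toward the current weights, so the average changes slowly and smooths out step-to-step noise.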
for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
        ema_model.update_parameters(model)

# Update bn statistics for the ema_model at the end
torch.optim.swa_utils.update_bn(loader, ema_model)
# Use ema_model to make predictions on test data
preds = ema_model(test_input)