IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> SWA实战:使用SWA进行微调,提高模型的泛化 -> 正文阅读

[人工智能]SWA实战:使用SWA进行微调,提高模型的泛化

摘要

论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa

论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第 i i i个epoch的checkpoint为 w i w_{i} wi?,一般情况下我们会选择训练过程中最后的一个epoch的模型 w n w_{n} wn?或者在验证集上效果最好的一个模型 w i ? w^{*}_{i} wi??作为最终模型。但SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值。

pytorch使用举例:

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
    for batch_idx, (data, target) in enumerate(train_loader):   
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        # 在反向传播前要手动将梯度清零
        optimizer.zero_grad()
        output = model(data)
        #计算losss
        loss = train_criterion(output, targets)
        # 反向传播求解梯度
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']   
    swa_model.update_parameters(model)
    swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。

详细实现过程

环境

pyotrch:1.10

准备

在开始今天的代码前,我们要准备好训练好的模型。然后才能开始今天的代码。

实现过程

定义模型,并将训练好的模型载入,代码如下:

    model_ft = efficientnet_b1(pretrained=True)
    print(model_ft)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Linear(num_ftrs, classes)
    model_ft.to(DEVICE)
    model_ft = torch.load(model_path)
    print(model_ft)
    fine_epoch = 80
    fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,
              use_amp)

定义模型为efficientnet_b1,这里要和训练的模型保持一致。

如果保存的整个模型,则使用torch.load(model_path)载入模型,如果只保存了权重信息,则要使用model_ft=load_state_dict(torch.load(model_path)),载入模型。

然后,设置fine的epoch为80。

接下来,我们一起去看fine_tune函数中的内容。

 # 采用SGD优化器
    optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
    if use_amp:
        model, optimizer = amp.initialize(model_ft, optimizer, opt_level="O1")  # 这里是“欧一”,不是“零一”

定义优化器为SGD。

如果使用混合精度,则对amp初始化。

 # 随机权重平均SWA,实现更好的泛化
 swa_model = AveragedModel(model).to(device)
 # SWA调整学习率
 swa_scheduler = SWALR(optimizer, swa_lr=1e-6)

初始化SWA。

使用SWALR调整学习率。

接下来循环epoch,这里都是比较通用的逻辑。

 for epoch in range(1, epoch + 1):
        model.train()
        train_loss = 0
        total_num = len(train_loader.dataset)
        print(total_num, len(train_loader))
        for batch_idx, (data, target) in enumerate(train_loader):
            if len(data) % 2 != 0:
                print(len(data))
                data = data[0:len(data) - 1]
                target = target[0:len(target) - 1]
                print(len(data))
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            samples, targets = mixup_fn(data, target)
            output = model(samples)
            loss = train_criterion(output, targets)
            optimizer.zero_grad()
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print_loss = loss.data.item()
            train_loss += print_loss
            if (batch_idx + 1) % 10 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                    epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                           100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
        swa_model.update_parameters(model)
        swa_scheduler.step()

主要步骤有:

1、计算loss。

2、是否使用amp混合精度,如果使用混合精度则使用scaled_loss反向传播求梯度,否则直接loss反向传播求梯度。

3、 swa_model.update_parameters(model)更新swa_model的参数。

4、 swa_scheduler.step()更新学习率。

等待所有的epoch执行完成后。

torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), "last.pt")

更新BN层参数。

然后保存模型的权重。注意:这里只能保存模型的权重,不能保存整个模型。

完成之后就可以测试了,执行代码:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
from torchvision.models.mobilenetv3 import mobilenet_v3_large
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR
from timm.models.efficientnet import efficientnet_b1
import numpy as np

def show_outputs(output):

    output_sorted = sorted(output, reverse=True)
    top5_str = '-----TOP 5-----\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = '{}: {}\n'.format(index[j], value)
            else:
                topi = '-1: 0.0\n'
            top5_str += topi
    print(top5_str)

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = efficientnet_b1(pretrained=True)

num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 8)
swa_model = AveragedModel(model)
swa_model.load_state_dict(torch.load("last.pt"))
swa_model.to(DEVICE)
swa_model.eval()

path = 'test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = swa_model(img)
    out = out.data.cpu().numpy()[0]
    print(file)
    show_outputs(out)

这里测试代码和以前的写法没有啥区别,唯一不同的地方:

重新定义模型,然后载入权重。
运行结果:
image-20220425210850314
完整代码:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85223146

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-27 11:19:34  更:2022-04-27 11:20:41 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:15:12-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码