IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 机器学习(十一) 迁移学习 -> 正文阅读

[人工智能]机器学习(十一) 迁移学习

前言

??迁移学习在计算机视觉任务和自然语言处理任务中经常使用,这些模型往往需要大数据、复杂的网络结构。如果使用迁移学习,可将预训练的模型作为新模型的起点,这些预训练的模型在开发神经网络的时候已经在大数据集上训练好、模型设计也比较好,这样的模型通用性也比较好。如果要解决的问题与这些模型相关性较强,那么使用这些预训练模型,将大大地提升模型的性能和泛化能力。

1 原理

??迁移学习(Transfer Learning)是机器学习的一个研究方向,主要研究如何将任务 A 上面学习到的知识迁移到任务 B 上,以提高在任务 B 上的泛化性能。例如任务 A 为猫狗分类问题,需要训练一个分类器能够较好的分辨猫和狗的样本图片,任务 B 为牛羊分类问题。可以发现,任务 A 和任务 B 存在大量的共享知识,比如这些动物都可以从毛发、体型、形 态、发色等方面进行辨别。因此在任务 A 训练获得的分类器已经掌握了这部份知识,在训练任务 B 的分类器时,可以不从零开始训练,而是在任务 A 上获得的知识的基础上面进行训练(Feature Extraction)或微调(Fine tuning),这和“站在巨人的肩膀上”思想非常类似。通过迁移任务 A 上学习的知识,在任务 B 上训练分类器可以使用更少的样本和更少的训练代价,并且获得不错的泛化能力。
在这里插入图片描述
在神经网络迁移学习中,主要有两个应用场景:特征提取和微调。
? 特征提取(Feature Extraction) :冻结除最终完全连接层之外的所有网络的权重。最后一个全连接层被替换为具有随机权重的新层,因只需要更新最后一层全连接层,使得更新参数极大地减少,节省大量的 训练时间GPU 资源。
? 微调(Fine Tuning) :对于卷积神经网络,一般认为它能够逐层提取特征,越末层的网络的抽象特征提取能力越强,输出层一般使用与类别数相同输出节点的全连接层,作为分类网络的概率分布预测。对于相似的任务 A 和 B,如果它们的特征提取方法是相近的,则网络的前面数层可以重用。而微调技术就是使用预训练网络初始化网络,而不是随机初始化,用新数据训练部分或整个网络。小幅度更新前面的层的参数。

2 实例

??进行迁移学习需要使用对应的预训练模型。PyTorch提供了很多现成的预 训练模块,我们直接拿来使用就可以。主要集成在 torchvision.models 模块中,预训练模型可以通过传递参数 pretrained = True 构造。主要的模型有 AlexNet,VGG,ResNet,SqueezeNet,DenseNet,Inception v3,GoogLeNet,ShuffleNet v2 等。
??所有的预训练模型都要求输入图片以相同的方式进行标准化,即:小批l量3通道RGB格式 (3 × H × W) ,其中H和W应等于 224 。图片加载时像素值的范围应在 [0, 1] 内,然后通过指定 mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225] 进行标准化。

2.1 特征提取

??本次案例使用的数据集是 CIFAR-10数据集 ,目标是对数据集中10类物体进行分类,只使用几层的卷积和全连接层的分类正确率只有 68% 左右,结果不算好。此案例使用迁移学习中特征提取方法来实现这个任务,预训练模型采用 retnet18 网络。

#导入相关包
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from datetime import datetime

#为适合预训练模型,增加了一些预处理功能,如数据标准化,对图片进行裁剪等
trans_train = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                         std = [0.229, 0.224, 0.225])])

#对测试集的预处理有一定不同,这一点对结果的影响很大
trans_vaild = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                         std = [0.229, 0.224, 0.225])])

#CIFAR10数据集下载
trainset = torchvision.datasets.CIFAR10(
    root = 'data',
    download = False,
    train = True,
    transform = trans_train
)
trainloader = DataLoader(trainset, batch_size = 64, shuffle = True)

testset = torchvision.datasets.CIFAR10(
    root = 'data',
    download = False,
    train = False,
    transform = trans_vaild
)
testloader = DataLoader(testset, batch_size = 64, shuffle = False)

#这里会自动下载预训练模型,该模型网络架构为resnet18,
#已经在 ImageNet大数据集上训练好了,该数据集有1000类别
net = models.resnet18(pretrained = True)
#冻结于训练模型的全部参数
for param in net.parameters() :
    param.requires_grad = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#修改最后的全连接层,CIFAR10数据集只有10类
net.fc = nn.Linear(512, 10)

#查看总参数及训练参数
total_params = sum(p.numel() for p in net.parameters())
print(f'原参数的数量 : {total_params}') #11181642
total_params_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'需要训练的参数 : {total_params_trainable}') #5130

criterion = nn.CrossEntropyLoss()
#要注意这个地方 net.fc.parameters(),只更新最后的全连接参数而不是net.parameters()
optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, weight_decay = 0.001, momentum = 0.9)

#训练
net = net.to(device)
for epoch in range(20) :
    prev_time = datetime.now()
    train_losses = 0.0
    train_acc = 0.0
    net.train()
    for x, label in trainloader :
        x, label = x.to(device), label.to(device)
        optimizer.zero_grad()
        out = net(x)
        loss = criterion(out, label)
        train_losses += loss.item()
        _, pred = torch.max(out, dim = 1)
        num_correct = (pred == label).sum().item()
        train_acc += num_correct / x.size(0)
        loss.backward()
        optimizer.step()
    #计算每个循环所花费的时间
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    #测试
    with torch.no_grad() :
        net.eval()
        test_losses = 0.0
        test_acc = 0.0
        for x, label in testloader :
            x, label = x.to(device), label.to(device)
            out = net(x)
            loss = criterion(out, label)
            test_losses += loss.item()
            _, pred = torch.max(out, dim = 1)
            num_correct = (pred == label).sum().item()
            test_acc += num_correct / x.size(0)

    print(f'Eopch {epoch}. Train Loss: {(train_losses / len(trainloader)):.4f}, '
          f'Train Acc: {(train_acc / len(trainloader)):.3f}, '
          f'Vaild Loss: {(test_losses / len(testloader)):.4f}, '
          f'Vaild Acc: {(test_acc / len(testloader)):.3f}, '
          f'Time: {time_str}')

结果:在这里插入图片描述
在前三个Epoch准确率就达到了 73.6% ,最终结果会达到 75% 左右,从精确率比第6章提升了近10个百分点。但是对于分类效果来说仍不是很理想。

2.2 微调

??微调允许修改预先训练好的网络参数来学习目标任务,所以,虽然训练时间要比特征抽取方法长,但精度更高。微调的大致过程是在预先训练过的网络上添加新的随机初始化层,此外预先训练的网络参数也会被更新,但会使用较小的学习率以防止预先训练好的参数发生较大的改变。
??在本次的微调任务中采用了数据增强的方法来使得分类效果更加。因为数据增强是提高模型的泛化能力最重要因素,数据增强技术主要有 水平或垂直翻转图像、裁剪、色彩 变换、扩展和旋转 等,通过数据增强技术不仅可以扩大训练数据集的规 模、降低模型对某些属性的依赖,从而提高模型的泛化能力,同时可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同的位置,从而减轻模型对物体出现位置的依赖性。并通过调整亮度、色彩等因素来降低模型对色彩的敏感度等。在PyTorch中图像增强的方法集成在 torchvision.transforms 模块中,主要的有:
? torchvision.transforms.Resize() :随机比例缩放。
? torhvision.transforms.RandomCrop() :在图像随机位置进行裁取。
? torhvision.transforms.CenterCrop() :在图像中心置进行裁取。
? torchvision.transforms.RandomHorizontalFlip() :随机水平翻转。
? torchvision.transforms.RandomVerticalFlip() :随机竖直翻转。
? torchvision.transforms.RandomRotation() :随机旋转。
? torchvision.transforms.ColorJitter() :改变亮度、对比度和颜色。
微调的代码与特征提取的不同地方主要在图像预处理部分和参数更新部分。
这里对训练数据添加了几种数据增强方法,如图像裁剪、旋转、颜色改变等方法。测试数据与特征提取一样,没有变化。

trans_train = transforms.Compose([
    transforms.RandomResizedCrop(256, scale = (0.8, 1.0)),
    transforms.RandomRotation(degrees = 15),
    transforms.ColorJitter(),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.485, 0.456, 0.406],
                         std = [0.229, 0.224, 0.225])])

优化器部分,注意不要冻结预训练模型的参数。

optimizer = optim.SGD(net.parameters(), lr = 0.001, weight_decay = 0.001, momentum = 0.9)

结果:
在这里插入图片描述
由结果知微调+数据增强的方法在第三个Epoch正确率就可以达到 92% ,最终结果可达到 95% 左右,正确很高。本次实验只设置了20个Eopch,当继续增加Epoch时,正确率会接近 100% 。

参考文献:
?Python深度学习基于PyTorch
?TensorFlow深度学习

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-28 15:28:52  更:2022-02-28 15:32:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 17:42:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码