IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch——迁移学习实战宝可梦精灵分类 -> 正文阅读

[人工智能]pytorch——迁移学习实战宝可梦精灵分类

数据集

使用宝可梦精灵的图片数据集。数据集地址:

  • 链接:https://pan.baidu.com/s/1zDERMsV1AvwfZudhuae6Ew
  • 提取码:rs4h

数据集中的每一类别的图片放在一个文件夹中
在这里插入图片描述
数据集共包含5个类别的图片,我们取每个文件夹(类别):

  • 前60%做训练集
  • 60%~80%做验证集
  • 80%~100%做测试集
    在这里插入图片描述

数据集处理

'''
load图片数据集
'''
import torch
import os, glob
import random, csv

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        '''
        :param root: 数据集目录
        :param resize: 图片的输出size
        :param mode: train/val/test
        '''
        super(Pokemon, self).__init__()

        self.root = root  # 根目录
        self.resize = resize  # 图片的输出size
        self.name2label = {} # 对目录名(类别)进行编码
        for name in sorted(os.listdir(os.path.join(root))):  # 遍历目录和文件
            if not os.path.isdir(os.path.join(root, name)):  # 如果不是目录(是图片)
                continue

            self.name2label[name] = len(self.name2label.keys())  # 用字典保存类别的编码
        # print(self.name2label)

        '''读入图片数据集'''
        # image, label
        self.images, self.labels = self.load_csv('images.csv')

        '''划分train、val、test集'''
        if mode=='train':  # train: 60%
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode=='val':  # val: 20% = 60%->80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:  # test: 20% = 80%->100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]


    def load_csv(self, filename):
        '''
        一次加载进所有图片可能会造成内存不够用,因此我们可以把图片保存到一个csv文件
        :param filename:保存的文件名
        :return:
        '''

        # 如果csv文件不存在,就创建文件
        # 如果csv文件存在,就是之前已经创建过,直接读取就好了
        if not os.path.exists(os.path.join(self.root, filename)):

            '''把所有的文件放到一个list中去。文件的class可以通过路径名来判定'''
            images = []
            for name in self.name2label.keys():
                # 'pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            print(len(images), images)  # 1167

            random.shuffle(images)  # 打乱顺序

            '''写入csv文件'''
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                print('writen into csv file:', filename)

        '''read from csv file'''
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # 'pokemon\\bulbasaur\\00000000.png', 0
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)  # 检查条件,不符合就终止

        return images, labels


    def __len__(self):
        '''
        返回总体样本数量
        :return:
        '''
        return len(self.images)


    def denormalize(self, x_hat):
        '''
        逆标准化处理
        :param x_hat: 标准化的tensor
        :return: 逆标准化的tensor
        '''
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x: [channel, high, wight]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print(mean.shape, std.shape)
        x = x_hat * std + mean

        return x


    def __getitem__(self, idx):
        '''
        取得当前位置图片
        :param idx: 图片索引
        :return:
        '''

        img, label = self.images[idx], self.labels[idx]

        '''数据增强之后将图片转换为tensor'''
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'), # string path= > image data
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),  # 图片放大1.25倍
            transforms.RandomRotation(15),  # 随机旋转,在-15° ~ +15°之间
            transforms.CenterCrop(self.resize),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化,这几个数是大范围统计出来的rgb三原色的均值和方差
                                 std=[0.229, 0.224, 0.225])
        ])

        # tf = transforms.Compose([
        #     lambda x:Image.open(x).convert('RGB'),  # string path= > image data
        #     transforms.Resize((self.resize, self.resize)),  # 图片放大1.25倍
        #     transforms.ToTensor(),
        # ])

        img = tf(img)
        label = torch.tensor(label)

        return img, label


def main():
    '''
    可视化查看数据集

    此处需要安装并开启visdom
    安装:pip install visdom
    开启:python -m visdom.server
    '''
    import visdom
    import time
    import torchvision

    viz = visdom.Visdom()

    # 如果图片的存储很标准,可以用这种方法
    # tf = transforms.Compose([
    #                 transforms.Resize((64,64)),
    #                 transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    #
    # print(db.class_to_idx)
    #
    # for x,y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)


    # 通用的方法
    db = Pokemon('pokemon', 64, 'train')

    x,y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

    # 加载一张图片
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    # viz.image(x, win='sample_x', opts=dict(title='sample_x'))

    # 加载一个batch的图片
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)


if __name__ == '__main__':
    main()

迁移学习网络

原理

Pokemon和ImageNet都需要图片中提取特征,因此存在某些共性的knowledge。因此我们可以利用更加通用的ImageNet的模型,帮我们解决特定的图片分类任务。

我们采用torchvision.models中训练好的resnet18,使用它训练好的卷积部分提取图像特征,并训练新的分类器处理我们提取到的特征。

这样我们只需要训练分类器,而不用再训练特征提取器,因此可以减少所需训练量。
在这里插入图片描述

代码实现

辅助文件:utils.py

from matplotlib import pyplot as plt
import torch
from torch import nn

'''
定义一个神经网络层
第一个维度保持,其他维度打平成一个维度
'''
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


'''
把image打印在matplotlab上
'''
def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

实现网络构建,网络训练与评估的文件:train_transfer.py

'''
利用迁移学习

torchvision提供了训练好的resnet18、resnet34、resnet50...

此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''

import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

from pokemon import Pokemon
from utils import Flatten

# 引入已经训练好的model
from torchvision.models import resnet18



batchsz = 32
lr = 1e-3
epochs = 10

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
# 每次会开启num_work个线程,分别去加载dataset里面的数据,直到每个worker加载数据量为batch_size 大小(总共num_work*batch_size)才会进行下一步训练
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


viz = visdom.Visdom()

def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():  # 不计算梯度
            logits = model(x)  # 前向运算
            pred = logits.argmax(dim=1)  # 选出输出层最大的元素
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total

def main():

    '''初始化网络'''
    trained_model = resnet18(pretrained=True)  # 已经训练好的model
    # x: [b, 3, 224, 224]
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 3, 224, 224] => [b, 512, 1, 1] # 取出从0到17层,作为特征提取器
                          Flatten(),  # [b, 512, 1, 1] => [b, 512] # 自己定义的类,改变tensor维度
                          nn.Linear(512, 5)  # [b, 512] => [b, 5] # 随机初始化的一个新的线性层,作为分类器
                          ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    '''记录实验结果参数'''
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    '''训练与评估'''
    for epoch in range(epochs):

        '''训练一次模型'''
        for step, (x, y) in enumerate(train_loader):  # 遍历
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)

            # logits: [b, 5]
            # y: [b]
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        '''评估模型'''
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')  # 保存评估结果最好的模型

                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    '''加载最优模型'''
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    '''测试模型'''
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)



if __name__ == '__main__':
    main()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-12 16:36:04  更:2021-08-12 16:37:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:49:38-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码