IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 第四周作业:卷积神经网络(Part2) -> 正文阅读

[人工智能]第四周作业:卷积神经网络(Part2)

AI研习社“猫狗大战”

结果展示

vgg结果

正确率卡在了91%
在这里插入图片描述
在这里插入图片描述

resnet结果

第一次迁移只有50左右的正确率,修改了损失函数有了80左右的正确率。
在这里插入图片描述
请添加图片描述

lenet结果

在这里插入图片描述

自己的cnn网络

import os
import torch
from torchvision import transforms,datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
from PIL import Image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())

数据预处理

data_transform = transforms.Compose([
    transforms.Resize(84),
    transforms.CenterCrop(84),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])

获取数据,先把数据的路径,名称,标签都写入txt,然后通过txt获得信息,

class MyDataSet(Dataset):
    def __init__(self, txtPath, data_transform):
        self.imgPathArr = []
        self.labelArr = []
        with open(txtPath, "rb") as f:
            txtArr = f.readlines()
        for i in txtArr:
            fileArr = str(i.strip(), encoding = "unicode_escape").split(" ")
            self.imgPathArr.append(fileArr[0])
            self.labelArr.append(fileArr[1])
        self.transforms = data_transform

    def __getitem__(self, index):
        label = np.array(int(self.labelArr[index]))
        img_path = self.imgPathArr[index]
        pil_img = Image.open(img_path)
        if self.transforms:
            data = self.transforms(pil_img)
        else:
            pil_img = np.asarray(pil_img)
            data = torch.from_numpy(pil_img)
        return data, label

    def __len__(self):
        return len(self.imgPathArr)

一个简单的cnn网络

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(3, 16, 5, 1, 2),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        self.out = nn.Linear(32 * 21 * 21, 2)   

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           
        output = self.out(x)
        return output, x  
if __name__=='__main__':
    
    train_dataset = MyDataSet('D:/猫狗大战数据集/cat_dog/train.txt', data_transform) 
    train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 4,shuffle = True,num_workers = 4)

    test_dataset = MyDataSet('D:/猫狗大战数据集/cat_dog/text.txt', data_transform) 
    test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = 1,shuffle = True,num_workers = 4)

  
    net = Net().to(device)
     
    cirterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),lr = 0.0001,momentum = 0.9)
    t = 0
    # Train
    try:
        for epoch in range(10):
            running_loss = 0.0  
            for i,data in enumerate(train_loader,0):
                inputs,labels = data
                inputs,labels = Variable(inputs),Variable(labels.long())
                outputs = net(inputs)[0]
                predicted = torch.max(outputs.data,1)[1].data.numpy()
            
                loss = cirterion(outputs,labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                
          
                if i % 100 == 99:
                    print('[%d %5d] loss: %.3f' % (epoch + 1,i + 1,running_loss / 100))
                    running_loss = 0.0
    finally:
        print('finished training!')
        torch.save(net.state_dict(),'net_params.pkl')


    # Test
    correct = 0
    total = 0
     
    for data in test_loader:
        images,labels = data
        images,labels = Variable(images),labels
        print(labels)
        outputs = net(images)[0]

        predicted = torch.max(outputs.data,1)[1].data.numpy()
        total += labels.size(0)
        
        correct += (predicted == labels.numpy()).sum()
     
    print('Accuracy of the network on the 2000 test images: %d %%' % (100 * correct / total))

vgg迁移学习

别的部分没太大的区别,就是把网络换一下,然后把网络的输入输出跟自己的数据做一下适配。

model_vgg = models.vgg16(pretrained=True)
model_vgg_new = model_vgg;

for param in model_vgg_new.parameters():
    param.requires_grad = False
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2)
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1)

model_vgg_new = model_vgg_new.to(device)

resnet迁移学习

跟vgg迁移过程类似。
在vgg迁移学习的基础上迁移了resnet网络,效果却并不理想,准确率只有50%左右。改了损失函数以后acc上升到了80%多,现在仍在修改。
下一步准备扩大训练集试试效果。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-27 14:05:27  更:2021-09-27 14:08:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 12:51:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码