IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 开发工具 -> Kaggle入门:手写数字识别Digit Recognizer竞赛 -> 正文阅读

[开发工具]Kaggle入门:手写数字识别Digit Recognizer竞赛

一、Kaggle介绍

Kaggle是在AI行业领域内,一个国内外都比较出名的网站,国内的阿里天池就是对标这个网站。上面有着丰富的数据集,包括计算机视觉CV领域的(包括:图像识别、目标检测、语义分割等)数据集,也有像波士顿房价预测这样的数据集,还有语音识别方面的数据集等等,上面不定时会举办一些比赛,任何注册成员都可以参加,Kaggle不但免费提供数据集,每周还有一定的免费GPU和TPU使用额度(GPU每周免费使用30个小时,TPU免费使用20个小时),我们的代码可以在kaggle所提供的NoteBook上运行,你可以学习参考别人分享的代码,也可以将自己的代码分享出来。总的来说,这是一个AI方面学习者很好的社区,有着良好的开源氛围,通过参加比赛和交流也可以增加我们的相关经验,免费的GPU和CPU也降低了我们学习的门槛。

Kaggle网站地址:Kaggle: Your Machine Learning and Data Science Community

二、Digit Recognizer竞赛

这是一个计算机视觉图像识别方面入门级的比赛,你可以理解为学习计算机视觉方面的“Hello World”,这个Mnist手写数字数据集可以追溯到上个世纪,可能比你的年纪还要大,当时是训练后运用到 邮票编码的数字识别银行支票的数字识别,有些模型甚至到现在都还在使用。

数据集简介:

MNIST数据集由Yann LeCun搜集,是一个大型的手写体数字数据库,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。MNIST数字文字识别数据集数据量不会太多,而且是单色的图像,较简单,适合深度学习初学者练习建立模型、训练、预测。MNIST数据库中的图像集是NIST(National Institute of Standards and Technology)的两个数据库的组合:专用数据库1和特殊数据库3。数据集是有250人手写数字组成,一半是高中生,一半是美国人口普查局。

MNIST数据集共有训练数据60000项、测试数据10000项。每张图像的大小为28*28(像素),每张图像都为灰度图像,位深度为8(灰度图像是0-255)。大概长这样:

三、实战代码

我的代码是在Kaggle的Notebook上直接运行的。

1.训练:

import os
import math
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, models, transforms
import torch.optim as optim
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataframe_train_valid = pd.read_csv(os.path.join('../input/digit-recognizer/', 'train.csv'), dtype=np.float32)
dataframe_test = pd.read_csv(os.path.join('../input/digit-recognizer/', 'test.csv'), dtype=np.float32)
class mnist_data(Dataset):
    def __init__(self, type, dataframe, transform):
        if type == 'train' or type == 'valid':
            labels = dataframe.label.values 
            features = (dataframe.loc[:, dataframe.columns != "label"].values) 
            # 划分训练集与验证集
            features_train, features_valid, labels_train, labels_valid = \
            train_test_split(features, labels, test_size=0.2, random_state=0)
            if type == 'train':
                self.X = features_train.reshape((-1,28,28))
                self.y = labels_train
            elif type == 'valid':
                self.X = features_valid.reshape((-1,28,28))
                self.y = labels_valid
        if type == 'test':
            self.X = dataframe.values.reshape((-1,28,28))
            self.y = None
        self.transform = transform
    
    def __getitem__(self, index):
        if self.y is not None:
            return self.y[index], self.transform(self.X[index])
        else:
            return self.transform(self.X[index])
    
    def __len__(self):
        return self.X.shape[0]

batch_size = 256
train_dataset = mnist_data('train', dataframe_train_valid, 
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(mean=(0.1307,), std=(0.3081,))
                           ]))
valid_dataset = mnist_data('valid', dataframe_train_valid, 
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(mean=(0.1307,), std=(0.3081,))
                           ]))
test_dataset = mnist_data('test', dataframe_test, 
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(mean=(0.1307,), std=(0.3081,))
                           ]))
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

model = models.resnet18()#pretrained=True
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# 选择优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

#optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-2)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=1000)#, eta_min=1e-6
'''
# 若训练时测量值(如loss)停滞,则调整学习率
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                       patience=5, 
                                                       verbose=1,
                                                       factor=0.5, 
                                                       min_lr=1e-5)'''
# 选择loss function
criterion = nn.CrossEntropyLoss()


# 使用gpu进行训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion.to(device)

count = 0
loss_list = []
iteration_list = []
accuracy_list = []
best_accuracy = 0
for epoch in range(1000):
    for i, (labels, images) in enumerate(train_dataloader):
        train = Variable(images.type(torch.FloatTensor)).to(device)
        labels = Variable(labels.type(torch.LongTensor)).to(device)
        
        optimizer.zero_grad()
        outputs = model(train)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        count = count + 1
        if count % 50 == 0:
        # 检查loss与该模型在验证集下的识别准确率
            correct = 0
            total = 0
            for i, (labels, images) in enumerate(valid_dataloader):
                valid = Variable(images.type(torch.FloatTensor)).to(device)
                labels = Variable(labels.type(torch.LongTensor)).to(device)
                outputs = model(valid)
                predicted = torch.max(outputs.data, 1)[1]
                total += len(labels)
                correct += (predicted == labels).sum()
            accuracy = 100 * correct / float(total)
            loss_list.append(loss.data)
            iteration_list.append(count)
            accuracy_list.append(accuracy)
            print('Epoch:{} Iteration: {}  Loss: {}  Accuracy: {} %'.format(epoch,count, 
                                                                   loss.data, 
                                                                   accuracy))
            if accuracy > best_accuracy:
                torch.save(model.state_dict(),'./mymodel.pt')
    scheduler.step()#loss
    optimizer.step()

2.模型评估并且生成提交文件:

from torchvision import datasets, models, transforms
import torch.nn as nn
model = models.resnet18()#pretrained=True
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)#识别种类数
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.load_state_dict(torch.load('./mymodel.pt'))
model.eval()

model.to(device)
prediction = []
with torch.no_grad():
    for i, images in enumerate(test_dataloader):
        test = Variable(images.type(torch.FloatTensor)).to(device)
        outputs = model(test)
        predicted = torch.max(outputs.data, 1)[1]
        prediction.append(predicted.cpu())
p = [x.numpy() for x in prediction]
p = np.array(p,dtype=object)
p = np.hstack(p)
print(p.shape)
submission =  pd.DataFrame({
        "ImageId": np.arange(len(p))+1,
        "Label": p.tolist()
})
submission.to_csv('./sample_submission_leonard2021.csv', index=False)
print(submission)

?完成训练和评估后,可以刷新Output界面,将 提交文件 下载到你的电脑再上传到比赛界面中就可得到你的成绩和排名了。

我使用的自己改造过ResNet18的网络,将网络的输入图片格式改为单通道,并且将全连接层的输出改为10,使用的是Aadm优化器(初始学习率为0.01),学习率调整算法为CosineAnnealingLR余弦退火。

我简单训练提交后,比赛的评估成绩是准确率为0.99325,排名在前17%左右,中规中矩,你也可以自己改进模型和调参,或者尝试冻结部分网络参数,来达到更好的成绩。

?——————————————————————————————————————————

如果本文对你有帮助,欢迎一键三连!!!

  开发工具 最新文章
Postman接口测试之Mock快速入门
ASCII码空格替换查表_最全ASCII码对照表0-2
如何使用 ssh 建立 socks 代理
Typora配合PicGo阿里云图床配置
SoapUI、Jmeter、Postman三种接口测试工具的
github用相对路径显示图片_GitHub 中 readm
Windows编译g2o及其g2o viewer
解决jupyter notebook无法连接/ jupyter连接
Git恢复到之前版本
VScode常用快捷键
上一篇文章      下一篇文章      查看所有文章
加:2022-05-03 09:26:51  更:2022-05-03 09:27:00 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 2:38:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码