IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 深度学习之基于CNN和ResNet实现鸟类识别 -> 正文阅读

[人工智能]深度学习之基于CNN和ResNet实现鸟类识别

本次利用迁移学习用已经构建好的ResNet网络对鸟类图片进行分类,但是结果不甚理想。

1.导入库

import numpy as np
import tensorflow as tf
import os,PIL
import random
import pathlib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.utils import np_utils

2.数据加载

(需要数据的可以私信我

加载数据
dataset_url = "E:/tmp/.keras/datasets/Birds_photos"
dataset_dir = pathlib.Path(dataset_url)
train_Bananaquit = os.path.join(dataset_dir,"train","Bananaquit")
train_BlackSki = os.path.join(dataset_dir,"train","Black Skimmer")
train_BTB = os.path.join(dataset_dir,"train","Black Throated Bushtiti")
train_Cockatoo = os.path.join(dataset_dir,"train","Cockatoo")
train_dir = os.path.join(dataset_dir,"train")

test_Bananaquit = os.path.join(dataset_dir,"test","Bananaquit")
test_BlackSki = os.path.join(dataset_dir,"test","Black Skimmer")
test_BTB = os.path.join(dataset_dir,"test","Black Throated Bushtiti")
test_Cockatoo = os.path.join(dataset_dir,"test","Cockatoo")
test_dir = os.path.join(dataset_dir,"test")
#统计训练集和测试集的数据数目
train_Bananaquit_num = len(os.listdir(train_Bananaquit))
train_BlackSki_num = len(os.listdir(train_BlackSki))
train_BTB_num = len(os.listdir(train_BTB))
train_Cockatoo_num = len(os.listdir(train_Cockatoo))
train_all = train_Bananaquit_num+train_BlackSki_num+train_BTB_num+train_Cockatoo_num

test_Bananaquit_num = len(os.listdir(test_Bananaquit))
test_BlackSki_num = len(os.listdir(test_BlackSki))
test_BTB_num = len(os.listdir(test_BTB))
test_Cockatoo_num = len(os.listdir(test_Cockatoo))
test_all = test_Bananaquit_num+test_BlackSki_num+test_BTB_num+test_Cockatoo_num

3.超参数的设置

其实这一模块博主一直不太明白,每次都是乱试,不知道怎样设置超参数才能使得效果最好。

batch_size = 32
epochs = 10
height = 224
width = 224

4.数据预处理

数据预处理的几步:归一化->调整图片大小

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)

train_data_gen = train_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(height,width),
    class_mode="categorical"
)

test_data_gen = test_generator.flow_from_directory(
    batch_size=batch_size,
    directory=test_dir,
    shuffle=True,
    target_size=(height,width),
    class_mode="categorical"
)

5.CNN网络搭建&&编译

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),
    tf.keras.layers.AveragePooling2D((2,2)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128,activation="relu"),
    tf.keras.layers.Dense(4,activation='softmax')
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=["acc"])

结果如下所示:
在这里插入图片描述
虽然加入了Dropout层,但是仍然出现了过拟合的现象。基于此,进行数据增强操作。

6.数据增强

这一部分应当与数据预处理合为一步操作。数据增强包括随机选择、水平翻转、放大操作等。

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255,
                                                                  rotation_range=45,#随机翻转
                                                                  width_shift_range=.15,
                                                                  height_shift_range=.15,
                                                                  horizontal_flip=True,#水平翻转
                                                                  zoom_range=0.5#放大操作
                                                                  )

经过数据增强之后的实验结果如下所示:
在这里插入图片描述
经过20次epochs之后,过拟合的现象得到了缓解。

7.ResNet

利用已经搭建好的ResNet网络对同样的数据集进行训练。

conv_base = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(height,width, 3))
conv_base.trainable = False
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(4,activation='sigmoid'))
model.compile(optimizer='Adam',
              loss='binary_crossentropy',
              metrics=['acc'])

由于硬件的原因,训练速度特别慢,而且实验效果很差,在没有经过数据增强之前,过拟合现象(有可能不是这种现象)很严重,至于数据增强之后的效果如何,博主并未测试。
在这里插入图片描述
对于训练集,准确率有时高达100%。但是对于测试集,实验效果很难差强人意。希望过路的大佬指正。除此之外,博主还利用了VGG16网络进行训练,实验效果相对于ResNet50而言变好了,但是训练速度特别慢。

努力加油a啊

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-07-23 10:47:03  更:2021-07-23 10:47:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 11:06:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码