IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> tensorflow实现猫狗分类器(三)Inception V3迁移学习 -> 正文阅读

[人工智能]tensorflow实现猫狗分类器(三)Inception V3迁移学习

部分内容来自 博主史丹利复合田的Keras 入门课6 – 使用Inception V3模型进行迁移学习
地址:https://blog.csdn.net/tsyccnh/article/details/78889838

迁移学习主要分为两种

  • 第一种即所谓的transfer learning,迁移训练时,移掉最顶层,比如ImageNet训练任务的顶层就是一个1000输出的全连接层,换上新的顶层,比如输出为10的全连接层,然后训练的时候,只训练最后两层,即原网络的倒数第二层和新换的全连接输出层。可以说transfer learning将底层的网络当做了一个特征提取器来使用。
  • 第二种叫做fine tune,和transfer learning一样,换一个新的顶层,但是这一次在训练的过程中,所有的(或大部分)其它层都会经过训练。也就是底层的权重也会随着训练进行调整。

下载Inception V3相关数据

import os

from tensorflow.keras import layers
from tensorflow.keras import Model
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 \
    -O /tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
  
from tensorflow.keras.applications.inception_v3 import InceptionV3

local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

设置pre_model架构(也叫base_model)

InceptionV3模型,其两个参数比较重要,一个是weights,如果是’imagenet’,Keras就会自动下载已经在ImageNet上训练好的参数,如果是None,系统会通过随机的方式初始化参数,目前该参数只有这两个选择。另一个参数是include_top,如果是True会保留全连接层。如果是False,会去掉顶层的全连接网络
这里的input_shape = (150, 150, 3)是我们输入网络的猫狗图片结构

pre_trained_model = InceptionV3(input_shape = (150, 150, 3), 
                                include_top = False, 
                                weights = None)

pre_trained_model.load_weights(local_weights_file)

冻结pre_trained_model所有层,让骨架模型不再被训练

for layer in pre_trained_model.layers:
  layer.trainable = False

由于网络结构太长这里只展示部分

pre_trained_model.summary()

在这里插入图片描述在这里插入图片描述

这里我们不采用InceptionV3的全部层次,让网络取到mixed7层作为输出连接到我们新添加的网络

last_layer = pre_trained_model.get_layer('mixed7')
print('last layer output shape: ', last_layer.output_shape)
last_output = last_layer.output

last layer output shape: (None, 7, 7, 768)

给网络添加我们自己的几个层次,采用dropout减少过拟合

from tensorflow.keras.optimizers import RMSprop

# Flatten the output layer to 1 dimension
x = layers.Flatten()(last_output)
# Add a fully connected layer with 1,024 hidden units and ReLU activation
x = layers.Dense(1024, activation='relu')(x)
# Add a dropout rate of 0.2
x = layers.Dropout(0.2)(x)                  
# Add a final sigmoid layer for classification
x = layers.Dense  (1, activation='sigmoid')(x)           

model = Model( inputs=pre_trained_model.input,  outputs=x) 

model.compile(optimizer = RMSprop(lr=0.0001), 
              loss = 'binary_crossentropy', 
              metrics = ['acc'])

在这里插入图片描述

定义目录,采用数据增强

base_dir = '/tmp/cats_and_dogs_filtered'

train_dir = os.path.join( base_dir, 'train')
validation_dir = os.path.join( base_dir, 'validation')


train_cats_dir = os.path.join(train_dir, 'cats') # Directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs') # Directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats') # Directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')# Directory with our validation dog pictures

train_cat_fnames = os.listdir(train_cats_dir)
train_dog_fnames = os.listdir(train_dogs_dir)

# Add our data-augmentation parameters to ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator( rescale = 1.0/255. )

# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size = 20,
                                                    class_mode = 'binary', 
                                                    target_size = (150, 150))     

# Flow validation images in batches of 20 using test_datagen generator
validation_generator =  test_datagen.flow_from_directory( validation_dir,
                                                          batch_size  = 20,
                                                          class_mode  = 'binary', 
                                                          target_size = (150, 150))

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

训练网络

history = model.fit_generator(
            train_generator,
            validation_data = validation_generator,
            steps_per_epoch = 100,
            epochs = 20,
            validation_steps = 50,
            verbose = 2)

Epoch 1/20
100/100 - 29s - loss: 0.3360 - acc: 0.8655 - val_loss: 0.1211 - val_acc: 0.9470
Epoch 2/20
100/100 - 23s - loss: 0.2193 - acc: 0.9145 - val_loss: 0.1096 - val_acc: 0.9640
Epoch 3/20
100/100 - 23s - loss: 0.2038 - acc: 0.9290 - val_loss: 0.0888 - val_acc: 0.9660
Epoch 4/20
100/100 - 22s - loss: 0.1879 - acc: 0.9315 - val_loss: 0.1198 - val_acc: 0.9590
Epoch 5/20
100/100 - 23s - loss: 0.1760 - acc: 0.9415 - val_loss: 0.1155 - val_acc: 0.9660
Epoch 6/20
100/100 - 22s - loss: 0.1771 - acc: 0.9375 - val_loss: 0.1540 - val_acc: 0.9450
Epoch 7/20
100/100 - 23s - loss: 0.1916 - acc: 0.9370 - val_loss: 0.1616 - val_acc: 0.9550
Epoch 8/20
100/100 - 22s - loss: 0.1594 - acc: 0.9440 - val_loss: 0.1422 - val_acc: 0.9630
Epoch 9/20
100/100 - 23s - loss: 0.1669 - acc: 0.9465 - val_loss: 0.1099 - val_acc: 0.9650
Epoch 10/20
100/100 - 23s - loss: 0.1677 - acc: 0.9445 - val_loss: 0.1245 - val_acc: 0.9600
Epoch 11/20
100/100 - 22s - loss: 0.1653 - acc: 0.9470 - val_loss: 0.0918 - val_acc: 0.9730
Epoch 12/20
100/100 - 22s - loss: 0.1542 - acc: 0.9455 - val_loss: 0.1623 - val_acc: 0.9570
Epoch 13/20
100/100 - 22s - loss: 0.1525 - acc: 0.9520 - val_loss: 0.1087 - val_acc: 0.9670
Epoch 14/20
100/100 - 23s - loss: 0.1454 - acc: 0.9565 - val_loss: 0.1314 - val_acc: 0.9640
Epoch 15/20
100/100 - 22s - loss: 0.1279 - acc: 0.9525 - val_loss: 0.1515 - val_acc: 0.9630
Epoch 16/20
100/100 - 23s - loss: 0.1255 - acc: 0.9530 - val_loss: 0.1306 - val_acc: 0.9650
Epoch 17/20
100/100 - 22s - loss: 0.1430 - acc: 0.9575 - val_loss: 0.1226 - val_acc: 0.9660
Epoch 18/20
100/100 - 23s - loss: 0.1350 - acc: 0.9510 - val_loss: 0.1583 - val_acc: 0.9520
Epoch 19/20
100/100 - 22s - loss: 0.1288 - acc: 0.9580 - val_loss: 0.1170 - val_acc: 0.9710
Epoch 20/20
100/100 - 22s - loss: 0.1363 - acc: 0.9550 - val_loss: 0.1260 - val_acc: 0.9660

绘制损失和准确率图

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)
plt.figure()


plt.show()

在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-22 14:41:05  更:2021-09-22 14:43:13 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 12:35:09-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码