IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 每周汇报 - 树叶的图像分类 -> 正文阅读

[人工智能]每周汇报 - 树叶的图像分类

数据集 ??

图片简介

这项任务是预测树叶图像的类别。?该数据集包含176个类别,18353幅图像。?每个类别至少有50幅图像用于训练。

图片样品

代码实现

引入相关类库

import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

?使用pandas读取csv文件,csv有两列,一列是图片位置,另一列是对应的标签。

df = pd.read_csv('../input/classify-leaves/train.csv')
df

?

csv的图像位置不完全正确,需要改成正确的相对位置(以当前运行目录为基准)。

df["image"] = "../input/classify-leaves/" + df["image"]
df

?

使用sample对dataframe文件进行打乱,frac=1是全选的意思,并重置索引,丢弃原索引。因为dataframe的真实排列是以索引为准,如果不更正索引,即使表面上看行被打乱,实际数据的排列也会跟原来一样。

df = df.sample(frac=1).reset_index(drop = True)

取90%的数据集作为训练数据,取另外10%作为测试数据。

train_df = df[1800:]
test_df = df[:1800]

构建训练数据集和验证数据集的迭代器,preprocessing_function使用tf.keras的mobilenet_v2,对数据集进行归一化处理,使用了preprocessing_function之后就不能再使用rescale,不然会重复处理。这里的validation_split=0.11是把训练集切11%出来作为验证集,这样整体的训练集:验证集:测试集 = 8: 1: 1。rotation_range、width_shift_range等操作是数据增强的方法,这里暂时不用,因为会大大增大训练的量。

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=0.11,
#     rotation_range=360,
#     width_shift_range=0.1,
#     height_shift_range=0.1,
#     shear_range=0.1,
#     zoom_range=0.1,
#     horizontal_flip=True,
#     vertical_flip=True,
)

同样,也需要一个测试集的迭代器。因为要用真实的数据测试,所以不会使用数据增强。

test_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet50.preprocess_input
)

?训练批量为 128。

BATCH_SIZE = 128

使用train_generator,从dataframe中产生对应的训练集和验证集。

train_images = train_generator.flow_from_dataframe(
    dataframe=train_df,
    x_col='image',
    y_col='label',
    color_mode='rgb',
    class_mode='categorical',
    target_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
    subset='training'
)

//Found 14733 validated image filenames belonging to 176 classes.
val_images = train_generator.flow_from_dataframe(
    dataframe=train_df,
    x_col='image',
    y_col='label',
    color_mode='rgb',
    class_mode='categorical',
    target_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
    subset='validation'
)

//Found 1820 validated image filenames belonging to 176 classes.

同样,使用 test_generator 产生测试集。

test_images = test_generator.flow_from_dataframe(
    dataframe=test_df,
    x_col='image',
    y_col='label',
    color_mode='rgb',
    class_mode='categorical',
    target_size=(224, 224),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

//Found 1800 validated image filenames belonging to 176 classes.

?模型搭建,使用迁移学习,以ResNet50为主模型。

def create_model(input_shape=(224, 224, 3)):
    
    inputs = tf.keras.layers.Input(input_shape)
    
    base_model = tf.keras.applications.ResNet50(
        input_shape=(224, 224, 3), # 图片尺寸是 224 x 224 , 3 频道(RGB)。 
        weights='imagenet',  # imagenet权重
        pooling='avg',    # 平均池化
        include_top=False 
        )
    
    base_model.trainable = False # 只训练最后一层的参数,其他保持不动。
    
    x = base_model(inputs)
    x = tf.keras.layers.Dense(256, activation='relu')(x) # 全连接层
    x = tf.keras.layers.Dropout(0.25)(x)    # dropout层
    outputs =  tf.keras.layers.Dense(176, activation='softmax')(x) # 输出176种树叶分类情况
    
    model = tf.keras.models.Model(inputs, outputs)
    
    return model

实例化模型

model = create_model(input_shape = (224, 224, 3))

?模型的编译,使用adam优化器,使用交叉熵,标尺是准确率。

model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

?模型概况

model.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
resnet50 (Functional)        (None, 2048)              23587712  
_________________________________________________________________
dense_1 (Dense)              (None, 256)               524544    
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 176)               45232     
=================================================================
Total params: 24,157,488
Trainable params: 569,776
Non-trainable params: 23,587,712

正式训练

history = model.fit(
    train_images,
    validation_data=val_images,
    epochs=150,
    callbacks=[
        tf.keras.callbacks.EarlyStopping( # 提前停止,防止过拟合
            monitor='val_loss', # 监控验证集的损失
            patience=5, # 最多忍耐5个epoch
            restore_best_weights=True # 当停止时,恢复最佳的权重参数
        )
    ]
)

?展示前5个epoch

Epoch 1/150
116/116 [==============================] - 45s 390ms/step - loss: 3.3070 - accuracy: 0.2583 - val_loss: 1.7681 - val_accuracy: 0.5522
Epoch 2/150
116/116 [==============================] - 43s 373ms/step - loss: 1.5507 - accuracy: 0.5736 - val_loss: 1.1069 - val_accuracy: 0.7082
Epoch 3/150
116/116 [==============================] - 44s 375ms/step - loss: 1.0763 - accuracy: 0.6924 - val_loss: 0.8298 - val_accuracy: 0.7846
Epoch 4/150
116/116 [==============================] - 44s 375ms/step - loss: 0.8250 - accuracy: 0.7579 - val_loss: 0.7166 - val_accuracy: 0.8033
Epoch 5/150
116/116 [==============================] - 44s 378ms/step - loss: 0.6627 - accuracy: 0.8061 - val_loss: 0.6385 - val_accuracy: 0.8143

?使用测试集对模型进行测试

results = model.evaluate(test_images)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-16 18:49:57  更:2021-11-16 18:50:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 6:21:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码