IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 图像识别实战(一)----数据集的预处理 -> 正文阅读

[Python知识库]图像识别实战(一)----数据集的预处理

图像识别实战(一)----数据集的预处理

1.模块的导入

import os
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.数据集的读取

data_dir = './flower_data'
train_dir = data_dir+ '/train'
valid_dir = data_dir+ '/valid'

3.数据集的预处理

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45度到45度之间
                                transforms.CenterCrop(224),#从中心开始裁剪
                                transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率
                                transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
                                transforms.ColorJitter(brightness=0.2, contrast=0.1,saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
                                transforms.RandomGrayscale(p=0.025),#概率转化为灰度值,3通道就是R=G=B
                                transforms.ToTensor(),#转化为Tensor格式,在预处理结束后必须添加
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#均值,标准差,经过这样处理后的数据符合标准正态分布,即均值为0,标准差为1。使模型更容易收敛。
    'valid': transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]), 

transforms.Compose()这个类的主要功能是串联图片的变换操作,类似于一个列表。

4.数据集的组织与加载

batch_size = 8
image_datasets = {x:datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}#就是用来包装所使用的数据,每次抛出一批数据
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
dataset=torchvision.datasets.ImageFolder(
                       root, #图片存储的根目录
                       transform=None, #图片的预处理操作
                       target_transform=None, #对图片类别做预处理操作
                       loader=<function default_loader>, #数据集加载方式
                       is_valid_file=None)#获取图像文件的路径并检查该文件是否为有效文件
#print(dataset.classes)  #根据分的文件夹的名字来确定的类别
#print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
#print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

我们打印出 image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 3614
     Root location: F:/flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=(-45, 45), resample=False, expand=False)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 56
     Root location: F:/flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=PIL.Image.BILINEAR)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

我们打印出 dataset_sizes 帮助理解{}中的逻辑

{'train': 3614, 'valid': 56}

5.数据集图像展示

def im_convert(tensor):
    """展示数据"""
    image = tensor.to('cpu').clone().detach()#将Tensor数据从GPU放到CPU,复制和这个Tensor并且去掉梯度
    image = image.numpy().squeeze()#祛除数组中为1 的维度
    image = image.transpose(1,2,0)#Pytorch中为[Channels, H, W],而plt.imshow()中则是[H, W, Channels],所以交换一下通道
    image = image*np.array((0.229, 0.224, 0.225))+np.array((0.485, 0.456, 0.406))# 反转一下transforms.Normalize()的过程
    image = image.clip(0, 1)#归一化
    return image
fig = plt.figure(figsize=(20, 12))#设置图像尺寸
columns = 4
rows = 2
#我们设置的一个batchsize=8,所以dataloaders里只有8张图片,最多显示8张图片
dataiter = iter(dataloaders['valid'])#iter()迭代器
inputs, classes = dataiter.next()
for idx in range (columns*rows):
    ax = fig.add_subplot(rows,columns, idx+1,xticks=[], yticks=[])#图像区域划分row行,colums列,第idx+1个
    ax.set_title(class_names[classes[idx].item()])
    plt.imshow(im_convert(inputs[idx]))

plt.show()   

在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-10-17 11:57:20  更:2021-10-17 11:59:28 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 11:51:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计