前言
提示:个人分类数据集预处理:
MNIST是经典的手写数字分类数据集,数据集中的图像是灰度图像,图像格式为png,图像尺寸为28*28,最主要的是MNIST数据集格式如下图所示。
我们怎么把自己制作的分类数据集中jpg或者png格式的图片及标签转化为上面那种ubyte格式,因为,很多算法用如下代码来加载MNIST数据,所以我们也可以把自己的数据格式转化为MNIST格式,也这样调用。
train_dataset = torchvision.datasets.MNIST(root= data_path, train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_set = torchvision.datasets.MNIST(root= data_path, train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
提示:以下是本篇文章正文内容,下面案例可供参考
一、自建交通标志分类数据集并做预处理
自建的交通标志分类部分数据如下截图; 右转交通标志 左转交通标志
人行道交通标志
我的分类数据集为10类交通标志图片,类别分别为 [‘ahead’, ‘clearway_no_stopping’, ‘crosswalk’, ‘motorway’, ‘non-motorized_lane’, ‘speed_limit_50’, ‘split-way’, ‘turn_left’, ‘turn_right’, ‘warning_sign’]
10类交通标志分类数据集下载链接---->10类4000张交通标志数据集 有需要的可以下载
1.RGB图像转灰度图像
from PIL import Image
import os
INPUT_PATH = r'F:\DL_xian_data\tools\training-images\9'
OUPUT_PATH = r'F:\DL_xian_data\tools\training-images\9_gray'
files_list = os.listdir(INPUT_PATH)
for file in files_list:
I = Image.open(INPUT_PATH + "/" + file)
L = I.convert('L')
L.save(OUPUT_PATH + "/" + file)
2.灰度图像缩放为28*28尺寸
import os
import os.path
from PIL import Image
infile = r'F:\DL_xian_data\SNN_model\dataset\turn_left'
outfile = r'F:\DL_xian_data\SNN_model\dataset_28_28\7'
list_img = os.listdir(infile)
n = 0
l = len(list_img)
for each_img in list_img:
print(each_img)
image_input_fullname = infile + '/' + each_img
resize_img = Image.open(image_input_fullname)
out_w = 28
out_h = 28
out_img = resize_img.resize((out_w,out_h), Image.ANTIALIAS)
image_output_fullname = outfile + "/" + each_img
out_img.convert('RGB')
out_img.save(image_output_fullname)
n += 1
print('%d/%d img has been resized!' %(n,l))
print('total_num is {%d} success resized img!' %len(list_img))
二、转换为MNIST数据格式
1.转换代码如下
import os
from PIL import Image
from array import *
from random import shuffle
Names = [['train-images','train'], ['t10k-images','t10k']]
for name in Names:
data_image = array('B')
data_label = array('B')
FileList = []
for dirname in os.listdir(name[0]):
path = os.path.join(name[0],dirname)
for filename in os.listdir(path):
if filename.endswith(".png"):
FileList.append(os.path.join(name[0],dirname,filename))
shuffle(FileList)
for filename in FileList:
print(FileList)
print(filename)
label = int(filename.split('\\')[1])
print(label)
Im = Image.open(filename)
pixel = Im.load()
width, height = Im.size
for x in range(0,width):
for y in range(0,height):
data_image.append(Im.getpixel((x, y)))
data_label.append(label)
hexval = "{0:#0{1}x}".format(len(FileList),6)
header = array('B')
header.extend([0,0,8,1,0,0])
header.append(int('0x'+hexval[2:][:2],16))
header.append(int('0x'+hexval[2:][2:],16))
data_label = header + data_label
if max([width,height]) <= 256:
header.extend([0,0,0,width,0,0,0,height])
else:
raise ValueError('Image exceeds maximum size: 256x256 pixels');
header[3] = 3
data_image = header + data_image
output_file = open(name[1]+'-images-idx3-ubyte', 'wb')
data_image.tofile(output_file)
output_file.close()
output_file = open(name[1]+'-labels-idx1-ubyte', 'wb')
data_label.tofile(output_file)
output_file.close()
for name in Names:
os.system('gzip '+name[1]+'-images-idx3-ubyte')
os.system('gzip '+name[1]+'-labels-idx1-ubyte')
2.转换过程打印如下
3.最终转换结果
然后就可以用这些转好的MNIST格式的数据开始训练了
MNIST格式的10类交通标志数据集下载链接---->交通标志数据集10类下载链接MNIST格式
总结
以上就是今天要讲的内容,本文仅仅简单介绍了自建分类数据集转换为MNIST格式数据的过程,作者转换过程遇到一些坑,所以特此记录,供需要的人参考。
|