一 pytorch 数据加载的研究
一、dataloader and dataset?
Dataset抽象类,所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。
DataLoader(): 迭代器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
二、类的实例化
大多数文章,并没有仔细探究Dataset这个类,究竟是怎么一步步完成数据和标签的加载的
first of all ,它是个类 所以,从类的角度,继承,重写,实例化,这个面向对象的思路,先研究一下
1.继承Dataset
代码如下(示例):xxx代表可以自己定义的内容
class myDataset(Dataset):
def __init__(self, xxx):
def __getitem__(self,index):
return xxx,xxx
def __len__(self):
return len(xxx)
可见,getitem 和 len 需要自己重写,并返回一些东西
2.重写父类函数
这里,采用了kaggle 的Dog Breed Identification项目的数据 是个分类任务,使用resnet vgg 就可以解决 数据集包含 3个文件 train(文件夹) test(文件夹) label.csv 可以到官网看 https://www.kaggle.com/competitions/dog-breed-identification/ 代码如下(示例):
from torch.utils.data import Dataset
import pandas as pd
import cv2
class myDataset(Dataset):
def __init__(self, dogdir):
self.imgset = dogdir["id"]
self.labelset = dogdir["breed"]
dog_breeds = sorted(list(set(self.labelset )))
n_classes = len(dog_breeds)
self.class_to_num = dict(zip(dog_breeds, range(n_classes)))
def __getitem__(self,index):
imgpath = "train/"+self.imgset[index] + ".jpg"
img = cv2.imread(imgpath)
labelname = self.labelset[index]
labelhot = self.class_to_num.get(labelname)
return img, labelhot
def __len__(self):
return len(self.imgset)
3.实例化
看看继承后的dataset
df = pd.read_csv('labels.csv')
myd = myDataset(df)
img,label = myd.__getitem__(4)
lenth = myd.__len__()
print(label)
print(lenth)
49 #one hot 编码后的标签 10222 # 总体数量
总结
提示:这里对文章进行总结:
例如:以上就是今天要讲的内容
|