DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
其参数: 其中常用的参数有,dataset 为要使用的数据集;batch_size 为一次性要加载的数据个数;shuffle 为是否打乱数据,True 为打乱,False 为不打乱;num_workers 我们加载数据为多进程还是单进程,如果是单进程就写0 ,如果是多进程就写>=1 ;在windows 下如果是写多进程可能会报错,可以直接写成0,在Linux 下如果有多进程则可以写多进程;drop_last 为总共的数据除以batch_size 是否希望有余数,若不希望有余数则True ,若希望有余数则False 。
简单粗暴上代码:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
tran_tensor = transforms.ToTensor()
test_set = torchvision.datasets.CIFAR10(root ="./dataset",train=False,transforms=tran_tensor,download=True)
test_loader = DataLoader(dataset=test_set,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
for epoch in range(2):
step = 0
for data in test_loader:
writer.add_images("img_tensor",imgs,step)
step = step+1
writer.close()
run 之后,输入命令行:
tensorboard --logdir=logs
两个epoch,一样的step,里面的数据是不一样的,则证明shuffle=True成功!
上一章 初识Pytorch之torchvision中的数据集使用 下一章 初识Pytorch之nn.Module神经网络基本架构的使用
|