一、Dataset、DataLoader介绍及使用
代码所用数据集:https://pan.baidu.com/s/16ac0Z97Za-nmwDD90EaUGw 提取码:160h
from torch.utils.data import Dataset, DataLoader
# Absolute location of the SMSSpamCollection data file that MyDataset reads.
# This is a machine-specific Windows path; point it at the local copy.
path = r"C:\Users\Administrator\PycharmProjects\pythonProject\data\SMSSpamCollection"
class MyDataset(Dataset):
    """Line-oriented text dataset: one sample per line of the data file.

    All lines are loaded into memory at construction time. Each sample is
    returned as a ``(line, label)`` pair, where the label is derived from
    the first four characters of the line.
    """

    def __init__(self, file_path=None):
        """Read every line of the data file into ``self.lines``.

        Args:
            file_path: Optional location of the data file. When omitted,
                the module-level ``path`` is used, exactly as before.

        Raises:
            OSError: If the file cannot be opened.
        """
        source = path if file_path is None else file_path
        # The context manager closes the handle as soon as the lines are
        # read instead of leaving it open until garbage collection.
        with open(source, encoding='utf-8') as handle:
            self.lines = handle.readlines()

    def __getitem__(self, index):
        """Return ``(line, label)`` for the sample at ``index``.

        ``line`` is the whole stripped line (label prefix included);
        ``label`` is its first four characters with whitespace removed.
        """
        cur_line = self.lines[index].strip()
        # Four characters cover "spam" exactly. Assumes the only other
        # label is "ham" followed by a tab, which strip() then removes
        # -- TODO confirm against the data file.
        label = cur_line[:4].strip()
        return cur_line, label

    def __len__(self):
        """Return the number of lines read from the file."""
        return len(self.lines)
# Build the dataset once, then wrap it in a loader that shuffles the samples
# and yields batches of three; drop_last=True discards a final short batch.
my_dataset = MyDataset()
data_load = DataLoader(
    dataset=my_dataset,
    batch_size=3,
    shuffle=True,
    drop_last=True,
)

if __name__ == '__main__':
    # Show only the first batch, then stop iterating.
    for batch_index, batch in enumerate(data_load):
        print(batch_index, batch)
        break
    # Number of samples versus number of batches.
    print(len(my_dataset))
    print(len(data_load))
二、PyTorch自带数据集介绍
PyTorch中自带的数据集由两个库提供,分别是:torchvision 和 torchtext
- torchvision.datasets
- torchtext.datasets
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
# Fetch the MNIST training split into ./data (downloaded on first use).
mnist = MNIST(root='./data', train=True, download=True)
# Keep a reference to the loader; the original built it and discarded it,
# so the batching configuration could never be used.
# NOTE(review): no transform is passed, so samples are presumably
# (PIL image, label) pairs; iterating this loader with the default collate
# function would likely need transform=ToTensor() -- confirm before use.
mnist_loader = DataLoader(dataset=mnist, batch_size=256, shuffle=True)
# Inspect the first sample directly from the dataset.
print(mnist[0])
|