我先随便写一下,时间紧任务重,后面再慢慢完善……毕竟是第一次用_
Read_data
利用console直接读图片,在console中可以看到img的相关信息。继承Dataset类,写自己的dataset。
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
    """Image dataset where every file in ``root_dir/label_dir`` shares one label.

    The label returned for each sample is the folder name ``label_dir`` itself
    (e.g. ``"ants"``), so one instance covers exactly one class.

    Args:
        root_dir: Directory that contains the class sub-folder.
        label_dir: Name of the class sub-folder to read; also used as the label.
        transform: Optional callable applied to each opened image before it is
            returned (e.g. ``transforms.ToTensor()``). ``None`` (the default)
            returns the image exactly as ``Image.open`` produced it.
    """

    def __init__(self, root_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.transform = transform
        self.path = os.path.join(self.root_dir, self.label_dir)
        # Holds file *names* (not full paths); the attribute name is kept for
        # compatibility. os.listdir returns entries in arbitrary,
        # filesystem-dependent order, so sort to make index -> file stable.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file at position ``index``."""
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        if self.transform is not None:
            img = self.transform(img)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the class folder."""
        return len(self.img_path)
# Root of the hymenoptera training split; each sub-folder name doubles as its label.
root_dir = "C:/Users/JiangShan/Documents/git/brain_cell/first_source/pytorch_practice/hymenoptera_data/train"
ant_label_dir = "ants"
ants_dataset = MyData(root_dir, ant_label_dir)

# Indexing goes through MyData.__getitem__, which yields an (image, label) pair.
first_sample = ants_dataset[0]
ima, label = first_sample
print(label)
可以在console中看到各个变量或者类的属性及方法
Transform
看官方文档,主要关注输入输出以及需要什么参数;或者在python文件中导入:
from torchvision import transforms
在PyCharm中按住Ctrl点击transforms,会打开transforms.py文件。transforms.py中有很多类,里面都有很好的解释;选择transforms中需要的类作为工具,然后使用对应的功能。
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
# ToTensor converts a PIL image into a channels-first float tensor
# (see the torchvision ToTensor docs for the exact scaling rules).
tensor_trans = transforms.ToTensor()

# The path is relative to the directory the script is launched from.
img_path = "../hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)
print(img)  # PIL repr: image class, mode and size
tensor_img = tensor_trans(img)
transforms中几个常用类的使用
from PIL import Image
from torchvision import transforms
imag = Image.open("../hymenoptera_data/train/ants/0013035.jpg")

# ToTensor: PIL image -> channels-first (C, H, W) float tensor.
trans_totensor = transforms.ToTensor()
imag_tensor = trans_totensor(imag)
print(imag_tensor[0][0][0])  # channel 0, row 0, column 0

# Normalize computes (input - mean) / std per channel; with mean = std = 0.5
# values in [0, 1] are mapped to [-1, 1].
trans_normal = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
imag_normal = trans_normal(imag_tensor)
print(imag_normal[0][0][0])

# Resize with an (h, w) tuple forces exactly that output size.
print(imag.size)  # PIL reports (width, height)
trans_resize = transforms.Resize((512, 512))
imag_resize = trans_resize(imag)
print(imag_resize.size)

# Resize with a single int matches the smaller edge to 512 and keeps the aspect ratio.
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize_2])
imag_resize_2 = trans_compose(imag)
print(imag_resize_2.size)

# RandomCrop takes a random 512x512 patch; the input must be at least 512 px
# on each side, otherwise torchvision raises (unless pad_if_needed=True).
trans_rc = transforms.RandomCrop(512)
trans_compose_2 = transforms.Compose([trans_rc])
for _ in range(10):
    # Keep the crop in its own name rather than overwriting the loop counter.
    imag_crop = trans_compose_2(imag)
    print(imag_crop)
Dataset
import torchvision
# `train` must be a real bool: CIFAR10 only tests its truthiness, and any
# non-empty string (including "false") is truthy, which would silently load
# the training split for test_set as well.
train_set = torchvision.datasets.CIFAR10(root="../", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../", train=False, download=True)
print(train_set)
print(test_set)
# Without a transform each sample is an (image, target) pair, where target is
# the integer class index.
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])  # map the class index to its name
dataset+transform
# A Compose pipeline applied to every sample's image as it is loaded.
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# Pass real booleans for `train`: the string "false" is truthy and would load
# the training split instead of the test split.
train_set = torchvision.datasets.CIFAR10(
    root="../", train=True, transform=dataset_transform, download=True
)
test_set = torchvision.datasets.CIFAR10(
    root="../", train=False, transform=dataset_transform, download=True
)
print(test_set[0])
DataLoader
from torch.utils.data import DataLoader
import torchvision
# `train=False` must be a bool; the string "false" is truthy and would load
# the training split.
test_set = torchvision.datasets.CIFAR10(
    root="../", train=False, transform=torchvision.transforms.ToTensor(), download=True
)
# drop_last=False keeps a final, possibly smaller-than-4 batch.
test_loader = DataLoader(
    dataset=test_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False
)
img, target = test_set[0]
print(target)
print(img.shape)
print(len(test_loader))  # number of batches: ceil(len(test_set) / 4)

# Iterating the loader yields one batch per step; left disabled here.
# for data in test_loader:
#     imgs, targets = data
#     print(len(data))
|