IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 2021-11-08 -> 正文阅读

[人工智能]2021-11-08


土堆的pytorch学习记录

DAY1 p1-p15

1.数据的读取

from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):

    #os.path.join(A , B)是对路径进行一个拼接,前面的A加上后面的B
    #os.listdir(A)是对这个路径里的文件进行列表排序,以便索引
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir , self.label_dir)
        self.img_path = os.listdir(self.path)


    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir , img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img , label

    def __len__(self):
        return len(self.img_path)`

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir , ants_label_dir)
bees_dataset = MyData(root_dir , bees_label_dir)
img , label= bees_dataset[1]
img.show()

2.可视化展示

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image

write = SummaryWriter("logs")
# 这里的logs指的是生成的文件夹的名字,不用创建,默认生成

img_path = "data/train/ants_image/0013035.jpg"

# Image.open默认将打开的图片类型变成PIL型
# 下面做的是从PIL型变成Array型
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)

# 将图片进行可视化展示
write.add_image("test" , img_array , 1 , dataformats='HWC')

# y=2x
for i in range(100):
    # 可视化绘图
    write.add_scalar("y=2x" , 2*i , i)

write.close()

注意最后要查看展示的图片是要运行代码后,在terminal里敲 tensorboard --logdir=SummaryWriter里生成的文件夹名字

3.一些有用的transforms函数

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter("logs")
img = Image.open("dataset/train/ants/0013035.jpg")
print(img)

# Totensor的使用:将图片的类型转成tensor型
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)



# Normalize的使用:使用均值和标准差对张量图像进行归一化
# Normalize输入的图像类型得是tensor型
# output[channel] = (input[channel] - mean[channel]) / std[channel]
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5,0.5,0.5] , [0.5,0.5,0.5])
# 一般RGB图像channels为3,灰度图像为1
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize" , img_norm)

# Resize:重新调整图片大小至想要的尺寸
# Resize的输入类型得是PIL
print(img.size)
trans_resize = transforms.Resize((512,512))
# img PIL -> resize ->img_resize PIL
img_resize = trans_resize(img)
# img_resize PIL -> tensor ->img_resize tensor
img_resize = trans_totensor(img)
writer.add_image("Resize" , img_resize , 0)
print(img_resize)

# Compose - resize - 2
trans_resize_2 = transforms.Resize(312)
trans_compose = transforms.Compose([trans_resize_2,trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize" , img_resize_2 , 1)


writer.close()

4.网上数据的下载与安装

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 定义一个transform操作,将数据集中所有的图片都转变成tensor类型
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# 下载数据集并载入
# root路径是存放的路径,定好名后可以自己生成 , train表示是不是训练集 ,downlo一般设为true
train_set = torchvision.datasets.CIFAR10(root="./dataseset_CIFAR10", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataseset_CIFAR10", train=False, transform=dataset_transform, download=True)

# print(test_set[0])
# print(test_set.classes)
#
# img , target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()

# print(test_set[0])

writer = SummaryWriter("p10")
for i in range(10):
    img , target = test_set[i]
    writer.add_image("test_set" , img , i)

writer.close()

5.数据集的载入

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataseset_CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

#载入数据集:batch_size是指一次抽取多少张图片,shuffle(true指打乱顺序抽取,false指固定顺序抽取) ,num_workers一般设置为0 , drop_last是指最后抽剩的数据还要不要(false要,true不要)
test_loader = DataLoader(dataset=test_data , batch_size=64 , shuffle=True , num_workers=0 , drop_last=False)

# 测试数据集中第一张图片及target
img , target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
step=0
for data in test_loader:
    imgs , targets = data
    # print(imgs.shape)
    # print(targets)
    writer.add_images("test_data" , imgs ,step)
    step = step + 1

writer.close()
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-09 19:29:17  更:2021-11-09 19:33:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 8:07:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码