IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题 -> 正文阅读

[Python知识库]治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题

torch大神请忽略此文。。。

1,一个简单例子回顾DataSet

from torch.utils.data import Dataset

class dataset(Dataset):
    def __init__(self):
        # 需要转化为array,不然运行结果会很奇怪
        self.data = np.array([[1,1,1,1],
                [2,2,2,2],
                [3,3,3,3],
                [4,4,4,4],
                [5,5,5,5],
                [6,6,6,6]])        
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # return data[idx],这句少了self,报错...
        return self.data[idx]

dataset = dataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=2,
                                         shuffle=False,
                                         num_workers=1)

for i, data_ in enumerate(dataloader):
    print(i)
    print(data_)

观察运行结果:

0
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
        [4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
        [6, 6, 6, 6]])

这个例子足够理解DataSet了

2,维度

参考这篇文章:
https://blog.csdn.net/xddwz/article/details/108405817

# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
 
import cv2
from PIL import Image
 
 
class MyDataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transforms.Compose([
            transforms.ToTensor()      # 这里仅以最基本的为例
        ])
        self.image_path = './image_data2/'
        self.image_names = os.listdir(self.image_path)
 
    def __len__(self):
        return len(self.image_names)
 
    def __getitem__(self, item):
        image_name = self.image_names[item]
 
        image = cv2.imread(os.path.join(self.image_path, image_name))   # 读到的是BGR数据
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)                  # 转化为RGB,也可以用img = img[:, :, (2, 1, 0)]
        # 这时的image是H,W,C的顺序,因此下面需要转化为C, H, W
        image = torch.from_numpy(image).permute(2, 0, 1)             
 
        # image = Image.open(os.path.join(self.image_path, image_name))
        # # print(image.shape)
        # image = self.transform(image)
        return image

这段代码是上述链接中的,搬运过来的原因是强调一个事情:我们知道torch的维度顺序是BCHW,而上述代码中的__getitem__(),是返回一张图片,那么这个时候我们需要注意的是,单张图片本来的维度顺序是HWC,即维度是(height, width, channel),我们需要将它的维度调整为(channel, height, width),然后再返回。

同时,根据上述代码,也需要注意,返回的image的shape不是(1, channel, height, width),而是(channel, height, width),batch对应的那一维在__getitem__()中不需要考虑。

对于一些channel为1的图片,如果需要增加channel维度,那么只需要squeeze(0)就行了。

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-04-22 18:32:27  更:2022-04-22 18:32:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 13:17:36-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计