IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> collate_fn参数 -> 正文阅读

[Python知识库]collate_fn参数

collate英文释义

vt. 校对,整理
n. 核对;小吃
n. 校对者,整理者

PyTorch中提供的参数细节

【link】
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
这里是把一些列样本融合为小batch的张量。

代码示例

DataLoader中默认的collate_function函数:

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

主要功能:对我们输入的batch进行处理,得到新的batch。

疑问解析

为什么设置好collate_fn会出现报错?

数据的输入过程中不仅和collate_fn有关系和batch_size也存在关系。
Pytorch对于训练维度的检查时按照每个batch_size的维度进行检查的,因此数据的数量和batch_size的大小会影响结果,不同的batch维度不匹配也会导致报错。

维度不一致代码示例

import numpy as np
import torch


# collate_fn:即用于collate的function,用于整理数据的函数

class NewDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):  # 必须重写
        return len(self.data)

    # 这个方法对数据集添加索引,进行相应取值
    def __getitem__(self, idx):  # 必须重写
        return self.data[idx]


# batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值
def New_collate(batch):
    real_batch = np.array(batch)  # 等式右边是对batch的处理过程
    real_batch = torch.from_numpy(real_batch)
    return real_batch


# 如果数据维度不一致,使用默认的collate_fn会报错
# 大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘。
# 如果后面解决这个问题的方法是:在不足维度上进行补0操作
test = [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]
dataset = NewDataset(test)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=New_collate)
# list、tuple等都是可迭代对象,我们可以通过iter()函数获取这些可迭代对象的迭代器。
# 对获取到的迭代器不断使?next()函数来获取下?条数据。
it = iter(dataloader)
c = next(it)
print(c)
c = next(it)
print(c)

运行结果示例:
在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-10-22 21:11:19  更:2022-10-22 21:11:59 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 6:50:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码