IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> DataLoader的collate_fn参数 default_collate 与 自定义的 collate_fn -> 正文阅读

[人工智能]DataLoader的collate_fn参数 default_collate 与 自定义的 collate_fn


collate_fn :DataLoader中的一个参数 实现自定义的batch输出。主要是自己在不满意默认的default_collate的batch处理结果的情况下,自己写一个collate函数来处理batch数据,以适配自己的模型数据接口。

1、构建数据,datasets

import torch
import numpy as np
import torch.utils.data as Data
from sklearn.datasets import make_classification
x,y = make_classification(n_samples=10, n_features=5, n_informative=3, n_classes=3, random_state=78465654)
x,y = torch.Tensor(x) , torch.Tensor(y)
'''
这里使用了TensorDataset,在此补上它的源码,注意放回的是元组形式
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
'''
datasets = Data.TensorDataset(x,y)
print(x,y)

输出:

tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
        [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
        [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375],
        [ 0.3055,  0.7218,  0.4611, -0.0954,  0.6447],
        [ 1.9943,  1.5691, -0.9082,  1.4362,  1.8467],
        [ 0.0846, -0.4411,  0.4979, -1.9108,  0.3863],
        [-0.7324,  0.2127,  1.1746, -0.7633, -0.1693],
        [ 1.5626,  0.9349, -0.8652,  0.8973,  1.3515],
        [-0.1696, -1.7607, -1.0723, -0.7889, -0.8615],
        [ 1.0102,  0.6907, -0.5504,  0.7180,  0.8804]]) 
        
tensor([0., 1., 0., 2., 1., 2., 2., 1., 0., 0.])

2、使用default_collate函数构建dataloader(也就是不调用自己实现的collate_fn)

dataloader_with_defaultcollate = Data.DataLoader(dataset=datasets, batch_size = 3)
for (i,j) in dataloader_with_defaultcollate:
    print(i)
    print(j)

输出:

tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
        [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
        [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
tensor([0., 1., 0.])
tensor([[ 0.3055,  0.7218,  0.4611, -0.0954,  0.6447],
        [ 1.9943,  1.5691, -0.9082,  1.4362,  1.8467],
        [ 0.0846, -0.4411,  0.4979, -1.9108,  0.3863]])
tensor([2., 1., 2.])
tensor([[-0.7324,  0.2127,  1.1746, -0.7633, -0.1693],
        [ 1.5626,  0.9349, -0.8652,  0.8973,  1.3515],
        [-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]])
tensor([2., 1., 0.])
tensor([[ 1.0102,  0.6907, -0.5504,  0.7180,  0.8804]])
tensor([0.])

3、使用自己的mycollate函数构建dataloader,以实现和default_collate一样的效果:

def mycollate(batch):
    '''
        输入的batch数据
        [(tensor([ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448]), tensor(0.)),
         (tensor([1.0556, 1.8977, 0.2893, 1.1656, 1.4490]), tensor(1.)),
         (tensor([ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]), tensor(0.))]

        目标:
        tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
                [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
                [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
        tensor([0., 1., 0.])
    '''
    #----------------------zip(*)-----------------------
    '''
    	zip(*)的使用示例
        a = [(1,2), (1,0),(4,4)]
        x,y = zip(*a)
        print(x,y)# 直接输出
        for i in zip(*a): # 迭代输出
            print(i)
        output:
            (1, 1, 4) (2, 0, 4)
            (1, 1, 4)
            (2, 0, 4)
    '''
    #-------------------torch.stack--------------------------
    '''
    torch.stack(inputs, dim=0): 把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,
    也就是在增加新的维度进行堆叠。 outputs = torch.stack(inputs, dim=?)→Tensor
    '''
    #print(batch)
    batch = [x for x in zip(*batch)]
    x,y = batch
    return (torch.stack(x,0), torch.stack(y,0))
dataloader_with_mycollate = Data.DataLoader(dataset=datasets, batch_size = 3, collate_fn=mycollate)
for (i,j) in dataloader_with_mycollate:
    print(i)
    print(j)

输出:

tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
        [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
        [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
tensor([0., 1., 0.])
tensor([[ 0.3055,  0.7218,  0.4611, -0.0954,  0.6447],
        [ 1.9943,  1.5691, -0.9082,  1.4362,  1.8467],
        [ 0.0846, -0.4411,  0.4979, -1.9108,  0.3863]])
tensor([2., 1., 2.])
tensor([[-0.7324,  0.2127,  1.1746, -0.7633, -0.1693],
        [ 1.5626,  0.9349, -0.8652,  0.8973,  1.3515],
        [-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]])
tensor([2., 1., 0.])
tensor([[ 1.0102,  0.6907, -0.5504,  0.7180,  0.8804]])
tensor([0.])

4、在输出的格式上稍作修改,输出字典格式

def mycollate(batch):
#     print(batch)
    batch = [x for x in zip(*batch)]
    x,y = batch
    return {"DATA":torch.stack(x,0), "CLASS":torch.stack(y,0)}
dataloader_with_mycollate = Data.DataLoader(dataset=datasets, batch_size = 3, collate_fn=mycollate)
for i in dataloader_with_mycollate:
    print(i)

输出:

{'DATA': tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
        [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
        [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]]), 'CLASS': tensor([0., 1., 0.])}
{'DATA': tensor([[ 0.3055,  0.7218,  0.4611, -0.0954,  0.6447],
        [ 1.9943,  1.5691, -0.9082,  1.4362,  1.8467],
        [ 0.0846, -0.4411,  0.4979, -1.9108,  0.3863]]), 'CLASS': tensor([2., 1., 2.])}
{'DATA': tensor([[-0.7324,  0.2127,  1.1746, -0.7633, -0.1693],
        [ 1.5626,  0.9349, -0.8652,  0.8973,  1.3515],
        [-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]]), 'CLASS': tensor([2., 1., 0.])}
{'DATA': tensor([[ 1.0102,  0.6907, -0.5504,  0.7180,  0.8804]]), 'CLASS': tensor([0.])}

5、关于Default_collate函数

如果没有自定义的collate_fn函数的情况下,系统会自动调用default_collate函数,

根据其输出的batch的类型不同其处理的过程,结果也有不同。

下面是输出类型映射的一般输入类型(基于批处理中元素的类型)
 * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
 * NumPy Arrays -> :class:`torch.Tensor`
 * `float` -> :class:`torch.Tensor`
 * `int` -> :class:`torch.Tensor`
 * `str` -> `str` (unchanged)
 * `bytes` -> `bytes` (unchanged)
 * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
 * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
 default_collate([V2_1, V2_2, ...]), ...]`
 * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
 default_collate([V2_1, V2_2, ...]), ...]`

Pytorch官网上的一些例子

# Example with a batch of `int`s:
default_collate([0, 1, 2, 3])
# Example with a batch of `str`s:
default_collate(['a', 'b', 'c'])
# Example with `Map` inside the batch:
default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
# Example with `NamedTuple` inside the batch:
Point = namedtuple('Point', ['x', 'y'])
default_collate([Point(0, 0), Point(1, 1)])
# Example with `Tuple` inside the batch:
default_collate([(0, 1), (2, 3)])
# Example with `List` inside the batch:
default_collate([[0, 1], [2, 3]])

Default_collate源码:(一堆if else 判断输出数据的类型以便进行不同的处理,通过递归的形式处理多重嵌套)

def default_collate(batch):
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
        else:
            try:
                return elem_type([default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

参考链接:
default_collate源码

PyTorch 1.11.0 documentation

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-13 11:44:22  更:2022-05-13 11:45:51 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/1 21:50:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码