collate_fn:DataLoader 中的一个参数,用于实现自定义的 batch 输出。当默认的 default_collate 对 batch 的处理结果不满足需求时,可以自己编写一个 collate 函数来处理 batch 数据,以适配自己模型的数据接口。
1、构建数据,datasets
import torch
import numpy as np
import torch.utils.data as Data
from sklearn.datasets import make_classification
# Build a tiny 3-class toy problem: 10 samples with 5 features each, 3 of them
# informative. Both values returned by make_classification (features, labels)
# are converted with torch.Tensor, which is why the labels are printed as
# floats (0., 1., 2.) rather than integers.
x, y = map(
    torch.Tensor,
    make_classification(
        n_samples=10,
        n_features=5,
        n_informative=3,
        n_classes=3,
        random_state=78465654,
    ),
)
'''
这里使用了TensorDataset,在此补上它的源码,注意返回的是元组形式
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Args:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
tensors: Tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
'''
# Wrap the feature and label tensors in a TensorDataset; indexing it returns
# a tuple (x[i], y[i]) for sample i.
datasets = Data.TensorDataset(x,y)
# Show the full feature matrix and label vector for reference.
print(x,y)
输出:
tensor([[ 1.9991, -0.5978, -1.2000, -1.9008, 1.6448],
[ 1.0556, 1.8977, 0.2893, 1.1656, 1.4490],
[ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375],
[ 0.3055, 0.7218, 0.4611, -0.0954, 0.6447],
[ 1.9943, 1.5691, -0.9082, 1.4362, 1.8467],
[ 0.0846, -0.4411, 0.4979, -1.9108, 0.3863],
[-0.7324, 0.2127, 1.1746, -0.7633, -0.1693],
[ 1.5626, 0.9349, -0.8652, 0.8973, 1.3515],
[-0.1696, -1.7607, -1.0723, -0.7889, -0.8615],
[ 1.0102, 0.6907, -0.5504, 0.7180, 0.8804]])
tensor([0., 1., 0., 2., 1., 2., 2., 1., 0., 0.])
2、使用default_collate函数构建dataloader(也就是不调用自己实现的collate_fn)
# No collate_fn is passed, so the DataLoader uses its built-in default collate
# function to assemble each batch.
dataloader_with_defaultcollate = Data.DataLoader(dataset=datasets, batch_size=3)

# Each batch unpacks into (features, labels). With 10 samples and
# batch_size=3 the last batch contains a single sample.
for (i, j) in dataloader_with_defaultcollate:
    print(i)
    print(j)
输出:
tensor([[ 1.9991, -0.5978, -1.2000, -1.9008, 1.6448],
[ 1.0556, 1.8977, 0.2893, 1.1656, 1.4490],
[ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
tensor([0., 1., 0.])
tensor([[ 0.3055, 0.7218, 0.4611, -0.0954, 0.6447],
[ 1.9943, 1.5691, -0.9082, 1.4362, 1.8467],
[ 0.0846, -0.4411, 0.4979, -1.9108, 0.3863]])
tensor([2., 1., 2.])
tensor([[-0.7324, 0.2127, 1.1746, -0.7633, -0.1693],
[ 1.5626, 0.9349, -0.8652, 0.8973, 1.3515],
[-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]])
tensor([2., 1., 0.])
tensor([[ 1.0102, 0.6907, -0.5504, 0.7180, 0.8804]])
tensor([0.])
3、使用自己的mycollate函数构建dataloader,以实现和default_collate一样的效果:
def mycollate(batch):
    """Collate a list of ``(features, label)`` samples into two batched tensors.

    Produces the same batched tensors the default collate function yields
    for this dataset.

    Args:
        batch: list of per-sample tuples as returned by the dataset, e.g.::

            [(tensor([ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448]), tensor(0.)),
             (tensor([ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490]), tensor(1.)),
             (tensor([ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]), tensor(0.))]

    Returns:
        A tuple ``(features, labels)`` with the samples stacked along a new
        leading dimension, e.g.::

            tensor([[ 1.9991, -0.5978, -1.2000, -1.9008,  1.6448],
                    [ 1.0556,  1.8977,  0.2893,  1.1656,  1.4490],
                    [ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
            tensor([0., 1., 0.])

    Raises:
        ValueError: if the samples do not unpack into exactly two fields
            (this includes an empty ``batch``).
    """
    # zip(*batch) transposes "list of samples" into "one tuple per field":
    #     a = [(1, 2), (1, 0), (4, 4)]
    #     x, y = zip(*a)        # x == (1, 1, 4), y == (2, 0, 4)
    features, labels = zip(*batch)

    # torch.stack(inputs, dim=0) joins the tensors along a NEW dimension:
    # several 1-D tensors become one 2-D tensor, several 2-D tensors become
    # one 3-D tensor, and so on.
    return (torch.stack(features, 0), torch.stack(labels, 0))
# Pass the custom collate function; the DataLoader calls it with the list of
# samples that make up each batch.
dataloader_with_mycollate = Data.DataLoader(dataset=datasets, batch_size=3, collate_fn=mycollate)

# mycollate returns a (features, labels) tuple, so each batch unpacks into
# two tensors.
for (i, j) in dataloader_with_mycollate:
    print(i)
    print(j)
输出:
tensor([[ 1.9991, -0.5978, -1.2000, -1.9008, 1.6448],
[ 1.0556, 1.8977, 0.2893, 1.1656, 1.4490],
[ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]])
tensor([0., 1., 0.])
tensor([[ 0.3055, 0.7218, 0.4611, -0.0954, 0.6447],
[ 1.9943, 1.5691, -0.9082, 1.4362, 1.8467],
[ 0.0846, -0.4411, 0.4979, -1.9108, 0.3863]])
tensor([2., 1., 2.])
tensor([[-0.7324, 0.2127, 1.1746, -0.7633, -0.1693],
[ 1.5626, 0.9349, -0.8652, 0.8973, 1.3515],
[-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]])
tensor([2., 1., 0.])
tensor([[ 1.0102, 0.6907, -0.5504, 0.7180, 0.8804]])
tensor([0.])
4、在输出的格式上稍作修改,输出字典格式
def mycollate(batch):
    """Collate ``(features, label)`` samples into a dict of batched tensors.

    Args:
        batch: list of per-sample ``(features, label)`` tensor tuples.

    Returns:
        dict: ``{"DATA": features, "CLASS": labels}`` where each value is the
        corresponding field of every sample stacked along a new leading
        dimension.

    Raises:
        ValueError: if the samples do not unpack into exactly two fields
            (this includes an empty ``batch``).
    """
    # Transpose the list of samples into one tuple per field.
    features, labels = zip(*batch)
    return {"DATA": torch.stack(features, 0), "CLASS": torch.stack(labels, 0)}
# Use the custom collate function so that every batch is whatever mycollate
# returns for the samples in that batch.
dataloader_with_mycollate = Data.DataLoader(dataset=datasets, batch_size=3, collate_fn=mycollate)

# Each batch is a single object here, so it is not unpacked.
for i in dataloader_with_mycollate:
    print(i)
输出:
{'DATA': tensor([[ 1.9991, -0.5978, -1.2000, -1.9008, 1.6448],
[ 1.0556, 1.8977, 0.2893, 1.1656, 1.4490],
[ 0.3073, -0.9410, -0.6531, -0.8985, -0.0375]]), 'CLASS': tensor([0., 1., 0.])}
{'DATA': tensor([[ 0.3055, 0.7218, 0.4611, -0.0954, 0.6447],
[ 1.9943, 1.5691, -0.9082, 1.4362, 1.8467],
[ 0.0846, -0.4411, 0.4979, -1.9108, 0.3863]]), 'CLASS': tensor([2., 1., 2.])}
{'DATA': tensor([[-0.7324, 0.2127, 1.1746, -0.7633, -0.1693],
[ 1.5626, 0.9349, -0.8652, 0.8973, 1.3515],
[-0.1696, -1.7607, -1.0723, -0.7889, -0.8615]]), 'CLASS': tensor([2., 1., 0.])}
{'DATA': tensor([[ 1.0102, 0.6907, -0.5504, 0.7180, 0.8804]]), 'CLASS': tensor([0.])}
5、关于Default_collate函数
如果没有指定自定义的 collate_fn 函数,DataLoader 会自动调用 default_collate 函数。
根据输入 batch 中元素类型的不同,其处理过程和结果也不同。
下面是输入类型到输出类型的一般映射关系(基于 batch 中元素的类型):
* :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
* NumPy Arrays -> :class:`torch.Tensor`
* `float` -> :class:`torch.Tensor`
* `int` -> :class:`torch.Tensor`
* `str` -> `str` (unchanged)
* `bytes` -> `bytes` (unchanged)
* `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
* `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
default_collate([V2_1, V2_2, ...]), ...]`
* `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
default_collate([V2_1, V2_2, ...]), ...]`
Pytorch官网上的一些例子
# Imports needed to run the examples on their own.
from collections import namedtuple

from torch.utils.data import default_collate

# The result shown after each call is the one listed in the PyTorch
# documentation for default_collate.

# A batch of ints becomes a single tensor.
default_collate([0, 1, 2, 3])  # tensor([0, 1, 2, 3])

# A batch of strings is returned unchanged.
default_collate(['a', 'b', 'c'])  # ['a', 'b', 'c']

# A batch of mappings is collated key by key.
default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])  # {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

# A batch of namedtuples is collated field by field, keeping the namedtuple type.
Point = namedtuple('Point', ['x', 'y'])
default_collate([Point(0, 0), Point(1, 1)])  # Point(x=tensor([0, 1]), y=tensor([0, 1]))

# A batch of plain tuples is collated position by position and returned as a list.
default_collate([(0, 1), (2, 3)])  # [tensor([0, 2]), tensor([1, 3])]

# A batch of lists is collated position by position.
default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])]
default_collate 源码:(通过一系列 if/elif 判断 batch 中元素的类型,以便分别进行不同的处理;对多重嵌套的结构通过递归的方式处理)
def default_collate(batch):
    """Collate a list of samples into a batch, dispatching on the element type.

    The type of the first element decides how the whole batch is combined;
    containers (mappings, namedtuples, sequences) are handled by recursing
    into their fields.

    NOTE(review): this is an excerpt. ``collections``, ``string_classes``,
    ``np_str_obj_array_pattern`` and ``default_collate_err_msg_format`` are
    not defined in the visible code and must come from the module this
    function was copied from.

    Args:
        batch: non-empty list of samples, all expected to be of the same kind.

    Returns:
        The collated batch: a tensor for tensors / NumPy values / numbers,
        the list itself for strings, and a container of recursively collated
        fields for mappings, namedtuples and sequences.

    Raises:
        TypeError: if the element type is not supported, or a NumPy array
            has a dtype matched by ``np_str_obj_array_pattern``.
        RuntimeError: if sequence elements have different lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # Running inside a DataLoader worker: preallocate the result in
            # storage obtained from `_new_shared` so that torch.stack writes
            # straight into it instead of allocating a separate output.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # Reject arrays whose dtype string matches the pattern
            # (presumably string/object dtypes, going by the pattern's name).
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            # Convert each array to a tensor and collate via the tensor branch.
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():
            # NumPy scalars: convert the whole batch at once.
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        # Strings are returned as-is (the list is not converted).
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            # Collate key by key and rebuild the same mapping type.
            return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type could not be built from a dict; use a plain dict.
            return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
        # namedtuple: collate field by field and rebuild the namedtuple.
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # All elements of the batch must have the same length.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        # Materialized as a list because it may be iterated more than once
        # (the `except` branch below re-reads it).
        transposed = list(zip(*batch))
        if isinstance(elem, tuple):
            # Plain tuples are returned as a list of collated fields.
            return [default_collate(samples) for samples in transposed]
        else:
            try:
                return elem_type([default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type could not be built from a list; use a list.
                return [default_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
参考链接: default_collate源码
PyTorch 1.11.0 documentation
|