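What "collate" means in English
vt. to proofread, to put in order; n. (collation) a comparison or check, also a light meal; n. (collator) a person who proofreads or organizes
Parameter details from the PyTorch documentation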
collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset. In other words, it merges a list of samples into a mini-batch of tensors.
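To make the parameter description concrete, here is a minimal sketch of what a collate_fn receives and returns. The toy data and the name show_collate are made up for illustration; only DataLoader, torch.stack and torch.tensor are real PyTorch calls.

```python
import torch
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset: it has __getitem__ and __len__.
# Each sample is a (feature tensor, integer label) pair (toy data).
samples = [(torch.tensor([float(i), float(i)]), i) for i in range(4)]


def show_collate(batch):
    # `batch` is a Python list holding batch_size samples, each exactly as
    # returned by the dataset's __getitem__.
    features = torch.stack([x for x, _ in batch])  # shape: (batch_size, 2)
    labels = torch.tensor([y for _, y in batch])   # shape: (batch_size,)
    return features, labels


loader = DataLoader(samples, batch_size=2, collate_fn=show_collate)
features, labels = next(iter(loader))
print(features.shape, labels.shape)
```

The DataLoader only groups samples into a list; how that list becomes tensors is decided entirely by the collate_fn.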
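Code example
The default collate function used by DataLoader (quoted from an older PyTorch release; helper names such as int_classes, string_classes and container_abcs are defined elsewhere in the PyTorch source):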
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
Main purpose: it processes the batch passed in (a list of samples) and returns a new batch.
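The branches of the default collate function can be observed by leaving collate_fn unset. The sketch below uses made-up dictionary samples to exercise the Mapping, Tensor, int and string branches; the field names x, label and name are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader

# Toy samples: each one is a dict, so the Mapping branch collates key by key.
samples = [
    {"x": torch.tensor([1.0, 2.0]), "label": 0, "name": "a"},
    {"x": torch.tensor([3.0, 4.0]), "label": 1, "name": "b"},
]

# No collate_fn is given, so the DataLoader falls back to its default.
loader = DataLoader(samples, batch_size=2)
batch = next(iter(loader))

print(batch["x"].shape)   # tensors are stacked along a new batch dimension
print(batch["label"])     # Python ints become a single integer tensor
print(batch["name"])      # strings are returned as a plain list, not a tensor
```

The result keeps the structure of one sample (a dict with the same keys), with the batch dimension added inside each field.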
Q&A
Why do I still get an error after setting collate_fn?
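How the data is fed in depends on batch_size as well as on collate_fn. PyTorch checks dimensions one batch at a time: the samples grouped into a batch are collated together, so the number of samples and the value of batch_size both affect the result (when the dataset size is not a multiple of batch_size, the last batch is smaller by default). Mismatched dimensions within a batch also raise an error.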
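The interaction between batch_size and sample dimensions can be reproduced with a small experiment. This is a sketch on made-up data; try_first_batch is a hypothetical helper written for this illustration, and the exact error text depends on the PyTorch version.

```python
import torch
from torch.utils.data import DataLoader

# Toy data: the second sample has length 3, the others have length 2.
ragged = [torch.zeros(2), torch.zeros(3), torch.zeros(2), torch.zeros(2)]


def try_first_batch(data, batch_size):
    """Return the shape of the first batch, or the error message if collation fails."""
    loader = DataLoader(data, batch_size=batch_size)  # default collate function
    try:
        return tuple(next(iter(loader)).shape)
    except RuntimeError as err:
        return f"RuntimeError: {err}"


# With batch_size=1 every batch holds a single sample, so nothing can mismatch.
print(try_first_batch(ragged, 1))
# With batch_size=2 the first batch mixes lengths 2 and 3, so stacking fails.
print(try_first_batch(ragged, 2))
```

The same data succeeds or fails depending only on batch_size, which is why an error can appear even though collate_fn itself looks correct.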
Code example: inconsistent dimensions
import numpy as np
import torch


# collate_fn: the function used to collate (assemble) samples into a batch
class NewDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):  # must be overridden
        return len(self.data)

    # Index into the dataset and return the corresponding sample
    def __getitem__(self, idx):  # must be overridden
        return self.data[idx]


# batch (a list of samples) is passed to collate_fn, which processes it
# further and returns real_batch
def New_collate(batch):
    real_batch = np.array(batch)  # the right-hand side is the processing applied to batch
    real_batch = torch.from_numpy(real_batch)
    return real_batch


# If the samples have inconsistent dimensions, the default collate_fn raises an error.
# Most neural networks take fixed-length input, and many operations require
# matching dimensions before tensors can be added or multiplied.
# The fix covered later: pad the shorter dimensions with zeros.
test = [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]
dataset = NewDataset(test)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=New_collate)

# Lists, tuples and the like are iterables; iter() returns an iterator over them.
# Call next() on the iterator repeatedly to fetch the next item.
it = iter(dataloader)
c = next(it)
print(c)
c = next(it)
print(c)
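Sample output: the first batch is a 3×2 tensor holding [0, 0], [1, 1], [2, 2]; the second batch holds the remaining two samples, [3, 3] and [4, 4].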
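The comments in the example mention zero-padding as the fix for samples of different lengths. One possible sketch uses torch.nn.utils.rnn.pad_sequence inside a custom collate_fn. The ragged toy data and the name pad_collate are assumptions for illustration, and returning the original lengths alongside the padded tensor is a design choice, not something the source prescribes.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy data with samples of different lengths.
ragged = [[0], [1, 1], [2, 2, 2], [3, 3], [4]]


def pad_collate(batch):
    # Convert each sample to a 1-D tensor, then right-pad with zeros up to
    # the longest sample in this batch.
    tensors = [torch.tensor(sample) for sample in batch]
    lengths = torch.tensor([len(t) for t in tensors])
    padded = pad_sequence(tensors, batch_first=True, padding_value=0)
    return padded, lengths


loader = DataLoader(ragged, batch_size=3, collate_fn=pad_collate)
batches = list(loader)
for padded, lengths in batches:
    print(padded, lengths)
```

Padding happens per batch, so the padded width can differ from one batch to the next; keeping the lengths lets later code ignore the padded positions.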