Torch跑的时候报了这个:
TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
我之前加载数据的操作应该木啥问题好吧:
就是先继承 torch.utils.data.Dataset 类写一个子类 然后 init 一个 torch.utils.data.DataLoader 对象
结果在调用的时候:
for x, y in trainloader:
print(x.shape)
print(y.shape)
break
TypeError
default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
离谱,我在 torch.utils.data.Dataset.__getitem__ 下面明明写的是返回numpy数组啊:
if self.mode != 'test':
label_info = currrent_Series[17:-1].values
return number_info_array, label_info
else:
return number_info_array
无论咋打印 number_info_array , label_info 他俩都是 ndarray 直到我把他们打印出来一看,hhh:
这个数组里边的 dtype 都是 object , hhh,想起来之前在这个数组里边存过 str, 怪不得现在是 object
解决方式:手动改一下类型就行了
|