工作内容:
一、 在公共数据集中测试自己搭建的一维CNN模型。 二、 刘老师开会教授做PPT技巧,并且自己拿挑战杯PPT进行练手。
出现的问题:
针对一: 在训练过程中出现LOSS为Nan的情况。
解决过程:
针对一: 初步怀疑是学习率lr设置不得当,导致的梯度爆炸,将lr设置为0,发现此问题依旧存在;接下来目光转向数据集,数据集中可能存在空值。将batch_size设置为1,而且dataloader中shuffle参数设置为True,进行训练时观察第几步出现loss为nan的情况,调试进入观察输入值,发现输入数据中存在空值,故问题排查出来。 (由于是pandas工具包加载的数据集,所以用df.isnull()) 由于数据集的加载是用pandas工具包直接加载的csv文件,若对空值数据进行处理,有以下几种办法: https://zhuanlan.zhihu.com/p/457718042 1、 删除空值:df.dropna() 2、 填补空值:df.fillna()
其他收获:
一、发现一种简单的打包时间序列CSV格式数据集加载进模型的方法: 1、建立自己的数据集类 2、在类中直接加载csv数据集 3、对数据集样本的划分放在__getitem__()函数中
二、学到一种新的利用滑动窗口处理时间序列数据的方法: 借用__getitem__(self,idx)函数的特殊性质,idx为dataloader()函数每次取样本的索引,我们只需在此索引上加上5(方便说明所以选取5),即可得到一个序列长度为5的样本,即:5行数据=一个样本。注意:要在数据集长度的基础上减去5,防止末尾的idx取不到5个长度的值。 样本标签值的处理办法是(假设类别为3,且经过one-hot编码): 1、如果在类别交接处,比如: 2、先选取5条数据的标签值:y = self.df.iloc[idx: idx + 5, 3:].values
3、对以上标签数据按列相加:z = np.sum(y, axis=0)
4、利用np.argmax()返回以上数据数值最大处的索引:ind = np.argmax(z) ind的值为0,即索引0处的值最大。 5、创建一个全0数组,大小与热编码后的标签数据一样:label = np.zeros_like(self.df.iloc[0, 3:].values) ,如下所示:
6、将label的ind索引处设置为1,即此样本的标签就得到了。是5条数据中标签最多的作为此条样本的标签。label[ind] = 1
7、除了在__getitem__()函数中返回数据x,x = self.df.iloc[idx: idx + 5, : 3].values 还要返回label。 8、那么,一个csv格式的时间序列数据就可以打包成dataset数据集加载进模型中进行训练了。
class IMUDataset(Dataset):
def __init__(self, mode='test', transform=None):
if mode == 'train':
self.df = pd.read_csv('Data_pre/train.csv', header=0, index_col=0)
self.df = self.df.fillna(method='ffill',axis=0)
elif mode == 'test':
self.df = pd.read_csv('Data_pre/test.csv', header=0, index_col=0)
self.df = self.df.fillna(method='ffill',axis=0)
elif mode == 'val':
self.df = pd.read_csv('Data_pre/val.csv', header=0, index_col=0)
self.df = self.df.fillna(method='ffill',axis=0)
self.transform = transform
self.df = self.df.reset_index()
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
y = self.df.iloc[idx: idx + 80, 3:].values
ind = np.argmax(np.sum(y, axis=0))
label = np.zeros_like(self.df.iloc[0, 3:].values)
label = label.astype('float')
label[ind] = 1
x = self.df.iloc[idx: idx + 80, : 3].values
x = x.astype('float')
assert (x.shape == (80, 3))
assert (label.shape == (6,))
return x, label
|