时间驱动
使用双层全连接网络进行MNIST分类
我们关注clock_driven.examples.lif_fc_mnist.py ,只有一个main函数,原注释如下
def main():
'''
* :ref:`API in English <lif_fc_mnist.main-en>`
.. _lif_fc_mnist.main-cn:
:return: None
使用全连接-LIF-全连接-LIF的网络结构,进行MNIST识别。这个函数会初始化网络进行训练,并显示训练过程中在测试集的正确率。
* :ref:`中文API <lif_fc_mnist.main-cn>`
.. _lif_fc_mnist.main-en:
The network with FC-LIF-FC-LIF structure for classifying MNIST. This function initials the network, starts training
and shows accuracy on test dataset.
'''
a. 网络结构
双层全连接FC网络结构如下
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False),
neuron.LIFNode(tau=tau),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau)
)
LIF神经元的动态微分方程
τ
m
d
V
(
t
)
d
t
=
?
(
V
(
t
)
?
V
reset?
)
+
R
m
I
(
t
)
(1)
\tau_{m} \frac{\mathrm{d} V(t)}{\mathrm{d} t}=-\left(V(t)-V_{\text {reset }}\right)+R_{m} I(t) \tag{1}
τm?dtdV(t)?=?(V(t)?Vreset??)+Rm?I(t)(1)
相应的差分方程:
τ
m
(
V
(
t
)
?
V
(
t
?
1
)
)
=
?
(
V
(
t
?
1
)
?
V
r
e
s
e
t
)
+
X
(
t
)
(2)
\tau_{m}(V(t)-V(t-1))=-\left(V(t-1)-V_{r e s e t}\right)+X(t) \tag{2}
τm?(V(t)?V(t?1))=?(V(t?1)?Vreset?)+X(t)(2)
neuron.py 中LIFNODE 继承自BaseNode ,BaseNode 中forward按照充电、放电、重置的顺序进行前向传播。
class BaseNode(nn.Module):
def __init__(self, v_threshold=1.0, v_reset=0.0, surrogate_function=surrogate.Sigmoid(), detach_reset=False, monitor_state=False):
'''
* :ref:`API in English <BaseNode.__init__-en>`
.. _BaseNode.__init__-cn:
:param v_threshold: 神经元的阈值电压
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,则电压会被减去 ``v_threshold``
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:param detach_reset: 是否将reset过程的计算图分离
:param monitor_state: 是否设置监视器来保存神经元的电压和释放的脉冲。
若为 ``True``,则 ``self.monitor`` 是一个字典,键包括 ``h``, ``v`` ``s``,分别记录充电后的电压、释放脉冲后的电压、释放的脉冲。
对应的值是一个链表。为了节省显存(内存),列表中存入的是原始变量转换为 ``numpy`` 数组后的值。
还需要注意,``self.reset()`` 函数会清空这些链表
可微分SNN神经元的基类神经元。
'''
super().__init__()
self.v_threshold = v_threshold
self.v_reset = v_reset
self.detach_reset = detach_reset
self.surrogate_function = surrogate_function
self.monitor = monitor_state
self.reset()
@abstractmethod
def neuronal_charge(self, dv: torch.Tensor):
'''
定义神经元的充电差分方程。子类必须实现这个函数。
'''
raise NotImplementedError
def neuronal_fire(self):
'''
根据当前神经元的电压、阈值,计算输出脉冲。
'''
if self.monitor:
if self.monitor['h'].__len__() == 0:
if self.v_reset is None:
self.monitor['h'].append(self.v.data.cpu().numpy().copy() * 0)
else:
self.monitor['h'].append(self.v.data.cpu().numpy().copy() * self.v_reset)
else:
self.monitor['h'].append(self.v.data.cpu().numpy().copy())
self.spike = self.surrogate_function(self.v - self.v_threshold)
if self.monitor:
self.monitor['s'].append(self.spike.data.cpu().numpy().copy())
def neuronal_reset(self):
'''
根据当前神经元释放的脉冲,对膜电位进行重置。
'''
if self.detach_reset:
spike = self.spike.detach()
else:
spike = self.spike
if self.v_reset is None:
self.v = self.v - spike * self.v_threshold
else:
self.v = (1 - spike) * self.v + spike * self.v_reset
if self.monitor:
self.monitor['v'].append(self.v.data.cpu().numpy().copy())
def forward(self, dv: torch.Tensor):
'''
:param dv: 输入到神经元的电压增量
:return: 神经元的输出脉冲
按照充电、放电、重置的顺序进行前向传播。
'''
self.neuronal_charge(dv)
self.neuronal_fire()
self.neuronal_reset()
return self.spike
def set_monitor(self, monitor_state=True):
'''
设置开启或关闭monitor。
'''
if monitor_state:
self.monitor = {'h': [], 'v': [], 's': []}
else:
self.monitor = False
def extra_repr(self):
'''
设置模块的额外表示信息。
'''
return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
def reset(self):
'''
重置神经元为初始状态,也就是将电压设置为 ``v_reset``。
如果子类的神经元还含有其他状态变量,需要在此函数中将这些状态变量全部重置。
一般在下一个epoch训练前进行
'''
if self.v_reset is None:
self.v = 0.0
else:
self.v = self.v_reset
self.spike = None
if self.monitor:
self.monitor = {'h': [], 'v': [], 's': []}
LIFNODE 继承BaseNode ,根据公式(2)的神经元更新方程,重写neuronal_charge
class LIFNode(BaseNode):
def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=surrogate.Sigmoid(), detach_reset=False,
monitor_state=False):
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)
self.tau = tau
def extra_repr(self):
return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, tau={self.tau}'
def neuronal_charge(self, dv: torch.Tensor):
if self.v_reset is None:
self.v += (dv - self.v) / self.tau
else:
self.v += (dv - (self.v - self.v_reset)) / self.tau
回顾一下网络
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False),
neuron.LIFNode(tau=tau),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau)
)
这里neuron.LIFNode 的作用就是将14*14的模拟连续的神经元变为LIF神经元,按0,1发放脉冲,整个网络也就变成了SNN
b. 泊松编码
源程序对输入图像进行了泊松编码,无状态编码器的基类是StatelessEncoder
class StatelessEncoder(nn.Module):
def __init__(self):
"""
无状态编码器的基类。无状态编码器 ``encoder = StatelessEncoder()``,直接调用 ``encoder(x)`` 即可将 ``x`` 编码为 ``spike``。
"""
super().__init__()
@abstractmethod
def forward(self, x: torch.Tensor):
"""
:param x: 输入数据
:type x: torch.Tensor
:return: ``spike``, shape 与 ``x.shape`` 相同
:rtype: torch.Tensor
"""
raise NotImplementedError
PoissonEncoder 继承自StatelessEncoder
class PoissonEncoder(StatelessEncoder):
def __init__(self):
"""
无状态的泊松编码器。输出脉冲的发放概率与输入 ``x`` 相同。
.. warning::
必须确保 ``0 <= x <= 1``。
"""
super().__init__()
def forward(self, x: torch.Tensor):
out_spike = torch.rand_like(x).le(x)
return out_spike
c. 训练、重置
训练部分即按一次batchsize 个样本进行权重的更新,更新之后LIF神经元部分需要重置。
for epoch in range(train_epoch):
net.train()
for img, label in tqdm(train_data_loader):
img = img.to(device)
label = label.to(device)
label_one_hot = F.one_hot(label, 10).float()
optimizer.zero_grad()
for t in range(T):
if t == 0:
out_spikes_counter = net(encoder(img).float())
else:
out_spikes_counter += net(encoder(img).float())
out_spikes_counter_frequency = out_spikes_counter / T
loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
loss.backward()
optimizer.step()
functional.reset_net(net)
accuracy = (out_spikes_counter_frequency.max(1)[1] == label.to(device)).float().mean().item()
writer.add_scalar('train_accuracy', accuracy, train_times)
train_accs.append(accuracy)
train_times += 1
这里的functional.reset_net(net) 函数如下:
def reset_net(net: nn.Module):
'''
:param net: 任何属于 ``nn.Module`` 子类的网络
:return: None
将网络的状态重置。做法是遍历网络中的所有 ``Module``,若含有 ``reset()`` 函数,则调用。
reset()函数即是LIFNode类中的方法
'''
for m in net.modules():
if hasattr(m, 'reset'):
m.reset()
一个epoch参数更新后在整个测试集上进行模型测试,方便训练时看效果:
net.eval()
with torch.no_grad():
test_sum = 0
correct_sum = 0
for img, label in test_data_loader:
img = img.to(device)
for t in range(T):
if t == 0:
out_spikes_counter = net(encoder(img).float())
else:
out_spikes_counter += net(encoder(img).float())
correct_sum += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
test_sum += label.numel()
functional.reset_net(net)
test_accuracy = correct_sum / test_sum
writer.add_scalar('test_accuracy', test_accuracy, epoch)
test_accs.append(test_accuracy)
max_test_accuracy = max(max_test_accuracy, test_accuracy)
print(f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}')
d. 结果、分析
需要注意的是,训练这样的SNN,所需显存数量与仿真时长 T 线性相关,更长的 T 相当于使用更小的仿真步长,训练更为“精细”,但训练效果不一定更好,因此 T 太大,SNN在时间上展开后就会变成一个非常深的网络,梯度的传递容易衰减或爆炸。由于我们使用了泊松编码器,因此需要较大的 T。
100个epoch下不同batch_size的结果。大的batch_size达到相同精度需要更多的epoch,但是处理单个epoch的时间更短。
|