IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 【Spikingjelly】SNN框架教程的代码解读 -> 正文阅读

[Python知识库]【Spikingjelly】SNN框架教程的代码解读

时间驱动

使用双层全连接网络进行MNIST分类

我们关注clock_driven.examples.lif_fc_mnist.py,只有一个main函数,原注释如下

def main():
    '''
    * :ref:`API in English <lif_fc_mnist.main-en>`

    .. _lif_fc_mnist.main-cn:

    :return: None

    使用全连接-LIF-全连接-LIF的网络结构,进行MNIST识别。这个函数会初始化网络进行训练,并显示训练过程中在测试集的正确率。

    * :ref:`中文API <lif_fc_mnist.main-cn>`

    .. _lif_fc_mnist.main-en:

    The network with FC-LIF-FC-LIF structure for classifying MNIST. This function initials the network, starts training
    and shows accuracy on test dataset.
    '''

a. 网络结构

双层全连接FC网络结构如下

    # 定义并初始化网络
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 14 * 14, bias=False), #这里不加bias应该是偏置在SNN中不好表示
        neuron.LIFNode(tau=tau),
        nn.Linear(14 * 14, 10, bias=False),
        neuron.LIFNode(tau=tau)
    )

LIF神经元的动态微分方程
τ m d V ( t ) d t = ? ( V ( t ) ? V reset? ) + R m I ( t ) (1) \tau_{m} \frac{\mathrm{d} V(t)}{\mathrm{d} t}=-\left(V(t)-V_{\text {reset }}\right)+R_{m} I(t) \tag{1} τm?dtdV(t)?=?(V(t)?Vreset??)+Rm?I(t)(1)

相应的差分方程:
τ m ( V ( t ) ? V ( t ? 1 ) ) = ? ( V ( t ? 1 ) ? V r e s e t ) + X ( t ) (2) \tau_{m}(V(t)-V(t-1))=-\left(V(t-1)-V_{r e s e t}\right)+X(t) \tag{2} τm?(V(t)?V(t?1))=?(V(t?1)?Vreset?)+X(t)(2)

neuron.pyLIFNODE继承自BaseNodeBaseNode中forward按照充电、放电、重置的顺序进行前向传播。

class BaseNode(nn.Module):
    def __init__(self, v_threshold=1.0, v_reset=0.0, surrogate_function=surrogate.Sigmoid(), detach_reset=False, monitor_state=False):
        '''
        * :ref:`API in English <BaseNode.__init__-en>`

        .. _BaseNode.__init__-cn:

        :param v_threshold: 神经元的阈值电压

        :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
            如果设置为 ``None``,则电压会被减去 ``v_threshold``

        :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数

        :param detach_reset: 是否将reset过程的计算图分离

        :param monitor_state: 是否设置监视器来保存神经元的电压和释放的脉冲。
            若为 ``True``,则 ``self.monitor`` 是一个字典,键包括 ``h``, ``v`` ``s``,分别记录充电后的电压、释放脉冲后的电压、释放的脉冲。
            对应的值是一个链表。为了节省显存(内存),列表中存入的是原始变量转换为 ``numpy`` 数组后的值。
            还需要注意,``self.reset()`` 函数会清空这些链表

        可微分SNN神经元的基类神经元。
       '''
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.monitor = monitor_state
        self.reset()
        
    @abstractmethod
    def neuronal_charge(self, dv: torch.Tensor):
    	'''
    	定义神经元的充电差分方程。子类必须实现这个函数。
    	'''
        raise NotImplementedError
        
    def neuronal_fire(self):
        '''
        根据当前神经元的电压、阈值,计算输出脉冲。
        '''
        if self.monitor:
            if self.monitor['h'].__len__() == 0:
                # 补充在0时刻的电压
                if self.v_reset is None:
                    self.monitor['h'].append(self.v.data.cpu().numpy().copy() * 0)
                else:
                    self.monitor['h'].append(self.v.data.cpu().numpy().copy() * self.v_reset)
            else:
                self.monitor['h'].append(self.v.data.cpu().numpy().copy())

        self.spike = self.surrogate_function(self.v - self.v_threshold)
        #surrogate function默认是Sigmoid,\alpha为1。 在前向传播时,使用神经元的输出			
        #离散的0和1,我们的网络仍然是SNN;而反向传播时,使用梯度替代函数的梯度来代替脉冲函数的梯度。
        if self.monitor:
            self.monitor['s'].append(self.spike.data.cpu().numpy().copy())

    def neuronal_reset(self):
        '''
        根据当前神经元释放的脉冲,对膜电位进行重置。
        '''
        if self.detach_reset:
            spike = self.spike.detach()
        else:
            spike = self.spike

        if self.v_reset is None: #不重置减去阈值
            self.v = self.v - spike * self.v_threshold
        else: #重置恢复为v_reset值
            self.v = (1 - spike) * self.v + spike * self.v_reset

        if self.monitor:
            self.monitor['v'].append(self.v.data.cpu().numpy().copy())

    def forward(self, dv: torch.Tensor): #前向forward
        '''
        :param dv: 输入到神经元的电压增量
        :return: 神经元的输出脉冲
        按照充电、放电、重置的顺序进行前向传播。
        '''
        self.neuronal_charge(dv)
        self.neuronal_fire()
        self.neuronal_reset()
        return self.spike

    def set_monitor(self, monitor_state=True):
        '''
        设置开启或关闭monitor。
        '''
        if monitor_state:
            self.monitor = {'h': [], 'v': [], 's': []}
        else:
            self.monitor = False
 
     def extra_repr(self):
     	'''
     	设置模块的额外表示信息。
   		'''
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
        
    def reset(self):
        '''
        重置神经元为初始状态,也就是将电压设置为 ``v_reset``。
        如果子类的神经元还含有其他状态变量,需要在此函数中将这些状态变量全部重置。
        一般在下一个epoch训练前进行
        '''
        if self.v_reset is None:
            self.v = 0.0
        else:
            self.v = self.v_reset

        self.spike = None

        if self.monitor:
            self.monitor = {'h': [], 'v': [], 's': []}

LIFNODE继承BaseNode,根据公式(2)的神经元更新方程,重写neuronal_charge

class LIFNode(BaseNode):
    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=surrogate.Sigmoid(), detach_reset=False,
                 monitor_state=False):
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)
        self.tau = tau

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, tau={self.tau}'

    def neuronal_charge(self, dv: torch.Tensor):
        if self.v_reset is None: #这里的dv就是上一层的输出,公式中的X(t)
            self.v += (dv - self.v) / self.tau
        else:
            self.v += (dv - (self.v - self.v_reset)) / self.tau

回顾一下网络

    # 定义并初始化网络
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 14 * 14, bias=False), #这里不加bias应该是偏置在SNN中不好表示
        neuron.LIFNode(tau=tau),
        nn.Linear(14 * 14, 10, bias=False),
        neuron.LIFNode(tau=tau)
    )

这里neuron.LIFNode的作用就是将14*14的模拟连续的神经元变为LIF神经元,按0,1发放脉冲,整个网络也就变成了SNN

b. 泊松编码

源程序对输入图像进行了泊松编码,无状态编码器的基类是StatelessEncoder

class StatelessEncoder(nn.Module):
    def __init__(self):
        """
        无状态编码器的基类。无状态编码器 ``encoder = StatelessEncoder()``,直接调用 ``encoder(x)`` 即可将 ``x`` 编码为 ``spike``。
        """
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """
        :param x: 输入数据
        :type x: torch.Tensor
        :return: ``spike``, shape 与 ``x.shape`` 相同
        :rtype: torch.Tensor
        """
        raise NotImplementedError

PoissonEncoder继承自StatelessEncoder

class PoissonEncoder(StatelessEncoder):
    def __init__(self):
        """
        无状态的泊松编码器。输出脉冲的发放概率与输入 ``x`` 相同。
        .. warning::
            必须确保 ``0 <= x <= 1``。
        """
        super().__init__()

    def forward(self, x: torch.Tensor):
        out_spike = torch.rand_like(x).le(x)
        # torch.rand_like(x)生成与x相同shape的介于[0, 1)之间的随机数, 这个随机数小于等于x中对应位置的元素,则发放脉冲
        return out_spike

c. 训练、重置

训练部分即按一次batchsize个样本进行权重的更新,更新之后LIF神经元部分需要重置。

    for epoch in range(train_epoch):
        net.train()
        for img, label in tqdm(train_data_loader):
            img = img.to(device)
            label = label.to(device)
            label_one_hot = F.one_hot(label, 10).float()
            optimizer.zero_grad()


            # 运行T个时长,out_spikes_counter是shape=[batch_size, 10]的tensor
            # 记录整个仿真时长内,输出层的10个神经元的脉冲发放次数
            for t in range(T):
                if t == 0:
                    out_spikes_counter = net(encoder(img).float())
                else:
                    out_spikes_counter += net(encoder(img).float())

            # out_spikes_counter / T 得到输出层10个神经元在仿真时长内的脉冲发放频率
            out_spikes_counter_frequency = out_spikes_counter / T

            # 损失函数为输出层神经元的脉冲发放频率,与真实类别的MSE
            # 这样的损失函数会使,当类别i输入时,输出层中第i个神经元的脉冲发放频率趋近1,而其他神经元的脉冲发放频率趋近0
            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()
            optimizer.step()
            # 优化一次参数后,需要重置网络的状态,因为SNN的神经元是有“记忆”的
            functional.reset_net(net)

            # 正确率的计算方法如下。认为输出层中脉冲发放频率最大的神经元的下标i是分类结果
            accuracy = (out_spikes_counter_frequency.max(1)[1] == label.to(device)).float().mean().item()
            
            writer.add_scalar('train_accuracy', accuracy, train_times)
            train_accs.append(accuracy)

            train_times += 1

这里的functional.reset_net(net)函数如下:

def reset_net(net: nn.Module):
    '''
    :param net: 任何属于 ``nn.Module`` 子类的网络
    :return: None
    将网络的状态重置。做法是遍历网络中的所有 ``Module``,若含有 ``reset()`` 函数,则调用。
    reset()函数即是LIFNode类中的方法
    '''
    for m in net.modules():
        if hasattr(m, 'reset'):
            m.reset()

一个epoch参数更新后在整个测试集上进行模型测试,方便训练时看效果:

        net.eval()
        with torch.no_grad():
            # 每遍历一次全部数据集,就在测试集上测试一次
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                for t in range(T):
                    if t == 0:
                        out_spikes_counter = net(encoder(img).float())
                    else:
                        out_spikes_counter += net(encoder(img).float())

                correct_sum += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
                test_sum += label.numel()
                functional.reset_net(net)
            test_accuracy = correct_sum / test_sum
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            test_accs.append(test_accuracy)
            max_test_accuracy = max(max_test_accuracy, test_accuracy)
        print(f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}')

d. 结果、分析

需要注意的是,训练这样的SNN,所需显存数量与仿真时长 T 线性相关,更长的 T 相当于使用更小的仿真步长,训练更为“精细”,但训练效果不一定更好,因此 T 太大,SNN在时间上展开后就会变成一个非常深的网络,梯度的传递容易衰减或爆炸。由于我们使用了泊松编码器,因此需要较大的 T。

100个epoch下不同batch_size的结果。大的batch_size达到相同精度需要更多的epoch,但是处理单个epoch的时间更短。
在这里插入图片描述
在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-07-25 11:36:25  更:2021-07-25 11:37:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/25 15:13:49-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计