首页 > 其他分享 >spikingjelly学习-训练网络

spikingjelly学习-训练网络

时间:2024-04-09 22:32:00浏览次数:25  
标签:spikingjelly layer seq 训练 钩子 脉冲 网络 output 神经元

【MNIST数据集包含若干尺寸为28*28的8位灰度图像,总共有0~9共10个类别。以MNIST的分类为例,一个简单的单层ANN网络如下

我们也可以用完全类似结构的SNN来进行分类任务。就这个网络而言,只需要先去掉所有的激活函数,再将尖峰神经元添加到原来激活函数的位置,这里我们选择的是LIF神经元。神经元之间的连接层需要用
spikingjelly.activation_based.layer包装:

在 spikingjelly 中,我们约定,只能输出脉冲,即0或1的神经元,都可以称之为“脉冲神经元”。使用脉冲神经元的网络,进而也可以称之为脉冲神经元网络(Spiking Neural Networks, SNNs)。这里使用了 neuron.IFNode() 来构建 IF 神经元层,该神经元层有如下构造函数:
  1. v_threshold – 神经元的阈值电压
  2. v_reset – 神经元的重置电压。
  3. surrogate_function – 反向传播时用来计算脉冲函数梯度的替代函数
    神经元的数量是在初始化或调用 reset() 函数重新初始化后,根据第一次接收的输入的 shape 自动决定的。此处则是10个神经元。其中膜电位衰减常数 需要通过参数tau设置,替代函数这里选择surrogate.ATan。
    然后是训练SNN网络,指定好训练参数如学习率等以及若干其他配置优化器默认使用Adam,以及使用泊松编码器,在每次输入图片时进行脉冲编码。

【训练代码的编写需要遵循以下三个要点:
 脉冲神经元的输出是二值的,而直接将单次运行的结果用于分类极易受到编码带来的噪声干扰。因此一般认为脉冲网络的输出是输出层一段时间内的发放频率(或称发放率),发放率的高低表示该类别的响应大小。因此网络需要运行一段时间,即使用T个时刻后的平均发放率作为分类依据。
 我们希望的理想结果是除了正确的神经元以最高频率发放,其他神经元保持静默。常常采用交叉熵损失或者MSE损失,这里我们使用实际效果更好的MSE损失。
 每次网络仿真结束后,需要重置网络状态

 # 保存绘图用数据
    net.eval()
    # 注册钩子
    output_layer = net.layer[-1] # 输出层
    output_layer.v_seq = []
    output_layer.s_seq = []
    def save_hook(m, x, y):
        m.v_seq.append(m.v.unsqueeze(0))
        m.s_seq.append(y.unsqueeze(0))

    output_layer.register_forward_hook(save_hook)


    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j]表示神经元i在j时刻的电压值
        np.save("v_t_array.npy",v_t_array)
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j]表示神经元i在j时刻释放的脉冲,为0或1
        np.save("s_t_array.npy",s_t_array)

在这里插入图片描述
【在PyTorch中,钩子(hooks)是一种强大的工具,允许你在模型的前向传播(forward pass)或反向传播(backward pass)过程中插入自定义操作。这些操作可以用于调试、可视化、保存中间状态等目的,而不需要修改模型的定义。
钩子的类型
前向钩子(Forward Hooks):在层的前向传播执行完毕后立即执行。它们通常用于检查、修改或记录从层输出的数据。
反向钩子(Backward Hooks):在层的梯度计算过程中执行。它们用于检查或修改梯度值。
这段代码中的钩子使用
在提供的代码段中,使用了一个前向钩子(save_hook)来保存神经网络某层在前向传播过程中的电压值(v)和脉冲值(s)。
这个钩子函数save_hook接收三个参数:
m:注册钩子的模块(在这个例子中是输出层)。
x:输入到该模块的数据。
y:从该模块输出的数据。
在钩子函数内部,它将模块m的电压值v和输出脉冲y保存到列表中。这里使用unsqueeze(0)是为了增加一个批次维度,使得每次迭代的数据可以被堆叠起来。
钩子的注册
这行代码将save_hook函数注册为output_layer(网络的最后一层)的前向钩子。这意味着每当output_layer完成前向传播时,save_hook函数都会被调用。
数据的保存
在所有测试图像通过网络并且钩子函数被调用之后,v_seq和s_seq列表中的数据被合并(使用torch.cat)并转换为NumPy数组,然后通过np.save保存到文件中。这些文件包含了在整个测试集上,输出层神经元的电压值和脉冲发放情况,可以用于进一步的分析和可视化。】
这段代码通过注册一个前向钩子来捕获并保存神经网络最后一层在前向传播过程中的电压和脉冲数据。这种方法非常有用,因为它允许在不修改网络结构的情况下收集内部状态信息,对于理解和分析网络的行为非常有帮助。

标签:spikingjelly,layer,seq,训练,钩子,脉冲,网络,output,神经元
From: https://blog.csdn.net/weixin_44781508/article/details/137021168

相关文章

  • WDS+MDT网络启动自动部署windows(二)基本安装配置
    简介:WDS网络部署服务的核心只有TFTP和多播,采用WDS来做,就是因为多播这个优势,否则TFTPD,iventory,openwrt都更简单方便。见前几篇博客。当然这依托于DHCP将信息发送给客户端来实现。本文将介绍DHCPWDS的简单安装配置。DHCP和WDS同一台服务器,也可以分开两台安装。所有未截图描述的......
  • 网络协议
    1.网际互连协议-IP(InternetProtocol)IP指网际互连协议,是TCP/IP体系中的网络层协议。设计IP的目的是提高网络的可扩展性:一是解决互联网问题,实现大规模、异构网络的互联互通;二是分割顶层网络应用和底层网络技术之间的耦合关系,以利于两者的独立发展。根据端到端的设计原则,IP只为主......
  • 目标检测:yolov8(ultralytics)训练自己的数据集,新手小白也能学会训练模型,一看就会
    目录1.环境配置2.数据集获取2.1网上搜索公开数据集2.2自制数据集2.2.1Labelimg安装2.2.2Labelimg使用2.3数据集转换及划分2.3.1数据集VOC格式转yolo格式2.3.2数据集划分3.训练模型3.1创建data.yaml3.2训练模型4.模型测试5.可视化界面分为4部分,......
  • 基于樽海鞘群算法优化的广义回归神经网络(GRNN)预测
    基于樽海鞘群算法优化的广义回归神经网络(GRNN)预测文章目录基于樽海鞘群算法优化的广义回归神经网络(GRNN)预测1.GRNN神经网络概述2.GRNN的网络结构3.GRNN的理论基础4.数据集5.樽海鞘群算法优化GRNN6.实验结果7.Matlab代码摘要:本文介绍基于樽海鞘群算法优化的广......
  • WDS+MDT网络启动自动部署windows(一)实验环境介绍
    简介:这个系列以前搞过一次,挺顺利的,这次搞起来,居然折腾了两周,不知道问题出在哪里,始终无法正常PXE引导UEFI模式的计算机。经过不断的折腾,终于发现,DHCPoption60PXEClient,不应该设置。不知道是UEFI和BIOS处理方式不同,还是Windowsserver2022的WDS有bug,提示我两个都要勾选的。实......
  • 突破编程_C++_网络编程(Windows 套接字(setsockopt 选项设置))
    1setsockopt函数介绍Windows套接字(Winsock)的setsockopt函数是用于设置套接字选项的重要工具。通过这个函数,开发者可以调整套接字的行为,以满足特定的网络应用需求。(1)函数原型intsetsockopt(SOCKETs,intlevel,intoptname,constchar*optval,intop......
  • Java IO与NIO-Java内存管理-Java虚拟机(JVM)-Java网络编程-Java注解(Annotation)
    JavaIO与NIO:请解释Java中的IO(Input/Output)和NIO(NewInput/Output)的区别是什么?它们各自的优势是什么?答案:Java中的IO是基于流(Stream)的方式进行输入输出操作,而NIO则是基于通道(Channel)和缓冲区(Buffer)的方式进行输入输出操作。NIO相比于IO具有非阻塞IO、选择器(Selector)和内存映......
  • 无忧网络验证系统 getInfo SQL注入漏洞复现
    0x01产品简介无忧网络验证是一套安全稳定高效的网络验证系统,基于统一核心的通用互联网+信息化服务解决方案,是为软件作者设计的一套完整免费的网络验证体系。可以为开发的软件增加收费授权的功能,让作者开发的软件可以进行销售、充值、登陆等操作,并且提供防破解验证功能,可以......
  • docker ——网络配置和管理
    docker网络基础了解docker网络两种docker网络单主机与多主机的docker网络网络驱动网络驱动介绍bridge桥接网络,这是默认的网络驱动程序host主机网络overlay覆盖网络macvlan将mac地址分配给容器,使容器作为网络上的物理设备none表示关闭容器的所有......
  • 计算机网络常见网络命令使用与协议的分析
    实验目的常见网络命令使用与协议的分析实验条件Windows,ethereal实验内容常见命令使用:Ipconfig     网络协议分析:IPArp                       tcpudp           Ping/pingip–n–l        ......