首页 > 其他分享 >信号处理--基于gumbel-softmax方法实现运动想象分类的通道选择

信号处理--基于gumbel-softmax方法实现运动想象分类的通道选择

时间:2024-03-17 20:30:06浏览次数:31  
标签:dim nn -- self torch hot gumbel softmax size

目录

背景

亮点

环境配置

数据

方法

结果

代码获取

参考文献


背景

基于Gumbel-softmax方法EEG通道选择层的PyTorch实现。该层可以放置在任何深度神经网络架构的前面,以共同学习给定任务和网络权重的脑电图通道的最佳子集。这一层由选择神经元组成,每个神经元都使用输入通道上离散分布的连续松弛来学习最佳的单热权重向量来选择输入通道,而不是线性组合它们。

亮点

使用Gumbel-softmax方法对多通道脑电数据进行单通道选择(非多通道线性加权)

使用多尺度滤波卷积网络实现运动想象4分类。

环境配置

PyTorch 0.3.1,

CUDA 9.1

数据

High-Gamma Dataset

方法

多尺度滤波卷积网络主要代码:

class MSFBCNN(nn.Module):
	def __init__(self,input_dim,output_dim,FT=10):
		super(MSFBCNN, self).__init__()
		self.T = input_dim[1]
		self.FT = FT
		self.D = 1
		self.FS = self.FT*self.D
		self.C=input_dim[0]
		self.output_dim = output_dim
		
		# Parallel temporal convolutions
		self.conv1a = nn.Conv2d(1, self.FT, (1, 65), padding = (0,32),bias=False)
		self.conv1b = nn.Conv2d(1, self.FT, (1, 41), padding = (0,20),bias=False)
		self.conv1c = nn.Conv2d(1, self.FT, (1, 27), padding = (0,13),bias=False)
		self.conv1d = nn.Conv2d(1, self.FT, (1, 17), padding = (0,8),bias=False)

		self.batchnorm1 = nn.BatchNorm2d(4*self.FT, False)
		
		# Spatial convolution
		self.conv2 = nn.Conv2d(4*self.FT, self.FS, (self.C,1),padding=(0,0),groups=1,bias=False)
		self.batchnorm2 = nn.BatchNorm2d(self.FS, False)

		#Temporal average pooling
		self.pooling2 = nn.AvgPool2d(kernel_size=(1, 75),stride=(1,15),padding=(0,0))

		self.drop=nn.Dropout(0.5)

		#Classification
		self.fc1 = nn.Linear(self.FS*math.ceil(1+(self.T-75)/15), self.output_dim)

	def forward(self, x):

		# Layer 1
		x1 = self.conv1a(x);
		x2 = self.conv1b(x);
		x3 = self.conv1c(x);
		x4 = self.conv1d(x);

		x = torch.cat([x1,x2,x3,x4],dim=1)
		x = self.batchnorm1(x)

		# Layer 2
		x = torch.pow(self.batchnorm2(self.conv2(x)),2)
		x = self.pooling2(x)
		x = torch.log(x)
		x = self.drop(x)
		
		# FC Layer
		x = x.view(-1, self.num_flat_features(x))
		x = self.fc1(x)
		return x

	def num_flat_features(self, x):
		size = x.size()[1:]  # all dimensions except the batch dimension
		num_features = 1
		for s in size:
			num_features *= s
		return num_features

Gumbel-softmax 再参数化主要代码:

class SelectionLayer(nn.Module):
	def __init__(self, N,M,temperature=1.0):

		super(SelectionLayer, self).__init__()
		self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
		self.N = N
		self.M = M
		self.qz_loga = Parameter(torch.randn(N,M)/100)

		self.temperature=self.floatTensor([temperature])
		self.freeze=False
		self.thresh=3.0

	def quantile_concrete(self, x):

		g = -torch.log(-torch.log(x))
		y = (self.qz_loga+g)/self.temperature
		y = torch.softmax(y,dim=1)

		return y

	def regularization(self):
		
		eps = 1e-10
		z = torch.clamp(torch.softmax(self.qz_loga,dim=0),eps,1)
		H = torch.sum(F.relu(torch.norm(z,1,dim=1)-self.thresh))

		return H

	def get_eps(self, size):

		eps = self.floatTensor(size).uniform_(epsilon, 1-epsilon)

		return eps

	def sample_z(self, batch_size, training):

		if training:

			eps = self.get_eps(self.floatTensor(batch_size, self.N, self.M))
			z = self.quantile_concrete(eps)
			z=z.view(z.size(0),1,z.size(1),z.size(2))
	 
			return z

		else:

			ind = torch.argmax(self.qz_loga,dim=0)
			one_hot = self.floatTensor(np.zeros((self.N,self.M)))
			for j in range(self.M):
					one_hot[ind[j],j]=1
			one_hot=one_hot.view(1,1,one_hot.size(0),one_hot.size(1))
			one_hot = one_hot.expand(batch_size,1,one_hot.size(2),one_hot.size(3))

			return one_hot

	def forward(self, x):

		z = self.sample_z(x.size(0),training=(self.training and not self.freeze))
		z_t = torch.transpose(z,2,3)
		out = torch.matmul(z_t,x)
		return out

结果

实现从64通道脑电信号中提取出N个重要通道脑电信号,增强后续分类任务的性能

代码获取

https://download.csdn.net/download/YINTENAXIONGNAIER/88946872

参考文献

  • Strypsteen, Thomas, and Alexander Bertrand. "End-to-end learnable EEG channel selection for deep neural networks with Gumbel-softmax." Journal of Neural Engineering 18.4 (2021): 0460a9.

标签:dim,nn,--,self,torch,hot,gumbel,softmax,size
From: https://blog.csdn.net/YINTENAXIONGNAIER/article/details/136639158

相关文章

  • [C语言]——函数
    一.函数的概念数学中我们其实就见过函数的概念,⽐如:⼀次函数y=kx+b,k和b都是常数,给⼀个任意的x,就得到⼀个y值。其实在C语言也引⼊函数(function)的概念,有些翻译为:子程序,子程序这种翻译更加准确⼀些。C语言中的函数就是⼀个完成某项特定的任务的一小段代码。这段代码是有特殊......
  • ESP32学习笔记-读取SD卡并显示到屏幕上
    硬件FireBeetle2ESP32-E开发板1.54"240x240 IPS 广视角TFT显示屏硬件接线测试代码//加载库#include"Arduino.h"#include"FS.h"#include"SD.h"#include"SPI.h"#include"DFRobot_GDL.h"//定义显示屏针脚#defineTFT_DCD2#......
  • 如何查找访问 Nginx 的前 10 个 IP?
    在管理和维护Web服务器时,了解谁正在访问您的网站是非常重要的。Nginx是一个流行的Web服务器,通过分析其访问日志,您可以了解访问者的来源、频率以及他们的行为。有时候,您可能希望查找访问量最高的IP地址,以便进一步分析或采取措施,比如加强安全性或优化性能。本文将详细......
  • 【10】Python3之使用openpyxl,操纵表格
    使用openpyxl,读取Excel文件fromopenpyxlimportload_workbook#加载工作簿,后面写Excel的路径wb=load_workbook(r"C:\Users\以权天下\Desktop\月光.xlsx")#选择活动工作表或特定工作表wb.activesheet=wb['2024']#2024是表名Excel_data=sheet['A2'].value#A2是单元格......
  • Delphi10.3主从表步骤(18)
    1.选择两个FDQuery和两个dataSource,分别命名为master,slave,然后将两个dataSource的属性Dataset设置为对应的FDQuery,假设主从表的关联字段为:从表的mainID和主表的Mid做为主从关联2.在FDQuery1的sql中写入主表语句select*frommainT3.在FDQuery2的sql中写入从表的语......
  • Leecode 二叉树的前序遍历
    Day2刷题我的思路:用数组list存储遍历结果,利用ArrayList的方法实现嵌套!importjava.util.*;classSolution{//defininganarraylistpublicList<Integer>preorderTraversal(TreeNoderoot){List<Integer>Traversal=newArrayList<>();......
  • CF1948
    A题意:定义一个字符是特殊的,当且仅当它左右两边恰有一个字符与之相同。要求构造一个字符串,使得恰好有\(n\)个特殊字符,或判断无解。考虑一个连续的字符段,如果长度\(1\),不贡献特殊字符;否则必然贡献\(2\)个。所以无解条件就是\(2\not\midn\)。否则可以用AABAABAABAAB.........
  • 基于Rust的Tile-Based游戏开发杂记(02)ggez绘图实操
    尽管ggez提供了很多相关特性的demo供运行查看,但笔者第一次使用的时候还是有很多疑惑不解。经过仔细阅读demo代码并结合自己的实践,逐步了解了ggez在不同场景下的绘图方式,在此篇文章进行一定的总结,希望能够帮助到使用ggez的读者。供运行查看,但笔者第一次使用的时候还是有很多疑惑不......
  • JAVA面向对象高级:static注意事项
    packagecom.itheima.static1;publicclassStudent{staticStringschoolName;doublescore;//实例变量//1.类方法中可以直接访问类的成员,不可以直接访问实例成员publicstaticvoidprinthelloworld(){//注意:同一个类中,访问类成员,可以省略类......
  • boss
    importsubprocessimportreimportrequestsfromurllib.parseimporturlparse,parse_qsfromfunctoolsimportpartialsubprocess.Popen=partial(subprocess.Popen,encoding="utf-8")importexecjs获取security-checkurlparams={"query":......