首页 > 其他分享 >pytorch--训练分层学习率设置

pytorch--训练分层学习率设置

时间:2023-05-27 21:57:24浏览次数:42  
标签:nn -- 模型 分层 pytorch params model self out

在训练模型时,我们经常会使用两个神经网络模型进行融合,若两个模型的复杂度不同,或者激活函数不同,导致训练后的模型训练损失忽高忽低,差距巨大,有可能是陷入了局部最优的状况。这时候采用分层学习率的策略可能帮助模型度过局部最优困境。

下面是一个简单的示例:

对于一个继承于nn.Module的神经网络模型Model

class Model(nn.Module):
	def __init__(self):
		super().__init__()
		self.layer1 = nn.Sequential(nn.Linear(20, 10), nn.Tanh())
		self.layer2 = nn.Linear(10, 1)
	
	def forward(self, x):
		out = self.layer1(x)
		out = self.layer2(out)
		return out

那么分层学习率的设置大致如下:

model = Model() # 模型初始化
# 设置分层学习率
params_list = [{'params': model.layer1.parameters(), 'lr': 0.001},
	{'params': model.layer2.parameters(), 'lr': 0.002}]
# 将学习率传入优化器 
optimizeer = torch.optim.RMSprop(params_list)
# 模型训练
train(model, max_epoch, optimizer, train_iter, vali_iter, test_iter, loss_func)

标签:nn,--,模型,分层,pytorch,params,model,self,out
From: https://www.cnblogs.com/huxiaohu52/p/17437429.html

相关文章

  • cartographer代码——世界坐标系点和像素坐标系点的转换
    构建栅格地图,要弄清楚坐标之间的关系。本篇根据代码,画出了坐标转换的关系。如下图:cartographer中的代码如下://Returnstheindexofthecellcontainingthe'point'whichmaybeoutside//themap,i.e.,negativeortoolargeindicesthatwillreturnfalsefo......
  • wait,notify,notifyAll,sleep,join等线程方法的全方位演练
    一、概念解释1.进入阻塞:有时我们想让一个线程或多个线程暂时去休息一下,可以使用wait(),使线程进入到阻塞状态,等到后面用到它时,再使用notify()、notifyAll()唤醒它,线程被唤醒后,会等待CPU调度。不过需要注意的是:在执行wait()方法前必须先拿到这个对象的monitor锁。2.线程......
  • 2023-05 多校联合训练 HZNU站
    我想要原石然而,由于提瓦特大陆实在是太大了,游戏中设置了许多传送锚点。众所周知,每个传送锚点附近都有若干个原石(其实并没有),曾经有一位丰富经验的旅行者开辟了\(n−1\)条路和\(n\)个由路连通的传送锚点。为了便于后续的旅行者知道地图上原石的分布情况,他决定给旅行者一些提示......
  • 线程的 6 个状态(生命周期)
    线程的6个状态(生命周期)1.线程的一生中有哪几个状态有6种状态,分别如下:NewRunnableBlockedWaitingTimed_WaitingTerminated2.每个状态的含义是什么New:是在newThread()之后,执行start()方法之前的一个状态;Runnable:是在线程调用start()方法之后的状态(其实包括......
  • [ARC160F] Count Sorted Arrays
    ProblemStatementThereareaninteger$N$and$M$pairsofintegers:$(a_1,b_1),(a_2,b_2),\dots,(a_M,b_M)$.Eachpair$(a_i,b_i)$satisfies$1\leqa_i\ltb_i\leqN$.Initally,youhaveall$N!$permutationsof$(1,2,\dots,N)$.Youwillperf......
  • mysql监控工具sqlprofiler,类似sqlserver的profiler工具
    最近无意发现了mysql的客户端监控工具“NeroProfileSQL”,刚开始还不知道怎么使用,经过半小时摸索,现将使用步骤写下来。背景:开发的时候,如果数据存储层这块使用EF,或者其他orm框架,数据库是mysql,想知道最终执行的sql语句,那么这款工具就帮你忙了。1、去官网下载安装windows......
  • CS61b_最小区间排序
       publicstaticvoidzorkSort(int[]A,intk){inti;intn=A.length;i=0;PriorityQueue<Integer>pq=newPriorityQueue<>();while(i<k){pq.add(A[i]);i++;}whil......
  • 1 基础语法
    1、查看数据:1)View2)str3)class4)typeof5)mode6)glimpse7)summary2、R中数据结构 1)  *同质数据结构:向量、矩阵、多维数组*异质数据结构:列表、数据框 2)*原子向量,各个值是同类型的:logical、interger、double、character、c......
  • PKUCPC2023游记
    PKUCPC2023游记怎么有大学生写游记呢?怎么回事呢?怎么回事呢?day?补了之前PKUCPC2022除了计算几何以外的所有题,感觉题目不是很难的样子,争取拿个二等奖!day0听说秋丽他们要打THUPC所以来了北京,于是打算去面基群友,顺便蹭餐晚饭,不过去之前得知加上我有十二个人,非常恐怖,不知道......
  • 我的十年程序员生涯--无锡之旅,开启岗前培训
    2012年的那个春天,考研初试结果出来了,很不理想。面临着二战及工作两种选择,最终选择了工作。当时的理由是“研究生之后,仍旧要工作,不如现在去工作,而且还可以积累三年的工作经验”。现在来看这个理由很不成立,工作的头两年感觉不到学历的重要,越是随着工作年限的增长,越感觉到学历的重要......