首页 > 其他分享 >深度学习(学习率)

深度学习(学习率)

时间:2024-08-02 22:05:43浏览次数:18  
标签:optimizer optim torch list 学习 lr scheduler 深度

Pytorch做训练的时候,可以调整训练学习率。

通过调整合适的学习率曲线可以提高模型训练效率和优化模型性能。

各种学习率曲线示例代码如下:

import torch
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

if __name__ == '__main__':

    lr_init = 0.5   #初始学习率
    parameter = [nn.Parameter(torch.tensor([1, 2, 3], dtype=torch.float32))]
    optimizer = optim.SGD(parameter, lr=lr_init)

    scheduler_list=[]

    #每迭代step_size次,学习率乘以gamma
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    scheduler_list.append(scheduler)

    #在迭代到millestones次时,学习率乘以gamma
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,40,80], gamma=0.5)
    scheduler_list.append(scheduler)

    #每次学习率是上一次的gamma倍
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 
    scheduler_list.append(scheduler)

    #前total_inters次迭代,学习率从lr_init*(start_factor~end_factor)线性下降,total_iters次之后稳定在end_factor
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer,start_factor=1,end_factor=0.01,total_iters=50)
    scheduler_list.append(scheduler)

    #学习率在base_lr~max_lr之间循环,上升step_size_up个周期,下降step_size_down个周期
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr=0.01,max_lr=0.1,step_size_up=10,step_size_down=30)
    scheduler_list.append(scheduler)

    #学习率为cos曲线,T_max为半个周期,最小为eta_min,最大为lr_init
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=20,eta_min=0.1)
    scheduler_list.append(scheduler)

    #cos退火学习率,第一个周期为T_0,后面每一个周期为前一个的T_mult倍,最小值为eta_min,最大值为lr_init
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.01)
    scheduler_list.append(scheduler)

    #学习率先上升后下降,pct_start学习率上升部分占比,最大学习率=max_lr,初始学习率=max_lr/div_factor,最终学习率=初始学习率/final_div_factor
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr=0.1,pct_start=0.2,total_steps=100,div_factor=10,final_div_factor=5)
    scheduler_list.append(scheduler)

    #total_iters次以内学习率为lr_init,total_iters之后学习率为lr_init/factor
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer,factor=0.5,total_iters=50)
    scheduler_list.append(scheduler)

    #多个学习率组合,将学习率在milestones次循环处分割为两端
    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer,schedulers=[
        torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9),
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)],
        milestones=[50])
    scheduler_list.append(scheduler)

    #同样多种学习率组合,可以给出连续学习率
    scheduler = torch.optim.lr_scheduler.ChainedScheduler([
        torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99),
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)])
    scheduler_list.append(scheduler)

    #自定义lambda函数设定学习率,这里是lr_init * 1.0/(step+1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0 /(step+1))
    scheduler_list.append(scheduler)

    #自定义lambad函数设定学习率,这里是lr[t] = 0.95*lr[t-1]
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,lr_lambda=lambda epoch:0.95)
    scheduler_list.append(scheduler)

    #当指标度量停止改进时,ReduceLROnPlateau会降低学习率,scheduler.step中需要设置loss参数
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=5,threshold=1e-4,threshold_mode='abs',cooldown=0,min_lr=0.001,eps=1e-8)
    # scheduler_list.append(scheduler)

    learning_rates = []
    for sch in scheduler_list:
        rates=[]
        optimizer.param_groups[0]['lr'] = lr_init

        for _ in range(100):
            optimizer.step()
            sch.step()
            rates.append(sch.get_last_lr()[0])
        learning_rates.append(rates)

    numpy_rates = np.array(learning_rates)

    for i in range(numpy_rates.shape[0]):
        plt.subplot(4,4,i+1)
        plt.plot(numpy_rates[i,:])          

    plt.show()

各种学习率曲线如下:

标签:optimizer,optim,torch,list,学习,lr,scheduler,深度
From: https://www.cnblogs.com/tiandsp/p/18327978

相关文章

  • Python应用开发——30天学习Streamlit Python包进行APP的构建(23):构建多页面应用程序
    创建动态导航菜单通过st.navigation,可以轻松创建动态导航菜单。您可以在每次重新运行时更改传递给st.navigation的页面集,从而更改与之匹配的导航菜单。这是创建自定义、基于角色的导航菜单的便捷功能。本教程使用st.navigation和st.Page,它们是在Streamlit1.36.0版中......
  • 虚幻五 学习(五)开始写C++代码
    经过正确创建文件现在就有C++文件了   UFUNCTION(BlueprintCallable) voidOpenLobby(); UFUNCTION(BlueprintCallable) voidCallOpenLevel(constFString&Address); UFUNCTION(BlueprintCallable) voidCallClientTravel(constFString&Address);#include"......
  • Datawhale AI夏令营(AI+生命科学)深度学习-Task3直播笔记
    机器学习lgm上分思路    1、引入新特征(1)对于Task2特征的再刻画        GC含量是siRNA效率中的一个重要且基本的参数,可以作为模型预测的特征。这是因为低GC含量会导致非特异性和较弱的结合,而高GC含量可能会阻碍siRNA双链在解旋酶和RISC复合体作用下的解旋。......
  • 书籍分享《TensorFlow机器学习实战指南》从入门到实战,免费领取!
    Google公司开发的TensorFlow深度学习库因其简单易学、应用场景广泛已经快成为各家公司开展人工智能研究的标配了。TensorFlow机器学习实战指南作者:NickMcClure,资深数据科学家,就职于美国西雅图PayScale公司,曾经在Zillow公司和Caesar’sEntertainment公司工作,获得蒙......
  • Python学习笔记50:游戏篇之外星人入侵(十一)
    前言本篇文章接着之前的内容,继续对游戏功能进行优化,主要是优化游戏状态以及对应的处理。状态一个游戏包含多种状态,这个状态是一个可以很复杂也可以很简单的内容。条件所限,我们这个游戏的状态就比较简单:未开始游戏中暂停结束我们通过一个字段进行控制,并且将这个字段放置......
  • pytorch深度学习实践(刘二大人)课堂代码&作业——线性回归
    一、课堂代码1.torch.nn.linear构造linear对象,对象里包含了w和b,即直接利用linear实现wx+b(linear也继承自module,可以自动实现反向传播)2.torch.nn.MSELoss损失函数MSE包含2个参数:size_average(求均值,一般只考虑这个参数)、reduce(求和降维)3.torch.optim.SGDSGD优化器,设置......
  • javascript学习 - DOM 元素获取、属性修改
    什么是WebAPIWebAPI是指网页服务器或者网页浏览器的应用程序接口。简单来讲,就是我们在编写JavaScript代码时,可以通过WebAPI来操作HTML网页和浏览器。WebAPI又可以分为两类:DOM(文档对象模型)BOM(浏览器对象模型)DOM(DocumentObjectModel),即文档对象模型,主要用......
  • javascript学习 - DOM 事件
    事件什么是事件在之前DOM的学习中,我们主要学习了如何获取DOM元素,并且学会了如何给获取的元素进行属性修改等操作。但这些基本都是静态的修改,并没有接触到一些动作。而今天要学习的事件,其实就是这些动作的总称。所谓事件,就是在编程时系统内所发生的动作或者发生的事情......
  • Redis学习[5] ——Redis过期删除和内存淘汰
    六、Redis过期键值删除6.1Redis的过期键值删除策略6.1.1什么是过期键值删除?Redis中是可以对key设置过期时间的,所以需要有相应的机制将已过期的键值对删除,也就是**过期键值删除策略。Redis会用一个过期字典(expiresdict)**来存储有过期时间的所有key。当查询一个key时,Red......
  • PCIe学习笔记(15)
    设备就绪状态(DeviceReadinessStatus,DRS)消息(DeviceReadinessStatus(DRS)是PCIe规范中引入的一种机制,旨在改进设备初始化和就绪状态的检测与报告。在以往的PCIe版本中,系统通常依赖于固定的超时机制来判断设备是否已经成功初始化并准备好进行数据传输。然而,这种方法存......