
Deep Learning (Optimizers)


The following implements several optimizers used in deep learning: SGD, Momentum, Nesterov, AdaGrad, RMSProp, AdaDelta, Adam, and AdamW. Each class exposes the same zero_grad()/step() interface as the optimizers in torch.optim, so any of them can be dropped into a standard training loop.

The code is as follows:

import torch
import torch.nn as nn
from torchvision import transforms, datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SGD:
    """Vanilla stochastic gradient descent."""
    def __init__(self, model, lr=0.001):
        self.lr = lr
        self.model = model

    def zero_grad(self):
        self.model.zero_grad()

    def step(self):
        with torch.no_grad():
            for p in self.model.parameters():
                if p.requires_grad: 
                    p -= self.lr * p.grad
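
For reference, plain SGD updates each parameter by a step of size \(\eta\) (the learning rate) along the negative gradient:

\(\theta \leftarrow \theta - \eta\, g, \qquad g = \nabla_\theta L\)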

class Momentum:
    """SGD with classical (heavy-ball) momentum."""
    def __init__(self, model, lr=0.001, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.model = model

        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()

    def step(self):
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad: 
                    self.v[id] = self.momentum*self.v[id] - self.lr*p.grad
                    p += self.v[id]
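
Momentum smooths the gradient direction by accumulating a velocity \(v\) (with coefficient \(\mu = 0.9\) above):

\(v \leftarrow \mu v - \eta g, \qquad \theta \leftarrow \theta + v\)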

class Nesterov:
    """Nesterov momentum, reformulated so the gradient is evaluated at the
    current parameters (http://arxiv.org/abs/1212.0901)."""
    def __init__(self, model, lr=0.001, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.model = model

        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.v[id] = torch.zeros_like(p).to(device)
    
    def zero_grad(self):
        self.model.zero_grad()    

    def step(self):
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:            
                    p += self.momentum * self.momentum * self.v[id] - ((1 + self.momentum) * self.lr * p.grad)
                    self.v[id] = self.momentum*self.v[id] - self.lr * p.grad
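
The two lines in step() are the standard reformulation of Nesterov momentum that lets the gradient be evaluated at the current parameters. Substituting the new velocity \(v' = \mu v - \eta g\) into the lookahead update \(\theta \leftarrow \theta + \mu v' - \eta g\) gives

\(\theta \leftarrow \theta + \mu^2 v - (1+\mu)\,\eta g\)

which is exactly the parameter update performed before \(v\) is refreshed.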

class AdaGrad:
    """AdaGrad: per-parameter learning rates from accumulated squared gradients."""
    def __init__(self, model, lr=0.001):
        self.lr = lr
        self.model = model
        self.v = {}
        for id,p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()   

    def step(self):
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:      
                    self.v[id] += p.grad * p.grad
                    p -= self.lr * p.grad / (torch.sqrt(self.v[id]) + 1e-7)  
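
AdaGrad accumulates squared gradients and scales each coordinate's step by the inverse square root of that sum; because the accumulator only grows, the effective learning rate decays monotonically:

\(G \leftarrow G + g^2, \qquad \theta \leftarrow \theta - \eta\, g \,/\, (\sqrt{G} + \epsilon)\)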

class RMSprop:
    """RMSProp: AdaGrad with an exponential moving average instead of a running sum."""
    def __init__(self, model, lr=0.001, decay_rate = 0.99):
        self.lr = lr  
        self.decay_rate = decay_rate   
        self.model = model
        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()   

    def step(self):            
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:               
                    self.v[id] = self.decay_rate*self.v[id] + (1 - self.decay_rate) * p.grad * p.grad
                    p -= self.lr * p.grad / (torch.sqrt(self.v[id]) + 1e-7)
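
RMSProp replaces AdaGrad's ever-growing sum with an exponential moving average (decay rate \(\rho\)), so the effective learning rate no longer shrinks toward zero:

\(E[g^2] \leftarrow \rho\, E[g^2] + (1-\rho)\, g^2, \qquad \theta \leftarrow \theta - \eta\, g \,/\, (\sqrt{E[g^2]} + \epsilon)\)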

class AdaDelta:
    """AdaDelta: RMSProp-style scaling with an additional running average of updates."""
    def __init__(self, model, lr=0.001, decay_rate = 0.99):
        self.lr = lr  
        self.decay_rate = decay_rate   
        self.model = model
        self.u = {}
        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.u[id] = torch.ones_like(p).to(device)  # running avg of squared updates (the paper initializes this to zero)
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()   

    def step(self):            
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:  
                    self.v[id] = self.decay_rate * self.v[id] + (1 - self.decay_rate) * p.grad * p.grad
                    delta_w = (torch.sqrt(self.u[id] + 1e-7)/torch.sqrt(self.v[id] + 1e-7)) * p.grad
                    self.u[id] = self.decay_rate*self.u[id] + (1 - self.decay_rate) * delta_w * delta_w
                    p -= self.lr * delta_w
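
AdaDelta additionally tracks a running average of squared updates \(E[\Delta\theta^2]\) and uses its RMS to rescale the step, making the update approximately unit-consistent:

\(\Delta\theta = \dfrac{\sqrt{E[\Delta\theta^2] + \epsilon}}{\sqrt{E[g^2] + \epsilon}}\; g, \qquad \theta \leftarrow \theta - \eta\, \Delta\theta\)

Note that the original paper has no learning rate (effectively \(\eta = 1\)) and initializes both accumulators to zero; the version above keeps an \(\eta\) factor and initializes \(E[\Delta\theta^2]\) to ones, which avoids a near-zero first step.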

class Adam:
    """Adam: momentum plus RMSProp-style scaling, with bias correction
    (http://arxiv.org/abs/1412.6980v8)."""
    def __init__(self, model, lr=0.001, beta1=0.9, beta2=0.999):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.iter = 0
        self.model = model
        self.m = {}
        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.m[id] = torch.zeros_like(p).to(device)
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()   

    def step(self):
        self.iter += 1
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:    
                    self.m[id] = self.beta1 * self.m[id] + (1 - self.beta1) * p.grad
                    self.v[id] = self.beta2 * self.v[id] + (1 - self.beta2) * (p.grad**2)
                    m_hat = self.m[id] / (1 - self.beta1**self.iter)
                    v_hat = self.v[id] / (1 - self.beta2**self.iter)
                    p -= self.lr * m_hat / (torch.sqrt(v_hat) + 1e-7)
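
Adam combines momentum (first moment \(m\)) with RMSProp-style scaling (second moment \(v\)), and corrects both for their zero initialization:

\(m \leftarrow \beta_1 m + (1-\beta_1)\, g, \qquad v \leftarrow \beta_2 v + (1-\beta_2)\, g^2\)

\(\hat m = m / (1-\beta_1^t), \qquad \hat v = v / (1-\beta_2^t), \qquad \theta \leftarrow \theta - \eta\, \hat m \,/\, (\sqrt{\hat v} + \epsilon)\)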

class AdamW:
    """AdamW: Adam with decoupled weight decay (https://arxiv.org/abs/1711.05101)."""
    def __init__(self, model, lr=0.001, beta1=0.9, beta2=0.999, weight_decay = 0.01):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.weight_decay = weight_decay
        self.iter = 0
        self.model = model
        self.m = {}
        self.v = {}
        for id, p in enumerate(model.parameters()):
            if p.requires_grad: 
                self.m[id] = torch.zeros_like(p).to(device)
                self.v[id] = torch.zeros_like(p).to(device)

    def zero_grad(self):
        self.model.zero_grad()   

    def step(self):
        self.iter += 1
        with torch.no_grad():
            for id, p in enumerate(self.model.parameters()):
                if p.requires_grad:    
                    self.m[id] = self.beta1 * self.m[id] + (1 - self.beta1) * p.grad
                    self.v[id] = self.beta2 * self.v[id] + (1 - self.beta2) * (p.grad**2)
                    m_hat = self.m[id] / (1 - self.beta1**self.iter)
                    v_hat = self.v[id] / (1 - self.beta2**self.iter)
                    p -= self.lr * (m_hat / (torch.sqrt(v_hat) + 1e-7) + self.weight_decay * p)
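
AdamW differs from Adam only in how weight decay enters: instead of folding an L2 penalty into the gradient (where the adaptive denominator would rescale it), the decay term \(\lambda\theta\) is decoupled and subtracted directly:

\(\theta \leftarrow \theta - \eta\,\big(\hat m \,/\, (\sqrt{\hat v} + \epsilon) + \lambda\,\theta\big)\)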


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))  # (N, 6, 24, 24)
        x = torch.max_pool2d(x, 2)     # (N, 6, 12, 12)
        x = torch.relu(self.conv2(x))  # (N, 16, 8, 8)
        x = torch.max_pool2d(x, 2)     # (N, 16, 4, 4)
        x = x.view(x.size(0), -1)      # (N, 256)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)                # (N, 10) class logits
        return x

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()

# Uncomment exactly one optimizer to try it out:
#opt = SGD(model, 0.001)
#opt = Momentum(model, 0.001, 0.9)
#opt = Nesterov(model, 0.001, 0.9)
#opt = AdaGrad(model, 0.001)
#opt = RMSprop(model, 0.001, 0.99)
#opt = AdaDelta(model, 0.001, 0.99)
#opt = Adam(model, 0.001, 0.9, 0.999)
opt = AdamW(model, 0.001, 0.9, 0.999, 0.01)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
 
        output = model(images)
        loss = criterion(output, labels)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {(100 * correct / total):.2f}%")

References:

手写深度学习之优化器(SGD、Momentum、Nesterov、AdaGrad、RMSProp、Adam) - CSDN博客

AdamW和Adam优化器对比分析 - CSDN博客

Adam和AdamW的区别 - CSDN博客

From: https://www.cnblogs.com/tiandsp/p/18575046
