首页 > 其他分享 >第三章 3.9 在训练过程中修改学习率

第三章 3.9 在训练过程中修改学习率

时间:2024-12-16 14:13:00浏览次数:8  
标签:loss plt 第三章 val batch 修改 train model 3.9

Learning_rate_annealing.ipynb

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

###################  Chapter Three #######################################

# 第三章  读取数据集并显示
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
########################################################################
from torchvision import datasets
import torch
data_folder = '~/data/FMNIST' # This can be any directory you want to
# download FMNIST to
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)
tr_images = fmnist.data
tr_targets = fmnist.targets

val_fmnist = datasets.FashionMNIST(data_folder, download=True, train=False)
val_images = val_fmnist.data
val_targets = val_fmnist.targets


########################################################################
import matplotlib.pyplot as plt
#matplotlib inline
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

########################################################################
class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/255 #归一化
        x = x.view(-1,28*28)
        self.x, self.y = x, y
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix]
        return x.to(device), y.to(device)
    def __len__(self):
        return len(self.x)

from torch.optim import SGD, Adam
def get_model():
    model = nn.Sequential(
        nn.Linear(28 * 28, 1000),
        nn.ReLU(),
        nn.Linear(1000, 10)
    ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-2)
    return model, loss_fn, optimizer

def train_batch(x, y, model, optimizer, loss_fn):
    model.train()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss.item()

def accuracy(x, y, model):
    model.eval()
    # this is the same as @torch.no_grad
    # at the top of function, only difference
    # being, grad is not computed in the with scope
    with torch.no_grad():
        prediction = model(x)
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()

########################################################################
def get_data():
    train = FMNISTDataset(tr_images, tr_targets)
    trn_dl = DataLoader(train, batch_size=32, shuffle=True)#批大小
    val = FMNISTDataset(val_images, val_targets)
    val_dl = DataLoader(val, batch_size=len(val_images), shuffle=False)
    return trn_dl, val_dl
########################################################################
#@torch.no_grad()
def val_loss(x, y, model):
    with torch.no_grad():
        prediction = model(x)
    val_loss = loss_fn(prediction, y)
    return val_loss.item()

########################################################################
trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()

########################################################################
from torch import optim
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0, threshold = 0.001, min_lr = 1e-5, threshold_mode = 'abs')
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
for epoch in range(30): #轮数 30次
    print(epoch)
    train_epoch_losses, train_epoch_accuracies = [], []
    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        batch_loss = train_batch(x, y, model, optimizer, loss_fn)
        train_epoch_losses.append(batch_loss)
    train_epoch_loss = np.array(train_epoch_losses).mean()

    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        is_correct = accuracy(x, y, model)
        train_epoch_accuracies.extend(is_correct)
    train_epoch_accuracy = np.mean(train_epoch_accuracies)
    for ix, batch in enumerate(iter(val_dl)):
        x, y = batch
        val_is_correct = accuracy(x, y, model)
        validation_loss = val_loss(x, y, model)
        scheduler.step(validation_loss) #
    val_epoch_accuracy = np.mean(val_is_correct)
    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_losses.append(validation_loss)
    val_accuracies.append(val_epoch_accuracy)

########################################################################
epochs = np.arange(30)+1#轮数 30次
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
#%matplotlib inline
plt.figure(figsize=(20,5))
plt.subplot(211)
plt.plot(epochs, train_losses, 'bo', label='Training loss')
plt.plot(epochs, val_losses, 'r', label='Validation loss')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation loss when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
#plt.show()
plt.subplot(212)
plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation accuracy when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()])
plt.legend()
plt.grid('off')
plt.show()
# plt.figure(figsize=(20,5))
# for ix, par in enumerate(model.parameters()):
#     print(f'绘图:{ix}')
#     print(f'数据:{par.shape}')
#     if(ix==0):
#         plt.subplot(411)
#         plt.hist(par.cpu().detach().numpy().flatten())
#         plt.title('Distribution of weights conencting input to hidden layer')
#         #plt.show()
#     elif(ix ==1):
#         plt.subplot(412)
#         plt.hist(par.cpu().detach().numpy().flatten())
#         plt.title('Distribution of biases of hidden layer')
#         #plt.show()
#     elif(ix==2):
#         plt.subplot(413)
#         plt.hist(par.cpu().detach().numpy().flatten())
#         plt.title('Distribution of weights conencting hidden to output layer')
#         #plt.show()
#     elif(ix ==3):
#         plt.subplot(414)
#         plt.hist(par.cpu().detach().numpy().flatten())
#         plt.title('Distribution of biases of output layer')
#         plt.show()

 

标签:loss,plt,第三章,val,batch,修改,train,model,3.9
From: https://www.cnblogs.com/excellentHellen/p/18609967

相关文章

  • 第三章:3.8.1 绘制各层参数分布图 hist
    Chapter03/Varying_learning_rate_on_scaled_data.ipynb绘制各层参数分布图#https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch#https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################ChapterThree###......
  • 第三章 3.7 优化器的影响
    代码:#https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch#https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################ChapterThree########################################第三章读取数据集并显示fro......
  • 记录一次Centos镜像修改以及升级OpenSSL和OpenSSH
    事情是这样的:公司的阿里云服务器被说有漏洞需要修复--查看说漏洞大多都是OpenSSL和OpenSSH的,想到版本比较低就升级他两不就行了吗?结果更新升级发现app-stream均无法成功,原因centos已经停了维护,各种镜像均已不再维护了。第一步修改为阿里云镜像entOS8现已可使用国内的aliyun......
  • 第三章 3.6 批大小的影响
    第三章3.4训练神经网络 #https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch#https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################ChapterThree########################################第三章......
  • Java核心技术卷1 第三章选读
    前言本文内容选自Java核心技术卷1第10版,感兴趣的小伙伴可以自行阅读原书,以下内容为本人学习后摘取的片段与大家分享。正文3.3.2浮点类型所有的浮点数值计算都遵循IEEE754规范。具体来说,下面是用于表示溢出和出错情况的三个特殊的浮点数值:正无穷大负无穷大NaN(不......
  • 网站底部二维码怎么修改,如何轻松更新网站底部二维码
    如果您需要修改网站底部的二维码,可以按照以下步骤进行操作:登录后台管理:使用您的账户信息登录网站的后台管理系统。导航至底部设置:登录后,导航至“模板管理”或“页面管理”等相关页面。这些页面通常会包含底部内容的编辑功能。选择模板文件:在模板管理页面中,找到当前使用的......
  • 网站被修改能恢复吗,如何恢复被篡改的网站
    网站被修改后,可以通过以下步骤恢复到正常状态:备份当前文件:首先备份当前的网站文件,防止恢复过程中出现问题。查找备份文件:如果有定期备份的习惯,可以从备份文件中恢复。如果没有备份,可以联系托管商或云服务提供商,询问是否有自动备份。使用版本控制系统:如果使用Git......
  • 网站公司信息怎么修改,如何在网站后台管理系统中修改公司信息
    在网站后台管理系统中修改公司信息是一个常见的维护任务。以下是具体步骤:登录后台:使用管理员账号登录网站的后台管理系统。导航到公司信息管理:在后台菜单中,找到“公司信息”或“关于我们”模块。编辑公司信息:在编辑页面中,可以修改公司的基本信息,如公司名称、地址......
  • 企业网站修改首页底部,如何优雅地调整首页底部信息
    企业网站的首页底部通常包含版权信息、联系方式、导航链接等内容。以下是如何优雅地调整首页底部信息的步骤:确定需求:明确需要添加或修改的内容,如新的联系方式、社交媒体链接等。编辑HTML:在网站的前端代码中,找到首页底部对应的HTML文件,通常位于footer.html或类似文件中。添加内......
  • 织梦怎么修改网站文字,织梦CMS网站文字修改指南
    在织梦CMS中修改网站文字通常涉及编辑模板文件或内容管理模块。以下是详细的步骤:登录后台管理系统:使用管理员账号登录织梦CMS的后台管理系统。编辑模板文件:进入“模板”->“默认模板管理”。找到需要修改的模板文件,通常是index.htm或其他相关的HTML文件。打开模板......