
Chapter 3, Section 3.7: The Impact of the Optimizer


Code:

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

###################  Chapter Three #######################################

# Chapter 3: read the dataset and display it
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
########################################################################
from torchvision import datasets
data_folder = '~/data/FMNIST' # any directory you want to download FMNIST to
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)
tr_images = fmnist.data
tr_targets = fmnist.targets

val_fmnist = datasets.FashionMNIST(data_folder, download=True, train=False)
val_images = val_fmnist.data
val_targets = val_fmnist.targets


########################################################################
device = 'cuda' if torch.cuda.is_available() else 'cpu'

########################################################################
class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/255 # normalize pixel values to [0, 1]
        x = x.view(-1,28*28)
        self.x, self.y = x, y
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix]
        return x.to(device), y.to(device)
    def __len__(self):
        return len(self.x)
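
# Quick sanity check (an addition, not in the original listing): each sample
# should be a flattened 784-dim float vector in [0, 1] paired with its label.
_ds = FMNISTDataset(tr_images, tr_targets)
_x0, _y0 = _ds[0]
assert _x0.shape == (28 * 28,)
assert 0.0 <= _x0.min().item() and _x0.max().item() <= 1.0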

from torch.optim import SGD, Adam
def get_model():
    model = nn.Sequential(
        nn.Linear(28 * 28, 1000),
        nn.ReLU(),
        nn.Linear(1000, 10)
    ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=1e-2) # optimizer: SGD here; swap in Adam for the comparison
    return model, loss_fn, optimizer
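
# For the comparison this section is about, swap the SGD line inside get_model
# for Adam (lr=1e-3 is a common Adam default, an assumption rather than the
# book's prescription):
#     optimizer = Adam(model.parameters(), lr=1e-3)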

def train_batch(x, y, model, optimizer, loss_fn):
    model.train()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()    # compute gradients
    optimizer.step()         # update the weights
    optimizer.zero_grad()    # reset gradients for the next batch
    return batch_loss.item()

def accuracy(x, y, model):
    model.eval()
    # The with-block below is equivalent to decorating the function with
    # @torch.no_grad(); the only difference is that gradient tracking is
    # disabled just for this block rather than for the whole function.
    with torch.no_grad():
        prediction = model(x)
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()
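
# Equivalent decorator form of the no-grad evaluation above (a minimal sketch,
# not in the original): @torch.no_grad() disables gradient tracking for the
# whole function body instead of only a with-block.
@torch.no_grad()
def accuracy_nograd(x, y, model):
    model.eval()
    prediction = model(x)
    return (prediction.argmax(-1) == y).cpu().numpy().tolist()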

########################################################################
def get_data():
    train = FMNISTDataset(tr_images, tr_targets)
    trn_dl = DataLoader(train, batch_size=32, shuffle=True) # batch size 32
    val = FMNISTDataset(val_images, val_targets)
    val_dl = DataLoader(val, batch_size=len(val_images), shuffle=False) # whole validation set in one batch
    return trn_dl, val_dl
########################################################################
def val_loss(x, y, model):
    # relies on the module-level loss_fn created by get_model() below
    model.eval()
    with torch.no_grad():
        prediction = model(x)
    val_loss = loss_fn(prediction, y)
    return val_loss.item()

########################################################################
trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()

########################################################################
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
for epoch in range(10): # train for 10 epochs
    print(epoch)
    train_epoch_losses, train_epoch_accuracies = [], []
    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        batch_loss = train_batch(x, y, model, optimizer, loss_fn)
        train_epoch_losses.append(batch_loss)
    train_epoch_loss = np.array(train_epoch_losses).mean()

    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        is_correct = accuracy(x, y, model)
        train_epoch_accuracies.extend(is_correct)
    train_epoch_accuracy = np.mean(train_epoch_accuracies)
    for ix, batch in enumerate(iter(val_dl)):
        x, y = batch
        val_is_correct = accuracy(x, y, model)
        validation_loss = val_loss(x, y, model)
    val_epoch_accuracy = np.mean(val_is_correct)
    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_losses.append(validation_loss)
    val_accuracies.append(val_epoch_accuracy)

########################################################################
epochs = np.arange(10)+1
import matplotlib.ticker as mticker
# %matplotlib inline
plt.figure(figsize=(20,5))
plt.subplot(211)
plt.plot(epochs, train_losses, 'bo', label='Training loss')
plt.plot(epochs, val_losses, 'r', label='Validation loss')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation loss with the SGD optimizer')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(False)
#plt.show()
plt.subplot(212)
plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation accuracy with the SGD optimizer')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.gca().yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1, decimals=0))
plt.legend()
plt.grid(False)
plt.show()
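
To actually see the optimizer's impact, the listing above has to be run twice, once with SGD and once with Adam. The sketch below is an addition, not the book's code: run_experiment is a hypothetical helper, and the learning rates are common defaults rather than prescribed values. It reuses get_data, train_batch, and the model definition from above to collect both training-loss curves in a single run:

# Hypothetical helper for comparing optimizers; assumes the definitions above.
def run_experiment(opt_cls, lr, n_epochs=5):
    trn_dl, _ = get_data()
    model = nn.Sequential(
        nn.Linear(28 * 28, 1000),
        nn.ReLU(),
        nn.Linear(1000, 10)
    ).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = opt_cls(model.parameters(), lr=lr)
    epoch_losses = []
    for _ in range(n_epochs):
        losses = [train_batch(x, y, model, optimizer, loss_fn) for x, y in trn_dl]
        epoch_losses.append(np.mean(losses))
    return epoch_losses

sgd_losses = run_experiment(SGD, lr=1e-2)
adam_losses = run_experiment(Adam, lr=1e-3)
xs = np.arange(1, len(sgd_losses) + 1)
plt.plot(xs, sgd_losses, 'bo-', label='SGD (lr=1e-2)')
plt.plot(xs, adam_losses, 'ro-', label='Adam (lr=1e-3)')
plt.title('Training loss: SGD vs Adam')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()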

 

From: https://www.cnblogs.com/excellentHellen/p/18609515
