首页 > 其他分享 >定义损失函数并以此训练和评估模型

定义损失函数并以此训练和评估模型

时间:2024-07-15 23:25:37浏览次数:12  
标签:opt loss 定义 val yb train model 评估 函数

基础神经网络模型搭建 

【Pytorch】数据集的加载和处理(一)

【Pytorch】数据集的加载和处理(二)

损失函数计算模型输出和目标之间的距离。通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率

目录

基础模型搭建

数据集的加载和处理

定义损失函数

定义优化器

训练并评估模型


基础模型搭建

import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, x):
         pass
def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 20, 5, 1)
    self.conv2 = nn.Conv2d(20, 50, 5, 1)
    self.fc1 = nn.Linear(4*4*50, 500)
    self.fc2 = nn.Linear(500, 10)
def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2, 2) 
    x = x.view(-1, 4*4*50)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
Net.__init__ = __init__
Net.forward = forward
model = Net()

检查搭建情况 

print(model)

 

原位置为cpu 

 

 转移至所需CUDA设备

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

 

数据集的加载和处理

导入MNIST训练数据集和验证数据集并处理

from torch import nn
from torchvision import datasets
from torch.utils.data import TensorDataset
path2data="./data"
train_data=datasets.MNIST(path2data, train=True, download=True)
x_train, y_train=train_data.data,train_data.targets
val_data=datasets.MNIST(path2data, train=False, download=True)
x_val,y_val=val_data.data, val_data.targets
if len(x_train.shape)==3:
    x_train=x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape)==3:
    x_val=x_val.unsqueeze(1)
print(x_val.shape)
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
for x,y in train_ds:
    print(x.shape,y.item())
    break

from torch.utils.data import DataLoader 
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)

 

 

定义损失函数

损失函数计算模型输出和目标之间的距离。Pytorch 中的 optim 包提供了各种优化算法的实现,例如SGD、Adam、RMSprop 等。

通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率。

loss_func = nn.NLLLoss(reduction="sum")
for xb, yb in train_dl:
    # move batch to cuda device
    xb=xb.type(torch.float).to(device)
    yb=yb.to(device)
    out=model(xb)
    loss = loss_func(out, yb)
    print (loss.item())
    break

得到一个测试值 

 

定义优化器

定义一个Adam优化器,优化器的输入是模型参数和学习率

from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-4)

通过opt .step()自动更新模型参数,同时需要注意计算下一批的梯度之前需将梯度归0

opt.step()
opt.zero_grad()

训练并评估模型

定义一个辅助函数 loss_batch来计算每个小批量的损失值。函数的 opt 参数引用优化器,如果给定,则计算梯度并按小批量更新模型参数。

def  loss_batch(loss_func,  xb,  yb,yb_h,  opt=None): 
    loss = loss_func(yb_h, yb) 
    metric_b =  metrics_batch(yb,yb_h) 
    if opt is  not None: 
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(),metric_b

 定义一个辅助函数metrics_batch来计算每个小批量的性能指标,这里以准确率作为分类任务的性能指标,并使用 output.argmax 来获取概率最高的预测类

def metrics_batch(target, output):
    pred = output.argmax(dim=1, keepdim=True)
    corrects=pred.eq(target.view_as(pred)).sum().item()
    return corrects

定义一个辅助函数loss_epoch来计算整个数据集的损失和指标值。使用数据加载器对象获取小批量,将它们提供给模型,并计算每个小批量的损失和指标,通过两个运行变量来分别添加损失值和指标值。

def loss_epoch(model,loss_func,dataset_dl,opt=None):
    loss=0.0
    metric=0.0
    len_data=len(dataset_dl.dataset)
    for xb, yb in dataset_dl:
        xb=xb.type(torch.float).to(device)
        yb=yb.to(device)
        yb_h=model(xb)
        loss_b,metric_b=loss_batch(loss_func, xb, yb,yb_h, opt)
        loss+=loss_b
        if metric_b is not None:
            metric+=metric_b
    loss/=len_data
    metric/=len_data
    return loss, metric

最后,定义一个辅助函数train_val来训练多个时期的模型。在每个时期使用验证数据集评估模型的性能。训练和评估需要分别使用 model.train()和 model.eval()模式。torch.no_grad()可以阻止 autograd 在评估期间计算梯度。

def train_val(epochs, model, loss_func, opt, train_dl, val_dl):
    for epoch in range(epochs):
        model.train()
        train_loss,train_metric=loss_epoch(model,loss_func,train_dl,opt)
        
        model.eval()
        with torch.no_grad():
            val_loss, val_metric=loss_epoch(model,loss_func,val_dl)
        accuracy=100*val_metric
        
        print("epoch: %d, train loss: %.6f, val loss: %.6f,accuracy: %.2f" %(epoch, train_loss,val_loss,accuracy))

 设定时期数为5,调用函数进行训练和评估

num_epochs=5
train_val(num_epochs, model, loss_func, opt, train_dl, val_dl)

 

标签:opt,loss,定义,val,yb,train,model,评估,函数
From: https://blog.csdn.net/weixin_73404807/article/details/140450445

相关文章

  • 四、Python集合与函数
    集合set1.不同元素组成2.无序3.集合中元素必须是不可变类型s={1,2,3,4,5}集合常用魔法s={1,2,3,4,5}s.add('s')print(s)#>>>{1,2,3,4,5,'s'}s.add(6)print(s)#>>>{1,2,3,4,5,'s',6}s.clear()print(s)s={1,2,3,4,5}v=s......
  • DO、DTO、BO、AO、VO、POJO定义规范
    DO、DTO、BO、AO、VO、POJO定义分层领域模型规约:DO(DataObject):与数据库表结构一一对应,通过DAO层向上传输数据源对象DTO(DataTransferObject):数据传输对象,Service或Manager向外传输的对象BO(BusinessObject):业务对象。由Service层输出的封装业务逻辑的对象AO(Applicatio......
  • Vue 3 中 defineExpose() 函数的使用
    什么是defineExpose()?defineExpose()是Vue3提供的一个CompositionAPI函数,用于明确指定哪些内部响应式状态或函数可以被外部访问。如何使用defineExpose()?在子组件中定义:import{ref,defineExpose}from'vue'exportdefault{setup(){constcount=......
  • 全面分析构造函数(1)
    什么是构造函数             构造函数是在创建类对象时,由系统自动调用,初始化新对象的函数,给其中的成员变量赋值。构造函数没有返回值,名字与类名相同,有参数,所以可以进行函数重载,构造函数大致可以分为一下几类:无参构造:没有参数的构造函数,也是默认构造函数有参......
  • 关于使用自定义按键的探索和实现
    目前游戏中大都支持自定义按键,在设置页面中,为了给玩家一个舒适或者更有空间的操作方式,我对在ue4中自定义按键输入的实现进行探索,当然ue4和ue5版本差别不大可以说大同小异。对于自定义按键,比如玩家开枪,跳跃翻滚,本来使用q,we,来完成,但是我们在设置中可以改变他的按键,所以初步实现......
  • VTK-自定义交互器、可拖拽坐标轴、视图定向立方体
    源代码:https://github.com/qianqiu10000/mySWInteractorStyle1.0.git仿照SolidWorks的操作习惯自定义的VTK交互器:1.左键单击Actor,可以选择Actor,并显示红色2.左键双击Actor,可以在Actor位置弹出拖拽坐标轴,可以移动、旋转3.单击空格键,可以弹出立方体视图定向工具4.按住鼠标......
  • 14 mysql 函数
    在mysql中,函数主要分为内置函数(系统函数)和自定义函数不管是内置函数还是自定义函数,都是使用select函数名(参数列表);字符串函数char_length():判断字符串的字符数length():判断字符串的字节数(字符集有关)SELECTchar_length('你好,中国'),length('你好,中国');--返回结果:51......
  • STM32标准库函数功能介绍————EXTI库
    1.voidEXTI_DeInit(void);函数解释:EXTI的反初始化函数,即恢复默认状态。参数解释:无参数2.voidEXTI_Init(EXTI_InitTypeDef*EXTI_InitStruct);函数解释:EXTI的初始化函数参数解释:注意要加&号3.voidEXTI_StructInit(EXTI_InitTypeDef*EXTI_InitStruct);函数解释:将EXTI......
  • 遥感降水评估
    遥感降水可以作为地面雨量计和雷达观测降水的补充,在偏远山区和缺资料地区更为适合。目前,学界有多种降水数据,每一种降水数据都有独特的方法制作。因此,在使用前需要对这些降水的可靠性进行评估。在获得误差基础上,方可进行应用。 准备:经过插值的遥感降水:cmorph  降水,CN05.1......
  • opencv—常用函数学习_“干货“_7
    目录十九、模板匹配从图像中提取矩形区域的子像素精度补偿(getRectSubPix)在图像中搜索和匹配模板(matchTemplate)比较两个形状(轮廓)的相似度(matchShapes)解释二十、图像矩计算图像或轮廓的矩(moments)计算图像或轮廓的Hu不变矩(HuMoments)解释使用示例二一、......