首页 > 其他分享 >24、检查所训练的模型是否训练好,是否达到预期目标

24、检查所训练的模型是否训练好,是否达到预期目标

时间:2023-02-27 11:37:21浏览次数:32  
标签:24 loss 训练 是否 item step test total

1、其实在没训练完 一轮 之后,可以对它进行一个测试,在测试数据集上跑一遍

以测试集上的损失或者正确率来评估模型是否训练好

2、在测试的过程中不需要进行调优,所以可以用 with torch.no_grad():

        #测试步骤:
        total_test_loss=0   #记录总的损失差
        with torch.no_grad():
            for data in test_dataloader:
                imgs,targets=data
                outputs=tuidui(imgs)
                loss=loss_fn(outputs,targets)    #这个loss只是一条数据的,要求在测试集上的总的损失
                total_test_loss=total_test_loss+loss.item()
        print('整体的测试集上的loss:{}',format(total_test_loss))

3、对于打印要求的设置:

#可以进行打印设置:每训练100次,打印一次
        if total_test_step%100==0:
            print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))

4、使用tensorboard来进行可视化

'''添加tensorboard;把每一次的训练进行可视化'''
writer=SummaryWriter('logs')

for i in range(epoch):
print('--------第{}轮训练开始----------'.format(i+1))
#训练步骤开始
for data in train_dataloader:
imgs,targets=data
outputs=tuidui(imgs)
loss=loss_fn(outputs,targets)

#进行优化的第一步是梯度清零
optimizer.zero_grad()
#利用损失来求每一个参数节点的梯度
loss.backward()
#进行优化
optimizer.step()
#更新训练次数
total_train_step+=1
#可以进行打印设置:每训练100次,打印一次
if total_test_step%100==0:
print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))
#loss.item(),使用item会让tensor类型的数据直接变成数字,数值型

writer.add_scalar('train_loss',loss.item(),total_train_step)

#测试步骤:
total_test_loss=0 #记录总的损失差
with torch.no_grad():
for data in test_dataloader:
imgs,targets=data
outputs=tuidui(imgs)
loss=loss_fn(outputs,targets) #这个loss只是一条数据的,要求在测试集上的总的损失
total_test_loss=total_test_loss+loss.item()
print('整体的测试集上的loss:{}',format(total_test_loss))
writer.add_scalar('test_loss',total_test_loss,total_test_step)
total_test_step=total_test_step+1
writer.close()

 5、测试模型的正确率

 模型中我们的输出output是概率值的形式,并不是标签的形式。

转换:

  Argmax,它可以输出横向中概率值最大的那个的位置索引。转化完之后,再让模型输出的标签和真实图像的标签进行比较。如果全部相等则表明预测全部正确,如果个数不一致代表不相等。

 

 

 

 

 

 

 

 

 

标签:24,loss,训练,是否,item,step,test,total
From: https://www.cnblogs.com/ar-boke/p/17158298.html

相关文章

  • 【2023-02-24】连岳摘抄
    23:59我生活中最喜欢的东西不需要花钱。很明显,我们拥有的最宝贵的资源就是时间。                           ......
  • 《分布式技术原理与算法解析》学习笔记Day24
    分布式缓存在计算机领域,缓存是一个非常重要的、用来提升性能的技术。什么是分布式缓存?缓存技术是指用一个更快的存储设备存储一些经常用到的数据,供用户快速访问。分布......
  • 第三章图3124
    importpandasaspdcatering_sale="C:/Users/Lenovo/Desktop/catering_sale.xls"data=pd.read_excel(catering_sale,index_col=u'日期')print(data.describe())importm......
  • stm32笔记[5]-FreeRTOS及(软IIC)读写AT24C02
    STM32CubeIDE使用FreeRTOS教程资料FreeRTOS从入门到精通1--实时操作系统的前世今生FreeRTOS从入门到精通2--人生若只如初见,初识STM32CubeIDEFreeRTOS从入门到精通3--......
  • 2022-SZUACM招新 训练赛2
    2022-SZUACM招新训练赛2https://vjudge.net/contest/544906#overview下午打了一下,稍微记录一波,有一些蛮有意思的小题。A-Arrayhttps://codeforces.com/problemset/p......
  • 24.生产环境中对不同时区问题的处理办法
    1.current_date,current_timestamp,localtimestamp--1.在会话中修改时区--偏移量-tz_offset--数据库时区--time_zone--系统本地时区--local--区域名--v$timezone_name......
  • 蓝桥杯训练赛二-问题 A
    题目描述用简单素数筛选法求N以内的素数。输入N输出2~N的素数样例输入100样例输出23571113171923293137414347535961677173798......
  • 242. 有效的字母异位词
    1classSolution{2public:3boolisAnagram(strings,stringt){4if(s.size()!=t.size())returnfalse;5string::iterators_it......
  • 23、完整模型训练的步骤
    1、以CIFAR10为例子1'''以CIFAR10为例子训练完整的模型步骤'''2importtorch3importtorchvision4fromtorchimportnn5fromtorch.nnimportConv2d6......
  • Educational Codeforces Round 24
    EducationalCodeforcesRound24https://codeforces.com/contest/818有些题就是从某个角度想好复杂,不好实现,但是换一种思考方式,从另一个角度想就会豁然开朗,也很好写。这......