首页 > 其他分享 >学习笔记14:模型保存

学习笔记14:模型保存

时间:2024-06-04 09:46:23浏览次数:15  
标签:acc loss 14 模型 笔记 epoch test new model

转自:https://www.cnblogs.com/miraclepbc/p/14361926.html

保存训练过程中使得测试集上准确率最高的参数

import copy
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(extend_epoch):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)
    if epoch_test_acc > best_acc:
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = epoch_test_acc
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
model.load_state_dict(best_model_wts)

保存模型

PATH = 'E:/my_model.pth'
torch.save(model.state_dict(), PATH)

重新加载模型

new_model = models.resnet101(pretrained = True)
in_f = new_model.fc.in_features
new_model.fc = nn.Linear(in_f, 4)
new_model.load_state_dict(torch.load(PATH))

测试是否加载成功

new_model.to(device)
test_correct = 0
test_total = 0
new_model.eval()
with torch.no_grad():
    for x, y in test_dl:
        x, y = x.to(device), y.to(device)
        y_pred = new_model(x)
        loss = loss_func(y_pred, y)
        y_pred = torch.argmax(y_pred, dim = 1)
        test_correct += (y_pred == y).sum().item()
        test_total += y.size(0)
epoch_test_acc = test_correct / test_total
print(epoch_test_acc)

标签:acc,loss,14,模型,笔记,epoch,test,new,model
From: https://www.cnblogs.com/gongzb/p/18230159

相关文章

  • 学习笔记15:第二种加载数据的方法
    转自:https://www.cnblogs.com/miraclepbc/p/14367560.html构建路径集和标签集取出所有路径importgloball_imgs_path=glob.glob(r"E:\datasets2\29-42\29-42\dataset2\dataset2\*.jpg")获得所有标签species=['cloudy','rain','shine',&......
  • 学习笔记16:残差网络
    转自:https://www.cnblogs.com/miraclepbc/p/14368116.html产生背景随着网络深度的增加,会出现网络退化的现象。网络退化现象形象化解释是在训练集上的loss不增反降。这说明,浅层网络的训练效果要好于深层网络一个想法就是,如果将浅层网络的特征传到深层网络,那么深层网络的训练效果......
  • 学习笔记17:DenseNet实现多分类(卷积基特征提取)
    转自:https://www.cnblogs.com/miraclepbc/p/14378379.html数据集描述总共200200类图像,每一类图像都存放在一个以类别名称命名的文件夹下,每张图片的命名格式如下图:数据预处理首先分析一下我们在数据预处理阶段的目标和工作流程获取每张图像以及对应的标签划分测试集和训......
  • 学习笔记19:图像定位
    转自:https://www.cnblogs.com/miraclepbc/p/14385623.html图像定位的直观理解不仅需要我们知道图片中的对象是什么,还要在对象的附近画一个边框,确定该对象所处的位置。也就是最终输出的是一个四元组,表示边框的位置图像定位网络架构可以将图像定位任务看作是一个回归问题!数据......
  • 学习笔记11:预训练模型
    转自:https://www.cnblogs.com/miraclepbc/p/14348536.html什么是预训练网络预训练模型就是之前用较大的数据集训练出来的模型,这个模型通过微调,在另外类似的数据集上训练。一般预训练模型规模比较大,训练起来占用大量的内存资源。微调预训练网络我们采用vgg16作为预训练模型,来实......
  • 学习笔记12:图像数据增强及学习速率衰减
    转自:https://www.cnblogs.com/miraclepbc/p/14360231.html数据增强常用数据增强方法:transforms.RandomCrop#随机位置裁剪transforms.CenterCrop#中心位置裁剪transforms.RandomHorizontalFlip(p=1)#随机水平翻转transforms.RandomVerticalFlip(p=1)#随机上下......
  • 学习笔记13:微调模型
    转自:https://www.cnblogs.com/miraclepbc/p/14360807.htmlresnet预训练模型resnet模型与之前笔记中的vgg模型不同,需要我们直接覆盖掉最后的全连接层先看一下resnet模型的结构:我们需要先将所有的参数都设置成requires_grad=False然后再重新定义fc层,并覆盖掉原来的。重新定义的......
  • 笔记7:训练过程封装(代码模板)
    转自:https://www.cnblogs.com/miraclepbc/p/14335456.html相关包importtorchimportpandasaspdimportnumpyasnpimportmatplotlib.pyplotaspltfromtorchimportnnimporttorch.nn.functionalasFfromtorch.utils.dataimportTensorDatasetfromtorch.utils.......
  • 学习笔记8:全连接网络实现MNIST分类(torch内置数据集)
    转自:https://www.cnblogs.com/miraclepbc/p/14344935.html相关包导入importtorchimportpandasaspdimportnumpyasnpimportmatplotlib.pyplotaspltfromtorchimportnnimporttorch.nn.functionalasFfromtorch.utils.dataimportTensorDatasetfromtorch.ut......
  • 学习笔记9:卷积神经网络实现MNIST分类(GPU加速)
    转自:https://www.cnblogs.com/miraclepbc/p/14345342.html相关包导入importtorchimportpandasaspdimportnumpyasnpimportmatplotlib.pyplotaspltfromtorchimportnnimporttorch.nn.functionalasFfromtorch.utils.dataimportTensorDatasetfromtorch.ut......