首页 > 其他分享 >迁移学习代码复现

迁移学习代码复现

时间:2024-08-16 09:26:57浏览次数:11  
标签:nn 代码 train 复现 图像 import 迁移 224 transforms

一、前言

说来可能令人难以置信,迁移学习技术在实践中是非常简单的,我们仅需要保留训练好的神经网络整体或者部分网络,再在使用迁移学习的情况下把保留的模型重新加载到内存中,就完成了迁移的过程。之后,我们就可以像训练普通神经网络那样训练迁移过来的神经网络了。
我们使用已经训练好的大型图像分类卷积神经网络来做一个分类任务:区分画面上的动物是蚂蚁还是蜜蜂

二、数据导入

将蚂蚁蜜蜂数据集进行一些(增强)图形处理,如随机从原始图像中切下来一块224×224大小的区域,随机水平翻转图像,将图像的色彩数值标准化等等。数据增强的目的是:增强模型的鲁棒性,提高模型的泛化能力和性能。

1、导入相应的函数库

# 加载程序所需要的包
#import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy
import os

导入蚂蚁蜜蜂数据集,并将这些照片进行一定的照片增强处理,最后以data_loader的形式存储。注:该代码是在kaggle上进行训练,所以data_dir是kaggle中input存储的位置,运行时,需根据具体情况修改文件路径。

# 数据存储总路径
data_dir = '/kaggle/input/bees-and-ants/蚂蚁蜜蜂数据集'
# 图像的大小为224×224
image_size = 224
# 加载的过程将会对图像进行如下增强操作:
# 1. 随机从原始图像中切下来一块224×224大小的区域
# 2. 随机水平翻转图像
# 3. 将图像的色彩数值标准化
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                     transforms.Compose([
                                         transforms.RandomResizedCrop(image_size),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                         ])
                                     )

# 加载校验数据集,对每个加载的数据进行如下处理:
# 1. 放大到256×256
# 2. 从中心区域切割下224×224大小的区域
# 3. 将图像的色彩数值标准化
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                   transforms.Compose([
                                       transforms.Resize(256),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                       ])
                                   )
# 创建相应的数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 4, shuffle = True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 4, shuffle = True, num_workers=4)
# 读取得出数据中的分类类别数
num_classes = len(train_dataset.classes)

2、展示部分照片

数据存储在train_loader(val_loader)的迭代器中,其中imgs是以数字的形式展示。在 PyTorch 中,图像数据通常以 (C, H, W) 的顺序存储,其中 C 是颜色通道数(例如 RGB 的 3 个通道),H 是图像的高度,W 是图像的宽度。然而,matplotlib 的 imshow 函数期望的图像数据是 (H, W, C) 顺序的,即高度和宽度作为前两个维度,颜色通道作为最后一个维度。所以img_np是通过通道转换后的照片(同时也通过Normalize标准化后的照片),imgs[0][0]是灰度的照片。

imgs,label=next(iter(train_loader))
fig,ax=plt.subplots(1,2,figsize=(12,6))
# 展示标准化后的照片
img_np = imgs[0].permute(1, 2, 0).numpy()
ax[0].imshow(img_np)
# 展示标准化后的照片
ax[1].imshow(imgs[0][0])
plt.show()

原图如下:
在这里插入图片描述
图如下:
在这里插入图片描述

三、使用自定义的模型进行训练

1、定义模型

自定义卷积神经网络模型,并确定前向传播每层之间的连接方式。

# 使用自定义的CNN模型
depth = [4, 8]
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 4, 5, padding = 2) #输入通道为3,输出通道为4,窗口大小为5,padding为2
        self.pool = nn.MaxPool2d(2, 2) #一个窗口为2*2的pooling运算
        self.conv2 = nn.Conv2d(depth[0], depth[1], 5, padding = 2) #第二层卷积,输入通道为depth[0], 输出通道为depth[1],窗口为15,padding为2

标签:nn,代码,train,复现,图像,import,迁移,224,transforms
From: https://blog.csdn.net/weixin_57342469/article/details/141192793

相关文章

  • 「代码随想录算法训练营」第三十九天 | 动态规划 part12
    115.不同的子序列题目链接:https://leetcode.cn/problems/distinct-subsequences/文章讲解:https://programmercarl.com/0115.不同的子序列.html题目难度:困难视频讲解:https://www.bilibili.com/video/BV1fG4y1m75Q/题目状态:看题解思路:动态规划数组初始化创建一个二维动......
  • 混合策略改进的蜣螂算法(IDBO)优化长短期记忆神经网络原理及matlab代码
    目录0引言1数学模型2模型对比3matlab代码3.1改进的主代码3.2IDBO-LSTM4视频讲解0引言针对DBO算法全局探索能力不足、易陷入局部最优以及收敛精度不理想等问题,多为学者提出了混合多策略改进的蜣螂优化算法(IDBO)。主要混合策略改进首先是采用混沌映射结合随机反......
  • 混合策略改进的蜣螂算法(IDBO)优化支持向量机原理及matlab代码
    目录0引言1数学模型2模型对比3matlab代码3.1改进的主代码3.2IDBO-SVM4视频讲解0引言针对DBO算法全局探索能力不足、易陷入局部最优以及收敛精度不理想等问题,多为学者提出了混合多策略改进的蜣螂优化算法(IDBO)。主要混合策略改进首先是采用混沌映射结合随机反向......
  • 探索Swift模块化测试的艺术:构建可维护的代码框架
    标题:探索Swift模块化测试的艺术:构建可维护的代码框架在Swift语言的生态中,代码模块化测试是一个至关重要的实践,它不仅有助于确保代码的可靠性,还能提高开发效率和代码质量。Swift的模块化测试框架提供了一套强大的工具和方法,使得开发者能够以模块化的方式进行测试。本文将深......
  • 代码随想录Day16
    513.找树左下角的值给定一个二叉树的根节点root,请找出该二叉树的最底层最左边节点的值。假设二叉树中至少有一个节点。示例1:输入:root=[2,1,3]输出:1示例2:输入:[1,2,3,4,null,5,6,null,null,7]输出:7提示:二叉树的节点个数的范围是[1,104]-231<=......
  • 编程基础题:开关灯(C语言方式代码,C++方式代码,Python3方式编写)三种语言编写代码
    1.题目描述:假设有N蓋灯(N为不大于5000的正整数),从1到N按顺序依次编号,初始时全部处于开启状态;第一个人(1号)将灯全部关闭,第二个人(2号)将编号为2的倍数的灯打开,第三个人(3号)将编号为3的倍数的灯做相反处理(即,将打开的灯关闭,将关闭的灯打开)。依照编号递增顺序,以......
  • Swift编译器的代码验证机制:安全与效率的双重协奏
    标题:Swift编译器的代码验证机制:安全与效率的双重协奏引言Swift语言以其安全性和高性能著称,这在很大程度上归功于其编译器的精妙设计。Swift编译器的代码验证机制是确保代码既安全又高效的关键环节。本文将详细探讨Swift编译器的代码验证机制,并展示如何在实践中利用这些机......
  • Vue 项目中,设置的 `color` 样式为 Hex 代码,但最终显示为 RGB 代码 情况原因
    在Vue项目中,设置的color样式为Hex代码,但最终显示为RGB代码,这通常是由于以下几种情况导致:1.CSS预处理器(Sass,Less)的影响:当你使用Sass或Less等CSS预处理器时,它们会将Hex颜色代码转换为RGB颜色代码,以便更好地进行颜色计算和操作。如果你在style属性......
  • 【代码随想录】一、数组:6.前缀和
    二刷的时候发现更新了一些新的题目,尝试写了写后,发现我完全不会ACM输入输出模式。这两天在补前几天没背的八股,写得不够满意(几乎是完全誊代码了),先放着,后面再补充补充吧。1.题目:44.开发商购买土地#include<iostream>#include<vector>#include<climits>usingnamespacestd......
  • 【Django开发】前后端分离django美多商城项目第1篇:欢迎来到美多 项目主要页面介绍【附
    本教程的知识点为:项目准备项目准备配置1.修改settings/dev.py文件中的路径信息2.INSTALLED_APPS3.数据库用户部分图片1.后端接口设计:视图原型2.具体视图实现用户部分使用Celery完成发送判断帐号是否存在1.判断用户名是否存在后端接口设计:用户部分......