首页 > 其他分享 >学习笔记17:DenseNet实现多分类(卷积基特征提取)

学习笔记17:DenseNet实现多分类(卷积基特征提取)

时间:2024-06-04 09:44:20浏览次数:23  
标签:loss img 17 labels test epoch train 特征提取 DenseNet

转自:https://www.cnblogs.com/miraclepbc/p/14378379.html

数据集描述

总共200200类图像,每一类图像都存放在一个以类别名称命名的文件夹下,每张图片的命名格式如下图:

数据预处理

首先分析一下我们在数据预处理阶段的目标和工作流程

  • 获取每张图像以及对应的标签

  • 划分测试集和训练集

  • 通过写数据集类的方式,获取数据集并进一步获得DataLoader

  • 打印图片,验证效果

获取图像及标签

all_imgs_path = glob.glob(r'E:\birds\birds\*\*.jpg') # 获取所有图像路径列表
all_labels_name = [i.split('\\')[3].split('.')[1] for i in all_imgs_path] # 获取每张图像的标签名
label_to_index = dict([(v, k) for k, v in enumerate(unique_labels)]) # 将标签名映射到数值
# 获取每张图片的数值标签
all_labels = []
for img in all_imgs_path:
    for k, v in label_to_index.items():
        if k in img:
            all_labels.append(v)

划分测试集和训练集

以下代码可以作为模板来用,不做额外解释

np.random.seed(2021)
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path) * 0.8)

train_path = all_imgs_path[:s]
train_labels = all_labels[:s]
test_path = all_imgs_path[s:]
test_labels = all_labels[s:]

通过写数据集类的方式,获取数据集并进一步获得DataLoader

以下代码可以作为模板来用,不做额外解释

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

class BirdsDataset(data.Dataset):
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        pil_img = Image.open(img)
        pil_img = pil_img.convert('RGB') # 这一句是专门用来解决一种RuntimeError的
        np_img = np.array(pil_img, dtype = np.uint8)
        if np_img.shape == 2:
            img_data = np.repeat(np_img[:, :, np.newaxis], 3, axis = 2)
            pil_data = Image.fromarray(img_data)
        data = self.transforms(pil_img)
        return data, label
    def __len__(self):
        return len(self.imgs)

train_ds = BirdsDataset(train_path, train_labels, transform)
test_ds = BirdsDataset(test_path, test_labels, transform)
train_dl = data.DataLoader(train_ds, batch_size = 32) # 这里只是提取卷积基,不做训练,因此不用shuffle
test_dl = data.DataLoader(test_ds, batch_size = 32)

结果查看

取出一个批次的数据,绘图

img_batch, label_batch = next(iter(train_dl))
plt.figure(figsize = (12, 8)) # 定义画布大小
index_to_label = dict([(k, v) for k, v in enumerate(unique_labels)])
for i, (img, label) in enumerate(zip(img_batch[:3], label_batch[:3])):
    img = img.permute(1, 2, 0).numpy() # 将channel放在最后一维
    plt.subplot(1, 3, i + 1)
    plt.title(index_to_label.get(label.item()))
    plt.imshow(img)

结果如下:

提取卷积基

这一阶段的工作流程如下:

  • 获取DenseNet预训练模型,使用feature部分

  • 使用卷积基提取图像特征,并存放在列表中

预训练模型获取

my_densenet = models.densenet121(pretrained = True).features

if torch.cuda.is_available():
    my_densenet = my_densenet.cuda()

for p in my_densenet.parameters():
    p.requires_grad = False

提取图像特征

train_features = []
train_features_labels = []
for im, la in train_dl:
    out = my_densenet(im.cuda())
    out = out.view(out.size(0), -1) # 这里需要进行扁平化操作,因为后面要进行线性模型预测
    train_features.extend(out.cpu().data) # 这里注意是extend,extend可以将一个列表加到另一个列表的后面
    train_features_labels.extend(la)

test_features = []
test_features_labels = []
for im, la in test_dl:
    out = my_densenet(im.cuda())
    out = out.view(out.size(0), -1)
    test_features.extend(out.cpu().data)
    test_features_labels.extend(la)

重新定义数据集

因为后面要通过线性模型来预测,因此之前的图像数据集就不好用了

因此需要用刚刚提取到的特征,重新制作数据集

class FeatureDataset(data.Dataset):
    def __init__(self, feature_list, label_list):
        self.feature_list = feature_list
        self.label_list = label_list
    def __getitem__(self, index):
        return self.feature_list[index], self.label_list[index]
    def __len__(self):
        return len(self.feature_list)

train_feature_ds = FeatureDataset(train_features, train_features_labels)
test_feature_ds = FeatureDataset(test_features, test_features_labels)
train_feature_dl = data.DataLoader(train_feature_ds, batch_size = 32, shuffle = True)
test_feature_dl = data.DataLoader(test_feature_ds, batch_size = 32)

模型定义与预测

这里定义一个线性模型即可

模型定义

class FCModel(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = nn.Linear(in_size, out_size)
    def forward(self, input):
        return self.linear(input)

in_feature_size = train_features[0].shape[0]
net = FCModel(in_feature_size, 200)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 0.00001)
epochs = 30

模型训练

def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    
    model.train()
    for x, y in trainloader:
        y = torch.tensor(y, dtype = torch.long)
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim = 1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
    
    epoch_acc = correct / total
    epoch_loss = running_loss / len(trainloader.dataset)
    
    test_correct = 0
    test_total = 0
    test_running_loss = 0
    
    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            y = torch.tensor(y, dtype = torch.long)
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            y_pred = torch.argmax(y_pred, dim = 1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    epoch_test_acc = test_correct / test_total
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))
    
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, net, train_feature_dl, test_feature_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

训练结果

标签:loss,img,17,labels,test,epoch,train,特征提取,DenseNet
From: https://www.cnblogs.com/gongzb/p/18230180

相关文章

  • CTFshow-Crypto(17-25)
    17EZ_avbv(easy)18贝斯多少呢base62穷举分段给了段编码,hint为base628nCDq36gzGn8hf4M2HJUsn4aYcYRBSJwj4aE0hbgpzHb4aHcH1zzC9C3IL随波逐流和Cyberchef都没梭哈出来看了师傅们的wp大概意思是:分组长度固定,但是不一定是被整除为整数,只要找到从头开始截取一个长度解出明文,就......
  • [leetcode 3171] 解法列表
    线段树解法+二分classSolution{publicintminimumDifference(int[]nums,intk){this.nums=nums;this.n=nums.length;returncheck(k);}publicstaticvoidmain(String[]args){Solutionsolution=newSol......
  • 数学森林/洛谷P1750 出栈序列
    原创新思路求解出栈序列问题。问题描述:数学家小王经过千辛万苦长途跋涉终于来到了数学森林。无奈森林入口有很多个小矮人镇守。小矮人拿出一套题目让小王抽取一道题目说解出题目方能进入数学森林。题目如下:给定一个大小为c(最多可以同时存储c个元素)的堆栈,输入n个入栈的数,请输......
  • 【文末附gpt升级秘笈】关于论文“7B?13B?175B?解读大模型的参数的论文
    论文大纲引言简要介绍大模型(深度学习模型)的概念及其在各个领域的应用。阐述参数(Parameters)在大模型中的重要性,以及它们如何影响模型的性能。引出主题:探讨7B、13B、175B等参数规模的大模型。第一部分:大模型的参数规模定义“B”代表的意义(Billion/十亿)。解释7B、13B、175B等......
  • 持续性学习-Day17(MySQL)
    1、初识MySQLJavaEE:企业级Java开发Web前段(页面展示,数据)后端(连接点:连接数据库JDBC;链接前端:控制,控制反转,给前台传数据)数据库(存数据)1.1数据库分类关系型数据库(SQL):MySQL、Oracle、SqlServer、DB2、SQLlite通过表和表、行和列之间的关系进行数据的存储非关系型数......
  • 4.16-4.17技术支持面试
    1、讲讲你的实习经历xxx2、讲讲密码学,对称和非对称(公钥加密)的区别,非对称是否可以用私钥加密;对称和非对称区别在于,对称使用同一个密钥加密解密(有安全隐患),非对称是公钥加密私钥解密(私钥一般储存在服务器);可行,应用于数字签名方面可以,私钥加密公钥验证解密签名,但是数字签名的过程包......
  • 【计算机毕业设计】ssm717出租车管理系统的设计与实现+vue
    现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本出租车管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息,使用这种软件工具可以帮助管理人员提高事务处理效率,达到......
  • Navicat 17 体验官火热招募中 | 优选好礼等您来
    体验官火热招募中......
  • ../common/fdfs_global.h:17:26: fatal error: sf/sf_global.h: No such file or dire
    安装fastdfs之前需要安装一下libserverframe在解压后的fastdfs文件夹下的INSTALL里有说 打开链接:https://github.com/happyfish100/libserverframe/tags,选择一个合适的版本 [root@hqqfastdfs]#tar-zxvflibserverframe-1.2.3.tar.gz[root@hqqfastdfs]#cdlibserv......
  • Zcmu-1178
    思路:分析题目要求的就是由2,3,5,7单独相乘或者组合相乘的数字。所以将数字循环起来相乘,之后结果按从大到小地无重复放进数组当中。学长#include<set>#include<queue>#include<vector>#include<cstdio>usingnamespacestd;typedeflonglongll;intnum[4]={2,3,5......