首页 > 其他分享 >数据增强和泛化能力

数据增强和泛化能力

时间:2024-06-04 13:32:15浏览次数:15  
标签:loss 增强 泛化 nn self torch 能力 transforms test

一.数据增强是什么?

在深度学习中,数据增强是通过一定的方式改变输入数据,以生成更多的训练样本,从而提高模型的泛化能力和效果。数据增强可以减少模型对某些特征的过度依赖,从而避免过拟合。

二.什么是泛化能力?

模型泛化是指机器学习模型对新的、未见过的数据的适应能力。在机器学习中,我们通常会将已有的数据集划分为训练集和测试集,使用训练集训练模型,然后使用测试集来评估模型的性能。模型在训练集上表现得好,并不一定能在测试集或实际应用中表现得好。因此,我们需要保证模型具有良好的泛化能力,才能确保其在实际场景中的效果。

三.数据增强和模型泛化的联系

在深度学习中,要求样本数量充足,样本数量越多,训练出来的模型效果越好,模型的泛化能力越强。但是实际中,样本数量不足或者样本质量不够好,这时就需要对样本做数据增强,来提高样本质量。

例如在图像分类任务中,对于输入的图像可以进行一些简单的平移、缩放、颜色变换等操作,这些操作不会改变图像的类别,但可以增加训练样本的数量。这些增强后的样本可以帮助模型更好地学习和理解图像的特征,提高模型的泛化能力和准确率。

四.数据增强代码展示

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms

data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
        ]),
}

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        label = torch.from_numpy(np.array(label, dtype=np.int64))
        return image, label

training_data = food_dataset(file_path='.\\train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='.\\test.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=True)

# from matplotlib import pyplot as plt
# image, label = iter(train_dataloader).__next__()
# sample = image[2]
# sample = sample.permute((1, 2, 0)).numpy()
# plt.imshow(sample)
# plt.show()
# print('Label is: {}'.format(label[2].numpy()))

import torch

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(device)
# import os
#
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'

from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
        )

        self.out = nn.Linear(64 * 64 * 64, 20)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output

model = CNN().to(device)
print(model)

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model.forward(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss:>7f}  [number:{batch_size_num}]")
        batch_size_num += 1

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  #
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")
    acc_s.append(correct)
    loss_s.append(test_loss)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)

epochs = 50
acc_s = []
loss_s = []
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
from matplotlib import pyplot as plt
plt.subplot(1, 2, 1)
plt.plot(range(0, epochs), acc_s)
plt.xlabel("epoch")
plt.ylabel('accuracy')
plt.subplot(1, 2, 2)
plt.plot(range(0, epochs), loss_s)
plt.xlabel("epoch")
plt.ylabel('loss')
plt.show()

使用了数据增强即在这串代码里做出更改

data_transforms = {
    'train':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
        ]),
}

将其更改为:

data_transforms = {  # 也可以使用PIL库,smote 人工拟合出来数据
    'train':
        transforms.Compose([
            transforms.Resize([300, 300]),  # 是图像变换大小
            transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
            transforms.CenterCrop(256),  # 从中心开始裁剪[256,256]
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率
            transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
            # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
            transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=B
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化,均值,标准差
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
}

标签:loss,增强,泛化,nn,self,torch,能力,transforms,test
From: https://blog.csdn.net/lbr15660656263/article/details/139432988

相关文章

  • 有Cuda能力的GPU内核
    当CUDA应用程序启动一个内核时,CUDA运行时会确定系统中每个GPU的计算能力,并利用这些信息自动寻找最适合该GPU的内核cubin文件或PTX版本(如果可用)。如果存在支持目标GPU架构的cubin文件,将直接使用它;否则,CUDA运行时将加载PTX代码,并在启动之前将其即时编译(JIT编译)为GPU的本机cubin格式......
  • Syhunt Hybrid 7.0 (Windows) - 应用程序混合增强分析 (HAST)
    SyhuntHybrid7.0(Windows)-应用程序混合增强分析(HAST)SyhuntHybrid创新地融合了全面的静态和动态安全扫描请访问原文链接:https://sysin.org/blog/syhunt-hybrid/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.orgSyhuntHybrid获取深入的漏洞评估结果使用......
  • Cuda计算能力
    NVIDIACUDAC++编译器nvcc可用于生成针对特定架构的cubin文件和每个内核的向前兼容的PTX版本。每个cubin文件针对特定的计算能力版本,并且仅与具有相同主版本号的GPU架构向前兼容。例如,针对计算能力3.0的cubin文件在所有计算能力3.x(Kepler)设备上受支持,但在计算能力5.x(Maxwell)设备......
  • 学习笔记12:图像数据增强及学习速率衰减
    转自:https://www.cnblogs.com/miraclepbc/p/14360231.html数据增强常用数据增强方法:transforms.RandomCrop#随机位置裁剪transforms.CenterCrop#中心位置裁剪transforms.RandomHorizontalFlip(p=1)#随机水平翻转transforms.RandomVerticalFlip(p=1)#随机上下......
  • 计网期末复习指南(三):数据链路层(CRC冗余校验码计算、PPP协议、CSMA/CD协议、交换机的自
    前言:本系列文章旨在通过TCP/IP协议簇自下而上的梳理大致的知识点,从计算机网络体系结构出发到应用层,每一个协议层通过一篇文章进行总结,本系列正在持续更新中...  计网期末复习指南(一):计算机网络体系结构计网期末复习指南(二):物理层计网期末复习指南(三):数据链路层目录一.数......
  • 界面控件DevExtreme v23.2 - 可访问性、性能增强
    DevExtreme拥有高性能的HTML5/JavaScript小部件集合,使您可以利用现代Web开发堆栈(包括React,Angular,ASP.NETCore,jQuery,Knockout等)构建交互式的Web应用程序。从Angular和Reac,到ASP.NETCore或Vue,DevExtreme包含全面的高性能和响应式UI小部件集合,可在传统Web和下一代移动应用程序中......
  • WinDbg基本原理和使用方法,掌握基本的调试技术,并能够应用于实际的调试工作中;高级调试技
    WinDbg初级应用的大纲:1.WinDbg基础知识WinDbg简介:介绍WinDbg是什么以及其在Windows调试和分析中的作用。安装与配置:指导学员如何安装和配置WinDbg调试环境,包括下载安装、符号配置等基本步骤。2.调试基础调试流程:解释调试的基本流程,包括启动目标程序、设置断点、执行程序......
  • Autoruns工具的高级应用技巧和深度分析能力,能够在系统启动项管理、安全漏洞挖掘和恶意
    AutorunsforWindowsv14.11初级应用的大纲:1.简介与基础知识Autoruns简介:介绍Autoruns是一款由Sysinternals提供的系统启动项管理工具,用于查看和管理Windows系统启动时加载的所有程序、服务、驱动程序等。下载和安装:指导学习者如何下载并安装Autoruns工具,并介绍工具的界面和......
  • FSDump工具的内部原理和高级应用技术,基本用法和应用场景,掌握文件加密属性的查看和加密
    EFSDump初级应用的大纲:1.了解EFSDump简介:介绍EFSDump工具的作用、原理和功能。安装与配置:指导学习者如何获取和配置EFSDump工具,准备使用环境。2.基本用法查看文件加密属性:演示如何使用EFSDump查看文件的加密属性,识别加密文件。导出加密密钥:指导学习者如何导出文件的加......
  • ProcDump工具的基本用法和功能,并掌握如何利用它进行进程监视、性能分析和故障排查,从而
    ProcDump初级应用的大纲:1.ProcDump简介与基本用法介绍ProcDump工具的基本作用和功能。演示如何使用ProcDump来监视进程并在满足指定条件时生成转储文件。2.进程监视与性能分析探讨如何使用ProcDump监视进程的CPU利用率、内存占用等性能指标。演示如何利用ProcDump生成......