首页 > 其他分享 >Dataset and DataLoader

Dataset and DataLoader

时间:2024-08-17 11:53:11浏览次数:8  
标签:__ loss torch nn self DataLoader Dataset data

刘二大人_第八节课

代码:

import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import Dataset # 抽象类,不可实例化
from torch.utils.data import DataLoader # help us loading data in PyTorch
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="True"


class DiabetesDataset(Dataset):
    # 继承DataSet的类需要重写init,getitem,len魔法函数。
    # 分别是为了加载数据集,获取数据索引,获取数据总量
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=" ", dtype=np.float32)
        # shape本身是一个二元组(x,y)对应数据集的行数和列数,这里[0]我们取行数,即样本数
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    # 用文件的文件名作为文件内容的索引,读到内存中,等用的时候,再去索引文件名去读取内容
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset("diabetes_data.csv.gz")
# 我们用DataLoader为数据进行分组,batch_size是一个组中有多少个样本,shuffle表示要不要对样本进行随机排列
# 一般来说,训练集我们随机排列,测试集不。num_workers表示我们可以用多少进程并行的运算
train_loader = DataLoader(dataset=dataset,
                          batch_size= 32,
                          shuffle= True,
                          num_workers=2)

class Model(torch.nn.Module):
    def __init__(self): # 构造函数
        super(Model, self).__init__()
        # 创建了一个线性层,输入特征数为9,输出特征数为6。线性层在神经网络中通常用于实现权重矩阵乘法和偏置的加法
        self.linear1 = torch.nn.Linear(9, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # Sigmoid函数通常用于二分类问题中,将线性层的输出转换为概率值。
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.sigmoid(self.linear3(x)) # 对线性层加激活函数 sigmoid
        return x

model = Model() # 实例化模型

epoch_list = []
loss_list = []
# criterion = torch.nn.BCELoss(size_average=True)
criterion = torch.nn.BCELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

if __name__ == '__main__': # if这条语句在windows系统下一定要加,否则会报错
    for epoch in range(100): # 外层循环是训练周期
        for i, data in enumerate(train_loader, 0): # 内层循环是mini-batch
            # 1. Prepare data
            inputs, labels = data # #将输入的数据赋给inputs,标签赋给labels
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 添加到列表中
            loss_list.append(loss.item())
            epoch_list.append(epoch)
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4.Upgrade
            optimizer.step()
    # 绘图准备 
    plt.plot(epoch_list, loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('')
    plt.show()

训练结果如下:

标签:__,loss,torch,nn,self,DataLoader,Dataset,data
From: https://blog.csdn.net/weixin_72454974/article/details/141263729

相关文章

  • AttributeError:“CarvanaDataset”对象没有属性“image”
    我尝试制作一个Unet模型。这是我的代码:importtorchimporttorch.nnasnnimporttorch.optimasoptimimportalbumentationsasAfromalbumentations.pytorchimportToTensorV2fromtqdmimporttqdmfrommodelimportUnetfromutilsimportget_loadersimport......
  • [转]相同CRC不同数据的测试.CRC16 - CRC64 test results on 18.2M dataset
    转载自: http://www.backplane.com/matt/crc64.html  CRC16-CRC64testresultson18.2Mdataset,w/programsourceProgram&TestRunbyMattDillon18.2Mmessage-iddatasetsuppliedbyJoeGrecoIwouldliketothankeveryonewhoofferedtheirhistoryf......
  • Pytorch笔记|小土堆|P14-15|torchvision数据集使用、Dataloader使用
    学会看内置数据集的官方文档:https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10示例代码:importtorchvisionfromtorch.utils.tensorboardimportSummaryWriterfromtorchvisionimporttransforms#ToTensorte......
  • Python,Geopandas报错,AttributeError: The geopandas.dataset has been deprecated and
    Python版本3.9,Geopandas版本1.0.1问题描述:这是执行的代码,importpandasaspdimportgeopandasimportmatplotlib.pyplotaspltworld=geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))world.plot()plt.show()这是报错信息,Traceback(mo......
  • torch.utils.data.Dataset 和 torch.utils.data.DataLoader
    torch.utils.data是PyTorch中用于数据加载和预处理的模块。通常结合使用其中的Dataset和DataLoader两个类来加载和处理数据。Datasettorch.utils.data.Dataset是一个抽象类,用于表示数据集。需要用户自己实现两个方法:__len__和__getitem__。__len__方法返回数据集的大小,__getit......
  • Pytorch笔记|小土堆|P5-6|Dataset类
    Dataset类作用:模型的数据集接口__init__将对象实例化,创建对象时obj=class(...,...)会立即被调用,需要提供(输入)类中使用到的变量。__getitem__通过img,label=obj[idx]获取(返回)每一个数据和label__len__通过len(obj)获取(返回)数据量点击查看代码fromtorch.utils.dataim......
  • 【待做】【AI+安全】数据集:HTTP DATASET CSIC 2010
    HTTPDATASETCSIC2010HTTPDATASETCSIC2010包含已经标注过的针对Web服务的请求。该数据集由西班牙最高科研理事会CSIC在论文ApplicationoftheGenericFeatureSelectionMeasureinDetectionofWebAttacks中作为附件给出的,是一个电子商务网站的访问日志,包含36000......
  • Pytorch Dataloader 添加批量维度
    我认为这个问题已经被问过几次了,但我还没有在这里找到一个好的答案。所以我有一个由2个numpy数组组成的Pytorch数据集。以下是维度。特征=[10000,450,28]numpy数组。dim_0=样本数,dim_1=时间序列,dim_2=特征。基本上我有一个450帧长的数据,其中每......
  • 我想使用 torch DataLoader 并使用生产者/消费者模式,但它卡住了
    我想更改torchDataLoader,并在其中使用消费者/生产者模式。我有一个队列,一个线程将文件放入其中,这些项目由框架使用__getitem__使用。这是我的代码:importglobimporttimefromtorch.utils.dataimportDataLoader,Datasetimportthreadingimportqueue......
  • datasets(HuggingFace)学习笔记
    一、概述(1)datasets使用ApacheArrow格式,使得加载数据集没有内存限制(2)datasets的重要模块:load_dataset:用于加载原始数据文件load_from_disk:用于加载Arrow数据文件DatasetDict:用于操作多个数据集,保存、加载、处理等Dataset:用于操作单个数据集,保存、加载、处理等二、数据......