首页 > 其他分享 >8-加载数据集

8-加载数据集

时间:2024-08-14 21:38:01浏览次数:13  
标签:__ nn 数据 self torch data def 加载






数据的读取方式
1、如果数据量比较小,直接读入内存,通过data[i]获取
2、如果数据量很大,我们不能直接读入内存,比如数据有很多文件,我们可以将文件名存储到一个文件,通过names[i]获取文件名,然后再去读取数据

dataloader加载器

多线程的错误问题

在linux多线程是通过fork创建的,但是在windows是通过spawn创建的,所以会出现运行时错误。
解决方法是将代码写入if-else语句,而不是直接写在for循环

即下面这种形式

点击查看代码
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filePath):
        xy = np.loadtxt(filePath, delimiter=',', dtype=np.float32)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        self.len = xy.shape[0]

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz') # 创建dataset
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()
criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)

            print('epoch: ', epoch, 'i: ', i, 'loss: ', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

标签:__,nn,数据,self,torch,data,def,加载
From: https://www.cnblogs.com/morehair/p/18359819

相关文章

  • 成为MySQL DBA后,再看ORACLE数据库(十四、统计信息与执行计划)
    一、前言一条SQL到达数据库内核之后,会解析为一条逻辑执行计划,CBO优化器对逻辑计划进行改写和转换,生成多个物理执行计划。为SQL构造出搜索空间,根据数据的统计信息、基数估计、算子代价模型为搜索空间中的执行计划估算出执行所需要的代价(CPU、内存、网络、I/O等资源消耗),最终选出代......
  • 基于STM32的边缘计算实时数据处理可视化系统:嵌入式C++、 FreeRTOS、Kafka、Spring Bo
    一、项目概述本项目旨在设计并实现一个基于STM32的边缘计算实时数据处理系统。该系统能够在边缘设备端进行数据采集、预处理,并将处理后的数据实时传输到后端服务器进行进一步分析和存储。本项目主要解决以下问题:减轻后端服务器的数据处理负担,提高系统整体效率降低......
  • MySQL-2:数据库基础知识(50%-100%)
    目录前言一、SQL语言基础1.SQL语言简介2.SQL分类3.SELECT语句的使用4.INSERT语句的使用5.UPDATE语句的使用6.DELETE语句的使用二、基本查询1.WHERE子句的使用2.ORDERBY子句的使用3.GROUPBY和HAVING子句使用4.LIMIT子句的使用总结前言前一半MySQL-1:数据库......
  • MySQL数据库专栏(三)数据库服务维护操作
    1、界面维护,打开服务窗口找到MySQL服务,右键单击可对服务进行启动、停止、重启等操作。选择属性,还可以设置启动类型为自动、手动、禁用。2、指令维护卸载服务:scdelete [服务名称]例如:scdeleteMySQL启动服务:netstart[服务名称]例如:netstartMySQL停止服务:netsto......
  • 数据类型
    Java基础语法中的数据类型是编程中非常重要的一个概念,它决定了变量能够存储什么类型的数据以及这些数据在内存中的表示方式。Java的数据类型可以分为两大类:基本数据类型(PrimitiveTypes)和引用数据类型(ReferenceTypes)。基本数据类型基本数据类型是Java中不可变的数据类型,它们直......
  • 如何保证数据不丢失?(死信队列)
    死信队列1、什么是死信死信通常是消息在特定的场景下表现:消息被拒绝访问消费者发生异常,超过重试次数消息的Expiration过期时间过长或者队列TTL过期时间消息队列到达最大容量maxLength2、什么是死信队列用来存储死信的队列,并且队列中只由死信构成的消息队列是死信队列......
  • 高阶数据结构(Java):AVL树插入机制的探索
    目录1、概念1.1什么是AVL树2.1平衡因子3、AVL树节点的定义4、AVL树的插入机制4.1初步插入节点4.2更新平衡因子4.3 提升右树高度4.3.1右单旋4.3.2左右双旋4.4 提升左树高度4.4.1左单旋 4.4.2右左双旋5、AVL树的验证6、AVL树的删除1、概念1.1什......
  • C程序设计(安徽专升本3.2基本数据类型)
    一、数据类型的分类 在本章节我们之讲解基础的数据类型,因为后续的数据类型将会单独对此讲解,常考的为基本数据类型,数组,函数,指针这几种类型!其它类型作为了解,认识即可二、整型类型此处对整数类型的讲解排除字符型和布尔型,它们单独拉出讲解,且我不喜欢废话讲解,我直接列表加代码......
  • R 语言GJR-GARCH、GARCH-t、GARCH-ged分析金融数据波动性预测、检验、可视化
    全文链接:https://tecdat.cn/?p=37354原文出处:拓端数据部落公众号 在当今复杂多变的金融市场中,准确理解和预测股票指数的走势对于投资者和金融机构而言至关重要。GARCH模型作为一种有效的工具,能够捕捉金融时间序列数据中的波动聚集性和异方差性,为我们提供更深入的市场洞察。准......
  • ABP默认模板修改默认数据库类型并初始化数据库数据
    我这里以SQLite数据库为例,其他数据库类似。1.下载模板https://aspnetboilerplate.com/ 根据自己的需求选择版本和前端框架并填写项目名称,点击“Createmyproject!”即可下载一个ABP标准模板项目。  解压下载好的压缩包,找到目录:aspnet-core,接下来就可以用VS打开.sln......