首页 > 其他分享 >笔记5:TensorDataset、DataLoader及数据集划分

笔记5:TensorDataset、DataLoader及数据集划分

时间:2024-06-04 09:11:03浏览次数:25  
标签:loss torch TensorDataset DataLoader batch 笔记 epoch train test

TensorDataset

转自:https://www.cnblogs.com/miraclepbc/p/14333299.html

导入相关包

from torch.utils.data import TensorDataset

特征与标签合并

HRdataset = TensorDataset(X, Y)

模型训练

for epoch in range(epochs):
    for i in range(num_batch):
        x, y = HRdataset[i * batch_size: i * batch_size + batch_size]
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        print('epoch: ', epoch, 'loss: ', loss_func(model(X), Y).data.item())

DataLoader

导入相关包

from torch.utils.data import DataLoader

加载数据

HR_ds = TensorDataset(X, Y)
HR_dl = DataLoader(HR_ds, batch_size = batch_size, shuffle = True)

模型训练

for epoch in range(epochs):
    for x, y in HR_dl:
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        print('epoch: ', epoch, 'loss: ', loss_func(model(X), Y).data.item())

划分数据集

导入相关包

from sklearn.model_selection import train_test_split

划分数据集

train_x, test_x, train_y, test_y = train_test_split(X_data, Y_data)
  • 默认3:1

包装数据

train_x = torch.from_numpy(train_x).type(torch.float32)
test_x = torch.from_numpy(test_x).type(torch.float32)
train_y = torch.from_numpy(train_y).type(torch.float32)
test_y = torch.from_numpy(test_y).type(torch.float32)

train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle = True)
test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds, batch_size = batch_size)

定义准确率

def accuracy(y_pred, y_true):
    return ((y_pred.data.numpy() > 0.5).astype('int') == y_true.numpy()).mean()

模型训练

for epoch in range(epochs):
    for x, y in train_dl:
        y_pred = model(x)
        loss = loss_func(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        epoch_accuracy = accuracy(model(train_x), train_y)
        epoch_loss = loss_func(model(train_x), train_y).data
        epoch_test_accuracy = accuracy(model(test_x), test_y)
        epoch_test_loss = loss_func(model(test_x), test_y).data
        print('epoch: ', epoch, 'loss: ', round(epoch_loss.item(), 3), 'accuracy: ', round(epoch_accuracy.item(), 3),
              'test_loss: ', round(epoch_test_loss.item(), 3), 'test_accuracy: ', round(epoch_test_accuracy.item(), 3))

标签:loss,torch,TensorDataset,DataLoader,batch,笔记,epoch,train,test
From: https://www.cnblogs.com/gongzb/p/18230104

相关文章

  • 实战营学习笔记3
    在浦语大模型的第三课《基于Internlm和LangChain构建你的知识库》中,北辰老师以其生动有趣的风格,深入浅出地讲解了RAG(RetrievalAugmentedGeneration)的基本概念,并指导我们如何利用茴香豆搭建一个RAG助手。在此之前,我阅读过一些关于大型语言模型的资料,心中一直存有一个疑惑:既......
  • 浙大翁恺《C语言程序设计》课程笔记
    1.1计算机与编程语言设计算法->编写程序->计算机执行程序执行的两种方式1.解释:借助一个程序(解释器),那个程序能试图理解你的程序,然后按照你的要求让计算机执行2.编译:借助一个程序(编译器),把你的程序翻译成机器语言,然后让计算机执行编程语言本身没有解释型和编译型之......
  • [学习笔记]点分治
    一、主要思想很容易理解,我们将一个树以一个节点分割成若干个子树。对于这个节点,我们以一些方式统计和改变答案,然后不断地向子树递归。那应该选择哪个节点呢?显然是重心。树的重心有一个性质:所有子树的大小小于等于当前树的大小的二分之一。也就是说,这保证了递归层数\(log_2\)的......
  • Git 笔记
    Git笔记git原理git的四个区域文件的四种状态git的工作流程安装git配置信息和获取帮助常用命令创建仓库跟踪文件gitadd取消跟踪gitrm提交到仓库gitcommit推送到远程分支gitpushcommit的查看、修改、合并搭建git服务器git原理git的四个区域工作......
  • g++编译过程学习笔记
    g++编译过程学习笔记学习用例使用很简单的多文件编译项目,进行编译过程的学习,主要文件构成如下:.├──include│└──hello.h└──src├──hello.cpp└──main.cpp其中hello.h声明了一个可以输出HelloWorld!的函数并在hello.cpp中完成实现。ma......
  • JAVA学习笔记6
    学习目标:精通JAVA学习内容:1.方法调用packagecn.itcast.day04.demo02;/*publicclassDemo01Method{publicstaticvoidmain(String[]args){for(intj=1;j<5;j++){for(inti=1;i<20;i++){System.out.print(“*”);}System.out.println();}}}......
  • COD读书笔记
    计算机组成与设计课程复习与CSAPP中类似的部分做了忽略或者简化性能的度量知识回顾对于某个计算机X,定义性能和执行时间的关系表达式:\[\text{性能}_X=\frac{1}{\text{执行时间}_X}\]描述时钟周期和时钟频率的关系:\[\text{时钟周期}=\frac{1}{\text{时钟频率}}\]对......
  • 概率论笔记(上)
    学习视频如下:主要学习视频:《概率论与数理统计》教学视频全集(宋浩)_哔哩哔哩_bilibili其余知识点补充: 二维连续型随机变量的积分计算_哔哩哔哩_bilibili 014二维连续型随机变量_哔哩哔哩_bilibili 矩估计&最大似然估计通俗易懂版解释(自用)_哔哩哔哩_bilibili ......
  • python学习笔记-03
    流程控制1.顺序流程代码自上而下的执行。2.选择流程/分支流程根据在某一步的判断有选择的执行相应的逻辑。2.1单分支if语句if条件表达式: 代码 代码 ...2.2双分支if-else语句if条件表达式: 代码 代码 ...else:代码代码...2.3多分支if......
  • 初识C语言(02)—学习笔记
    转义字符转义字符释义\0结束标志\n换行\'打印单引号\"打印双引号\\打印一个反斜杠\t水平制表符\a警告字符,蜂鸣?在书写连续多个问号时使用,防止它们被解析成三字符\dddddd表示1~3个八进制的数字\xdddd表示2个十六进制数字\v垂直......