首页 > 其他分享 >Pytorch小土堆跟练代码(第1天)

Pytorch小土堆跟练代码(第1天)

时间:2024-10-12 22:53:52浏览次数:14  
标签:img self label Pytorch path 土堆 dataset 跟练 dir

本系列为跟练小土堆每集代码,然后进入李宏毅机器学习教程。在系列中会敲完所有视频中代码,并且在注释写出感悟和易错点。欢迎大家一起交流!

最前面的安装部分,可以移步我的另一个帖子

第一章·Dataset

首先讲了数据集的读取,主要调用了Dataset相关的函数,有图片和特征的地址提取和打开

'''数据集的读取'''
"一种直接把标签写到数据库上,另外一种数据集和label分开"

'''read_data.py'''
from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):

    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)

    '''标签'''
    def __getitem__(self,idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    '''返回长度'''
    def __len__(self):
        return len(self.img_path)


'''上面并没有变量,下面创建两个变量,验证一下'''
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
"实参后面是要变得,不是全一个!那个bees我就错了"


train_dataset = ants_dataset + bees_dataset
len(train_dataset)
len(ants_dataset)
len(bees_dataset)
img, label = train_dataset[124]
'''这是按照顺序调一下图片'''

'''这个是为了label复杂时候,可以分家。image和label分开,让txt的文档名字和它一致'''

要注意.和_不要写混,查bug查了好一会

第二章·TensorBoard

这里讲了如何在数据训练过程中生成图像来监视训练效果——生成losses

汉译其实就是张量板

'''transform的使用,能用于图像变化'''
import numpy as np
from PIL import Image
'''tensorboard是图像展示,可以看到losses,可以知道应该训练到哪一步'''
from torch.utils.tensorboard import SummaryWriter


"""
可以用ctrl打开
Writes entries directly to event files in the log_dir to be
consumed by TensorBoard.

The `SummaryWriter` class provides a high-level API to create an event file
in a given directory and add summaries and events to it. The class updates the
file contents asynchronously. This allows a training program to call methods
to add data to the file directly from the training loop, without slowing down
training.
"""
'''里面我们看到log_dir是SummaryWritter的参数,它的意思是目录。比如下面这个就是存到logs文件夹里面'''


writer = SummaryWriter("logs")

 # writer.add_image()     ctrl+/可以注释
 
 for i in range(100):
     writer.add_scalar("y=2x",2*i,i) #名字,横轴value,纵轴.
     #如果后面的轴导致线复回乱跑,就kill掉进程
 
 
 writer.close()

 '''左边多了一个logs的文件夹,是tensorboard的文件'''



'''接下来是生成图像'''
image_path = "dataset/train/bees/16838648_415acd9e3f.jpg"

img_PIL = Image.open(image_path)
img_array= np.array(img_PIL)
print(type(img_array))#看看什么类型是为了下面writer要输入这个,tensor型或者numpy型
print(img_array.shape)#顺序应该反一下

writer.add_image("test", img_array, 3, dataformats='HWC')   #换成第二张照片的时候这里也要改步骤。这里是同一个test下的
writer.add_image("train", img_array, 1, dataformats='HWC')

for i in range(100):
    writer.add_scalar("y=2x",2*i,i)

writer.close()

'''此时遇到pillow的问题,高版本下无法使用Image中的某个函数。下回8.4.0,就解决了'''

标签:img,self,label,Pytorch,path,土堆,dataset,跟练,dir
From: https://blog.csdn.net/2301_80060871/article/details/142886160

相关文章

  • PyTorchStepByStep - Chapter 2: Rethinking the Training Loop
      defmake_train_step_fn(model,loss_fn,optimizer):defperform_train_step_fn(x,y):#SetmodeltoTRAINmodemodel.train()#Step1-Computemodel'spredictions-forwardpassyhat=model(x)......
  • Transformer的Pytorch实现【1】
    使用Pytorch手把手搭建一个Transformer网络结构并完成一个小型翻译任务。首先,对Transformer结构进行拆解,Transformer由编码器和解码器(Encoder-Decoder)组成,编码器由Multi-HeadAttention+Feed-ForwardNetwork组成的结构堆叠而成,解码器由Multi-HeadAttention+Multi-HeadAtte......
  • 利用pytorch的datasets在本地读取MNIST数据集进行分类
    MNIST数据集下载地址:tensorflow-tutorial-samples/mnist/data_setatmaster·geektutu/tensorflow-tutorial-samples·GitHub数据集存放和dataset的参数设置:完整的MNIST分类代码:importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorchvisionimpor......
  • 使用PyTorch搭建Transformer神经网络:入门篇
    目录简介环境设置PyTorch基础Transformer架构概述实现Transformer的关键组件5.1多头注意力机制5.2前馈神经网络5.3位置编码构建完整的Transformer模型训练模型总结与进阶建议简介Transformer是一种强大的神经网络架构,在自然语言处理等多个领域取得了巨大......
  • 使用StyleGAN3合成自定义数据(pytorch代码)
    使用StyleGAN3合成自定义数据在现代计算机视觉和机器学习领域,生成对抗网络(GAN)已成为生成高质量图像的重要工具。其中,StyleGAN3是NVIDIA团队推出的第三代生成对抗网络,其显著改进了图像生成的质量和稳定性。本文旨在介绍如何在训练数据较少的情况下,使用StyleGAN3来合成......
  • 机器学习四大框架详解及实战应用:PyTorch、TensorFlow、Keras、Scikit-learn
    目录框架概述PyTorch:灵活性与研究首选TensorFlow:谷歌加持的强大生态系统Keras:简洁明了的高层APIScikit-learn:传统机器学习的必备工具实战案例图像分类实战自然语言处理实战回归问题实战各框架的对比总结选择合适的框架1.框架概述机器学习框架在开发过程中起着至......
  • 安装 Anaconda、PyTorch(GPU 版)库与 PyCharm
    Anaconda是一款巨大的Python环境集成平台,里面包含了Python解释器、JupyterNotebook代码编辑器以及很多的第三方库,所以安装Anaconda后我们无需再安装Python解释器,非常方便。一、安装Anaconda1.卸载Anaconda(可选)如果我们原来的电脑上安装过Anaconda,为了避免重复安......