首页 > 其他分享 >基于Pytorch的简单深度学习项目实战

基于Pytorch的简单深度学习项目实战

时间:2023-05-11 22:23:20浏览次数:46  
标签:实战 loss nn Pytorch train 深度 test data accuracy

基于Pytorch的简单深度学习项目实战_pytorch深度学习项目实战_NPC_0001的博客-CSDN博客

基于以上补充:

所需数据阿里网盘分享:

(暂不支持分享,后续补上)

代码:

import torch.utils.data
import torchvision

from torch import nn
from torch.utils.tensorboard import SummaryWriter

data_train = torchvision.datasets.CIFAR10("./dataset",train = True,transform= torchvision.transforms.ToTensor(),
                                          download=True)
data_test = torchvision.datasets.CIFAR10("./dataset",train = False,transform= torchvision.transforms.ToTensor(),
                                          download=True)
data_train_size = len(data_train)
data_test_size = len(data_test)

print("------训练集大小{}------".format(data_train_size))
print("------测试集大小{}------".format(data_test_size))

#加载数据
dataloader_train = torch.utils.data.DataLoader(data_train,batch_size=64)
dataloader_test = torch.utils.data.DataLoader(data_test,batch_size=64)

#创建网络
class FewShot(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,padding="same"),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,padding="same"),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,padding="same"),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )

    def forward(self,x):
        x=self.model(x)
        return x

few_shot = FewShot()

#损失函数
loss_fun = nn.CrossEntropyLoss()

#优化器
learning_rate =1e-2
optimier = torch.optim.SGD(few_shot.parameters(),lr=learning_rate)


#开始训练
#训练轮数
Epoch = 10

#记录训练次数
train_step = 0
test_step = 0

writer = SummaryWriter("logs")

#训练
for epoch in range(Epoch):
    print("------第{}轮训练开始------".format(epoch+1))

    #训练步骤
    few_shot.train()
    train_accuracy = 0
    for data in dataloader_train:
        imgs,label = data
        output = few_shot(imgs)
        loss = loss_fun(output,label)

        #优化器优化模型
        optimier.zero_grad()
        loss.backward()
        optimier.step()

        train_step+=1

        accuracy = (output.argmax(1) == label).sum()
        train_accuracy += accuracy
        if train_step % 100 ==0:
            writer.add_scalar("train_loss",loss.item(),train_step)
            print("第{}次训练,LOSS值为:{}".format(train_step,loss.item()))

    #测试
    few_shot.eval()
    loss_test = 0
    test_accuracy = 0
    with torch.no_grad():

        for data in dataloader_test:
            imgs,label = data
            output = few_shot(imgs)
            loss = loss_fun(output,label)
            loss_test +=loss.item()
            accuracy = (output.argmax(1) == label).sum()
            test_accuracy += accuracy

    test_step +=1
    print("第{}轮测试,LOSS值为:{}".format(epoch +1,loss_test))
    writer.add_scalar("test_loss",loss_test,test_step)

    print("第{}轮训练,准确率为:{}".format(epoch+1,train_accuracy / data_train_size))
    print("第{}轮测试,准确率为:{}".format(epoch + 1, test_accuracy / data_test_size))

    writer.add_scalar("train_accuracy",train_accuracy/data_train_size,test_step)
    writer.add_scalar("test_accuracy", test_accuracy / data_test_size, test_step)

    #模型保存
    torch.save(few_shot,"few_show_{}".format(epoch))
    print("模型保存成功")

writer.close()

标签:实战,loss,nn,Pytorch,train,深度,test,data,accuracy
From: https://www.cnblogs.com/gitLab/p/17392421.html

相关文章

  • 【Spring实战】第4章 面向切面的Spring
    POM依赖<dependency><groupId>org.springframework</groupId><artifactId>spring-aop</artifactId><version>4.0.7.RELEASE</version></dependency><!--SpringAOP依赖AspectJ,不然会报ReflectionWorldExc......
  • Pytorch语法——torch.autograd.grad
    Thetorch.autograd.gradfunctionisapartofPyTorch'sautomaticdifferentiationpackageandisusedtocomputethegradientsofgivenoutputswithrespecttogiveninputs.Thisfunctionisusefulwhenyouneedtocomputegradientsexplicitly,rathe......
  • OpenSeadragon 实战系列其他属性的使用
    viewport的使用我们打开openseadragn的官网,可以找到下图所示的viewport点开viewport,你可以看到很多viewport的方法那么如何使用viewport呢?在基础篇中的示例代码中,我们定义了viewer1varviewer=OpenSeadragon({2id:"openseadragon1",3......
  • OpenSeadragon 实战系列第三方插件
    序言在我们的项目中,一般不可能只是简单的显示图片,对应着还需要做一些图像标注、图像颜色过滤等操作,比如一些医学病理切片。所以openseadragon也为我们提供了一些插件,我们打开官网,找到plugins这些插件中有很多是中间件,各位根据自己的需求自行研究把,在我的项目中只使用......
  • Pytorch数据预处理
    为了能用深度学习来解决现实世界的问题,我们经常从预处理原始数据开始,而不是从那些准备好的张量格式数据开始。首先我们准备一个人工数据集: 这是一个.csv格式(用逗号隔开)的数据文件。该数据集有四行三列。其中每行描述了房间数量(“NumRooms”)、巷子类型(“Alley”)和房屋价格(“P......
  • OpenSeadragon 实战系列dzi图像切割命名规则篇
    序言根据前边的两篇文章,我们已经可以实现图像的显示了。但是现在我们显示的还是由微软软件自动生成的图片,在实际运用中,需要由后端将图片切割,具体切割方式在微软的dzi图片格式说明中也有,地址:https://docs.microsoft.com/en-us/previous-versions/windows/silverlight/dotnet-wi......
  • 官网使用conda&pip安装PyTorch命令总结(包含各版本)
    原网页https://pytorch.org/get-started/previous-versions/因为有时访问该网站比较慢,所以本博客记录该网页内容InstallingpreviousversionsofPyTorchWe’dpreferyouinstallthelatestversion,butoldbinariesandinstallationinstructionsareprovidedbelow......
  • 模板元编程实战--TypeList算法--查找
    从一个类型列表中查找是否包含某一个类型。要从一个类型列表中查找,那么首先要获得每一个类型,然后与特定的类型比较,然后将结果保存起来。首先考虑一下Elem应该如何实现。Elem将会展开参数列表,然后处理,这里使用之前演示Fold高阶函数回调处理:template<TLIn,typenameT>clas......
  • 蛋白质深度学习
    本文主要面向两类目标读者:一类是想使用机器学习的生物学家,一类是想进入生物学领域的机器学习研究者。如果你不熟悉生物学或机器学习,仍然欢迎你阅读本文,但有时你可能会觉得有点读不太懂!如果你已经熟悉这两者,那么你可能根本不需要本文——你可以直接跳到我们的示例notebook以......
  • 【pytorch】理解张量,了解张量的创建和操作
    深度学习的核心是卷积,卷积的核心是张量(Tensor)理解TensorTensor可以简单理解为是标量、向量、矩阵的高维扩展。你可以把张量看作多维数组,但相较于ndarray,Tensor包含了grad、requires_grad、grad_fn、device等属性,是为服务于神经网络而设计的类型,标量可以看作是零维张量、......