首页 > 其他分享 >4. 基础实战——FashionMNIST时装分类

4. 基础实战——FashionMNIST时装分类

时间:2022-08-21 23:44:44浏览次数:97  
标签:实战 loss 时装 nn train torch FashionMNIST model data

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 设置环境和超参数
## 方案一:使用os.environ
# os.environ['CUDA_VISIBLE_DEVICES']='0'
## 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device('cuda:1' if torch.cuda is_available() else 'cpu')

## 配置其他超参数,如batch_size, num_workers, learning rate, 以及总的epochs
batch_size = 256
num_workers = 4   # 对于Windows用户,这里应设置为0,否则会出现多线程错误
lr = 1e-4
epochs = 20

# 设置数据变换
from torchvision  import transforms

image_size = 28
data_transform = transform.Compose([
    transform.ToPILImage(),
    # 这一步取决于后续的数据读取方式,如果使用内置数据集读取方式则不需要
    transform.Resize(image_size),
    transform.ToTensor()])

## 读取方式一:使用torchvision自带数据集,下载可能需要一段时间
from torchvision import datasets

train_data = datasets.FashionMNIST(root='./', train=True, download=True, transform=data_transform)
test_data = datasets.FashionMNIST(root='./', train=False, download=True, transform=data_transform)

# 定义DataLoader类,加载数据
# drop_last对最后无法满足 batch_size大小的皮数据予以丢弃
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# 据可视化操作,验证读入的数据是否正确
import matplotlib.pyplot as plt
image, label = next(iter(train_loader))
print(image.shape, label.shape)
plt.imshow(image[0][0], cmap="gray")

# 模型设计
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(1, 32, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, stride = 2),
                nn.Dropout(0.3),
                nn.Conv2d(32, 64, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, stride=2),
                nn.Dropout(0.3))
        self.fc = nn.Sequential(
                nn.Linear(64*4*4, 512),
                nn.ReLU(),
                nn.Linear(512, 10))
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 64*4*4)
        x = self.fc(x)
        return x

model = Net()
model = model.cuda()
# model = nn.DataParallel(model).cuda()   # 多卡训练时的写法

## 设定损失函数
# 使用CrossEntropy损失会,自动把整数型的label转为one-hot型,用于计算CE loss
criterion = nn.CrossEntropyLoss()

## 设置优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

## 训练和测试
def train(epoch):
    model.train()
    train_loss = 0
    for data, label in train_loader:
        data, label = data.cuda(), label.cuda()
        optimizer = optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
    train_loss = train_loss/len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

def val(epoch):
    model.eval()
    val_loss = 0
    gt_labels = []
    pred_labels = []
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            preds = torch.argmax(output, 1)
            gt_labels.append(label.cpu().data.numpy())
            pred_labels.append(preds.cpu().data.numpy())
            loss = criterion(output, label)
            val_loss += loss.item()*data.size(0)
    val_loss = val_loss/len(test_loader.dataset)
    gt_labels, pred_labels = np.concatenate(gt_labels), np.concatenate(pred_labels)
    acc = np.sum(gt_labels==pred_labels)/len(pred_labels)
    print('Epoch: {} \tValidation Loss: {:.6f}, Accuracy: {:6f}'.format(epoch, val_loss, acc))

## 训练与测试
for epoch in range(1, epochs+1):
    train(epoch)
    val(epoch)

  模型保存

save_path = './FahionModel.pkl'
torch.save(model, save_path)

  加载模型

model = torch.load('model.pkl')

  注意:将模型保存成何种格式文件无所谓(比如pkl,pth等)。  

  保存与加载模型参数

torch.save(model.state_dict(), 'model_params.pth')
model.load(torch.load( 'model_params.pth')) 
 

标签:实战,loss,时装,nn,train,torch,FashionMNIST,model,data
From: https://www.cnblogs.com/5466a/p/16611426.html

相关文章

  • Elasticsearch 实战
    需求假设现在有这么一个需求,系统接了很多的报文,需要提供全文检索,为了简化,报文目前只有类型,流水号,内容这三个字段。索引设计建立msg索引,映射规则如下PUT/msg{ "mappi......
  • 大数据Hadoop之——HDFS小文件问题与处理实战操作
    目录一、背景1)小文件是如何产生的?2)文件块大小设置3)HDFS分块目的二、HDFS小文件问题处理方案1)HadoopArchive(HAR)2)Sequencefile3)CombineFileInputFormat4)开启JVM重用5)合并本......
  • 大数据Hadoop之——Hadoop HDFS多目录磁盘扩展与数据平衡实战操作
    目录一、概述二、HadoopDataNode多目录磁盘配置1)配置hdfs-site.xml2)配置详解1、dfs.datanode.data.dir2、dfs.datanode.fsdataset.volume.choosing.policy3、dfs.datanod......
  • 新一代分布式实时流处理引擎Flink入门实战操作篇
    @目录安装部署安装方式Local(Standalone单机部署)Standalone部署StandaloneHA部署FlinkOnYarn演示案例概述会话(Session)模式单作业(Per-Job)模式流程演示应用(Application)......
  • JPA 入门实战(2)--简单使用
    本文主要介绍JPA的实际使用,相关的环境及软件信息如下:JPA2.2(eclipselink2.7.10、hibernate-entitymanager5.6.10.Final、openjpa3.2.2),JPA3.0(eclipselink3.0.2、h......
  • Rust实战系列-Rust介绍
    “学习资料:rustinaction[1]1.Rust安装curl--proto'=https'--tlsv1.2-sSfhttps://sh.rustup.rs|shsource"$HOME/.cargo/env"2.helloworld创建hel......
  • 新一代分布式实时流处理引擎Flink入门实战之先导理论篇-上
    @目录概述定义为什么使用Flink应用行业和场景应用行业应用场景实时数仓演变FlinkVSSpark架构系统架构术语无界和有界数据流式分析基础分层API运行模式作业提交流程顶层抽......
  • Blazor预研与实战
    背景最近一直在搞一件事,就是熟悉Blazor,后期需要将Blazor真正运用到项目内。前期做了一些调研,包括但不限于Blazor知识学习组件库生态预研与现有SPA框架做比对与WebFor......
  • jmeter接口自动化实战--新增店员
    一、目标使用jmeter通过接口实现新增店员功能二、步骤及思想1、登录。  首先需要登录app2、进入新增店员页面。  进入app后调用任何接口需要有token,所以要提取......
  • canal同步mysql实战
    环境mysql5.6.41canal1.151.16测试过后,一直报错canal_config表不存在,更换版本后正常目的:同步一个数据库中的二个表1、创建表CREATETABLE`user01`(`id`int(......