首页 > 其他分享 >2023-11-30

2023-11-30

时间:2023-11-30 22:12:11浏览次数:31  
标签:11 loss nn 30 train 2023 test total data

import torchvision
import torchvision.datasets as datasets
from torch import nn
from torch.utils.data import DataLoader
import torch
# 设置学习率和训练轮数
learning_rate = 1e-3
epoch = 50

# 准备数据集
train_data = datasets.ImageFolder('./data/train', transform=torchvision.transforms.Compose([
torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()
]))
test_data = datasets.ImageFolder('./data/test', transform=torchvision.transforms.Compose([
torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()
]))

train_data_size = len(train_data)
test_data_size = len(test_data)

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 搭建神经网络
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 2 * 2, 128),
nn.Linear(128, 64),
nn.Linear(64, 5)
)
self.dropout = nn.Dropout(0.5)

def forward(self, x):
x = self.model(x)
x = self.dropout(x)
return x

tudui = Tudui()

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(tudui.parameters(), lr=learning_rate)

total_train_step = 0
total_test_step = 0

for i in range(epoch):
print("-------第 {} 轮训练开始-------".format(i+1))

# 训练步骤开始
tudui.train()
for data in train_dataloader:
imgs, targets = data
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()

total_train_step += 1
if total_train_step % 100 == 0:
print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))

# 测试步骤开始
tudui.eval()
total_test_loss = 0
correct_predictions = 0
total_samples = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets = data
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss += loss.item()
_, predicted = torch.max(outputs, 1)
correct_predictions += (predicted == targets).sum().item()
total_samples += targets.size(0)

print("整体测试集上的Loss: {}".format(total_test_loss))
print("整体测试集上的正确率: {}".format(correct_predictions / total_samples))
total_test_step += 1
torch.save(tudui, "tudui_{}.pth".format(i))
print("模型已保存")

标签:11,loss,nn,30,train,2023,test,total,data
From: https://www.cnblogs.com/wllovelmbforever/p/17868487.html

相关文章

  • 11.30二次探测法解决冲突
    设哈希表长为14,哈希函数是H=key%11,表中已有数据的关键字为15,38,61,84共四个,现要将关键字为49的元素加到表中,用二次探测法解决冲突,则放入的位置是(9)。15的位置是4,38的位置是5,61的位置是6,84的位置为749对应5和38冲突所以要用二次探索就是跳跃式的加数直到不重复且不超过哈希表长{1......
  • 20211128《信息安全系统设计与实现》第十四章学习笔记
    一、任务内容自学教材第14章,提交学习笔记(10分)1.知识点归纳以及自己最有收获的内容,选择至少2个知识点利用chatgpt等工具进行苏格拉底挑战,并提交过程截图,提示过程参考下面内容(4分)“我在学***X知识点,请你以苏格拉底的方式对我进行提问,一次一个问题”核心是要求GPT:“请你以苏格......
  • 20231130
     软件需求与分析课堂测试八—结构化建模分析(100分)个人答案非标准 【说明】某大学为进一步推进无纸化考试,欲开发一考试系统。系统管理员能够创建专业方向、课程编号、任课教师等相关考试基础信息。教师和考生进行考试相关工作。系统与考试有关的主要功能如下:(1)考试设置:教师制......
  • java-2023-11-30
    1、java中char类型由于使用Unicode编码所以是占两个字节而并不像C中是占一个字节。2、java中不使用0或非0值来代表假或真而是直接使用false或true。3、java中float和double由于精度不同不能进行比较,否则存在两值明显不等但输出的比较结果却为true的风险。4、如果运算结果可能超......
  • 2023.11.30 练习
    CF1887C首先容易想到区间加需转化为差分,字典序的比较呢就考虑二分哈希。二分第一个不一样的位置,这个位置也一定是差分数组第一个不一样的。把哈希如果放到线段树上,那么在线段树上二分即可。我们依次处理修改的时候,顺便处理当前的最小的字典序。我们这里如果采用主席树,那么会......
  • 2023-2024-1 20211306 密码系统设计与实现课程学习笔记12
    20211306密码系统设计与实现课程学习笔记12任务详情自学教材第14章,提交学习笔记知识点归纳以及自己最有收获的内容,选择至少2个知识点利用chatgpt等工具进行苏格拉底挑战,并提交过程截图,提示过程参考下面内容“我在学***X知识点,请你以苏格拉底的方式对我进行提问,一次一个......
  • NOIp 2023 游记
    咕了正好一周的NOIp游记,是我第一篇游记,也是一张寄往四年后不得不退役的、即将画上青春句号的自己的,包含了自己的青涩、期待与成长的信笺。Day\((-\infty,-7)\)CSP-S2023打炸,135,很清楚蓝勾是没了,所以每天都在想到底能不能去NOIp。问过并没有教过我什么的教练,回复到应该要......
  • 11.30每日总结
    实验一:百度机器翻译SDK实验一、实验要求 任务一:下载配置百度翻译Java相关库及环境(占10%)。 任务二:了解百度翻译相关功能并进行总结,包括文本翻译-通用版和文本翻译-词典版(占20%)。 任务三:完成百度翻译相关功能代码并测试调用,要求可以实现中文翻译成英文,英文翻译成中文(占30%)。......
  • ARC118
    ARC118第一次做arc场,被爆杀QAQARC118AlinkARC118A题意ARC国家的消费税率是\(t\)。其中\(t\)是正整数。ARC国家有整数屋。整数屋先生以不含税价格\(A\)日元处理着各个正整数\(A\),这个含税价格是\(\lfloor\frac{100+t}{100}A\rfloor\)日元。但是,对于实数\(x\),\(......
  • 多线程连接池报错报警告[connectionpool.py:305 ] [WARNING] Connection pool is full
    第一种方法:按照建议WebDriverWait帮助解决了问题 fromselenium.webdriver.support.uiimportWebDriverWaitfromselenium.webdriver.supportimportexpected_conditionsasECfromselenium.webdriver.common.byimportByfromselenium.common.exceptionsimportT......