首页 > 其他分享 >小土堆pytorch笔记

小土堆pytorch笔记

时间:2023-02-01 10:15:06浏览次数:54  
标签:loss torch 笔记 pytorch train tudui test 土堆 total

I 验证网络结构是否有误

  1. 初始化一个符合网络的输入数据

    input = torch.ones((64, 3, 32, 32))

  2. 将输入数据传进网络,看是否报错

    print(network(input).shape)

II 修改已知网络(比如vgg16)

vgg16_false = torchvision.models.vgg16(weights=None)   
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)  
# 添加模块  
vgg16_false.classifier.add_module(name="add_linear", module=nn.Linear(in_features=1000, out_features=10))  
# 修改vgg的classifier的第7个模块  
vgg16_true.classifier[6] = nn.Linear(4096, 10)

III 模型保存及对应加载方法

# 保存模型方法1:模型结构+模型参数。  
def save1(model, filename):  
    torch.save(model, filename)


# 对应使用save1()保存的模型的加载方法1。注意:要让加载的模型可被该方法访问。
def load1(filename):
    return torch.load(filename)


# 保存模型方法2:保存模型参数(官方推荐)
def save2(model, filename):
    torch.save(model.state_dict(), filename)


# 对应使用save2()保存的模型的加载方法2
def load2(model, filename):
    state_dict = torch.load(filename)
    model.load_state_dict(state_dict)
    return model

IV 训练模型步骤

0. 定义训练的设备

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. 设置数据集

1.1 定义数据集

train_dataset = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=torchvision.transforms.ToTensor(),
                                             download=True)
test_dataset = torchvision.datasets.CIFAR10(root="../dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                            download=True)

1.2 定义数据集相关参数

train_dataset_length = len(train_dataset)
test_dataset_length = len(test_dataset)

1.3 利用DataLoader加载数据集

train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

2. 设置模型

2.1 创建模型

tudui = Tudui()
tudui.to(device) # 可以不写成tudui = tudui.to(device)

2.2 定义训练相关参数

epoch = 10  # 训练轮数
train_total_step = 0  # 训练总次数
test_total_step = 0  # 测试总次数

2.3 设置tensorboard

writer = SummaryWriter(log_dir="../logs-train")

3. 定义损失函数

loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

4. 定义优化器

learning_rate = 1e-2  # 学习率
optim = torch.optim.SGD(params=tudui.parameters(), lr=learning_rate)

5. 训练模型

for i in range(epoch):
    # 5.1 开始训练
    tudui.train() # 对有drop等层有效
    for inputs, targets in train_dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = tudui(inputs)
        # 计算损失值
        loss = loss_fn(outputs, targets)
        # 使用优化器优化模型
        optim.zero_grad()
        loss.backward()
        optim.step()
        # 记录训练情况
        train_total_step += 1
        if train_total_step % 100 == 0:
            writer.add_scalar(tag="time", scalar_value=end_time-start_time, global_step=train_total_step)
            print("训练次数:{}, loss:{}".format(train_total_step, loss.item()))
            writer.add_scalar(tag="train_loss", scalar_value=loss.item(), global_step=train_total_step)

    # 5.2 测试网络
    tudui.eval()
    test_total_accuracy = 0
    test_total_loss = 0
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = tudui(inputs)
            loss = loss_fn(outputs, targets)
            test_total_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            test_total_accuracy += accuracy

    print("整体测试集loss:{}, accuracy:{}".format(test_total_loss, test_total_accuracy / test_dataset_length))
    writer.add_scalar(tag="test_loss", scalar_value=test_total_loss, global_step=i)
    writer.add_scalar(tag="test_accuracy", scalar_value=test_total_accuracy / test_dataset_length, global_step=i)

    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型文件:tudui_{}.pth已保存".format(i))

writer.close()

V 完整测试模型步骤

1. 导入图片

image = Image.open("../imgs/airplane.png")

1.1 调整图片通道数

由于png图片有4通道(多了一个透明度通道),故须转为RGB图片的三通道,从而适应jpg、png等各种类型的图片
image = image.convert("RGB")

1.2 调整图片尺寸大小和数据类型

trans = transforms.Compose([transforms.Resize((32, 32)),
                    transforms.ToTensor()])
image = trans(image)

1.3 调整图片维数

增加一个batch size维度
image = torch.reshape(image, (1, 3, 32, 32))

2. 加载训练好的模型权重

modle = torch.load("xx.pth")

3. 测试模型并打印输出

modle.eval()
with torch.no_grad():
    output = modle(image)
# 打印出每行最大的列索引值
print(output.argmax(1))

标签:loss,torch,笔记,pytorch,train,tudui,test,土堆,total
From: https://www.cnblogs.com/curie/p/17081624.html

相关文章

  • PostgreSQL学习笔记-4.基础知识:触发器、索引
    PostgreSQL触发器是数据库的回调函数,它会在指定的数据库事件发生时自动执行/调用。下面是关于PostgreSQL触发器几个比较重要的点:PostgreSQL触发器可以在BEFORE、AFT......
  • QT 问题笔记
    1.'QMainWindow'filenotfound  网上:解决方法:在.pro中加入QT+=coreguigreaterThan(QT_MAJOR_VERSION,4):QT+=widgets 别人的demo,#include<QMainWindow>可......
  • 极客时间 Java并发编程实战 笔记
    思考、再思考、总结、再总结01可见性、原子性和有序性举几个例子先。缓存可能导致可见性问题,因为多核CPU上的多个核可能都持有同一数据的不同缓存。两个线程并行地对......
  • Typora实现云笔记,支持云同步+多端查看
    前言:为啥要使用typora做笔记,因为是码农标配,但是typora没有云存储功能,所以通过多款软件,打造属于个人的云笔记安装typora请支持正版链接免费版本:蓝奏云密码:be24,下......
  • JavaScript学习笔记—DOM:事件
    事件(event)事件就是用户和页面之间发生的交互行为比如:点击按钮,鼠标移动,双击按钮,敲击键盘,松开按键...可以通过为事件绑定响应函数(回调函数),来完成和用户之间的交互绑定响......
  • 《RPC实战与核心原理》学习笔记Day14
    19|分布式环境下如何快速定位问题?分布式环境下定位问题有什么难点?分布式环境下定位问题的难点在于,各子应用、子服务之间有复杂的依赖关系,我们有时很难确定是哪个服务......
  • 动态规划学习笔记
    动态规划1,什么是动态规划私以为,动态规划就是在递归思想的基础上,用空间换时间,将已经计算过的结果用存储起来,消除冗余计算,提高算法效率。2,什么时候使用动态规划抽象一点......
  • Linux初学笔记
    关于java全栈开发要掌握的技术JavaSEMySQL前端(HTML、CSS、JS)JavaWebSSM框架(可以开始找工作了)SpringboardVueSpringCloudGitLinux关于Linux需要掌握的技术消......
  • 树链剖分学习笔记
    怕到时候忘了,来写一篇笔记前置芝士:树的存储与遍历,\(dfs\)序,线段树。树链剖分的思想及能解决的问题:树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。具体......
  • Elasticsearch 从入门到实践 小册笔记
    MappingJSON中是可以嵌套对象的,保存对象类型可以用object类型,但实际上在ES中会将原JSON文档扁平化存储的。假如作者字段是一个对象,那么可以表示为:{"author":{......