首页 > 其他分享 >Pytorch功能库留存

Pytorch功能库留存

时间:2022-10-15 17:23:14浏览次数:54  
标签:功能 12 nn self torch 留存 Pytorch net size

初始化

首先,介绍我们导入的包和基础的网络结构

import torch
import torch.nn as nn

#可替代网络结构部分
'''
神经网络类的定义
    1. 输入卷积: in_channel = 1, out_channel = 12, kernel_size = (5, 5), stride = (2, 2), padding = 2
    2. 激活函数: 1.7159Tanh(2/3*x)
    3. 第二层卷积: in_channel = 12, out_channel = 12, kernel_size = (5, 5), stride = (2, 2), padding = 2
    4. 激活函数同上
    5. 全连接层: 192 * 30
    6. 激活函数同上
    7. 全连接层:30 * 10
    8. 激活函数同上

    按照论文的说明,需要对网络的权重进行一个[-2.4/F_in, 2.4/F_in]的均匀分布的初始化
'''


class LeNet1989(nn.Module):
    def __init__(self):
        super(LeNet1989, self).__init__()

        self.conv1 = nn.Conv2d(1, 12, 5, stride=2, padding=2)
        self.act1 = nn.Tanh()
        self.conv2 = nn.Conv2d(12, 12, 5, stride=2, padding=2)
        self.act2 = nn.Tanh()

        self.fc1 = nn.Linear(192, 30)
        self.act3 = nn.Tanh()
        self.fc2 = nn.Linear(30, 10)
        self.act4 = nn.Tanh()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                F_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                m.weight.data = torch.rand(m.weight.data.size()) * 4.8 / F_in - 2.4 / F_in
            if isinstance(m, nn.Linear):
                F_in = m.in_features
                m.weight.data = torch.rand(m.weight.data.size()) * 4.8 / F_in - 2.4 / F_in

    def forward(self, x):
        x = self.conv1(x)
        x = 1.7159 * self.act1(2.0 * x / 3.0)

        x = self.conv2(x)
        x = 1.7159 * self.act2(2.0 * x / 3.0)

        x = x.view(-1, 192)

        x = self.fc1(x)
        x = 1.7159 * self.act3(2.0 * x / 3.0)

        x = self.fc2(x)
        out = 1.7159 * self.act4(2.0 * x / 3.0)

        return out

查看网络参数

方法一

推荐使用方法一,因为方法二可以得到更多的信息,但是要注意的是这一段看情况添加

看网络是在GPU跑还是CPU跑,我相信大部分是GPU,就用下面这个就可

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = LeNet1989().to(device)

用CPU就用这个

net = LeNet1989()

下面的部分就注意输入图片的input_size就可

from torchsummary import summary
#batch_size可以不指定,默认为-1,输入 模型(model)、输入尺寸(input_size)、批次大小(batch_size)、运行平台(device)
#summary(model, input_size, batch_size, device)
summary(net, input_size=(1, 16, 16), batch_size=1)

运行结果:

image-20221004211203977

方法二

代码:

#查看网络参数部分
net = LeNet1989()
print(net)
params = list(net.parameters())
k = 0
for i in params:
    l = 1
    # print("该层的结构:"+ str(list(i.size())))
    for j in i.size():
        l *= j
    # print("该层参数和:"+ str(l))
    k = k + l
print('总参数和:' + str(k))

运行结果:

image-20221004210139370

查看网络结构

from tensorboardX import SummaryWriter  # pip install tensorboardX
x = torch.randn(1, 16, 16) # 例如:x = torch.randn(1, 64, 64, 64)
net = LeNet1989()
writer = SummaryWriter('./tensorboard')
with writer:
    writer.add_graph(net, (x,))
# 终端中输入:tensorboard --logdir=tensorboard/ --host=127.0.0.1

图片通道转换

彩色图转灰度图

from torchvision import transforms

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1), # 彩色图像转灰度图像num_output_channels默认1
    transforms.ToTensor()
])

彩色图(三通道)转指定R,G,B通道

def change_image_channels(image):
    # 3通道转单通道
    if image.mode == 'RGB':
        r, g, b = image.split()
    return r,g,b

模型测试调用

.pt, .pth, .pkl文件都可

保存:

torch.save(model, mymodel.pth)#保存整个model的状态

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以
model.eval()

标签:功能,12,nn,self,torch,留存,Pytorch,net,size
From: https://www.cnblogs.com/dongxuelove/p/16794593.html

相关文章

  • [转]嵌入式系统上实现GPS全球定位功能
           GPS(GlobalPositioningSystem)即全球定位系统,是由美国建立的一个卫星导航定位系统,利用该系统,用户可以在全球范围内实现全天候、连续、实时的三维导航定......
  • Redis6 新功能介绍
    特性的详细细节在此不赘述,我们来看Redis6.0,。Redis6.0版本特性大约可以分为四类,如下表新特性内核优化应用优化其他ACL权限管控(包括ACLLOG)过期Key回收优化......
  • C语言中预编译功能,预处理器指令
    三种预处理包括:宏定义、文件包含、条件编译。宏定义是C语言提供的三种预处理功能的其中一种。宏定义和操作符的区别是:宏定义是替换,不做计算,也不做表达式求解。宏定义又......
  • Excel带模糊搜索功能的下拉菜单,助力职场效率提升!
    Excel情报局职场联盟Excel生产挖掘分享Excel基础技能Excel爱好者大本营用1%的Excel基础搞定99%的职场问题做一个超级实用的Excel公众号Excel是门手艺玩转需要勇气数万Excel......
  • win11+wls2+ubuntu2004配置cuda+cudnn+pytorch
    0.前置说明win11系统开启子系统wsl2安装Ubuntu2004版本子系统(2204版本未测试,请自测)1.安装wsl2-Ubuntu2004子系统win11以上默认是wsl2了,win10参考列表第一个子系统......
  • 项目中导出功能(word)
    导出方法:fileName:导出word文件名称this.url.exportword:接口地址,返回blob文件流exportsMethod(){letfileName=this.info.lcmcgetActionBlob(this.url.......
  • uni-app 148朋友圈列表分页功能实现
    下图是我测试的截图/pages/find/moments/moments.vue<template><view><free-transparent-bar:scrollTop="scrollTop"@clickRight="clickRight"></free-transparent-......
  • uni-app 190扫一扫加入群聊功能(二)
    /pages/chat/scan-add/scan-add.nvue<template><viewclass="page"><!--导航栏--><free-nav-bartitle="加入群聊"showBack:showRight="false"></free-nav-ba......
  • uni-app 110清空聊天记录功能
    chat.jsimport$Ufrom"./util.js";import$Hfrom'./request.js';classchat{constructor(arg){this.url=arg.urlthis.isOnline=falsethis.socket=......
  • uni-app 111发送表情包功能
    chat.jsimport$Ufrom"./util.js";import$Hfrom'./request.js';classchat{constructor(arg){this.url=arg.urlthis.isOnline=falsethis.socket=......