首页 > 其他分享 >机器学习——自定义层

机器学习——自定义层

时间:2023-11-03 11:22:05浏览次数:31  
标签:__ 机器 自定义 nn self torch weight 学习

深度学习成功背后的一个因素是神经网络的灵活性: 我们可以用创造性的方式组合不同的层,从而设计出适用于各种任务的架构。 例如,研究人员发明了专门用于处理图像、文本、序列数据和执行动态规划的层。 有时我们会遇到或要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,必须构建自定义层。本节将展示如何构建自定义层。

 

不带参数的层

首先,我们构造一个没有任何参数的自定义层。下面的CenteredLayer类要从其输入中减去均值。 要构建它,我们只需继承基础层类并实现前向传播功能

import torch
import torch.nn.functional as F
from torch import nn


class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X - X.mean()
layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))
tensor([-2., -1.,  0.,  1.,  2.])

 

现在,我们可以将层作为组件合并到更复杂的模型中。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

 作为额外的健全性检查,我们可以在向该网络发送随机数据后,检查均值是否为0。 由于我们处理的是浮点数,因为存储精度的原因,我们仍然可能会看到一个非常小的非零数。

Y = net(torch.rand(4, 8))
Y.mean()
tensor(7.4506e-09, grad_fn=<MeanBackward0>)

注意,torch.rand(4, 8)是一个PyTorch函数,作用是生成一个形状为(4, 8)的随机张量,元素服从[0,1)上的均匀分布

 

带参数的层

以上我们知道了如何定义简单的层,下面我们继续定义具有参数的层, 这些参数可以通过训练进行调整。 我们可以使用内置函数来创建参数,这些函数提供一些基本的管理功能。 比如管理访问、初始化、共享、保存和加载模型参数。 这样做的好处之一是:我们不需要为每个自定义层编写自定义的序列化程序。

python
import torch
from torch import nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        # 使用内置函数创建参数
        self.weight = nn.Parameter(torch.randn(in_features, out_features)) 
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight + self.bias

layer = MyLinear(3, 4)

# 保存和加载可以自动处理
torch.save(layer.state_dict(), 'path/to/params.pth')  
layer.load_state_dict(torch.load('path/to/params.pth'))

 这个自定义线性层使用了PyTorch内置的Parameter和nn.Module,可以自动管理参数初始化、保存、加载等功能。

 

import torch
from torch import nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        
        # 不使用Parameter,自行创建参数       
        self.weight = torch.randn(in_features, out_features)
        self.bias = torch.zeros(out_features)

    def forward(self, x):

        # 需要自行处理参数保存、加载等管理
        return x @ self.weight + self.bias 

layer = MyLinear(3, 4) 

# 需要自定义保存逻辑
torch.save({'weight': layer.weight, 'bias': layer.bias}, 'params.pth')

# 需要自定义加载逻辑
params = torch.load('params.pth')
layer.weight = params['weight']
layer.bias = params['bias']

 主要区别:

1. 不使用Parameter,自己定义参数tensor

2. forward函数中需要自行使用这些参数

3. 保存和加载时需要自定义处理逻辑

4. 初始化、访问等管理都需要自行实现

 

下面来看具体实现的例子(以PyTorch为例)

class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))
    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)
linear = MyLinear(5, 3)
linear.weight
Parameter containing:
tensor([[ 0.1775, -1.4539,  0.3972],
        [-0.1339,  0.5273,  1.3041],
        [-0.3327, -0.2337, -0.6334],
        [ 1.2076, -0.3937,  0.6851],
        [-0.4716,  0.0894, -0.9195]], requires_grad=True)

我们可以使用自定义层直接执行前向传播计算。

linear(torch.rand(2, 5))
tensor([[0., 0., 0.],
        [0., 0., 0.]])

 

标签:__,机器,自定义,nn,self,torch,weight,学习
From: https://www.cnblogs.com/yccy/p/17807197.html

相关文章

  • java基础学习:path,java_home环境变量配置
    1.path变量: 装jdk后会自动配置java和javac的path路径 2.JAVA_HOME环境变量:   ......
  • Django实战项目-学习任务系统-发送邮件通知
    接着上期代码内容,继续完善优化系统功能。 本次增加发送邮件通知功能,学习任务系统发布的任务,需要及时通知到学生用户知晓。由于目前智能手机普及,人人都离不开手机,所以手机端接收通知信息更加及时有效。 其中微信使用频率最多,本来想使用微信通知功能,但是经过网上搜集资料测试......
  • 使用websocket开发智能聊天机器人
    前面我们学习了异步web框架(sanic)和http异步调用库httpx,今天我们学习websocket技术。websocket简介我们知道HTTP协议是:请求->响应,如果没有响应就一直等着,直到超时;但是有时候后台的处理需要很长时间才能给到结果,比如30分钟,那HTTP的请求不可能等这么久,所以,可以通过Ajax轮询来解决。......
  • Selenium 4.0beta:读源码学习新功能
    Selenium4源码分析这一篇文章我们来分析Selenium4python版源码。除非你对Selenium3的源码烂熟于心,否则通过对比工具分析更容易看出Selenium4更新了哪些API。文件对比工具推荐BeyondCompare驱动支持Selenium4去掉了android、blackberry和phantomjs等驱动支持。Selenium......
  • android侧滑应用学习记录
    android侧滑菜单怎么禁止滑动1、点击图标,看看是哪个软件的快捷组件。打开软件的设置,取消桌面或其它界面显示就OK。另外,也可以通过权限设置,禁止软件显示通知等等,禁止这一类的组件和任务栏显示。2、打开“设置”面板;找到“个人”类里的“安全”选项。点击进入;找到选项“屏幕锁定”选......
  • 【专题】2023中国工业机器人应用与趋势研究报告PDF合集分享(附原数据表)
    原文链接:https://tecdat.cn/?p=34132自18世纪中期工业革命以来,人类进入工业社会。在历次工业革命中,人类通过发明创造和管理革新,改进生产方式、降低成本、提高效率,随之而来的是生活、物质、文化、教育等各方面的变化,人际关系和社会结构也得以重塑。如今,数字化技术的发展为工业注入......
  • 11类型别名和自定义类型
    Go语言中没有“类”的概念,也不支持“类”的继承等面向对象的概念。Go语言中通过结构体的内嵌再配合接口比面向对象具有更高的扩展性和灵活性。类型别名和自定义类型自定义类型在Go语言中有一些基本的数据类型,如string、整型、浮点型、布尔等数据类型,Go语言中可以使用type关键......
  • 0为什么你应该学习Go语言
    终于等到你!Go语言——让你用写Python代码的开发效率编写C语言代码。为什么互联网世界需要Go语言世界上已经有太多太多的编程语言了,为什么又出来一个Go语言?硬件限制:摩尔定律已然失效摩尔定律:当价格不变时,集成电路上可容纳的元器件的数目,约每隔18-24个月便会增加一倍,性能也将提......
  • 博客园自定义主题教程
    https://www.cnblogs.com/cainiao-chuanqi/p/11388719.htmlhttps://blog.csdn.net/cxyliangzai/article/details/125094052?spm=1001.2101.3001.6650.8......
  • 《APUE》学习笔记
    学习资源:https://www.bilibili.com/video/av75586088/?p=2&spm_id_from=pageDriver&vd_source=1ecb7953e7a94890c19f9abe34af6240项目:IPV4流媒体广播系统知识点:多进程的实现及关系进程间通信多线程或多进程并发数据库文件I/O操作守护进程系统日志文件流量控制网络套接......