首页 > 其他分享 >6.2 手写卷积类

6.2 手写卷积类

时间:2023-08-05 18:31:51浏览次数:36  
标签:kernel nn 卷积 self torch shape 6.2 手写 size

import torch
from torch import nn
from d2l import torch as d2l

class Conv2D(nn.Module):
    def __init__(self,kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size)) #如kernel_size= (2,2),则随机初始化一个2x2的卷积
        self.bias = nn.Parameter(torch.zeros(1)) #bias初始化为0
    def forward(self,X):
        return self.corr2d(X,self.weight) + self.bias
    
    # 卷积操作函数
    def corr2d(self,X,K):
        h,w = K.shape
        # 卷积输出大小
        Y = torch.zeros((X.shape[0] - h + 1),X.shape[1] - w + 1)
        for i in range(Y.shape[0]):
            for j in range(Y.shape[1]):
                # (i,j)是目标区域的左上角坐标
                Y[i,j] = (X[i:i+h,j:j+w] * K).sum()
        return Y

# 测试
X = torch.tensor([[0.0,1.0,2.0],
                  [3.0,4.0,5.0],
                  [6.0,7.0,8.0]])
model = Conv2D(kernel_size=(2,2))
Y = model.forward(X)
print(Y)

标签:kernel,nn,卷积,self,torch,shape,6.2,手写,size
From: https://blog.51cto.com/u_16207976/6977185

相关文章

  • 01手写顺序表
    一、简介学习数据结构的第一个程序,手写实现顺序表。实现功能创建表清空表中元素判断表中数据是否为空求表中有效数据长度指定数据元素定位指定位置插入元素释放空间打印顺序表的内容删除指定位置上的元素二、完整代码sqlist.h#ifndef__SQLIST_H#define__SQLIST......
  • CNN tflearn处理mnist图像识别代码解说——conv_2d参数解释,整个网络的训练,主要就是为
    官方参数解释:Convolution2Dtflearn.layers.conv.conv_2d(incoming,nb_filter,filter_size,strides=1,padding='same',activation='linear',bias=True,weights_init='uniform_scaling',bias_init='zeros',regularizer=None,wei......
  • TDengine vs InfluxDB:写入速度领先 16.2 倍,查询速度超百倍
    为了验证TDengine3.0在IoT场景下的性能,我们针对第三方基准性能测试平台TSBS(TimeSeriesBenchmarkSuite)中的IoT场景,预设了五种规模的卡车车队基础数据集,在相同的AWS云环境下对TDengine3.0和InfluxDB1.8(该版本是InfluxDB能够运行TSBS框架的最新版本)进行了对比......
  • 为了成为Java大牛,我决定手写个JVM~
    JVM对我们很多人来说就像个黑盒子,无从下手,但是又是我们JavaCoder不得不去深入研究的一门技术国内玩JVM的大牛很少,知名的就那么几个,而玩好JVM又教好JVM的人更是少之又少。今天给大家介绍其中一位,江湖人送外号道格牙的子牙老师。下面的时间,交给他。哈喽,我就是江湖人送外号[......
  • 64位 CentOS 6.2 安装erlang及rabbitmq Server
    主题 RabbitMQErlangCentosCentOS6.264bit安装erlang及RabbitMQServer1、操作系统环境(CentOS6.264bit) [root@leekwen~]#cat/etc/issueCentOSrelease6.2(Final)Kernel\ronan\m[root@leekwen~]#cat/proc/cpuinfo|grep"clflushsize"c......
  • doubly block toeplitz matrix 在加速矩阵差卷积上的应用
    文档链接CNN的卷积是执行了\(w'_{i,j}=\sum\limits_{x,y}w_{i+x,j+y}\timesC_{x,y}\),有人认为每次平移卷积核,运算量很大,又是乘法又是加法。现在我们吧\(w_{x,y}\)展开形成一个\([n\timesm,1]\)的向量\(V\),然后构造一个大小为\([(n+1)\times(m+1),n\timesm]\)矩阵......
  • 9.手写实现智能指针类需要实现哪些函数?
    9.手写实现智能指针类需要实现哪些函数?1.智能指针是一个数据类型,一般用模板实现,模拟指针行为的同时还提供自动垃圾回收机制。它会自动记录SmartPointer<T*>对象的引用计数,一旦T类型对象的引用计数为0,就释放该对象。除了指针对象外,我们还需要一个引用计数的指针设定对象的值,并将......
  • react源码解析手写ReactDom.js和React
    前言大家好我是歌谣今天给大家带来react源码部分的实现创建项目首先npxcreate-react-appxxx降为17"dependencies":{"@testing-library/jest-dom":"^5.11.4","@testing-library/react":"^11.1.0","@testing-library/user-event&......
  • 多连接的数据库管理工具Navicat Premium 16.2.5 Mac版
    NavicatPremium是一款多连接的数据库管理工具,它是一款免费的多通道、多连接程序,它支持企业和组织同时使用多个应用程序,在一个应用程序中运行多个数据库管理程序。使用Premium可以在同一应用程序中执行多个数据库程序。NavicatPremium可根据应用程序或Web服务之间的速度差异调......
  • 卷积神经网络(LeNet)
    卷积神经网络(LeNet)卷积神经网络(LeNet)tensorflow..... pytorch实现LeNet5......