首页 > 编程语言 >Python小练习:权重初始化(Weight Initialization)

Python小练习:权重初始化(Weight Initialization)

时间:2023-04-06 20:12:40浏览次数:61  
标签:nn weight Python Initialization Weight init data self Linear

Python小练习:权重初始化(Weight Initialization)

作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/

调用Pytorch中的torch.nn.init.xxx实现对模型权重与偏置初始化。

1. weight_init_test.py

 1 # -*- coding: utf-8 -*-
 2 # Author:凯鲁嘎吉 Coral Gajic
 3 # https://www.cnblogs.com/kailugaji/
 4 # Python小练习:权重初始化(Weight Initialization)
 5 # Custom weight init for Conv2D and Linear layers.
 6 import torch
 7 import torch.nn.functional as F
 8 import torch.nn as nn
 9 # 根据网络层的不同定义不同的初始化方式
10 # 以下是两种不同的初始化方式:
11 # 正态分布+常数
12 def weight_init(m):
13     if isinstance(m, nn.Linear):
14         # 如果传入的参数是 nn.Linear 类型,则执行以下操作:
15         nn.init.xavier_normal_(m.weight) # 将权重初始化为 Xavier 正态分布
16         nn.init.constant_(m.bias, 0) # 将权重初始化为常数
17     elif isinstance(m, nn.Conv2d):
18         # 如果传入的参数是 nn.Conv2d 类型,则执行以下操作:
19         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # 将权重初始化为正态分布
20     elif isinstance(m, nn.BatchNorm2d):
21         # 如果传入的参数是 nn.BatchNorm2d 类型,则执行以下操作:
22         nn.init.constant_(m.weight, 1)
23         nn.init.constant_(m.bias, 0)
24 
25 # 正交+常数
26 def weight_init2(m):
27     if isinstance(m, nn.Linear):
28         # 如果传入的参数是 nn.Linear 类型,则执行以下操作:
29         nn.init.orthogonal_(m.weight.data) # 对权重矩阵进行正交化操作,使其具有对称性。
30         if hasattr(m.bias, 'data'):
31             m.bias.data.fill_(0.0) # 如果传入的参数包含偏置项,则将其填充为零。
32     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
33         # 如果传入的参数是 nn.Conv2d 或 nn.ConvTranspose2d 类型,则执行以下操作:
34         gain = nn.init.calculate_gain('relu') # 用于计算激活函数的增益
35         nn.init.orthogonal_(m.weight.data, gain) # 对权重矩阵进行正交化操作,使其具有对称性。
36         if hasattr(m.bias, 'data'):
37             m.bias.data.fill_(0.0) # 如果传入的参数包含偏置项,则将其填充为零。
38 
39 class Net(nn.Module):
40     def __init__(self, input_size=1):
41         self.input_size = input_size
42         super(Net, self).__init__()
43         self.fc1 = nn.Linear(self.input_size, 2)
44         self.fc2 = nn.Linear(2, 4)
45         self.fc3 = nn.Linear(4, 2)
46 
47     def forward(self, x):
48         x = x.view(-1, self.input_size)
49         x = F.relu(self.fc1(x))
50         x = F.relu(self.fc2(x))
51         x = self.fc3(x)
52         return F.log_softmax(x, dim=1)
53 
54 torch.manual_seed(1)
55 num = 4 # 输入维度
56 x = torch.randn(1, num)
57 # 方式1:
58 model = Net(input_size = num)
59 print('网络结构:\n', model)
60 print('输入:\n', x)
61 model.apply(weight_init)
62 y = model(x)
63 print('输出1:\n', y.data)
64 print('权重1:\n', model.fc1.weight.data)
65 # 方式2:
66 model = Net(input_size = num)
67 model.apply(weight_init2)
68 y = model(x)
69 print('输出2:\n', y.data)
70 print('权重2:\n', model.fc1.weight.data)

2. 结果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Neural Network/weight_init_test.py"
网络结构:
 Net(
  (fc1): Linear(in_features=4, out_features=2, bias=True)
  (fc2): Linear(in_features=2, out_features=4, bias=True)
  (fc3): Linear(in_features=4, out_features=2, bias=True)
)
输入:
 tensor([[0.6614, 0.2669, 0.0617, 0.6213]])
输出1:
 tensor([[-0.7233, -0.6639]])
权重1:
 tensor([[ 2.0709, -1.0573,  0.9230, -0.7373],
        [ 0.1879, -0.2766,  0.7962,  1.4599]])
输出2:
 tensor([[-0.6951, -0.6912]])
权重2:
 tensor([[-0.8471, -0.4721,  0.1653,  0.1795],
        [-0.4072,  0.5991, -0.6437,  0.2467]])

Process finished with exit code 0

完成。

标签:nn,weight,Python,Initialization,Weight,init,data,self,Linear
From: https://www.cnblogs.com/kailugaji/p/17294001.html

相关文章

  • 【转】python pip 换源阿里云
    via:pythonpip换源阿里云-知乎(zhihu.com)pip换源阿里云只需要在cmd输入一条命令:pipconfigsetglobal.index-urlhttps://mirrors.aliyun.com/pypi/simple ......
  • opencv-python 4.13. 霍夫线变换
    前言霍夫变换是一种特征检测(featureextraction),被广泛应用在图像分析(imageanalysis)、计算机视觉(computervision)以及数位影像处理(digitalimageprocessing)。霍夫变换(HoughTransform)是图像处理中的一种特征提取技术,它通过一种投票算法检测具有特定形状的物体。该过程......
  • Python字节
    python的文件操作中,一个中文字符等于3个字节。1B(byte,字节)=8bit(位)1KB(Kilobyte,千字节)=1024B =(10的3次方)B1MB(Megabyte,兆字节,百万字节,简称“兆”)=1024KB = (10的6次方)B1GB(GB,吉字节,千兆) = 1024MB1TB(TB,万亿字节,太字节) = 1024GB1PB(PB,千万亿字节,拍字节) ......
  • CS50-Python实验5,6
    Week5UnitTestsTestingmytwittr题目描述:测试ProblemSet2中Settingupmytwttr程序;题解:twttr.pydefmain():print("Output:",shorten(input("Input:")))defshorten(word):ans=""foriinword:ifi.lowe......
  • python系列007
    //使用类进行设备控制类文件内容importpyvisaimportnumpyasnpimporttimeclassPiDevice:def__init__(self):self.device_id=None#默认无设备连接deffind_device_address(self,device_id):rm=pyvisa.ResourceMa......
  • 20230406-Python-if判断-day4
    条件语句4月6场景假设:网吧上网去⽹吧进⻔想要上⽹必须做的⼀件事是做什么?(考虑重点)为什么要把身份证给⼯作⼈员?是不是就是为了判断是否成年?是不是如果成年可以上⽹?如果不成年则不允许上⽹?其实这⾥所谓的判断就是条件语句,即条件成⽴执⾏某些代码,条件不成⽴则不执⾏这些......
  • Python学习——Day1
      学习python与C语言相似,第一件事也是输出一个”HelloWorld"。  但是相比C语言,python的输出要简洁好多,他没有换行符\n也能自动换行,print()函数里字符串无论是使用单引号还是双引号结果都能正常输出且输出结果一样。  第二个就是注释,python则与C语言不同,这里用到......
  • 【Python从零到壹】Python条件语句详解
    欢迎大家来到互联网老辛的专栏《Python从零到壹》,在这里我将分享约300篇Python系列文章,所有文章都将结合案例、代码和作者的经验讲解,真心想把自己近十年的编程经验分享给大家,希望对您有所帮助,文章中不足之处也请海涵。从事教学工作以来,越来越觉得时间的宝贵,每届学生都要讲重复的课,......
  • 关于python安装模块之后pychram仍然提示没有安装模块的问题
    项目场景:如图所示:需要安装的包已经安装好,但是到了pycharm里就没法使用,相信很多小伙伴遇到过这个问题。原因分析:遇到这个问题的主要原因是你的电脑里安装了两个pycharm解释器,你安装后,实际上是安装到了你电脑的Python3而非pycharm解释器。解决方案:所以我们可以在pycharm里面直......
  • python requests-html
    #pipinstallrequests-html '''目标网站:https://pic.netbian.com'''fromrequests_htmlimportHTMLSessionimportre,osimportrequestsfromtqdmimporttqdmfromfunctoolsimportpartialfrommultiprocessingimportPools......