首页 > 其他分享 >每天五分钟玩转深度学习框架PyTorch:获取神经网络模型的参数

每天五分钟玩转深度学习框架PyTorch:获取神经网络模型的参数

时间:2024-09-10 23:53:45浏览次数:13  
标签:nn parameters list 神经网络 PyTorch 参数 玩转 print net

本文重点

当我们定义好神经网络之后,这个网络是由多个网络层构成的,每层都有参数,我们如何才能获取到这些参数呢?我们将再下面介绍几个方法来获取神经网络的模型参数,此文我们是为了学习第6步(优化器)。

获取所有参数Parameters

from torch import nn
net=nn.Sequential(
 nn.Linear(4,2),
 nn.Linear(2,2)
)
print(list(net.parameters()))#返回一个列表,列表中的元素是每一个参数
print("-----------------------------")
print(list(net.parameters())[0].shape)
print(list(net.parameters())[1].shape)
print(list(net.parameters())[2].shape)
print(list(net.parameters())[3].shape)

net.Parameters()可以获得模型的所有参数,返回一个列表[w0,b0,w1,b1],有多少参数,列表中就有多少元素。一个全连接层有两个参数w,b,那

标签:nn,parameters,list,神经网络,PyTorch,参数,玩转,print,net
From: https://blog.csdn.net/huanfeng_AI/article/details/142112146

相关文章