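1、nn.Parameter 函数
nn.Parameter 把一个张量包装成模块参数：将它赋值为 nn.Module 的属性时，会被自动注册到该模块中，出现在 module.parameters() 里，并且默认 requires_grad=True，因此会参与梯度计算和优化器更新。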
2、torch.mm 和 torch.matmul 的区别
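两者都是 PyTorch 中用于矩阵乘法的函数，但在使用上有细微的差别：torch.mm 只接受两个二维矩阵，并且不支持广播；torch.matmul 更通用，会根据输入的维度自动选择行为，支持一维向量、批量（三维及以上）矩阵乘法以及批次维度上的广播。下面的例子使用 torch.matmul，因此输入 x 可以带有任意的前置批次维度。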
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyLinear(nn.Module):
    """Fully connected layer followed by a ReLU, with hand-declared parameters.

    Computes ``relu(x @ weight + bias)``. ``weight`` and ``bias`` are wrapped
    in ``nn.Parameter`` so that assigning them as attributes registers them
    with the module: they appear in ``parameters()`` / ``named_parameters()``
    and receive gradients during backpropagation.

    Args:
        in_units: Size of the last dimension of the input tensor.
        out_units: Number of output features.
    """

    def __init__(self, in_units: int, out_units: int) -> None:
        super().__init__()
        # Stored as (in_units, out_units) so forward can compute x @ weight
        # directly; note this is the transpose of nn.Linear's (out, in) layout.
        # Initialised from an unscaled standard normal, as in the original post.
        self.weight = nn.Parameter(torch.randn((in_units, out_units)))
        self.bias = nn.Parameter(torch.randn(out_units))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(x @ weight + bias)``.

        ``torch.matmul`` is used instead of ``torch.mm`` so that ``x`` may be a
        1-D vector or carry leading batch dimensions: any tensor whose last
        dimension is ``in_units`` works, and ``bias`` broadcasts over the
        leading dimensions.
        """
        linear = torch.matmul(x, self.weight) + self.bias
        return F.relu(linear)
# Demo: build a 5 -> 3 layer, show its registered weight, then run a batch of
# two random 5-feature rows through it and show the (2, 3) ReLU output.
linear = MyLinear(in_units=5, out_units=3)
print(linear.weight)

y = linear(torch.rand(2, 5))
print(y)
标签:__,nn,自定义,18,self,torch,神经网络,units,linear
From: https://www.cnblogs.com/morehair/p/18378769