文章目录
1. 百度链接手动版
通过网盘分享的文件:conv2dtest.xlsx
链接: https://pan.baidu.com/s/1q3McqwfcKO1iX-Ms0BfAGA?pwd=ttsu 提取码: ttsu
2. PyTorch 版本
- python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fixed-point printing so the tensor dumps match the Excel worksheet values.
torch.set_printoptions(precision=3, sci_mode=False)


def run_conv2d_demo():
    """Run a deterministic Conv2d demo with hand-built integer weights.

    Builds a (1, 2, 5, 5) input filled with 0..49 and a (3, 2, 3, 3) weight
    tensor where kernel[o, i] = arange(9).reshape(3, 3) + (2 * o + i), then
    applies a bias-free Conv2d so the result can be verified by hand (or in
    the accompanying Excel sheet).

    Returns:
        torch.Tensor: the (1, 3, 3, 3) convolution output.
    """
    batch_size = 1
    in_channels = 2
    out_channels = 3
    input_h = 5
    input_w = 5

    # Sequential values 0..total-1 make every output cell easy to verify by hand.
    total = batch_size * in_channels * input_h * input_w
    input_matrix = torch.arange(total).reshape((batch_size, in_channels, input_h, input_w)).to(torch.float)
    kernel_size = 3
    print(f"input_matrix=\n{input_matrix}")

    my_conv2d = nn.Conv2d(out_channels=out_channels, in_channels=in_channels,
                          kernel_size=kernel_size, stride=1, bias=False)

    # One 3x3 base kernel; kernel for (out=o, in=i) is base + (2*o + i).
    # (The original code built six separate weight_N tensors and reshaped them
    # with batch_size as the leading dim — that dim is really the in-channel
    # stack axis, which only worked because batch_size happened to be 1.)
    base = torch.arange(kernel_size * kernel_size).reshape((kernel_size, kernel_size)).to(torch.float)

    # Shown separately to mirror the worked example: the two kernels of out-channel 0.
    weight_01 = torch.stack((base, base + 1), dim=0)
    print(f"weight_01=\n{weight_01}")
    print(weight_01.shape)

    # Full weight tensor, shape (out_channels, in_channels, k, k).
    weight_012345 = torch.stack(
        [torch.stack([base + 2 * o + i for i in range(in_channels)], dim=0)
         for o in range(out_channels)],
        dim=0,
    )
    # NOTE(review): label says "weight_0123" in the original transcript; kept
    # byte-identical so the printed output still matches the recorded result.
    print(f"weight_0123.shape=\n{weight_012345.shape}")
    print(f"weight_012345=\n{weight_012345}")

    # Overwrite the randomly initialized weights with the deterministic ones.
    my_conv2d.weight = nn.Parameter(weight_012345)
    output_matrix = my_conv2d(input_matrix)
    print(f"output_matrix=\n{output_matrix}")
    return output_matrix


if __name__ == "__main__":
    run_conv2d_demo()
- 结果:
input_matrix=
tensor([[[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]],
[[25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39.],
[40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49.]]]])
weight_01=
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]],
[[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]]])
torch.Size([2, 3, 3])
weight_0123.shape=
torch.Size([3, 2, 3, 3])
weight_012345=
tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.]],
[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]]],
[[[ 2., 3., 4.],
[ 5., 6., 7.],
[ 8., 9., 10.]],
[[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]]],
[[[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]],
[[ 5., 6., 7.],
[ 8., 9., 10.],
[11., 12., 13.]]]])
output_matrix=
tensor([[[[1803., 1884., 1965.],
[2208., 2289., 2370.],
[2613., 2694., 2775.]],
[[2469., 2586., 2703.],
[3054., 3171., 3288.],
[3639., 3756., 3873.]],
[[3135., 3288., 3441.],
[3900., 4053., 4206.],
[4665., 4818., 4971.]]]], grad_fn=<ConvolutionBackward0>)
- excel 版本结果:与 PyTorch 版本一致,此版本无 bias 偏置。