《深度学习的数学》 by 涌井良幸 and 涌井贞美 is a clear, easy-to-follow book. It even works through the neural-network calculations in Excel (as in the figure below), which is very nice. Still, an Excel example alone feels a bit limited; a PyTorch demonstration would make it even better.
I searched the web for a long time and could not find one written by anyone else, so I wrote my own (Python + PyTorch) for fellow beginners who are also studying neural networks.
(Note: this is section 5-6 of the book, "experiencing the error backpropagation method with a convolutional neural network". The data set is 96 6x6 images of the digits 1, 2 and 3, the cost function is the sum of squared errors, and the activation function is Sigmoid.)
(For the PyTorch demo of the section 4-4 neural-network calculation, see: https://blog.51cto.com/oldycat/8133220)
(Before reading this book, I suggest starting with 立石贤吾's 《白话机器学习的数学》; this book then becomes much easier.)
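To set the stage for the code below: each 6x6 image goes through a convolution layer with three 3x3 filters, a Sigmoid activation and 2x2 max pooling, which leaves three 2x2 feature maps, i.e. 3 * (2 * 2) = 12 values that feed the 3-neuron output layer. Here is a minimal shape-check sketch (my own addition, not from the book) that only confirms those sizes:
import torch
import torch.nn as nn

x = torch.zeros(1, 1, 6, 6)            # one 6x6 image with a single channel
x = nn.Conv2d(1, 3, kernel_size=3)(x)  # -> (1, 3, 4, 4): three 4x4 feature maps
x = torch.sigmoid(x)                   # Sigmoid activation, as in the book
x = nn.MaxPool2d(kernel_size=2)(x)     # -> (1, 3, 2, 2): each map pooled to 2x2
print(x.shape, x.view(1, -1).shape)    # torch.Size([1, 3, 2, 2]) torch.Size([1, 12])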
demo54.py (note: this version still differs in places from the Excel data, and I have not yet pinned down exactly what needs to change):
import torch
import torch.nn as nn
import torch.optim as optim
from torch import cosine_similarity
import demo54data as demo
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.activation = nn.Sigmoid()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(3 * (2 * 2), 3)  # the output layer has 3 neurons, one per digit 1, 2 and 3
        # fixed initial values from the book; to use normally distributed random initial values instead, comment out these four lines
        self.conv1.weight.data = demo.get_param_co().reshape(3, 1, 3, 3)
        self.conv1.bias.data = demo.get_param_co_bias()
        self.fc.weight.data = demo.get_param_op().reshape(3, 12)
        self.fc.bias.data = demo.get_param_o_bias()

    def forward(self, x):
        x = self.conv1(x)  # (N, 1, 6, 6) -> (N, 3, 4, 4)
        demo.print_x("zF=", x)
        x = self.activation(x)
        x = self.pool(x)  # (N, 3, 4, 4) -> (N, 3, 2, 2)
        demo.print_x("aF=", x)
        x = x.view(x.size(0), -1)  # flatten to (N, 12)
        x = self.fc(x)  # this may be wrong, which could be why the computed numbers do not fully match Excel
        demo.print_x("zO=", x)
        x = self.activation(x)
        demo.print_x("aO=", x)
        return x
def mse_loss(x, y):
    # the book's cost function: half the sum of squared errors (same as nn.MSELoss(reduction='sum') / 2)
    z = x - y
    print(" c= ", end='')
    print(((z[0] ** 2).sum() / 2).data.numpy(), end='')
    for i in range(1, z.size(0)):
        print("\t", ((z[i] ** 2).sum() / 2).data.numpy(), end='')
    print()
    return (z ** 2).sum() / 2  # summed over all three output neurons and all samples
# create the model instance
model = CNN()
for param in model.parameters():
    print(param)

# define the loss function and the optimizer
criterion = mse_loss
optimizer = optim.SGD(model.parameters(), lr=0.2)

# load the input data as tensors
train_data = demo.get_data()
train_labels = demo.get_result()

# start training
num_epochs = 1000
for epoch in range(num_epochs):
    print("\nepoch=", epoch + 1)
    optimizer.zero_grad()
    outputs = model(train_data)
    loss = criterion(outputs, train_labels)
    print("Loss: {:.4f}".format(loss.item()))
    loss.backward()
    optimizer.step()
    if (epoch + 1) == num_epochs or loss.item() < 0.05:
        print("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, loss.item()))
        break
# predict with the trained model
model.eval()
print()
output = model(demo.get_test()).data
print(output.argmax(dim=1) + 1)

print("\n======= compare all results ======")
test_data = demo.get_data()
predictions = model(test_data)
result = (predictions.argmax(dim=1) + 1)
print(result.data)
print("differences:")
print((demo.get_result2() - result).long())
print()
print("accuracy:", (torch.round(
    cosine_similarity(result.unsqueeze(0), demo.get_result2().unsqueeze(0)).mean() * 10000) / 100).data.numpy(),
    "%")
demo54data.py:
import torch
def get_param_co():
    # initial weights of the three 3x3 convolution filters (the fixed values used in the book)
    return torch.tensor([[
-1.277, -0.454, 0.358,
1.138, -2.398, -1.664,
-0.794, 0.899, 0.675
], [
-1.274, 2.338, 2.301,
0.649, -0.339, -2.054,
-1.022, -1.204, -1.900
], [
-1.869, 2.044, -1.290,
-1.710, -2.091, -2.946,
0.201, -1.323, 0.207
]])
def get_param_co_bias():
    # biases of the three convolution filters
    return torch.tensor([-3.363, -3.176, -1.739])
def get_param_op():
    # initial weights of the output layer: 3 output neurons, each reading the three pooled 2x2 feature maps
    return torch.tensor([
        [[
            -0.276, 0.124,
            -0.961, 0.718
        ], [
            -3.680, -0.594,
            0.280, -0.782
        ], [
            -1.475, -2.010,
            -1.085, -0.188
        ]],
        [[
            0.010, 0.661,
            -1.591, 2.189
        ], [
            1.728, 0.003,
            -0.250, 1.898
        ], [
            0.238, 1.589,
            2.246, -0.093
        ]],
        [[
            -1.322, -0.218,
            3.527, 0.061
        ], [
            0.613, 0.218,
            -2.130, -1.678
        ], [
            1.236, -0.486,
            -0.144, -1.235
        ]]
    ])
def get_param_o_bias():
    # biases of the three output neurons
    return torch.tensor([2.060, -2.746, -1.818])
def get_test():
    # two extra 6x6 images used for a quick prediction check after training
    return torch.tensor([[
        1.0, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 0, 0, 1, 1, 0,
        1, 1, 0, 0, 1, 0,
        1, 1, 1, 1, 0, 0], [
        0, 0, 1, 1, 1, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 1, 0]]).reshape(2, 1, 6, 6)
def print_x(name, x):
    # prints the intermediate values (zF, aF, zO, aO) so they can be compared with the book's Excel sheet;
    # feature maps (4-D tensors) are shown for the first sample only
    if x.dim() > 3:
        print(name, end='')
        for j in range(x.size()[1]):
            print("\t[", end='')
            for k in range(x.size()[2]):
                print("", x[0, j, k, :].data.numpy(), end='')
            print("]\n ", end='')
        print()
    elif x.dim() > 1:
        print(name, end='')
        print("", x[0, :].data.numpy(), end='')
        for i in range(x.size()[0]):
            if i > 0:
                print("\t\t", x[i, :].data.numpy(), end='')
        print()
def print_params(params):
    # pretty-prints weight matrices and bias vectors row by row
    for param in params:
        if param.dim() > 1:
            for i in range(param.size()[0]):
                print('\t[', end='')
                print(param[i, 0].data.numpy(), end='')
                for j in range(param.size()[1]):
                    if j > 0:
                        print('\t', param[i, j].data.numpy(), end='')
                print('] ', end='')
            print()
        else:
            print('\t[', end='')
            print(param[0].data.numpy(), end='')
            for i in range(param.size()[0]):
                if i > 0:
                    print('\t', param[i].data.numpy(), end='')
            print('] ')
    print()
def get_data():
    # the 96 6x6 training images: 32 each of the digits 1, 2 and 3
    return torch.tensor([[
0.0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, # 10
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 1, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0], [
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 0, 0, 0, # 20
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 0, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0], [
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0], [
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0], [
0, 1, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0], [
0, 1, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 0, 1, 0, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0, # 30
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
1, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
1, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
1, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
1, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0,
1, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 1, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 0, 0,
0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 1, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 1,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 1, 1,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 1, 1,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 0, 1,
0, 0, 0, 1, 1, 1,
0, 0, 0, 0, 1, 1,
0, 1, 0, 0, 0, 1,
0, 0, 1, 1, 1, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 0, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
1, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
1, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
1, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 0, 1,
0, 0, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 1, 0,
0, 0, 1, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
1, 0, 0, 0, 1, 0,
0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0,
0, 0, 0, 0, 1, 0,
1, 1, 1, 1, 0, 0], [
1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 1, 0,
0, 0, 1, 1, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 1,
0, 1, 1, 1, 1, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 1, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
0, 1, 1, 1, 0, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 1, 1, 0, 1, 0,
0, 0, 1, 1, 0, 0], [
0, 0, 1, 1, 0, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
0, 1, 1, 1, 1, 0], [
1, 1, 1, 1, 1, 0,
0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
0, 1, 1, 1, 0, 0], [
1, 1, 1, 1, 0, 0,
1, 1, 0, 0, 1, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 0,
1, 1, 1, 1, 0, 0], [
0, 0, 1, 1, 1, 0,
0, 1, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 0,
0, 1, 0, 0, 1, 1,
0, 0, 1, 1, 1, 0]]
    ).reshape(96, 1, 6, 6)
def get_result():
    # one-hot training labels: [1, 0, 0] = digit 1, [0, 1, 0] = digit 2, [0, 0, 1] = digit 3
    return torch.tensor([[
1.0, 0.0, 0.0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
1, 0, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 1, 0], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1], [
0, 0, 1]])
def get_result2():
    # the same labels as plain digits, used when comparing predictions for all 96 images
    return torch.tensor([
        1.0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3
    ])
Running results: