文章目录
前言
之前学习了深度学习的概念与基本过程,今天用一个简单的深度学习框架实现最大数字的找寻,理解深度学习的的基本流程。
问题描述
假设有一个5维数组, X = [ 2 , 3 , 7 , 4 , 5 ] X=[2,3,7,4,5] X=[2,3,7,4,5],则定义 Y = [ 0 , 0 , 1 , 0 , 0 ] Y=[0,0,1,0,0] Y=[0,0,1,0,0], X X X中数字最大的是第三个,所以 Y Y Y的第三个为1,并将他定义为第三类。
步骤1:生成训练集
步骤2:定义神经网络
步骤3:代入数据训练参数
步骤4:模型评估预测
生成训练集
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt##导入各种库
def build_sample():##数据规则
x=np.random.random(5)
y=np.zeros(5)
y[np.argmax(x)]=1
return x,y
def build_dataset(total_samlple_num):##生成数据
X=[]
Y=[]
for i in range(total_samlple_num):
x,y=build_sample()
X.append(x)
Y.append(y)
return torch.FloatTensor(np.array(X)),torch.FloatTensor(np.array(Y))
上面的代码块第一部分是导入各种库,第二部分是定义规则,先随机生成5维数组,然后将根据x的最大最小数值确定y的哪个值为1.第三部分为批量生成数据,并返回torch的浮点类型的数字,便于训练。
定义神经网络
我们利用最简单的线性层和激活层实现这个功能,我们的x是5维的,最终输出的y也是5维的,因为线性层有
y
=
X
w
+
b
y=Xw+b
y=Xw+b所以
w
w
w应该也是5*5的。我们生成的y值每一维应该是0到1的,属于五类数字的概率,哪个概率大就为哪类,所以我们激活函数可以选择为sigmoid函数,损失函数选择为交叉熵损失函数(因为是分类问题),最终的神经网络可以如下图(作者的手画):
利用代码定义这个网络层,有如下代码:
class TorchModel(nn.Module):
def __init__(self,input_size):
super(TorchModel,self).__init__()
self.linear=nn.Linear(input_size,5)
self.activation=torch.sigmoid
self.loss=nn.CrossEntropyLoss()
def forward(self,x,y=None):
x=self.linear(x)
y_pred=self.activation(x)
if y is not None:
return self.loss(y_pred,y)
else:
return y_pred
对于forward,有y值输入时为训练,计算损失函数,没有y值时时进行前向传播,输出预测值。
进行训练
# 测试代码
# 用来测试每轮模型的准确率
def evaluate(model):
model.eval()
test_sample_num = 100
x, y = build_dataset(test_sample_num)
a,b,c,d,e=0,0,0,0,0
for i in range(len(y)):
if y[i].argmax()==0:
a+=1
elif y[i].argmax()==1:
b+=1
elif y[i].argmax()==2:
c+=1
elif y[i].argmax()==3:
d+=1
else:
e+=1
print("本次预测集中共有%d个一类样本,%d个二类样本,%d个三类样本,%d个四类样本,%d个五类样本" % (a,b,c,d,e))
correct, wrong = 0, 0
with torch.no_grad():
y_pred = model(x) # 模型预测
for y_p, y_t in zip(y_pred, y): # 与真实标签进行对比
if int(y_p.argmax())==int(y_t.argmax()):
correct += 1 # 样本判断正确
else:
wrong += 1
print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))
return correct / (correct + wrong)
def main():
# 配置参数
epoch_num = 20 # 训练轮数
batch_size = 20 # 每次训练样本个数
train_sample = 5000 # 每轮训练总共训练的样本总数
input_size = 5 # 输入向量维度
learning_rate = 0.001 # 学习率
# 建立模型
model = TorchModel(input_size)
# 选择优化器
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
log = []
# 创建训练集,正常任务是读取训练集
train_x, train_y = build_dataset(train_sample)
# 训练过程
for epoch in range(epoch_num):
model.train()
watch_loss = []
for batch_index in range(train_sample // batch_size):
x = train_x[batch_index * batch_size : (batch_index + 1) * batch_size]
y = train_y[batch_index * batch_size : (batch_index + 1) * batch_size]
loss = model(x, y) # 计算loss
loss.backward() # 计算梯度
optim.step() # 更新权重
optim.zero_grad() # 梯度归零
watch_loss.append(loss.item())
print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
acc = evaluate(model) # 测试本轮模型结果
log.append([acc, float(np.mean(watch_loss))])
# 保存模型
torch.save(model.state_dict(), "max_number.pt")
# 画图
print(log)
plt.plot(range(len(log)), [l[0] for l in log], label="acc") # 画acc曲线
plt.plot(range(len(log)), [l[1] for l in log], label="loss") # 画loss曲线
plt.legend()
plt.show()
return
我们先定义一个评估函数evalute,在接下来训练时可以计算每轮训练的正确率。main函数为主要训练函数,先开始定义各种训练的超参数,20轮的反向传播。运行main()有如下结果:
可以看到损失函数在每轮过后逐渐减小,准确率在每轮训练后逐渐增大,准确率最终在百分之90以上。
模型评估预测
对于最终的结果我们可以查看模型 w w w, b b b的值是多少,并生成一些数据看看结果。
def predict(model_path, input_vec):
input_size = 5
model = TorchModel(input_size)
model.load_state_dict(torch.load(model_path)) # 加载训练好的权重
print(model.state_dict())
model.eval() # 测试模式
with torch.no_grad(): # 不计算梯度
result = model.forward(torch.FloatTensor(input_vec)) # 模型预测
for vec, res in zip(input_vec, result):
print("输入:%s, 预测类别:%d, 概率值:%f" % (vec, int(res.argmax())+1,res.max())) # 打印结果
test_vec = [[0.07889086,0.15229675,0.31082123,0.03504317,0.18920843],
[0.94963533,0.5524256,0.95758807,0.95520434,0.84890681],
[0.78797868,0.67482528,0.13625847,0.34675372,0.19871392],
[0.79349776,0.59416669,0.92579291,0.41567412,0.1358894]]
predict("model.pt",test_vec)
结果如下所示:
可以看到,对于差距比较大数字,该模型可以很大的概率预测准确的,对于精度比较小的值还是有小概率预测错误的。
以上就是深度学习的基本流程,你们还可以尝试用其他规则进行训练,比如说,第n个数字大于第m个数字归为正类,否则归为负类,小伙伴可以自己尝试一下。