import numpy
# Activation-function library (provides scipy.special.expit, the logistic sigmoid)
import scipy.special

import matplotlib.pyplot


# Neural network class definition
class neutralNetwork:
    """Three-layer (input -> hidden -> output) fully connected network.

    Both layers use the logistic sigmoid as activation function, and the
    network is trained one sample at a time by ``train``. The class name keeps
    its original spelling so existing callers keep working.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Build the network with normally distributed random weights.

        Args:
            inputnodes: number of input nodes.
            hiddennodes: number of hidden nodes.
            outputnodes: number of output nodes.
            learningrate: step size applied to every weight update in train().
        """
        # Layer sizes
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # Weight matrices, drawn from a zero-mean normal distribution.
        # win has shape (hidden, input) and who has shape (output, hidden), so a
        # column vector of inputs is propagated by left-multiplication.
        # NOTE(review): the standard deviation is 1/sqrt(size of the receiving
        # layer); the common fan-in rule would use inodes for win and hnodes
        # for who instead -- confirm which one is intended.
        self.win = numpy.random.normal(
            0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes)
        )
        self.who = numpy.random.normal(
            0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes)
        )

        # Activation function: logistic sigmoid, 1 / (1 + exp(-x)).
        self.activation_function = scipy.special.expit

        # Learning rate
        self.lr = learningrate

    def _forward(self, inputs_list):
        """Propagate one sample through both layers (shared by train/query).

        Returns (inputs, hidden_outputs, final_outputs); for a 1-D
        inputs_list each is a column vector of shape (k, 1).
        """
        inputs = numpy.array(inputs_list, ndmin=2).T

        hidden_outputs = self.activation_function(numpy.dot(self.win, inputs))
        final_outputs = self.activation_function(numpy.dot(self.who, hidden_outputs))
        return inputs, hidden_outputs, final_outputs

    # Train the network on one sample and update the weights
    def train(self, inputs_list, targets_list):
        """Run one sample forward and update both weight matrices in place.

        Args:
            inputs_list: 1-D sequence of ``inodes`` input values.
            targets_list: 1-D sequence of ``onodes`` target output values.
        """
        inputs, hidden_outputs, final_outputs = self._forward(inputs_list)
        targets = numpy.array(targets_list, ndmin=2).T

        output_errors = targets - final_outputs
        # Split the output errors back over the hidden nodes in proportion to
        # the weights. This must be computed before self.who is modified below.
        # NOTE(review): the raw output errors are propagated (not scaled by the
        # output sigmoid's derivative), so the hidden-layer update is a
        # simplification of the exact squared-error gradient.
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # Weight step: error * sigmoid'(x) -- where sigmoid'(x) = out * (1 - out)
        # -- multiplied by the transposed outputs of the previous layer.
        self.who += self.lr * numpy.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            hidden_outputs.T,
        )
        self.win += self.lr * numpy.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            inputs.T,
        )

    # Query the network's output for one sample
    def query(self, inputs_list):
        """Return the network's outputs for one sample as an (onodes, 1) column vector."""
        return self._forward(inputs_list)[2]


# One input node per pixel of a 28*28 image
input_nodes = 784
# Fewer hidden nodes than input nodes, forcing the network to summarise the
# main features of the input
hidden_nodes = 100
# Ten handwritten digits (0-9), hence ten output nodes
output_nodes = 10

learning_rate = 0.3

# Train for 2 epochs (too many would overfit)
epoches = 2

n = neutralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

# Load the MNIST training set
def _scale_pixels(pixel_fields):
    """Convert raw 0-255 pixel strings to floats in the range [0.01, 1.00].

    Inputs are kept non-zero because a zero input produces no update for the
    weights attached to it (the weight update is multiplied by the inputs).
    """
    # numpy.asfarray was removed in NumPy 2.0; asarray(..., dtype=float) is the
    # equivalent replacement.
    return numpy.asarray(pixel_fields, dtype=float) / 255.0 * 0.99 + 0.01


# Each CSV record is: label, followed by the 28*28 pixel values
with open("mnist_train.csv", 'r', encoding="utf-8") as training_data_file:
    training_data_list = training_data_file.readlines()

# Train the network on the training set
for e in range(epoches):
    for record in training_data_list:
        line = record.strip()
        # Skip blank lines (e.g. a trailing empty line) instead of crashing
        # on int('').
        if not line:
            continue
        all_values = line.split(',')

        inputs = _scale_pixels(all_values[1:])

        # Target vector: 0.99 for the correct digit, 0.01 elsewhere. The
        # sigmoid can never output exactly 0 or 1, so chasing those values
        # would push the weights ever larger and saturate the network.
        targets = numpy.zeros(output_nodes) + 0.01
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)

# Test the network
with open("mnist_test.csv", 'r', encoding="utf-8") as test_data_file:
    test_data_list = test_data_file.readlines()

# 1 for each correctly classified test record, 0 otherwise
scorecard = []

for record in test_data_list:
    line = record.strip()
    if not line:
        continue
    all_values = line.split(',')
    correct_label = int(all_values[0])
    print(correct_label, "correct label")
    # No plotting here: calling matplotlib.pyplot.imshow for every test record
    # stacks every image on the same axes. To inspect one sample, plot
    # numpy.asarray(all_values[1:], dtype=float).reshape((28, 28)) with
    # cmap='Greys' for that single record.

    outputs = n.query(_scale_pixels(all_values[1:]))

    # The predicted digit is the index of the strongest output node
    label = numpy.argmax(outputs)
    print(label, "network's answer:")

    scorecard.append(1 if label == correct_label else 0)

# Fraction of test records classified correctly (guarded against 0/0)
scorecard_array = numpy.asarray(scorecard, dtype=float)
if scorecard_array.size:
    print("performance=", scorecard_array.sum() / scorecard_array.size)
else:
    print("performance=", "no test records")
# 标签 (Tags): inputs, outputs, self, py, list, 神经网络 (neural network), 简单 (simple), hidden, numpy
# From: https://www.cnblogs.com/saucerdish/p/17841593.html