首页 > 其他分享 >简单神经网络(py)

简单神经网络(py)

时间:2023-11-19 09:11:38浏览次数:33  
标签:inputs outputs self py list 神经网络 简单 hidden numpy

  1 import numpy
  2 #激活函数库
  3 import scipy.special
  4 
  5 import matplotlib.pyplot
  6 
  7 #neutral network class definition
  8 class neutralNetwork:
  9     def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
 10         #定义各个节点
 11         self.inodes=inputnodes
 12         self.hnodes=hiddennodes
 13         self.onodes=outputnodes
 14 
 15         #初始化权重矩阵(利用正态分布)
 16         self.win=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
 17         self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
 18 
 19         #定义激活函数
 20         self.activation_function=lambda x: scipy.special.expit(x)
 21 
 22         #初始化学习率
 23         self.lr=learningrate
 24         pass
 25 
 26     #训练网络并更新权重
 27     def train(self,inputs_list,targets_list):
 28         inputs=numpy.array(inputs_list,ndmin=2).T
 29         targets=numpy.array(targets_list,ndmin=2).T
 30 
 31         hidden_inputs=numpy.dot(self.win,inputs)
 32         hidden_outputs=self.activation_function(hidden_inputs)
 33 
 34         final_inputs=numpy.dot(self.who,hidden_outputs)
 35         final_outputs=self.activation_function(final_inputs)
 36 
 37         output_errors=targets-final_outputs
 38         hidden_errors=numpy.dot(self.who.T,output_errors)
 39 
 40         self.who+=self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
 41         self.win+=self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
 42 
 43         pass
 44 
 45     #查询每次输出结果
 46     def query(self,inputs_list):
 47         inputs=numpy.array(inputs_list,ndmin=2).T
 48 
 49         hidden_inputs=numpy.dot(self.win,inputs)
 50         hidden_outputs=self.activation_function(hidden_inputs)
 51 
 52         final_inputs=numpy.dot(self.who,hidden_outputs)
 53         final_outputs=self.activation_function(final_inputs)
 54 
 55         return final_outputs
 56         pass
 57 
 58 #inputnode是像素的大小28*28
 59 input_nodes=784
 60 #选择比inputnode小的,强迫网络总结输入主要特点
 61 hidden_nodes=100
 62 #手写一共十个数字,所以设置outputnode为10
 63 output_nodes=10
 64 
 65 learning_rate=0.3
 66 
 67 #训练2世代(太大会过度拟合)
 68 epoches=2
 69 
 70 n=neutralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
 71 
 72 #加载mnist训练集
 73 training_data_file=open("mnist_train.csv",'r')
 74 training_data_list=training_data_file.readlines()
 75 training_data_file.close()
 76 
 77 #用训练集训练网络
 78 for e in range(epoches):
 79     for record in training_data_list:
 80         all_values=record.split(',')
 81 
 82         #转化成input矩阵格式(非0:会造成网络崩溃;除以最大像素是255得到0.01-0.99;激活函数不能达到1)
 83         inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
 84 
 85         #设置目标输出:不能为0和1,否则会存在饱和网络(为了无限接近不可能的值0和1)
 86         targets=numpy.zeros(output_nodes)+0.01
 87         targets[int(all_values[0])]=0.99
 88         n.train(inputs,targets)
 89         pass
 90     pass
 91 
 92 #测试网络
 93 test_data_file=open("mnist_test.csv",'r')
 94 test_data_list=test_data_file.readlines()
 95 test_data_file.close()
 96 
 97 scorecard=[]
 98 
 99 for record in test_data_list:
100     all_values=record.split(',')
101     correct_label=int(all_values[0])
102     print(correct_label,"correct label")
103     image_array=numpy.asfarray(all_values[1:]).reshape((28,28))
104     matplotlib.pyplot.imshow(image_array,cmap='Greys',interpolation='None')
105 
106     inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
107 
108     outputs=n.query(inputs)
109 
110     label=numpy.argmax(outputs)
111     print(label,"network's answer:")
112 
113     if(label==correct_label):
114         scorecard.append(1)
115     else:
116         scorecard.append(0)
117     pass
118 
119 scorecard_array=numpy.asfarray(scorecard)
120 print("performance=",scorecard_array.sum()/scorecard_array.size)

 

标签:inputs,outputs,self,py,list,神经网络,简单,hidden,numpy
From: https://www.cnblogs.com/saucerdish/p/17841593.html

相关文章

  • Python、Spark SQL、MapReduce决策树、回归对车祸发生率影响因素可视化分析
    原文链接:https://tecdat.cn/?p=34286原文出处:拓端数据部落公众号分析师:ShichaoZhong项目挑战如何处理庞大的数据集,并对数据进行可视化展示;在后续分析中特征选择是重点之一,要根据事实情况和数据易处理的角度来筛选变量解决方案任务/目标根据已有的车祸数据信息,计算严重车祸......
  • python 数据可视化:直方图、核密度估计图、箱线图、累积分布函数图
    本文使用数据来源自2023年数学建模国赛C题,以附件1、附件2数据为基础,通过excel的数据透视表等功能重新汇总了一份新的数据表,从中截取了一部分数据为例用于绘制图表。绘制的图表包括一维直方图、一维核密度估计图、二维直方图、二维核密度估计图、箱线图、累计分布函数图。 目录......
  • 【Python自动化】定时自动采集,并发送微信告警通知,全流程案例讲解!
    目录一、概要二、效果演示三、代码讲解3.1爬虫采集行政处罚数据3.2存MySQL数据库3.3发送告警邮件&微信通知3.4定时机制四、总结一、概要您好!我是@马哥python说,一名10年程序猿。我原创开发了一套定时自动化爬取方案,完整开发流程如下:采集数据->筛选数据->存MySQL数据库......
  • 大白话说Python+Flask入门(二)
    写在前面笔者技术真的很一般,也许只靠着笨鸟先飞的这种傻瓜坚持,才能在互联网行业侥幸的生存下来吧!为什么这么说?我曾不止一次在某群,看到说我写的东西一点技术含量都没有,而且很没营养,换作一年前的我,也许会怼回去,现在的话,我只是看到了,完事忘记了。早期写文章是为了当笔记用,不会随......
  • Java开发者的Python快速进修指南:控制之if-else和循环技巧
    简单介绍在我们今天的学习中,让我们简要了解一下Python的控制流程。考虑到我们作为有着丰富Java开发经验的程序员,我们将跳过一些基础概念,如变量和数据类型。如果遇到不熟悉的内容,可以随时查阅文档。但在编写程序或逻辑时,if-else判断和循环操作无疑是我们经常使用的基本结构。毕竟,......
  • 囚徒4.0_11_基于python的风云云检测算法
    #囚徒4.0_11_基于python的风云算法#关于昨天数据不同的问题:是因为IDL和Python的逻辑不同而导致的,数据读取没问题,我表示错了。#换语言好麻烦,现在都不知道什么语法对应什么语言了,一团糟。#从上午十点写到现在,测试的时候发现python他的读取逻辑和IDL不一样,他的循环也不一样,我真......
  • Halo2简单示例
    Halo2简介[[Halo2]]是使用[[Rust]]语言开发,基于[[PLANK算法]]的,一款开源交互式([[STARKs]]),[[零知识证明(ZKP)]]的[[证明系统]]。GitHub仓库地址:halo2不同于普通的开发框架,Halo2中的功能开发称为电路(Circuit)开发,电路开发使用表格来设计并记录运算,并包含一系列的约束来验证......
  • python数据提取-正则表达式
    1.正则表达式 (1)re的findall()方法importrer_list=re.findall('AB','ABCABDDGAAGDSGSDG')#后匹配前print(r_list)#输出:['AB','AB'] (2)也可以写作下面importrepattern=re.compile('AB')r_list=pattern.findall('ABCABDDGA......
  • python数据持久化(mysql+CSV+mongodb)
    1.创建数据库createdatabasemydbcharsetutf8;usemydb;createtablemydb(namevarchar(100),starvarchar(200),timevarchar(100))charset=utf8;2.使用pymysql模块在mytab表中插入一条表记录importpymysql#(1)创建数据库连接对象db=pymysql.connect('localhost','roo......
  • 爬取python网站下载地址,并下载最新文件
    1.下载https://www.python.org/ftp/python/最新版本python文件  一个下载网站,查看最新的,然后下载对应版本文件(如,列出python版本,并下载https://www.python.org/ftp/python/3.5.2/Python-3.5.2.tar.xz)。 代码如下:importrequestsfromlxmlimportetreeimporttimeimportr......