首页 > 其他分享 >tf.keras实现逻辑回归和softmax多分类

tf.keras实现逻辑回归和softmax多分类

时间:2024-05-31 15:46:08浏览次数:29  
标签:keras image label train softmax tf model

逻辑回归实现

转自:https://www.cnblogs.com/miraclepbc/p/14311509.html

相关库引用

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

加载数据

data = pd.read_csv("E:/datasets/dataset/credit-a.csv", header = None) # 获取数据
x = data.iloc[:, :-1]
y = data.iloc[:, -1].replace(-1, 0)
data.head()


观察发现,最后一列(label)非0即1。因此,这是一个二分类问题。可以考虑把-1全都替换成0

定义模型

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, input_shape = (15, ), activation = 'relu'))
model.add(tf.keras.layers.Dense(4, activation = 'relu'))
model.add(tf.keras.layers.Dense(1, activation = 'sigmoid'))
model.summary()

这个模型第一层,有4个神经元,因为输入是15个参数,因此参数个数为4∗15+4=644∗15+4=64。这里使用ReLU作为激活函数;
模型第二层,有4个神经元,输入是4个参数,因此参数个数为4∗4+4=204∗4+4=20。这里使用ReLU作为激活函数;
模型第三层,有1个神经元,输入是4个参数,因此参数个数为1∗4+1=51∗4+1=5。这里使用Sigmoid作为激活函数。
这里总共有89个参数

模型编译

model.compile(
    optimizer = 'adam',
    loss      = 'binary_crossentropy',
    metrics   = ['acc'] # 设置显示的参数
)

这里是二分类问题,因此损失函数可以设置为binary_crossentropy

训练模型

 
history = model.fit(x, y, epochs = 1000) # 训练1000次

下面我们来看一下模型的一些参数

history.history.keys()

发现有loss和acc两个参数
然后,我们再画出随着训练轮数的增加,loss和acc的变化曲线图

plt.plot(history.epoch, history.history.get('loss'))
plt.plot(history.epoch, history.history.get('acc'))

loss变化曲线图:

acc变化曲线图:

softmax多分类实现

加载数据

(train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data() # 获取数据
plt.imshow(train_image[0]) # 显示第一张图片

数据归一化:

train_image = train_image / 255
test_image = test_image / 255

定义模型

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape = (28, 28)))
model.add(tf.keras.layers.Dense(128, activation = 'relu'))
# model.add(tf.keras.layers.Dropout(0.5)) 添加一个dropout层,防止过拟合
model.add(tf.keras.layers.Dense(10, activation = 'softmax'))

因为输入图像是二维的(28*28),因此需要先将其变换成一维向量。
第一层128个神经元,激活函数为ReLU
第二层10个神经元,激活函数为softmax

模型编译

model.compile(
    optimizer = 'adam',
    loss      = 'sparse_categorical_crossentropy',
    metrics   = ['acc']
)

这里因为是多分类问题,并且标签是一般的数值标签,因此损失函数使用sparse_categorical_crossentropy

训练模型

model.fit(train_image, train_label, epochs = 10) # 训练10次
# model.fit(train_image, train_label, epochs = 10, validation_data = (test_image, test_label)) # validation_data可以同时查看测试集的正确率和损失

模型评价

在测试集上评估训练的模型

model.evaluate(test_image, test_label)

one-hot编码

one-hot编码的转换

train_label_onehot = tf.keras.utils.to_categorical(train_label)
test_label_onehot = tf.keras.utils.to_categorical(test_label)

模型的编译

model.compile(
    optimizer = 'adam',
    loss      = 'categorical_crossentropy',
    metrics   = ['acc']
)

因为使用的是one-hot编码,因此损失函数使用categorical-crossentropy

标签:keras,image,label,train,softmax,tf,model
From: https://www.cnblogs.com/gongzb/p/18224673

相关文章

  • tf.keras实现线性回归和多层感知器
    线性回归实现转自:https://www.cnblogs.com/miraclepbc/p/14287756.html相关库引用importtensorflowastfimportnumpyasnpimportpandasaspdimportmatplotlib.pyplotasplt%matplotlibinline加载数据data=pd.read_csv("E:/datasets/dataset/Income1.csv")#......
  • 低代码开发平台(Low-code Development Platform)的模块组成部分
    低代码开发平台(Low-codeDevelopmentPlatform)的模块组成部分主要包括以下几个方面:低代码开发平台的模块组成部分可以按照包含系统、模块、菜单组织操作行为等维度进行详细阐述。以下是从这些方面对平台模块组成部分的说明:包含系统低代码开发平台本身作为一个完整的系统,包含......
  • Delphi 2010 新增功能之: IOUtils 单元(1): 初识 TDirectory.GetFiles
    用IOUtils单元下的TDirectory.GetFiles获取文件列表太方便了;下面的例子只是TDirectory.GetFiles的典型应用...unitUnit1;interfaceuses Windows,Messages,SysUtils,Variants,Classes,Graphics,Controls,Forms, Dialogs,StdCtrls;type TForm1=......
  • 文本挖掘tf-idf,主题建模,情感分析,n-gram建模研究|附代码数据
    原文链接:http://tecdat.cn/?p=6864我们围绕文本挖掘技术进行一些咨询,帮助客户解决独特的业务问题。我们对20个Usenet公告板的20,000条消息进行分析 ( 点击文末“阅读原文”获取完整代码数据******** )。此数据集中的Usenet公告板包括新汽车,体育和密码学等主题。预处理我们首......
  • BUUCTF Crypto 1~20刷题记录
    文章目录一、Crypto1、MD52、Url编码3、摩丝4、password5、Quoted-printable6、篱笆墙的影子7、Rabbit8、RSA9、丢失的MD510、Alice与Bob11、大帝的密码武器12、rsarsa13、Windows系统密码14、信息化时代的步伐15、凯撒?替换?呵呵!16、萌萌哒的八戒17、权限获得第一步18、......
  • 模型节点操作学习笔记(Appendix)实验1 -- Tflite int8 删除最后的Round节点 (持续更新)
    背景如下:我要删除Round节点,同时看了一下,Dequantize和Quantize也是没有必要的。所以最好一起删除。原始项目地址:PINTO0309/hand-gesture-recognition-using-onnx:ThisisahandgesturerecognitionprogramthatreplacestheentireMediaPipeprocesswithONNX.Simultane......
  • BUUCTF-Misc(61-70)
    [ACTF新生赛2020]swp参考:[BUUCTFmisc专题(76)ACTF新生赛2020]swp-CSDN博客解开压缩包,密密麻麻,不懂咋办了然后这边进行协议分析大部分是tcp,所以我们导出对象->选择http然后我就找到这个加密的压缩包然后010editor打开发现伪加密,改成00,有两处我只圈了一处在flag.swp里面......
  • 【Swing】JTextField设置光标
    1、设置焦点焦点默认是在窗体的第一个组件上UIManager.setLookAndFeel(UIManager.getSystemLookAndFeelClassName());JFramewindow=newJFrame();JPanelpanel=newJPanel(newFlowLayout(FlowLayout.LEFT));JButtonmessageDialog=newJButton("消息框");messageDi......
  • TF-IDF算法
    TF-IDF(termfrequency–inversedocumentfrequency,词频-逆向文件频率)TF-IDF本质上是一种统计方法,用来评估一个词/token在整个语料库中当前文档中的重要程度,字词的重要性随着它在当前文档中出现的频率成正比增加,随着它在整个语料库中出现的频率成反比降低。主要思想:某个单词在当......
  • 自己实现dubbo参数校验(类似RestFul 参数校验)
    1.场景:因为工作中经常需要做参数校验,在springboot项目中使用@Valid+@NotNull、@NotBlank…注解开发API接口非常丝滑,相反在开发RPC接口时却还是需要编写大量的参数判断,严重影响主业务流程的开发(公司目前用的是Dubbo2.7.2)且代码整洁度、风格都受到了挑战。基于以上原因萌生了写一......