首页 > 其他分享 >TensorFlow实践——Softmax Regression

TensorFlow实践——Softmax Regression

时间:2023-06-14 20:38:40浏览次数:43  
标签:batch cost Softmax input tf TensorFlow Regression mnist


Softmax Regression是Logistic回归在多分类上的推广,对于Logistic回归以及Softmax Regression的详细介绍可以参见:

  • 简单易学的机器学习算法——Logistic回归
  • 利用Theano理解深度学习——Logistic Regression
  • 深度学习算法原理——Softmax Regression

下面的代码是利用TensorFlow基本API实现的Softmax Regression:

'''
@author:zhaozhiyong
@date:20170822
Softmax Regression
'''

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)

learning_rate = 0.01
training_epochs = 1000
batch_size = 100
display_step = 50

n_input = 784
n_classes = 10

x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])

w1 = tf.Variable(tf.random_normal([n_input, n_classes]))
b1 = tf.Variable(tf.random_normal([n_classes]))

pred = tf.add(tf.matmul(x, w1), b1)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

init = tf.global_variables_initializer()

with tf.Session() as sess:
	sess.run(init)
	for epoch in range(training_epochs):
		avg_cost = 0
		total_batch = int(mnist.train.num_examples/batch_size)
		for i in range(total_batch):
			batch_x, batch_y = mnist.train.next_batch(batch_size)
			_, c = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
			avg_cost += c / total_batch
		if epoch % display_step == 0:
			print "Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)
	print "Optimization Finished!"

	print "Get test data:"    	
	correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    	accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    	print "Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})

以下是运行的结果:

TensorFlow实践——Softmax Regression_Soft

参考文献

  1. [03]tensorflow实现softmax回归(softmax regression)


标签:batch,cost,Softmax,input,tf,TensorFlow,Regression,mnist
From: https://blog.51cto.com/u_16161414/6480389

相关文章

  • tensorflow实现花分类
    1.花数据集数据集来自kaggle官网下载。分为五类花,每类花有1000张图片。下载方式可以参考我的https://www.cnblogs.com/wancy/p/17446715.html 2.图片大小分布图训练模型之前,我们会需要先分析数据集,由于此类数据集每类花的图片数量一样,是均衡的。训练模型之前,我们......
  • 基于Tensorflow的Faster-Rcnn的断点续训
    一、前言最近在学习目标检测,到github上找了一个开源的Faster-RCNN项目(Tensorflow),项目地址是:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3根据网上的各种教程,模型训练还算顺利,不过这个项目缺少断点续训的功能。也就是中途误操作导致训练中止,就只能从头开始......
  • Python+tensorflow计算整数阶乘的方法与局限性
    本文代码主要演示tensorflow的基本用法。importtensorflowas#创建变量,保存计算结果start=tf.Variable(1,dtype=tf.int64)#初始化变量的opinit_op=tf.global_variables_initializer()#启用默认图withtf.Session()assess:#初始化变量sess.run(ini......
  • Google colab 更改Tensorflow深度学习框架版本
    Googlecolab默认导入的tensorflow版本是2.0想,想更改tensorflow版本为1.x则需要加入%tensorflow_version1.ximporttensorflowastftf.__version__之后要进行colab的重启,即可完成版本改装,由于colab没有terminal,所以只能这么修改。欢迎登陆官网(附https://tensorflow.google.cn/)......
  • 深度学习项目之mnist手写数字识别实战(TensorFlow框架)
    mnist手写数字识别是所有深度学习开发者的必经之路,mnist数据集的图片十分简单,是二值化图像,像素个数为28x28。所以对于所有研究深度学习的开发者来说学会mnist数据集的模型十分有必要。以此为实例进行计算机视觉如何进行识别出图片中的数据。MNIST手写数字数据集来自美国国家标准与......
  • CART——Classification And Regression Tree在python下的实现
    分类与回归树(CART——ClassificationAndRegressionTree))是一种非参数分类和回归方法,它通过构建二叉树达到预测目的。示例:1.样本数据集 2.运行结果-cart决策树的字典max_n_feats=3时tree_dict={house:{yes:agreen......
  • 神经网络:softmax激活函数
    softmax的作用:将多分类的输出值转换为范围在[0,1]和为1的概率分布soft反方词hardhardmax从一组数据中找到最大值softmax为每一个分类提供一个概率值,表示每个分类的可能性。所有分类的概念值之和是1.优点在x轴上一个很小的变化,可以导致y轴上很大的变化,将输出的数值拉开距离。在深......
  • 人工智能创新挑战赛:海洋气象预测Baseline[4]完整版(TensorFlow、torch版本)含数据转化
    人工智能创新挑战赛:海洋气象预测Baseline[4]完整版(TensorFlow、torch版本)含数据转化、模型构建、MLP、TCNN+RNN、LSTM模型训练以及预测1.赛题简介项目链接以及码源见文末2021“AIEarth”人工智能创新挑战赛,以“AI助力精准气象和海洋预测”为主题,旨在探索人工智能技术在气......
  • 人工智能创新挑战赛:海洋气象预测Baseline[4]完整版(TensorFlow、torch版本)含数据转化、
    人工智能创新挑战赛:海洋气象预测Baseline[4]完整版(TensorFlow、torch版本)含数据转化、模型构建、MLP、TCNN+RNN、LSTM模型训练以及预测1.赛题简介项目链接以及码源见文末2021“AIEarth”人工智能创新挑战赛,以“AI助力精准气象和海洋预测”为主题,旨在探索人工智能技术在气......
  • 如何使用深度学习和TensorFlow实现计算机视觉
    越来越多的地方正在使用计算机视觉。从增强安全系统到改进医疗保健诊断,计算机视觉技术正在彻底改变多个行业。##课程先睹为快本课程经过精心设计,涵盖了广泛的主题,从张量和变量的基础知识到高级深度学习模型的实现,以应对人类情感检测和图像生成等复杂任务。在介绍了先决条件并......