首页 > 其他分享 >学习进度笔记⑩

学习进度笔记⑩

时间:2025-01-16 21:21:36浏览次数:1  
标签:sess run training 笔记 学习 train 进度 tf cost

Tensorflow线性回归

源代码:

复制代码

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

#设置训练参数,learning_rate=0.01,training_epochs=1000,display_step=50
learning_rate=0.01
training_epochs=1000
display_step=50
#创建训练数据
train_X=np.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,
7.042,10.791,5.313,7.997,5.654,9.27,3.1])
train_Y=np.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
2.827,3.465,1.65,2.904,2.42,2.94,1.3])
n_samples=train_X.shape[0]
#构造计算图,使用变量Variable构造变量X,Y,代码如下:
X=tf.placeholder("float")
Y=tf.placeholder("float")
#设置模型的初始权重
W=tf.Variable(np.random.randn(),name="weight")
b=tf.Variable(np.random.randn(),name='bias')
#构造线性回归模型
pred=tf.add(tf.multiply(X,W),b)
#求损失函数,即均方差
cost=tf.reduce_sum(tf.pow(pred-Y,2))/(2*n_samples)
#使用梯度下降法求最小值,即最优解
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
#初始化全部变量
init =tf.global_variables_initializer()
#使用tf.Session()创建Session会话对象,会话封装了Tensorflow运行时的状态和控制。
with tf.Session() as sess:
sess.run(init)
#调用会话对象sess的run方法,运行计算图,即开始训练模型。
#Fit all training data
for epoch in range(training_epochs):
for (x,y) in zip(train_X,train_Y):
sess.run(optimizer,feed_dict={X:x,Y:y})

#Display logs per epoch step
if (epoch+1) % display_step==0:
c=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
print("Epoch:" ,'%04d' %(epoch+1),"cost=","{:.9f}".format(c),"W=",sess.run(W),"b=",sess.run(b))
#打印训练模型的代价函数。
training_cost=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
print("Train cost=",training_cost,"W=",sess.run(W),"b=",sess.run(b))
#可视化,展现线性模型的最终结果。
plt.plot(train_X,train_Y,'ro',label='Original data')
plt.plot(train_X,sess.run(W)*train_X+sess.run(b),label="Fitting line")
plt.legend()
plt.show()
 
复制代码

结果截图:

 

 

标签:sess,run,training,笔记,学习,train,进度,tf,cost
From: https://www.cnblogs.com/binglinll/p/18675768

相关文章

  • 学习进度笔记⑨
    tensorflow基本操作(类似numpy)源代码importtensorflowastfimportosos.environ["CUDA_VISIBLE_DEVICES"]="0"#构造计算图,创建两个常量节点a,b,值分别为2,3a=tf.constant(2)b=tf.constant(3)#创建一个Session会话对象,调用run方法,运行计算图。withtf.Session()assess:......
  • 深度学习-猫狗识别:警惕kaggle中有问题的数据集
    大家注意了,深度学习的书有讲到进行猫狗二元识别的。在kaggle上有一个支持率很高的数据集,无法用于学习。因为在Cat目录中的第666张图片,大小是0,这导致了加载函数的崩溃。考虑到这个图片名字很巧合,我有的理由怀疑这是故意的。 此数据集无法用于《Python深度学习·第2版》望周......
  • 【ArcGIS】基于ChatGPT、GIS与Python机器学习的地质灾害风险评估、易发性分析、信息化
    目录第一章、ChatGPT大语言模型提示词与地质灾害基础及平台介绍第二章、空间信息数据库建设第三章、ChatGPT支持下地质灾害风险评价模型与方法第四章、ChatGPT支持下地质灾害风险性、易损性、易发性评价第五章、基于ChatGPT、Python数据预处理与分析【进阶篇】第六章、Ch......
  • 大数据体系介绍与 SQL 处理流程自学笔记
    一、大数据体系概述(一)大数据的概念大数据是指数据的规模、复杂性和多样性超出了传统数据处理工具和技术的处理能力的数据集合。其特点通常被概括为“4V”:Volume(大量):数据量巨大,通常以TB、PB甚至EB为单位。例如,一家大型电商平台每天产生的用户浏览记录、购买记录、搜索......
  • Linux运维学习记录14
    1.Docker的网络模式Docker的网络支持5种网络模式nonebridgecontainerhostnetwork-namebridge:描述:默认网络模式。如果未显式指定网络模式,Docker会将容器连接到一个默认的桥接网络docker0。每个容器都有自己的虚拟网卡,并通过docker0与其他容器通信。特点:容器之间......
  • 数字电路课程笔记
    1、数制与码制补零二进制和十六进制之间的转换:整数部分:对于整数部分,您需要将二进制数从右往左,每四位一组。如果剩余的位数不足四位,就需要在左边补零,直到每组都有四位。小数部分:对于小数部分,您需要将二进制数从左往右,每四位一组。如果剩余的位数不足四位,就需要在右边补零,......
  • 前端基础学习问题
    前端学习笔记前端学习的目录a标签和iframe的使用css类选择器及优先级css选择器细节点前端学习的目录前端基础是由html、css、js三部分构成,后又扩展到vue和react等框架。1.html(超文本标记语言)负责页面整个框架的搭建,分为:用于显示信息的标签和用于收集信息的标签。......
  • Pytorch框架与经典卷积神经网络学习Day4|VGG原理与实战
    目录跟学视频1.原理1.1VGG网络诞生背景 1.2VGG网络结构 1.3VGG总结2.实战2.1model.py2.2model_train.py2.3model_test.py跟学视频炮哥带你学_Pytorch框架与经典卷积神经网络与实战1.原理VGG(VisualGeometryGroup)是一个深度卷积神经网络架构,广泛应用于计算机......
  • C语言学习笔记:if语句讲解和月份的天数判断以及闰年判断示例
    if语句的格式是:if(关系运算或逻辑运算表达式){语句 }if后如果不加花括号{}则判定只和第一句语句相关:if(n==0)printf(“1”);printf(“2”);//这种写法,if判定与语句2无关 (n==0)是否为真语句2都会打印 为避免if的(关系语句)少打一个= 可以把常量放在左侧如:if(7==n){}当......
  • 字玩FontPlayer开发笔记10 Tauri2多窗口通信
    字玩FontPlayer开发笔记10Tauri2多窗口通信字玩FontPlayer是笔者开源的一款字体设计工具,使用Vue3+ElementUI开发,源代码:github|gitee笔记最近在使用Tauri进行打包应用,这两天在测试可编程脚本模块时,发现原有代码使用了window.open方法,在Tauri应用中一直触发不了新窗口......