首页 > 其他分享 > TensorFlow从入门到精通——线性回归实现

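TensorFlow从入门到精通——线性回归实现

> 注：本文代码使用 TensorFlow 1.x 的图模式 API（如 `tf.Session`、`tf.random_normal`、`tf.train.GradientDescentOptimizer`）。在 TensorFlow 2.x 中需改用 `tf.compat.v1` 并先调用 `tf.compat.v1.disable_eager_execution()` 才能运行。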

时间：2022-11-01 18:07:19 | 浏览次数：47
标签:sess run 入门 self tf bias 线性 tensorflow data


import os

# Silence TensorFlow's native (C++) INFO and WARNING log lines. The variable
# is TF_CPP_MIN_LOG_LEVEL (the previous name, TF_LOG_MIN_LEVEL, is not read by
# TensorFlow), and it must be set before TensorFlow is imported so the native
# runtime sees it when it starts up.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must come after the env var above)

class LinearRegression:
    """Single-feature linear regression ``y = x @ W + b`` as a TF 1.x graph.

    ``__init__`` builds the graph once: prediction, mean-squared-error loss,
    a gradient-descent step and a variable initializer. ``fit`` runs that
    graph inside a fresh ``tf.Session``.
    """

    def __init__(self, data=None, learning_rate=0.001):
        """Build the regression graph.

        Args:
            data: optional ``(X, Y)`` pair. When omitted, synthetic tensors
                from ``gener_data`` are used. ``X`` must be rank 2 with a
                single column, because the weight matrix is [1, 1]. ``Y`` is
                reshaped to a column with one row per sample.
            learning_rate: step size for the gradient-descent optimizer
                (default 0.001, the value that used to be hard-coded).
        """
        if data is None:
            self.X, self.Y = self.gener_data()
        else:
            self.X, self.Y = data

        # tf.matmul needs matching dtypes: cast so NumPy float64 input works
        # with the float32 variables below.
        x = tf.cast(self.X, tf.float32)
        # Force Y into a column. A rank-1 Y of shape [n] would otherwise
        # broadcast against the [n, 1] prediction into an [n, n] residual
        # and silently produce a wrong loss.
        y = tf.reshape(tf.cast(self.Y, tf.float32), [-1, 1])

        # Both parameters start from N(0, 0.1^2).
        self.weights = tf.Variable(
            initial_value=tf.random_normal(shape=[1, 1], mean=0, stddev=0.1),
            dtype=tf.float32)
        self.bias = tf.Variable(
            initial_value=tf.random_normal(shape=[1, 1], mean=0, stddev=0.1),
            dtype=tf.float32)

        # Ask for the GPU for the forward pass. fit() enables
        # allow_soft_placement, so this can fall back to another device.
        with tf.get_default_graph().device("/gpu:0"):
            self.y_ped = tf.matmul(x, self.weights) + self.bias

        # Loss function: mean squared error.
        self.loss = tf.reduce_mean(tf.square(y - self.y_ped))
        # Optimizer: each sess.run(self.optim) performs one descent step.
        self.optim = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate).minimize(self.loss)
        self.init_data = tf.global_variables_initializer()

    def fit(self, epochs=1000, log_every=1):
        """Train with gradient descent in a fresh session, printing progress.

        Variables are (re)initialized at the start of every call. The session
        is closed on return, so the trained values are returned rather than
        kept live.

        Args:
            epochs: number of optimizer steps to run.
            log_every: print progress every ``log_every`` steps. The default
                of 1 prints every step; values <= 0 disable per-step output.

        Returns:
            ``(weights, bias, loss)`` as Python floats after the last step.
        """
        print("model training....")
        config = tf.ConfigProto(
            allow_soft_placement=True,   # let ops pinned to /gpu:0 fall back
            log_device_placement=True,   # print which device runs each op
        )
        with tf.Session(config=config) as sess:
            sess.run(self.init_data)
            weights, bias, loss = self._evaluate(sess)
            print("初始化变量:Weights = %f, bias = %f, loss = %f"
                  % (weights, bias, loss))

            for i in range(epochs):
                # One optimization step.
                sess.run(self.optim)
                if log_every > 0 and (i + 1) % log_every == 0:
                    weights, bias, loss = self._evaluate(sess)
                    print("训练后第%d次后:Weights = %f, bias = %f, loss = %f"
                          % (i + 1, weights, bias, loss))

            weights, bias, loss = self._evaluate(sess)
        return weights, bias, loss

    def _evaluate(self, sess):
        """Fetch weights, bias and loss in one ``sess.run`` as Python floats."""
        # A single run evaluates all three against the same draw of any random
        # input ops (see gener_data). .item() unwraps the [1, 1] arrays without
        # NumPy's deprecated ndarray-to-float conversion.
        weights, bias, loss = sess.run([self.weights, self.bias, self.loss])
        return weights.item(), bias.item(), float(loss)

    def predict(self, x):
        """Return a symbolic tensor computing ``x @ weights + bias``.

        The result must be evaluated in a session whose variables hold the
        desired values; ``fit`` closes its own session on return. ``x`` is
        cast to float32 and must be rank 2 with a single column.
        """
        return tf.matmul(tf.cast(x, tf.float32), self.weights) + self.bias

    def gener_data(self):
        """Build synthetic tensors following ``Y = 2 * X + 0.5 + noise``.

        Returns:
            ``(X, Y)``: two float32 tensors of shape [100, 1]. X is standard
            normal; the noise is standard normal scaled by 1/1000.

        NOTE: these are random ops, not constants. Every ``sess.run`` that
        depends on them draws a new sample, so each training step sees fresh
        data.
        """
        X = tf.random_normal(shape=[100, 1], dtype=tf.float32)
        noise = tf.random_normal(shape=[100, 1], dtype=tf.float32) / 1000.
        Y = tf.matmul(X, [[2.0]]) + 0.5 + noise
        return X, Y

if __name__ == "__main__":
    # Script entry point: build the model on its default synthetic data and
    # run 10000 gradient-descent steps.
    training_steps = 10000
    model = LinearRegression(data=None)
    model.fit(epochs=training_steps)



![在这里插入图片描述](/i/ll/?i=bd5efcf9b3514979aba18999f18d9b19.png?,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBA5bCP6ZmIcGhk,size_20,color_FFFFFF,t_30,g_se,x_16)
![在这里插入图片描述](/i/ll/?i=56a1000dc9ad4c1a8e7643e6f066448a.png?,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_20,text_Q1NETiBA5bCP6ZmIcGhk,size_20,color_FFFFFF,t_70,g_se,x_16)


标签:sess,run,入门,self,tf,bias,线性,tensorflow,data
From: https://blog.51cto.com/u_13859040/5814610

相关文章