首页 > 其他分享 >用TensorFlow实现线性回归

用TensorFlow实现线性回归

时间:2024-08-24 20:26:10浏览次数:6  
标签:loss keras 回归 tf shape 线性 TensorFlow net data

说明

本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然,如果您已经安装了tensorflow,可以采用python直接运行。

贡献

学习时采取动手学深度学习第二版作为教材,但由于本书通过引入d2l(著者自写库)进行深度学习,我希望将d2l的影响去掉,即不使用d2l,使用tensorflow,这一点通过查询GitHub中d2l库提供的相关函数尝试进行实现。

如果本系列文章具有良好表现,将译为英文版上传至Github。

预备知识

学习本篇文章之前,您最好具有以下基础知识:

  1. 线性回归的基础知识
  2. python的基础知识

基本原理 

使用一个仿射变换,通过y=wx+b的模型来对数据进行预测(w和x均为矩阵,大小取决于输入规模),反向传播采用随机梯度下降对参数进行更新,参数包括w和b,即权重和偏差。

实现过程

生成数据集

只需要引入tensorflow即可,synthetic_data()函数将初始化X和Y,即通过真实的权重和偏差值生成数据集。

import tensorflow as tf

def synthetic_data(w, b, num_examples):
    X = tf.zeros((num_examples, w.shape[0]))
    X += tf.random.normal(shape=X.shape)
    y = tf.matmul(X, tf.reshape(w, (-1, 1))) + b
    y += tf.random.normal(shape=y.shape, stddev=0.01)
    y = tf.reshape(y, (-1, 1))
    return X, y

true_w = tf.constant([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

读取数据集

加载刚刚生成的数据集,is_train表示是否进行打乱,默认对数据进行打乱处理,使用load_array函数加载数据集。

def load_array(data_arrays, batch_size, is_train=True):
    dataset = tf.data.Dataset.from_tensor_slices(data_arrays)
    if is_train:
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 10
data_iter = load_array((features, labels), batch_size)

定义模型

模型使用keras API实现,keras是tensorflow中机器学习相关的库。先使用Sequential类定义承载容器,之后添加一个单神经元的全连接层。在TensorFlow中,Sequential表示容器相关的类,layer表示层相关的类。线性回归只需要通过keras中的单神经元的全连接层即可实现,神经元的值即为输出结果。

net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1))

示例的线性回归仅有一个输入X,实际在其他线性回归过程中,很有可能有多个x及其对应的w,但keras的代码均不会发生改变,因为keras的Dense类可以自动判断输入的个数。 

初始化模型参数 

stddev表示标准差,initializer生成一个标准差为1,均值为0的正态分布。在构建全连接层时,使用该正态分布进行初始化。

initializer = tf.initializers.RandomNormal(stddev=0.01)
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1, kernel_initializer=initializer))

定义损失函数和优化算法 

损失函数使用平方损失函数进行计算,训练时使用小批量随机梯度下降SGD方法进行训练,学习率为0.03。

loss = tf.keras.losses.MeanSquaredError()
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)

训练

运行以下代码可以观察训练结果。运行轮次为3轮,每一轮对所有训练集数据进行学习。计算w和b的梯度值,使用梯度下降更新权重w和偏差b。每一轮输出损失函数的值,最终显示权重和偏差的估计误差。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        with tf.GradientTape() as tape:
            l = loss(net(X, training=True), y)
        grads = tape.gradient(l, net.trainable_variables)
        trainer.apply_gradients(zip(grads, net.trainable_variables))
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')
w = net.get_weights()[0]
print('w的估计误差:', true_w - tf.reshape(w, true_w.shape))
b = net.get_weights()[1]
print('b的估计误差:', true_b - b)

运行结果

epoch 1, loss 0.000194

epoch 2, loss 0.000091

epoch 3, loss 0.000091

w的估计误差: tf.Tensor([-0.00026917 0.00094557], shape=(2,), dtype=float32)

b的估计误差: [4.7683716e-06]

 改进尝试

  1. 更改SGD优化算法为Adam
  2. 更改MeanSquaredError为其他损失函数

对于上述改进,损失均有显著增加,表明原有方法已为最好方法。

标签:loss,keras,回归,tf,shape,线性,TensorFlow,net,data
From: https://blog.csdn.net/2301_79335566/article/details/141434340

相关文章

  • 线性dp:LeetCode674. 最长连续递增序列
    LeetCode674.最长连续递增序列阅读本文之前,需要先了解“动态规划方法论”,这在我的文章以前有讲过链接:动态规划方法论本文之前也讲过一篇文章:最长递增子序列,这道题,阅读本文的同时可以与“最长递增子序列进行对比”,这样更能对比二者的区别!LeetCode300.最长递增子序列-To......
  • 【全面指导】线性代数如何高效备考?选择哪本习题集?
    作为一个过来人,在备考过程中,我发现线性代数这是个不容小觑的科目,在考研数学一二三中都占比20%,其复习策略和方法对最终成绩起到了决定性作用。那么,如何选择适合的习题集?怎样制定有效的复习计划?这些问题都是我们必须认真思考和了解的。今天,我将分享我的备考经验,从复习书籍选择到......
  • Mac M1用tensorflow中的Keras进行基本图像分类
    一.为什么要进行图像分类、图像识别目的是为了利用计算机对图像进行处理、分析和理解,让计算机能够像人类一样理解和解释图像中的内容。‌这一技术的应用范围广泛,包括但不限于人脸识别和商品识别。人脸识别技术主要应用于安全检查、身份核验与移动支付等领域,而商品识别则广......
  • TensorFlow 的基本概念和使用场景
    TensorFlow是一个开源的机器学习框架,由Google开发和维护。它允许开发者使用图形计算的方式构建和训练机器学习模型。TensorFlow的基本概念如下:张量(Tensor):TensorFlow使用张量来表示数据。张量是多维数组,在计算图中流动,是TensorFlow的基本数据单元。张量可以是标量(0维数组)、......
  • 线性dp:最长公共子串
    最长公共子串本文讲解的题与leetcode718.最长重复子数组,题意一模一样,阅读完本文以后可以去挑战这题。力扣链接题目叙述:给定两个字符串,输出其最长公共子串的长度。输入ABACCBAACCAB输出3解释最长公共子串是ACC,其长度为3。与最长公共子序列的区别公共子串:字符必须......
  • 【漫谈C语言和嵌入式028】稳压器的选择之道:线性稳压器与开关稳压器的深入比较
            在电子电路设计中,稳压器(Regulator)是不可或缺的组件,用于提供稳定的输出电压以满足电路的需求。稳压器的种类多种多样,其中最常见的两大类是线性稳压器(LinearRegulator)和开关稳压器(SwitchingRegulator)。它们在工作原理、效率、复杂性等方面各具特点,适用于不同的......
  • 线性dp:大盗阿福(打家劫舍)
    大盗阿福本题与leetcode198题——打家劫舍的题意一模一样,阅读完本文以后可以尝试以下题目力扣题目链接)题目叙述:阿福是一名经验丰富的大盗。趁着月黑风高,阿福打算今晚洗劫一条街上的店铺。这条街上一共有N家店铺,每家店中都有一些现金。阿福事先调查得知,只有当他同时洗劫了两......
  • 线性dp:编辑距离
    编辑距离本题与力扣72.编辑距离题意一样,阅读完本文可以尝试leetcode72.力扣题目链接题目叙述输入两个字符串a,b。输出从字符串a修改到字符串b时的编辑距离输入NOTVLOVER输出4题目解释:动态规划思路这个问题显然是一个最优解问题,我们可以考虑动态规划的思路,那么我......
  • 回归预测|基于卷积神经网络-长短期记忆网络-自注意力机制的数据回归预测Python程序 多
    回归预测|基于卷积神经网络-长短期记忆网络-自注意力机制的数据回归预测Python程序多特征输入单输出CNN-LSTM-Attention文章目录前言回归预测|基于卷积神经网络-长短期记忆网络-自注意力机制的数据回归预测Python程序多特征输入单输出CNN-LSTM-Attention一、CNN-......
  • 回归预测|基于北方苍鹰优化-卷积神经网络-双向长短期记忆网络-自注意力机制的数据回归
    **回归预测|基于北方苍鹰优化-卷积神经网络-双向长短期记忆网络-自注意力机制的数据回归预测Matlab程序多特征输入单输出含基础模型NGO-CNN-BiLSTM-Attention**文章目录前言回归预测|基于北方苍鹰优化-卷积神经网络-双向长短期记忆网络-自注意力机制的数据回归预测M......