首页 > 其他分享 >实验10-使用keras完成线性回归

实验10-使用keras完成线性回归

时间:2024-05-14 22:09:34浏览次数:11  
标签:10 plt 训练 keras print test train 线性 model

VMware虚拟机 Ubuntu20-LTS

python3.6

tensorflow1.15.0

keras2.3.1

运行截图:

 

 

 

代码:

import numpy as np

np.random.seed(1337)
from keras.models import Sequential
from keras.layers import Dense
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt

# 创建数据集
# 在[-1,1]的区间内等间隔创建200个样本数
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)  # 将数据集随机化
# np.random.normal(0, 0.05, (200, ))为输出服从均值为0,标准方差为0.05的200个正太分布数据
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))  # 假设我们真实模型为:Y=0.5X+2
# 绘制数据集plt.scatter(X, Y)
plt.show()

X_train, Y_train = X[:160], Y[:160]  # 把前160个数据放到训练集
X_test, Y_test = X[160:], Y[160:]  # 把后40个点放到测试集

# 定义一个model,
model = Sequential()  # Keras有两种类型的模型,序贯模型(Sequential)和函数式模型
# 比较常用的是Sequential,它是单输入单输出的
model.add(Dense(output_dim=1, input_dim=1))  # 通过add()方法一层层添加模型
# Dense是全连接层,第一层需要定义输入,
# 第二层无需指定输入,一般第二层把第一层的输出作为输入

# 定义完模型就需要训练了,不过训练之前我们需要指定一些训练参数
# 通过compile()方法选择损失函数和优化器
# 这里我们用均方误差作为损失函数,随机梯度下降作为优化方法
model.compile(loss='mse', optimizer='sgd')

# 开始训练
print('Training -----------')
for step in range(301):
    cost = model.train_on_batch(X_train, Y_train)  # Keras有很多开始训练的函数,这里用train_on_batch()
    if step % 100 == 0:
        print('train cost: ', cost)

# 测试训练好的模型
print('\nTesting ------------')
cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()  # 查看训练出的网络参数
# 由于我们网络只有一层,且每次训练的输入只有一个,输出只有一个
# 因此第一层训练出Y=WX+B这个模型,其中W,b为训练出的参数
print('Weights=', W, '\nbiases=', b)

# plotting the prediction
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

#使用r2 score评估准确度
pred_acc = r2_score(Y_test, Y_pred)
print('pred_acc',pred_acc)

#保存模型
model.save('keras_linear.h5')

 

标签:10,plt,训练,keras,print,test,train,线性,model
From: https://www.cnblogs.com/liucaizhi/p/18192351

相关文章

  • 实验11-使用keras完成逻辑回归
    VMware虚拟机Ubuntu20-LTSpython3.6tensorflow1.15.0keras2.3.1运行截图:   代码:importnumpyasnpfromkeras.modelsimportSequentialfromkeras.layersimportDense,Dropout,Activation,Flattenimportmatplotlib.pyplotaspltfromsklearnimport......
  • 实验6-使用TensorFlow完成线性回归
    VMware虚拟机Ubuntu20-LTSpython3.6tensorflow1.15.0keras2.3.1运行截图:  代码: %matplotlibinlineimportnumpyasnpimporttensorflowastfimportmatplotlib.pyplotaspltplt.rcParams["figure.figsize"]=(14,8)n_observations=100xs=np.li......
  • KylinV10SP2实现ARM和x86架构系统PXE部署(S3)
    KylinV10SP2实现ARM和x86架构系统PXE部署(S3)本文介绍在esxi(虚拟化)中Centos7.9操作系统上部署PXE服务端,集成麒麟系统安装源,TFTP服务,DHCP服务,HTTP服务,能够向裸机发送PXE引导程序、Linux内核、启动菜单等数据,以及提供安装文件。系统引导模式分为uefi引导以及legacy引导,本文主要UEFI,......
  • 桌面图标间距Bug:Win10/Win11桌面图标占用空间变成长方形怎么办?
    阅读全文:https://itxiaozhang.com/win10-win11-desktop-icon-bug-rectangular-fix/此教程配合视频学习效果最佳,视频教程在文章末尾。问题描述在使用Windows10或Windows11操作系统时,桌面图标的间距突然变得很大,变成了长方形。该问题通常发生在修改屏幕分辨率、连接外部显示......
  • 洛谷题单指南-动态规划3-P1070 [NOIP2009 普及组] 道路游戏
    原题链接:https://www.luogu.com.cn/problem/P1070题意解读:1~n个环形机器人工厂,相邻工厂之间的道路是1~n,每个时刻可以从任意工厂购买机器人,走不超过p时间,不同工厂购买机器人花费不同的金币,不同时刻走到不同道路也能得到不同的金币,问一共m时间,最多可以得到多少金币(需减去购买机器人......
  • 5.10
    IPv6vs.IPv4我一直对IPv6这个名字感到困惑,因为我觉得IPv4名字来源于它用来表示32位的四个字节,所以IPv6应该被称为IP16。但实际上,这只是协议的版本号。在IPv4推出之前,曾存在过IPv1、IPv2和IPv3,它们主要用于内部研究IP协议,后来被我们现在的IPv4所取代。在上世纪80年代,还提出过IP......
  • 桌面图标间距Bug:Win10/Win11桌面图标占用空间变成长方形怎么办?
    阅读全文:https://itxiaozhang.com/win10-win11-desktop-icon-bug-rectangular-fix/此教程配合视频学习效果最佳,视频教程在文章末尾。问题描述在使用Windows10或Windows11操作系统时,桌面图标的间距突然变得很大,变成了长方形。该问题通常发生在修改屏幕分辨率、连接外部显示......
  • MCal工程通用计算式算量表V1.3.2.10 2024.5.14
     1、更新下tab菜单2、增加计算式结果四舍五入,四舍六入的设置,在显示效果-工程结果中选择3、次级计算式增加到20个,欢迎测试。下载地址:www.zawen.net         https://club.excelhome.net/thread-1644206-1-1.html......
  • Mysql的max()函数9大于10问题
    前言在公司老项目使用过程中都遇到过这个问题,所以这里记录下来问题描述使用系统中发现系统的字典新增之后排序不固定,于是查找问题,打开数据库发现sort大部分都是"10"mybatis中查询语句如下: `selectIFNULL(MAX(DIC_SORT),0)FROMDIC_INFOWHERE'ProjectId'=#{ProjectId}`......
  • 2024年AppScan 最新10.5.0破解版 附详细安装教程
     免责声明请勿利用文章内的相关技术从事非法测试。由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,请务必遵守网络安全法律法规。本文仅用于测试,请完成测试后24小时删除,请勿用于商业用途。如文中内容涉及侵权行......