首页 > 其他分享 >【机器学习】TensorFlow 202107090086

【机器学习】TensorFlow 202107090086

时间:2024-06-07 22:29:29浏览次数:9  
标签:loss 机器 损失 求得 202107090086 当前 tf TensorFlow mean

【源代码】

import tensorflow as tf

l2_reg = tf.keras.regularizers.l2(0.1)  # 设置模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(30, activation='relu',
                          kernel_initializer='he_normal', kernel_regularizer=l2_reg),
    tf.keras.layers.Dense(60, activation='relu',
                          kernel_initializer='he_normal', kernel_regularizer=l2_reg),
    tf.keras.layers.Dense(60, activation='relu',
                          kernel_initializer='he_normal', kernel_regularizer=l2_reg),
    tf.keras.layers.Dense(2, activation="softmax")
])


def random_batch(X, y, batch_size=32):  # 随机抽取数据
    idx = np.random.randint(len(X), size=batch_size)
    Xx = np.array([X[i] for i in idx])
    Yy = np.array([y[i] for i in idx])
    return Xx, Yy


def print_status_bar(iteration, total, loss, metrics=None):  # 输出状态
    #     print('iteration',iteration)
    #     print('total',total)
    #     print('loss',loss.result())

    metrics = "-".join(["{}:{:4f}".format(m.name, m.result()) for m in [loss] + (metrics or [])])
    #     print('metrics',metrics)
    #     print("===============")
    end = "" if iteration < total else "\n"
    print("\r{}/{}-".format(iteration, total) + metrics, end=end)


import numpy as np
import matplotlib.pyplot as plt

a = tf.random.normal([1000, 2], 8, 2)  # 生成数据a类
a_lable = tf.cast([[0] for i in range(1000)], dtype=tf.int64)

b = tf.random.normal([1000, 2], 1, 2)  # 生成数据b类
b_lable = tf.cast([[1] for i in range(1000)], dtype=tf.int64)
X_train = np.concatenate([a, b])  # 合并数据
y_train = np.concatenate([a_lable, b_lable])
y_train = tf.one_hot(y_train[:, 0], 2)  # 将标签合转化为one-hot编码

n_epochs = 5  # 迭代次数
batch_size = 32  # 批次大小
n_steps = int(len(X_train) / 32)  # 分批次
optimizer = tf.keras.optimizers.Nadam(lr=0.01)  # 算法优化器,以及学习率
loss_fn = tf.keras.losses.CategoricalCrossentropy()  # 损失函数
mean_loss = tf.keras.metrics.Mean()  # 平均损失函数
metrics = [tf.keras.metrics.MeanAbsoluteError()]  # 要计算的误差

for epoch in range(1, n_epochs + 1):  # 迭代次数
    print("Epoch{}/{}".format(epoch, n_epochs))
    for step in range(1, n_steps + 1):  # 分批次
        X_batch, y_batch = random_batch(X_train, y_train)  # 从数据集中随机抽取一个批次的训练数据
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)  # 得到当前模型的结果
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))  # 计算损失tf.keras.losses.BinaryCrossentropy()
            loss = tf.add_n([main_loss] + model.losses)  # 计算损失2
        gradients = tape.gradient(loss, model.trainable_variables)  # 求梯度
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))  # 应用梯度,并更新
        mean_loss(loss)  # 评价梯度
        for metric in metrics:  # 计算其他指标 metrics=[tf.keras.metrics.MeanAbsoluteError()]
            metric(y_batch, y_pred)
        print_status_bar(step * batch_size, len(y_train), mean_loss, metrics)  # 输出误差以及批次进度

    print_status_bar(len(y_train), len(y_train), mean_loss, metrics)  # 输出最后一次误差及批次进度

model.predict(a)  # 预测a类数据
model.predict(b)  # 预测b类数据
np.sum(np.argmax(model.predict(b), axis=1)) / len(b)  # TP计算准确率0.999
np.sum(np.argmax(model.predict(a), axis=1)) / len(a)  # TP计算准确率0.999

【运行结果】

Epoch1/5
2000/2000-mean:14.044858-mean_absolute_error:0.182928
Epoch2/5
2000/2000-mean:8.709734-mean_absolute_error:0.1166409
Epoch3/5
2000/2000-mean:6.232692-mean_absolute_error:0.088220
Epoch4/5
2000/2000-mean:4.812209-mean_absolute_error:0.073772
Epoch5/5
2000/2000-mean:3.906441-mean_absolute_error:0.065139

[1]:

0.008

【源代码】

import tensorflow as tf
import matplotlib.pyplot as plt

x=tf.linspace(1.,10.,100)
x=tf.reshape(x,[1,100])
w=tf.Variable([[0.5]])
y=tf.matmul(tf.transpose(w),x)
w=tf.Variable([[0.1]])
a=tf.Variable([[0.001]])
LOSS=[]
for i in range(100):
    with tf.GradientTape() as tape:
        y_pre=tf.matmul(tf.transpose(w),x)
        loss=0.5*tf.reduce_mean(tf.square(y-y_pre))
        LOSS.append(loss)
    g=tape.gradient(loss,w)
    w=w-a*g
    w=tf.Variable(w)
    print("当前损失为:",loss.numpy(),"求得的w为:",w.numpy())
plt.plot(LOSS)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

【运行结果】

当前损失为: 2.9709094 求得的w为: [[0.11485454]]
当前损失为: 2.7543488 求得的w为: [[0.12915745]]
当前损失为: 2.5535743 求得的w为: [[0.1429292]]
当前损失为: 2.3674352 求得的w为: [[0.1561895]]
当前损失为: 2.1948645 求得的w为: [[0.16895737]]
当前损失为: 2.0348728 求得的w为: [[0.1812511]]
当前损失为: 1.8865434 求得的w为: [[0.19308826]]
当前损失为: 1.7490267 求得的w为: [[0.20448585]]
当前损失为: 1.6215336 求得的w为: [[0.21546017]]
当前损失为: 1.5033342 求得的w为: [[0.22602694]]
当前损失为: 1.3937508 求得的w为: [[0.2362013]]
当前损失为: 1.2921551 求得的w为: [[0.24599783]]
当前损失为: 1.1979653 求得的w为: [[0.25543055]]
当前损失为: 1.1106414 求得的w为: [[0.26451296]]
当前损失为: 1.0296828 求得的w为: [[0.2732581]]
当前损失为: 0.95462537 求得的w为: [[0.28167847]]
当前损失为: 0.8850394 求得的w为: [[0.28978613]]
当前损失为: 0.8205256 求得的w为: [[0.2975927]]
当前损失为: 0.7607146 求得的w为: [[0.30510938]]
当前损失为: 0.7052632 求得的w为: [[0.3123469]]
当前损失为: 0.6538541 求得的w为: [[0.31931567]]
当前损失为: 0.60619223 求得的w为: [[0.32602564]]
当前损失为: 0.5620047 求得的w为: [[0.33248642]]
当前损失为: 0.52103806 求得的w为: [[0.33870727]]
当前损失为: 0.48305768 求得的w为: [[0.3446971]]
当前损失为: 0.4478459 求得的w为: [[0.35046446]]
当前损失为: 0.41520083 求得的w为: [[0.35601768]]
当前损失为: 0.3849353 求得的w为: [[0.36136466]]
当前损失为: 0.356876 求得的w为: [[0.36651307]]
当前损失为: 0.33086202 求得的w为: [[0.3714703]]
当前损失为: 0.30674422 求得的w为: [[0.37624344]]
当前损失为: 0.28438443 求得的w为: [[0.38083932]]
当前损失为: 0.2636546 求得的w为: [[0.38526452]]
当前损失为: 0.24443586 求得的w为: [[0.38952538]]
当前损失为: 0.22661799 求得的w为: [[0.393628]]
当前损失为: 0.21009903 求得的w为: [[0.39757827]]
当前损失为: 0.19478416 求得的w为: [[0.40138185]]
当前损失为: 0.1805856 求得的w为: [[0.40504417]]
当前损失为: 0.16742207 求得的w为: [[0.40857047]]
当前损失为: 0.15521811 求得的w为: [[0.41196582]]
当前损失为: 0.14390373 求得的w为: [[0.41523507]]
当前损失为: 0.13341412 求得的w为: [[0.41838294]]
当前损失为: 0.12368905 求得的w为: [[0.4214139]]
当前损失为: 0.11467293 求得的w为: [[0.4243323]]
当前损失为: 0.10631403 求得的w为: [[0.42714232]]
当前损失为: 0.098564394 求得的w为: [[0.429848]]
当前损失为: 0.09137968 求得的w为: [[0.4324532]]
当前损失为: 0.084718674 求得的w为: [[0.43496162]]
当前损失为: 0.07854325 求得的w为: [[0.43737692]]
当前损失为: 0.07281793 求得的w为: [[0.4397025]]
当前损失为: 0.067509964 求得的w为: [[0.44194174]]
当前损失为: 0.06258892 求得的w为: [[0.44409782]]
当前损失为: 0.05802657 求得的w为: [[0.44617382]]
当前损失为: 0.053796817 求得的w为: [[0.44817272]]
当前损失为: 0.04987539 求得的w为: [[0.45009738]]
当前损失为: 0.04623981 求得的w为: [[0.45195058]]
当前损失为: 0.04286923 求得的w为: [[0.45373496]]
当前损失为: 0.039744332 求得的w为: [[0.45545307]]
当前损失为: 0.03684724 求得的w为: [[0.45710737]]
当前损失为: 0.034161333 求得的w为: [[0.45870024]]
当前损失为: 0.031671196 求得的w为: [[0.46023396]]
当前损失为: 0.029362578 求得的w为: [[0.46171072]]
当前损失为: 0.027222238 求得的w为: [[0.46313265]]
当前损失为: 0.025237901 求得的w为: [[0.46450177]]
当前损失为: 0.02339822 求得的w为: [[0.46582004]]
当前损失为: 0.02169264 求得的w为: [[0.46708935]]
当前损失为: 0.020111393 求得的w为: [[0.46831155]]
当前损失为: 0.018645389 求得的w为: [[0.46948835]]
当前损失为: 0.017286247 求得的w为: [[0.47062144]]
当前损失为: 0.016026199 求得的w为: [[0.47171244]]
当前损失为: 0.014858003 求得的w为: [[0.47276294]]
当前损失为: 0.013774947 求得的w为: [[0.47377443]]
当前损失为: 0.012770831 求得的w为: [[0.47474834]]
当前损失为: 0.011839931 求得的w为: [[0.4756861]]
当前损失为: 0.01097687 求得的w为: [[0.47658902]]
当前损失为: 0.010176734 求得的w为: [[0.47745842]]
当前损失为: 0.00943492 求得的w为: [[0.47829553]]
当前损失为: 0.00874717 求得的w为: [[0.47910157]]
当前损失为: 0.008109552 求得的w为: [[0.47987765]]
当前损失为: 0.007518423 求得的w为: [[0.4806249]]
当前损失为: 0.006970382 求得的w为: [[0.48134443]]
当前损失为: 0.0064622904 求得的w为: [[0.48203725]]
当前损失为: 0.005991219 求得的w为: [[0.4827043]]
当前损失为: 0.0055545014 求得的w为: [[0.4833466]]
当前损失为: 0.0051496117 求得的w为: [[0.48396507]]
当前损失为: 0.0047742343 求得的w为: [[0.48456055]]
当前损失为: 0.0044262214 求得的w为: [[0.48513392]]
当前损失为: 0.0041035754 求得的w为: [[0.48568597]]
当前损失为: 0.0038044578 求得的w为: [[0.48621756]]
当前损失为: 0.0035271323 求得的w为: [[0.48672938]]
当前损失为: 0.0032700289 求得的w为: [[0.4872222]]
当前损失为: 0.00303167 求得的w为: [[0.4876967]]
当前损失为: 0.002810685 求得的w为: [[0.4881536]]
当前损失为: 0.0026058045 求得的w为: [[0.48859355]]
当前损失为: 0.002415853 求得的w为: [[0.48901713]]
当前损失为: 0.002239759 求得的w为: [[0.489425]]
当前损失为: 0.0020764905 求得的w为: [[0.4898177]]
当前损失为: 0.0019251321 求得的w为: [[0.49019584]]
当前损失为: 0.0017848014 求得的w为: [[0.49055994]]
当前损失为: 0.0016547017 求得的w为: [[0.4909105]]

标签:loss,机器,损失,求得,202107090086,当前,tf,TensorFlow,mean
From: https://blog.csdn.net/2201_75425839/article/details/139536336

相关文章

  • ChatGPT-4o在临床医学日常工作、数据分析与可视化、机器学习建模中的技术
    2022年11月30日,可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT-3.5,将人工智能的发展推向了一个新的高度。2023年11月7日,OpenAI首届开发者大会被称为“科技界的春晚”,吸引了全球广大用户的关注,GPT商店更是显现了OpenAI旨在构建AI生态......
  • 滑坡、泥石流等地质灾害风险评价、基于机器学习的滑坡易发性分析技术
    入门篇,ArcGIS软件的快速入门与GIS数据源的获取与理解;方法篇,致灾因子提取方法、灾害危险性因子分析指标体系的建立方法和灾害危险性评价模型构建方法;拓展篇,GIS在灾害重建中的应用方法;高阶篇:Python环境中利用机器学习进行灾害易发性评价模型的建立与优化方法。原文链接:滑坡、泥......
  • 机器学习--有监督学习--算法整理
     整理原因:为了更好的理解学习算法为什么有用,还是决定认真看看推导公式和过程。以下是有监督学习线性回归的推导过程。算法目标:根据一组x和y的对应关系,找到他们的线性关系,得到拟合线性方程:y=ax+b,从而对于任意的自变量x,都可以预测到对应的因变量y的值。并且,要保证这个a,b足够可靠......
  • 厂区车间佩戴安全帽检测系统 TensorFlow
    厂区车间佩戴安全帽检测系统提升了工作人员安全帽佩戴和面部实名认证管理效率和监管水平。厂区车间佩戴安全帽检测系统根据搜集现场施工作业人员的脸部信息内容和监控画面视频图像检测优化算法,可以设置访问限制。假如作业人员不戴头盔,作业人员将被禁止进入施工区域,并会语音播报......
  • 机器学习笔记(2): Logistic 回归
    Logistic回归是线性回归中一个很重要的部分。Logistic函数:\[\sigma(x)=\frac{L}{1+\exp(-k(x-x_0))}\]其中:\(L\)表示最大值\(x_0\)表示对称中心\(k\)表示倾斜度一般来说,都将\(L\)设为\(1\),而\(k\)和\(x_0\)在参数中控制。认为特征只有一个,那么自......
  • 机器学习策略篇:详解进行误差分析(Carrying out error analysis)
    从一个例子开始讲吧。假设正在调试猫分类器,然后取得了90%准确率,相当于10%错误,,开发集上做到这样,这离希望的目标还有很远。也许的队员看了一下算法分类出错的例子,注意到算法将一些狗分类为猫,看看这两只狗,它们看起来是有点像猫,至少乍一看是。所以也许的队友给一个建议,如何针对狗的......
  • 机器学习-聚类算法
    1.有监督学习与无监督学习有监督:在训练集中给的数据中有X和Y,根据这些数据训练出一组参数对预测集进行预测无监督:在训练集中给的数据只有X没有Y,根据X数据找相似度参数来对预测集进行预测2.数据间的相似度2.1距离相似度:每一条数据可以理解为一个n维空间中的点,可以根据点点之......
  • 机器学习-支持向量机
    目录一支持向量机1.支持向量机SVM2构建svm目标函数3.拉格朗日乘法,kkt条件拉格朗日乘法:kkt条件 对偶问题 4.最小化SVM目标函数kkt条件: 对偶转换: 5软间隔及优化优化svm目标函数 构造拉格朗日函数对偶转换关系:求解结果:总结:都看到这里了点个赞吧! 一支持......
  • 在虚拟机上搭建 Docker Kafka 宿主机器程序无法访问解决方法
    1、问题描述在虚拟机CentOS-7上搭建的DockerKafka,docker内部可以创建Topic、可以生产者数据、可以消费数据,而在宿主机开发程序无法消费Docker Kafka的数据。1.1、运行情况[docker@localhost~]$dockerps-aCONTAINERIDIMAGECOMMAND......
  • 机器学习算法(一):1. numpy从零实现线性回归
    系列文章目录机器学习算法(一):1.numpy从零实现线性回归机器学习算法(一):2.线性回归之多项式回归(特征选取)@目录系列文章目录前言一、理论介绍二、代码实现1、导入库2、准备数据集3、定义预测函数(predict)4代价(损失)函数5计算参数梯度6批量梯度下降7训练8可视化一下损失总结前......