首页 > 编程问答 >TruePositive 如何是 keras.metrics.TruePositives 中的十进制数?

TruePositive 如何是 keras.metrics.TruePositives 中的十进制数?

时间:2024-07-24 16:22:41浏览次数:8  
标签:python tensorflow machine-learning keras metrics

我正在尝试在图像数据集上训练 CNN 模型,但我被获取 TruePositives、TrueNegatives、FalsePositives 和 FalseNegatives 的十进制值所困扰。这怎么可能?

ERROR sample
Epoch 1/3
36/36 ━━━━━━━━━━━━━━━━━━━━ 69s 2s/step - false_negatives: 30.1351 - false_positives: 35.3784 - loss: 2.1995 - true_negatives: 389.0540 - true_positives: 437.6487

有一些 (tp+tn+fp+tn)不等于样本总数。

完整代码


import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.layers import Dense,Flatten,InputLayer,Conv2D,MaxPooling2D,Concatenate,Input,BatchNormalization
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.losses import BinaryCrossentropy,CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from sklearn.metrics import classification_report
from tensorflow.keras.callbacks import EarlyStopping
datagen=ImageDataGenerator(rescale=1.0/255.0)
train_gen=datagen.flow_from_directory('train',class_mode='binary',
                                      target_size=(224,224),batch_size=32,shuffle=True)

output
Found 1146 images belonging to 2 classes.
tp = tf.keras.metrics.TruePositives()
tn = tf.keras.metrics.TrueNegatives()
fp = tf.keras.metrics.FalsePositives()
fn = tf.keras.metrics.FalseNegatives()
tp.update_state([0.4, .9, .7, .8], [1.0, 0.0, 1.0, 1.0])
tp.result()
output
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
model_input=Input(shape=(224,224,3))

x=Conv2D(filters=32, kernel_size=(3,3),activation='relu',padding='valid')(model_input)
x=MaxPooling2D(pool_size=(2,2),strides=2)(x)
x=Conv2D(filters=64, kernel_size=(3,3),activation='relu',padding='valid')(x)
x=MaxPooling2D(pool_size=(2,2),strides=2)(x)
x=BatchNormalization()(x)
x=Conv2D(filters=64, kernel_size=(3,3),activation='relu',padding='valid')(x)
x=MaxPooling2D(pool_size=(2,2),strides=2)(x)
x=BatchNormalization()(x)
x=Flatten()(x)
x=Dense(units=1000,activation='relu')(x)
output=Dense(units=1,activation='sigmoid')(x)
model=Model(inputs=model_input,outputs=output)
model.compile(optimizer=Adam(),loss=BinaryCrossentropy(),metrics=[tp,fp,fn,tn])
early_stopping = EarlyStopping(monitor='val_loss', patience=2,restore_best_weights=True)


history=model.fit(x=train_gen,epochs=3,callbacks=[early_stopping])

十进制值错误

Epoch 1/3
36/36 ━━━━━━━━━━━━━━━━━━━━ 69s 2s/step - false_negatives: 30.1351 - false_positives: 35.3784 - loss: 2.1995 - true_negatives: 389.0540 - true_positives: 437.6487
Epoch 2/3
36/36 ━━━━━━━━━━━━━━━━━━━━ 61s 2s/step - false_negatives: 7.8378 - false_positives: 13.5135 - loss: 0.1692 - true_negatives: 283.1081 - true_positives: 300.4054
Epoch 3/3
36/36 ━━━━━━━━━━━━━━━━━━━━ 65s 2s/step - false_negatives: 2.3243 - false_positives: 3.0811 - loss: 0.0546 - true_negatives: 289.8108 - true_positives: 308.3513


在 Keras 中看到的十进制值对于 TruePositives TrueNegatives FalsePositives FalseNegatives 是因为这些指标在每个时期结束时 平均

Keras 默认情况下不计算每个时期的指标的总和。相反,它计算指标的 运行平均值 ,因为它看到了更多的小批量数据。这就是为什么看到十进制值而不是整数的原因。

让我们分解一下:

  1. 小批量更新: 在每个时期内,的数据集被分成多个小批量(在本例中为 32)。在处理完每个小批量后,都会根据到目前为止看到的数据计算并更新指标(TP、TN、FP、FN)。
  2. 运行平均值: 为了给出指标值的表示,Keras 维护一个运行平均值。使用此公式更新此运行平均值:

当前平均值 = ((当前平均值 * (先前看到的样本数)) + (当前小批量指标值)) / (先前看到的样本数 + 当前小批量大小)

  1. 并非总和: 重要的是,在每个时期结束时显示的值不是所有样本的简单总和,而是在整个时期内计算出的这些指标的平均值的表示。

为什么这些数字不加起来?

正确地观察到 (tp + tn + fp + fn) 的总和不等于样本总数。这是因为:

  • 平均值,而非总和: 如前所述,显示的值是平均值,而非总和。
  • 类别不平衡: 如果在类别之间的数据不平衡(例如,一个类别中的图像比另一个类别中的多得多),那么即使在考虑平均值后,这些数字也可能不会直接对应于类别总数。

如果希望获得每个时期的指标总和,则需要实现自定义回调函数。 但是,出于大多数目的,使用 Keras 提供的运行平均值就足够了,并且可以很好地了解模型的性能。

标签:python,tensorflow,machine-learning,keras,metrics
From: 78787112

相关文章

  • 将 Python 类型提示限制为元组中的值
    我的方法之一采用status中使用的filter()参数。此参数与定义如下的模型字段相关:STATUS_CHOICES=((1,_("draft")),(2,_("private")),(3,_("published")),)classMyModel(Model):status=models.PositiveSmallIntegerFi......
  • 在python中查找区间数据的中位数
    我正在探索不同的python库,我想知道如何找到分组数据集的近似中值。这里有一个表格供参考。年龄频率1-1012310-203502......
  • 比较Python中的字符串统一特殊字符
    也许我可以使用更好的英语,但我想要的是忽略单词中的重音(和类似的),所以:renè、rené、rene'和rene应该是相同的,所以应该mañana和manana或even-distribuited和evendistribuited,可能还有sho......
  • 如何使用 Python 脚本从客户账单电子邮件中获取订单 ID - WooCommerce API
    我想创建一个python脚本,返回只知道客户的账单电子邮件的订单。我尝试这样做,但返回所有最近的订单:fromwoocommerceimportAPIwcapi=API(url="https://siteexample.com",consumer_key="ck_xxx",consumer_secret="cs_xxx",version="wc/v3")......
  • python基础理论小总结
    1.python语言的特性Python是一门解释型语言,简单清晰,开源免费,跨平台,有大量第三方库辅助开发,支持面向对象与自动垃圾回收,方便与其他编程语言相互调用。Python在数据采集、人工智能、WEB后台开发、自动化运维、测试等方向应用广泛。2.解释型语言和编译型语言的区别执行方式不......
  • python编码规范
    本篇讲的是代码格式化的问题,解决格式化的方法在最下方,不想看内容的,滑到最下方就好了。一、变量的命名规则1.组成:字母、数字、下划线2.不可以以数字开头3.不建议使用下划线开头4.命名需见名知意5.不要与关键字重名。如何查找所有关键字?importkeywordprint(keyword.k......
  • Python爬虫开发中的常用库与框架安装指南
    在Python爬虫开发中,选择合适的库和框架可以大大提高开发效率和爬虫的性能。本文将介绍一些常用的解析库、请求库、储存库、Web库、App爬取库以及爬虫框架,并展示如何使用pip命令进行安装。一、解析库1.BeautifulSoupBeautifulSoup是一个用于从HTML或XML文件中提取数据的Pyth......
  • 如何在Python中的指定项目之后添加新项目到嵌套列表?
    给定的列表是这样的。list1=[10,20,[300,400,[5000,6000],500],30,40]预期输出是这样的。我知道这是一个非常基本的问题,但我很困惑。输出:[10,20,[300,400,[5000,6000,7000],500],30,40]我希望有人能帮助我解决这个问题。并解释了嵌套列表的插入功......
  • python带界面实现word文档比对功能
    python实现word文档比对的功能较简单,笔者这里将其界面话,可以指定输入比对的文档,相似度,最小相似参数等。输出的结果以word的形式保存,重复部分会标出,基本实现了商业软件的功能。先看界面这里不废话了,直接给出全部源码,觉得好的点个赞。程序打包的话,自己百度。fromtkinterimp......
  • 具有固定字典键的 Python 函数返回类型提示
    我有一个函数返回一个始终具有相同键的字典(通过网络发送并使用json进行“字符串化”)。基本上我的函数看起来像这样:defgetTemps(self)->dict:"""getroomandcputemperaturein°Caswellashumidityin%"""#sendtemperaturerequesttoserve......