首页 > 编程问答 >KerasRegressor 决定系数 R^2 分数

KerasRegressor 决定系数 R^2 分数

时间:2024-07-30 12:05:16浏览次数:4  
标签:python tensorflow machine-learning keras scikit-learn

我正在 Keras 中构建一个用于回归任务的小型神经网络,我想使用与 scikit-learn RandomForestRegressor :

系数 R^2 定义为|相同的精度指标| ||,其中 (1 - u/v) 是回归平方和 u ,而 ((y_true - y_pred) ** 2).sum() 是残差平方和 v ((y_true - y_true.mean()) ** 2).sum() 这是一个方便的指标,因为它显示的值高达1.0(类似于百分比准确度)分类)。我对 Keras 后端的使用是否符合我想要的准确度指标?

这似乎有效,没有任何错误,并且指标随着时间的推移逐渐增加到 1,但我想确保我正确实现了该指标。

def create_model():
    model = Sequential()
    model.add(Dense(10, input_dim=X.shape[1], activation="relu"))
    model.add(Dense(10, activation="relu"))
    model.add(Dense(1))

    # Compile model
    model.compile(loss="mean_squared_error", optimizer="adam", metrics=[det_coeff])
    return model

# Is this computing the right thing?
def det_coeff(y_true, y_pred):
    u = K.sum(K.square(y_true - y_pred))
    v = K.sum(K.square(y_true - K.mean(y_true)))
    return K.ones_like(v) - (u / v)

This appears to work in that nothing errors and the metric is increasing towards 1 over time, but I want to be sure I implemented the metric correctly.


关于 R² 指标的实现 几乎是正确的 。只需要对 K.mean(y_true) 进行一次小的修改即可。

问题在于,当前的实现计算的是 每个批次的平均值 ,而不是 整个数据集的全局平均值 。在 R² 的上下文中,需要将预测值与目标变量的全局平均值进行比较,而不仅仅是当前批次的平均值。

以下是如何修复它的方法:

def det_coeff(y_true, y_pred):
    SS_res =  K.sum(K.square( y_true - y_pred )) 
    SS_tot = K.sum(K.square( y_true - K.mean(y_true) ) ) 
    return ( 1 - SS_res/(SS_tot + K.epsilon()) )

解释:

  1. 使用 K.epsilon() : 我们在分母中添加了一个小的常数 K.epsilon() ,以防止出现除以零的情况,当 SS_tot 非常小时,可能会发生这种情况。

通过此更改,的 det_coeff 函数现在将计算作为 Keras 模型指标的 R² 分数。

其他建议:

  • 可以考虑使用 tf.keras.metrics.R_square ,它可以直接计算 R² 分数,而无需手动实现。
from tensorflow.keras.metrics import R_square

model.compile(loss="mean_squared_error", 
              optimizer="adam", 
              metrics=[R_square()])
  • 请记住,虽然 R² 是一个有用的指标,但它并不是评估回归模型的唯一指标。还应该考虑其他指标,例如均方误差 (MSE)、均方根误差 (RMSE) 和平均绝对误差 (MAE)。

标签:python,tensorflow,machine-learning,keras,scikit-learn
From: 45250100

相关文章

  • 在家用电脑上设置 Python 和 Jupyter,尝试打开 Jupyter 笔记本并显示错误,无法获取
    我有最新的Python版本3.12.4和以下版本的Jupyter:SelectedJupytercorepackages...IPython:8.26.0ipykernel:6.29.5ipywidgets:notinstalledjupyter_client:8.6.2jupyter_core:5.7.2jupyter_server:2.14.2jupyterlab......
  • Python - Reloading a module
    Eachmoduleisloadedintomemoryonlyonceduringaninterpretersessionorduringaprogramrun,regardlessofthenumberoftimesitisimportedintoaprogram.Ifmultipleimportsoccur,themodule’scodewillnotbeexecutedagainandagain.Suppose......
  • vscode python 3.7 pylance debugpy 插件 vsix
    可能报错  crashed5timesinthelast3minutes.Theserverwillnotberestarted.  ---pylance 可能报错  cannotreadpropertiesofundefinedreadingresolveEnvironment   --- debugger可能      vscodepython3.7调试没有反应......
  • Python获取秒级时间戳与毫秒级时间戳的方法[通俗易懂]
    参考资料:https://cloud.tencent.com/developer/article/21581481、获取秒级时间戳与毫秒级时间戳、微秒级时间戳代码语言:javascript复制importtimeimportdatetimet=time.time()print(t)#原始时间数据print(int(t))......
  • CEFPython
    在Tkinter界面中直接嵌入Selenium的浏览器视图并不是一件直接的事情,因为Selenium本身并不提供图形界面嵌入的功能。Selenium主要用于自动化web浏览器,但它并不直接控制浏览器窗口的显示方式,而是依赖于WebDriver来与浏览器交互。然而,你可以使用一些替代方案来在Tkinter应用中模拟或......
  • 《最新出炉》系列初窥篇-Python+Playwright自动化测试-58 - 文件下载
    1.简介前边几篇文章讲解完如何上传文件,既然有上传,那么就可能会有下载文件。因此宏哥就接着讲解和分享一下:自动化测试下载文件。可能有的小伙伴或者童鞋们会觉得这不是很简单吗,还用你介绍和讲解啊,不说就是访问到下载页面,然后定位到要下载的文件的下载按钮后,点击按钮就可以了。其实......
  • Python - Function Annotations
     deffunc(s:str,i:int,j:int)->str:returns[i:j]Theparametersissupposedtobeastring,soweplaceacolonaftertheparameternameandthenwritestr.Parametersiandjaresupposedtobeintegerssowewriteintforthem.Returntypeis......
  • 使用带有 pythonKit XCODE 的嵌入式 Python,在 iOS 应用程序中与 OpenCV-python 签名不
    我根据Beewares使用指南在XCODE中将Python嵌入到我的iOS项目中https://github.com/beeware/Python-Apple-support/blob/main/USAGE.md运行时,我得到pythonKit找不到由ultralytics导入的cv2错误。当我将OpenCV-python添加到我的app_packages文件夹时......
  • Python - Arguments and Parameters
    ParametersinFunctionDefinitionA.deffunc(name):MatchbypositionorbynameB.deffunc(name=value):DefaultargumentC.deffunc(*args):CollectextrapositionalargumentsintuplenamedargsD.deffunc(**kwargs):Collectextrakeywordargumentsi......
  • Python MySQL 无法连接,原因不明
    当我尝试使用python连接到我的MySQL数据库时,由于未知原因显示错误:dTraceback(mostrecentcalllast):File"/usr/local/bin/flask",line8,in<module>sys.exit(main())^^^^^^File"/usr/local/lib/python3.12/site-packages/flask/cli.py&......