首页 > 编程问答 >sklearn 的 DbScan 与 Rapids 的 Dbscan

sklearn 的 DbScan 与 Rapids 的 Dbscan

时间:2024-07-23 09:14:05浏览次数:13  
标签:python scikit-learn dbscan rapids cuml

我正在尝试与 GPU 配合使用的 DbScan Rapids 版本。

我用 Python 编写了一段简短的代码,用于查找 3D 随机生成点中的簇,当我使用该代码时,我意识到 CPU 的方式由于某种我不熟悉的原因,它比 GPU 更快:

from sklearn.cluster import DBSCAN
from cuml.cluster import DBSCAN as cuDBSCAN
import cupy as cp

pij_gpu = cp.asarray(pij.astype(np.float32))  # Transfer data to GPU
eps = 0.3
min_samples = 5
start = time.time()
cluster = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1)
cluster.fit(pij)
print(f"CPU: {time.time() - start}")

start = time.time()
cluster_gpu = cuDBSCAN(eps=eps, min_samples=min_samples)
cluster_gpu.fit(pij_gpu)
print(f"GPU: {time.time() - start}")

日志:

size is: (670761, 3)
CPU: 1.508012294769287
[W] [14:24:17.089902] Batch size limited by the chosen integer type (4 bytes). 6007 -> 3201. Using the larger integer type might result in better performance
GPU: 10.168060541152954

我打印了 pij 大小。

我错过了什么吗?

我浏览了互联网并使用 ChatGPT找到答案,但我找不到答案。

我注意到的一件事是,当点更接近时,GPU 往往比 CPU 快一点。

我有 nvidia 3090 和比较强的电脑。


观察到 RAPIDS 的 cuML DBSCAN 比 scikit-learn 的 CPU 版本慢,有几个原因。让我们来分析一下这些原因,并探讨一下如何改进:

  1. 数据传输开销: 正确地将数据传输到了 GPU (使用 cp.asarray ),但此过程本身就需要时间。对于较小的数据集,此开销可能会盖过 GPU 计算带来的好处。

  2. DBSCAN 算法的性质: DBSCAN 是一种基于算法,它涉及到基于密度的聚类,需要进行大量的最近邻搜索。虽然 GPU 擅长并行计算,但在某些情况下,基于树的 CPU 算法(例如 scikit-learn 中使用的算法)对于此类搜索来说效率更高,尤其是在数据集不太大的情况下。

  3. cuML DBSCAN 实现: cuML 仍在积极开发中,其某些算法(包括 DBSCAN)可能尚未针对所有情况进行全面优化。

  4. 批处理大小警告: 收到的警告表明,所选数据类型限制了批处理大小。这可能会影响 cuML DBSCAN 的性能。

可能的改进:

  1. 更大的数据集: 尝试使用更大的数据集。随着数据集大小的增加,GPU 加速的优势将更加明显。

  2. 参数调整: eps min_samples 参数进行试验。这些参数会影响 DBSCAN 算法的行为,并且 GPU 和 CPU 实现的最佳值可能不同。

  3. 不同的 GPU: 与 3090 相比,功能更强大的 GPU 可能会显示出更好的性能。

  4. cuML 版本: 确保使用的是最新版本的 cuML,因为可能会进行性能改进。

  5. 替代算法: 对于的特定数据集和用例,其他聚类算法(例如 K 均值或基于树的方法)可能更适合于 GPU 加速。cuML 提供了这些算法的 GPU 加速版本。

基准测试:

为了更好地了解性能差异,请对不同数据集大小和参数值进行基准测试,比较 cuML DBSCAN 和 scikit-learn DBSCAN。这将有助于确定哪种实现对的特定用例更有利。

总而言之, 虽然 GPU 在处理许多机器学习任务时可以提供显著的加速,但并非所有算法或数据集大小都能从中平等受益。对于像 DBSCAN 这样的算法,数据传输开销、算法的性质以及特定的实现细节都会影响 CPU 和 GPU 性能之间的相对速度。

标签:python,scikit-learn,dbscan,rapids,cuml
From: 78754311

相关文章

  • python接口自动化(四十)- logger 日志 - 下(超详解)
    宏哥微信粉丝群:https://bbs.csdn.net/topics/618423372 有兴趣的可以扫码加入 1.简介按照上一篇的计划,这一篇给小伙伴们讲解一下:(1)多模块使用logging,(2)通过文件配置logging模块,(3)自己封装一个日志(logging)类。可能有的小伙伴在这里会有个疑问一个logging为什么分两篇的篇幅......
  • Python面试题:使用NumPy进行高效数组运算
    NumPy是Python中进行高效数组运算的基础库。以下是一些示例,展示了如何使用NumPy进行高效的数组运算,包括创建数组、数组操作、数学运算以及一些高级操作。安装NumPy如果你还没有安装NumPy,可以通过以下命令进行安装:pipinstallnumpy示例代码1.创建数组import......
  • Python面试题:使用Matplotlib和Seaborn进行数据可视化
    使用Matplotlib和Seaborn进行数据可视化是数据分析中非常重要的一部分。以下示例展示了如何使用这两个库来创建各种图表,包括基本的线图、柱状图、散点图和高级的分类数据可视化图表。安装Matplotlib和Seaborn如果你还没有安装这两个库,可以使用以下命令进行安装:pipins......
  • 20、Python之容器:红楼主角都有谁?10行代码生成《红楼梦》词云图
    引言Python系列前面的文章中,我们介绍了Python中容器的基本使用,上一篇中,我们又重点介绍了Counter计数器的使用。这些介绍,应该足以应付日常的工作需求了。在今天的文章中,我想以词云图的生成这个综合案例,巩固一下前面关于容器、字典推导式、Counter的使用。同时,介绍两个比较好......
  • 使用snowflake.connector 3.0.3上的密钥对从python连接到snowflake
    我正在尝试使用Snowflake.connector包从我的Python代码连接到Snowflake。不幸的是,由于遗留代码,我只能在python3.7.3上使用Snowflake连接器版本3.0.3,并且无法升级我确实设法从我自己的计算机进行连接,使用:con=Snowflake.connector。连接(...私有密钥文件=......
  • python 文件(txt)操作
    我有一个txt文件,其中包含一列(大约2000个)单词。每个单词只用一个新行分隔。我想将所有这些单词存储在一个数组中words.txt文件的示例:applebananaorange我尝试过的代码:importrandomwithopen('E:/Code/learn/Projects/word-guessing-game/words.txt','r')a......
  • matplotlib与Python不兼容问题
    “我使用的是Matplotlib版本3.7.2。Anaconda的Spyder中的Python版本是3.8.18。当我运行importmatplotlib.pyplotasplt行时,出现错误module'matplotlib'hasnoattribute'rcParams'。当我尝试将Matplotlib和matplotlib.base包从当前的3.7.2版本升级到3.8......
  • 使用 json 配置文件进行 Python 日志记录
    我玩弄了日志模块,发现了一些我不太清楚的事情。首先,因为这是一个大程序,我希望代码尽可能简单,因此我使用了json配置文件.{"version":1,"disable_existing_loggers":false,"formatters":{"simple":{"format":"%(asctime)s-%(name)s......
  • Python随机库:从帕累托分布进行模拟(使用形状和尺度参数)
    根据Python文档,random.paretovariate(alpha)模拟帕累托分布,其中alpha是形状参数。但帕累托分布同时采用形状和尺度参数。如何从指定这两个参数的分布中进行模拟?你说的对,帕累托分布是由形状和尺度参数定义的,而random.paretovariate(alpha)函数只接受形状参数......
  • 跳过 Python Spark Pyspark Databricks 未知字段异常中的不良记录
    我想知道是否有人知道如何跳过我们从json文件获取的记录这是错误[UNKNOWN_FIELD_EXCEPTION.NEW_FIELDS_IN_RECORD_WITH_FILE_PATH]在解析过程中遇到未知字段:这是失败的代码sent=spark.readStream.format('cloudFiles')\.option('cloudFiles.format','json')......