首页 > 其他分享 >【rust】《Rust深度学习[6]-简单实现逻辑回归(Linfa)》

【rust】《Rust深度学习[6]-简单实现逻辑回归(Linfa)》

时间:2024-04-23 14:14:13浏览次数:31  
标签:matrix max threshold train let rust confusion Linfa Rust

什么是Linfa

Linfa 是一组Rust高级库的集合,提供了常用的数据处理方法和机器学习算法。Linfa对标Python上的 scikit-learn ,专注于日常机器学习任务常用的预处理任务和经典机器学习算法,目前Linfa已经实现了scikit-learn中的全部算法。

项目结构

依赖

[package]
name = "rust-ml-example"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
linfa = {version="0.6.0"}
linfa-logistic = {version="0.6.0"}
ndarray = { version = "0.15.6", default-features = false }
ndarray-csv = "0.5.1"
csv = "1.1"
plotters = { version = "^0.3.0" }
rand="0.8.5"

代码

use csv::ReaderBuilder;
use linfa::prelude::*;
use linfa_logistic::LogisticRegression;
use ndarray::{prelude::*, OwnedRepr};
use ndarray_csv::Array2Reader;
use plotters::prelude::*
// 训练模型 - 程序入口
fn main() {
    // 读取训练数据集
    let train = load_data("data/train.csv");
    // 读取测试数据集
    let test = load_data("data/test.csv");

    // // 获取训练集中的数据集
    // let features = train.nfeatures();
    // // 获取训练集中的标签集
    // let targets = train.ntargets();
    // println!("training with {} samples, testing with {} samples, {} features and {} target",
    //     train.nsamples(),test.nsamples(),features,targets);

    // println!("plotting data...");
    // 生成散点分布图
    plot_data(&train);

    // println!("training and testing model...");
    // 随机设置初始决策阈值与最大回归次数
    // 调用逻辑回归返回混淆矩阵数据
    let mut max_accuracy_confusion_matrix = iterate_with_values(&train, &test, 0.01, 100);

    // 临时决策阈值
    let mut best_threshold = 0.0;
    // 临时最大回归次数
    let mut best_max_iterations = 0;
    // 循环用的临时决策阈值
    let mut threshold = 0.02;

    // 迭代1000 - 4999 每500步长取一次值运算
    // 500步长方式递增回归次数
    for max_iterations in (1000..5000).step_by(500) {
        // 决策阈值小于1则一直循环
        while threshold < 1.0 {
            // 调用逻辑回归返回混淆矩阵数据
            let confusion_matrix = iterate_with_values(&train, &test, threshold, max_iterations);
            // 此次混淆矩阵的精确值 大于 随机设置参数调用的混淆矩阵的精确值
            if confusion_matrix.accuracy() > max_accuracy_confusion_matrix.accuracy() {
                // 覆盖随机参数调用的混淆矩阵结果
                max_accuracy_confusion_matrix = confusion_matrix;
                // 覆盖决策阈值
                best_threshold = threshold;
                // 覆盖最大回归次数
                best_max_iterations = max_iterations;
            }
            // 循环用的临时决策阈值 + 0.01
            threshold += 0.01;
        }
        // 还原 循环用的临时决策阈值 临时值
        threshold = 0.02;
    }
    println!("最精确混淆矩阵: {:?}", max_accuracy_confusion_matrix);
    println!("最优迭代次数: {}\n最优决策阈值: {}", best_max_iterations, best_threshold);
    println!("精确率:\t{}", max_accuracy_confusion_matrix.accuracy(),);
    println!("准确率:\t{}", max_accuracy_confusion_matrix.precision(),);
    println!("召回率:\t{}", max_accuracy_confusion_matrix.recall(),);
}


// csv文件数据读取
// 返回DataSet数据集
fn load_data(path: &str) -> Dataset<f64, &'static str, Ix1> {
    // 获取csv文件并验证是否存在数据
    let mut reader = ReaderBuilder::new()
        // 是否具有表头
        .has_headers(false)
        // 分隔符
        .delimiter(b',')
        // 文件路径
        .from_path(path)
        // 存在就输出并提示
        .expect("can create reader");

    // 获取csv文件内的数据
    let array: Array2<f64> = reader
        // 动态序列化数据(转ndarray格式列表)
        .deserialize_array2_dynamic()
        // 存在就输出并提示
        .expect("can deserialize array");

    // 结构转化为 (数据 , 标签)
    let (data, targets) = (
        // 取所有列内的0,1下标的两个数据
        // to_owned 拷贝原始数据并获取所有权
        // 转换成一维数组
        array.slice(s![.., 0..2]).to_owned(),
        // 获取第3列的数据, 就是0或1的数据
        array.column(2).to_owned(),
    );
    // 作散点图用的标签
    // 特征名
    let feature_names = vec!["test 1", "test 2"];
    // 用ndarray创建Linfa的Dataset数据集
    // 两个参数(数据, 标签)
    Dataset::new(data, targets)
        // 这里循环标签列表的第一个参数
        .map_targets(|x| {
            // 解引用x转usize类型 并且 ==1
            if *x as usize == 1 {
                // 录取
                "accepted"
            } else {
                // 淘汰
                "denied"
            }
        })
        // 对数据命名(重写vec列表)并将标签列表转为字符列表
        .with_feature_names(feature_names)
}

// 数据绘图
// 生成散点分布图片
fn plot_data(
    train: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
) {
    // 录取列表的数据集
    let mut positive = vec![];
    // 淘汰列表的数据集
    let mut negative = vec![];

    // 获取Dataset中数据列表
    let records = train.records().clone().into_raw_vec();
    // 将一维数组通过固定大小分组转化为二维数组
    // chunks(size:usize) 对数组/列表中的数据按指定大小size分块
    let features: Vec<&[f64]> = records.chunks(2).collect();
    // 获取Dataset中标签列表
    let targets = train.targets().clone().into_raw_vec();
    // 迭代二维数组
    for i in 0..features.len() {
        // 获取内部的数组
        let feature = features.get(i).expect("feature exists");
        // 判断标签是否为 录取
        if let Some(&"accepted") = targets.get(i) {
            // 录取列表数据集(第一次成绩, 第二次成绩)
            positive.push((feature[0], feature[1]));
        } else {
            // 淘汰列表数据集(第一次成绩, 第二次成绩)
            negative.push((feature[0], feature[1]));
        }
    }
    // 创建绘图文件
    let root_area = BitMapBackend::new("plot.png", (600, 400)).into_drawing_area();
    // 设置图片白色背景
    root_area.fill(&WHITE).unwrap();
    // 开始绘制基础图表结构
    let mut ctx = ChartBuilder::on(&root_area)
        // 图表左侧与图片边缘的间距
        .set_label_area_size(LabelAreaPosition::Left, 40)
        // 图表底部与图片边缘的间距
        .set_label_area_size(LabelAreaPosition::Bottom, 40)
        // 图表名称  (字体样式, 字体大小)
        .caption("分数分布", ("sans-serif", 30))
        // 构建二维图像, x轴 0.0 - 120.0; y轴 0.0 - 120.0;
        .build_cartesian_2d(0.0..120.0, 0.0..120.0)
        .unwrap();
    // 配置网格线
    ctx.configure_mesh().draw().unwrap();
    // 绘制录取的散点
    ctx.draw_series(
        positive
            // 迭代
            .iter()
            // 绘制三角形(坐标数据,大小,颜色)
            .map(|point| TriangleMarker::new(*point, 5, &BLUE)),
    )
    .unwrap();
    // 绘制淘汰的散点
    ctx.draw_series(
        negative
            // 迭代
            .iter()
            // 绘制圆圈(坐标数据,大小,颜色)
            .map(|point| Circle::new(*point, 5, &RED)),
    )
    .unwrap();
}

// 逻辑回归(循环迭代值训练模型)
// 迭代次数(max_iterations) 和 决策阈值(threshold)
//      我们需要反复多次测试以找到这两个参数的最有值,为此我们需要构造循环多次调用上面的过程
// 返回混淆矩阵
fn iterate_with_values(
    // 训练集
    train: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    // 测试集
    test: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    // 概率阈值
    threshold: f64,
    // 最大回归次数
    max_iterations: u64,
) -> ConfusionMatrix<&'static str> {
    // 构造逻辑回归模型
    let model = LogisticRegression::default()
        // 最大回归次数
        .max_iterations(max_iterations)
        // 设置梯度下降的学习率,当变化值小于该值时则停止迭代
        .gradient_tolerance(0.0001)
        // 传入训练集
        .fit(train)
        .expect("can train model");
    // 调用测试集对模型进行测试
    // .set_threshold(概率阈值)  设置预测“正”类的概率阈值,默认值为0.5
    let validation = model.set_threshold(threshold).predict(test);
    // 根据测试集创建混淆矩阵
    let confusion_matrix = validation
        // 传入测试集
        .confusion_matrix(test)
        .expect("can create confusion matrix");
    // 返回混淆矩阵数据
    confusion_matrix
}

test.csv

32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0

train.csv

45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1

 

文章转载至:https://jarod.blog.csdn.net/article/details/128089875

标签:matrix,max,threshold,train,let,rust,confusion,Linfa,Rust
From: https://www.cnblogs.com/-CO-/p/18152737

相关文章

  • 【rust】《Rust深度学习[4]-理解线性网络(Candle)》
    全连接/线性在神经网络中,全连接层,也称为线性层,是一种层,其中来自一层的所有输入都连接到下一层的每个激活单元。在大多数流行的机器学习模型中,网络的最后几层是完全连接的。实际上,这种类型的层执行基于在先前层中学习的特征输出类别预测的任务。全连接层的示例,具有四个输入节点......
  • 【rust】《Rust深度学习[5]-理解卷积神经网络(Candle)》
    卷积神经网络ConvolutionalNeuralNetwork,简称为CNN。CNN与一般的顺传播型神经网络不同,它不仅是由全结合层,还由卷积层(ConvolutionLayer)和池层(PoolingLayer)构成的神经网络。在卷积层和池化层中,如下图所示,缩小输入神经元的一部分区域,局部地与下一层进行对应。每一层都有一个称......
  • 【rust】《Rust深度学习[2]-数据分析和挖掘库(Polars)》
    什么是Polars?Polars是一个用于操作结构化数据的高性能DataFrame库,可以用来进行数据清洗和格式转换、数据分析和统计、数据可视化、数据读取和存储、数据合并和拼接等等,相当于Rust版本的Pandas库。Polars读写数据支持如下:  常见数据文件:csv、parquet(不支持xlsx、json文件) ......
  • 【rust】《Rust深度学习[3]-数据可视化(Plotters)》
    什么是Plotters?Plotters是一个用纯Rust开发的图形库,用于中渲染图形、图表和数据可视化。它支持静态图片渲染和实时渲染,并支持多种后端,包括:位图格式(png、bmp、gif等)、矢量图(svg)、窗口和HTML5Canvas。Plotters对不同后端使用统一的高级API,并允许开发者自定义坐标系。在Plotters......
  • 【rust】《Rust深度学习[1]-科学计算库(Ndarray)》
    什么是Ndarray?ndarray是Rust生态中用于处理数组的库。它包含了所有常用的数组操作。简单地说ndarray相当于Rust版本的numpy。ndarray生态系统中crate的文档:ndarray基础库ndarray-rand随机数生成库ndarray-stats统计方法  顺序统计(最小、最大、中值、分位数等);  汇总......
  • three.js使用Instanced Draw+Frustum Cull+LOD来渲染大场景(开源)
    大家好,本文使用three.js实现了渲染大场景,在移动端也有较好的性能,并给出了代码,分析了关键点,感谢大家~关键词:three.js、InstancedDraw、大场景、LOD、FrustumCull、优化、Web3D、WebGL、开源代码:Github我正在承接Web3D数字孪生项目,具体介绍可看承接各种Web3D业务加QQ群交流:106......
  • vim配置rust开发环境
    vim配置需要环境首先需要安装rust,然后安装rust-analysis,还需要nodejs,npm。插件使用vim-plug管理,也是需要提前安装的安装coc之后还需要安装CocInstallcoc-rust-analysis下边是踩坑出来的配置文件"插件安装在callplug#begin('~/.vim/plugged')和callplug#end()之间。cal......
  • Windows快速安装Rust
    本文是最简最快最小化安装重点提示:如果不想安装VS消耗时间和6-8G的空间,可以按本文安装。如果系统中已经安装了VS,那么直接运行rustup-init安装Rust,并一路回车即可。前置条件:安装C++环境rust底层是依赖C环境的连接器,所以需要先安装C/C++编译环境,点击下载64位mingw-builds......
  • The Stack and the Heap栈与堆__Rust
    Manyprogramminglanguagesdon’trequireyoutothinkaboutthestackandtheheapveryoften.许多编程语言并不会要求你经常思考堆栈。ButinasystemprogramminglanguagelikeRust,whetheravalueisonthestackortheheapaffectshowthelanguagebehaves......
  • rust程序中设置和访问环境变量
    在项目中,我们通常需要设置一些环境变量,用来保存一些凭证或其它数据,这时我们可以使用dotenv这个crate。1、添加crate依赖首先在项目中添加dotenv这个依赖:  2、添加.env文件在开发环境下,我们可以在项目根目录下创建和编辑.env这个文件: 在运行环境下,这个.env文件要......