什么是Linfa
Linfa 是一组Rust高级库的集合,提供了常用的数据处理方法和机器学习算法。Linfa对标Python上的 scikit-learn ,专注于日常机器学习任务常用的预处理任务和经典机器学习算法,目前Linfa已经实现了scikit-learn中的一部分常用算法,并在持续扩充中。
项目结构
依赖
[package]
name = "rust-ml-example"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
linfa = {version="0.6.0"}
linfa-logistic = {version="0.6.0"}
ndarray = { version = "0.15.6", default-features = false }
ndarray-csv = "0.5.1"
csv = "1.1"
plotters = { version = "^0.3.0" }
rand="0.8.5"
代码
use csv::ReaderBuilder;
use linfa::prelude::*;
use linfa_logistic::LogisticRegression;
use ndarray::{prelude::*, OwnedRepr};
use ndarray_csv::Array2Reader;
use plotters::prelude::*;
// Program entry point: loads the data, plots it, then grid-searches the
// logistic-regression decision threshold and iteration count for the
// parameter combination with the highest test-set accuracy.
fn main() {
    // Load the training and test datasets from CSV.
    let train = load_data("data/train.csv");
    let test = load_data("data/test.csv");
    // Render a scatter plot of the training data to plot.png.
    plot_data(&train);
    // Baseline run: threshold 0.01, 100 iterations.
    let initial_threshold = 0.01;
    let initial_max_iterations = 100;
    let mut max_accuracy_confusion_matrix =
        iterate_with_values(&train, &test, initial_threshold, initial_max_iterations);
    // Track the best parameters found so far, starting from the baseline.
    // (Previously these started at 0.0/0, so if no sweep combination beat
    // the baseline the printed "best" parameters were wrong.)
    let mut best_threshold = initial_threshold;
    let mut best_max_iterations = initial_max_iterations;
    // Sweep the iteration cap from 1000 to 4500 in steps of 500.
    for max_iterations in (1000..5000).step_by(500) {
        // For each cap, sweep the decision threshold from 0.02 toward 1.0
        // in steps of 0.01.
        let mut threshold = 0.02;
        while threshold < 1.0 {
            let confusion_matrix = iterate_with_values(&train, &test, threshold, max_iterations);
            // Keep the combination with the highest accuracy seen so far.
            if confusion_matrix.accuracy() > max_accuracy_confusion_matrix.accuracy() {
                max_accuracy_confusion_matrix = confusion_matrix;
                best_threshold = threshold;
                best_max_iterations = max_iterations;
            }
            threshold += 0.01;
        }
    }
    println!("最精确混淆矩阵: {:?}", max_accuracy_confusion_matrix);
    println!("最优迭代次数: {}\n最优决策阈值: {}", best_max_iterations, best_threshold);
    // accuracy = 准确率, precision = 精确率 — the two labels were swapped
    // in the original output.
    println!("准确率:\t{}", max_accuracy_confusion_matrix.accuracy());
    println!("精确率:\t{}", max_accuracy_confusion_matrix.precision());
    println!("召回率:\t{}", max_accuracy_confusion_matrix.recall());
}
// Reads a headerless, comma-separated CSV file at `path` and builds a
// Linfa `Dataset`: columns 0-1 become the two features ("test 1" /
// "test 2"), column 2 (0 or 1) becomes the label, mapped to
// "denied" / "accepted".
fn load_data(path: &str) -> Dataset<f64, &'static str, Ix1> {
    // Configure and open the CSV reader; panic with context on failure.
    let mut csv_reader = ReaderBuilder::new()
        .has_headers(false)
        .delimiter(b',')
        .from_path(path)
        .expect("can create reader");
    // Deserialize the whole file into a 2-D f64 ndarray.
    let raw: Array2<f64> = csv_reader
        .deserialize_array2_dynamic()
        .expect("can deserialize array");
    // First two columns -> owned feature matrix; third column -> labels.
    let features = raw.slice(s![.., 0..2]).to_owned();
    let labels = raw.column(2).to_owned();
    // Wrap in a Dataset, converting the numeric labels to readable
    // strings and naming the feature columns for plotting.
    Dataset::new(features, labels)
        .map_targets(|value| match *value as usize {
            1 => "accepted",
            _ => "denied",
        })
        .with_feature_names(vec!["test 1", "test 2"])
}
// Renders the training set as a scatter plot saved to plot.png:
// "accepted" samples as blue triangles, everything else as red circles.
fn plot_data(
    train: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
) {
    // Points for each class, as (score 1, score 2) pairs.
    let mut accepted_points = vec![];
    let mut denied_points = vec![];
    // Flatten the record matrix, then regroup it into rows of two
    // features each (the dataset has exactly two feature columns).
    let raw_records = train.records().clone().into_raw_vec();
    let rows: Vec<&[f64]> = raw_records.chunks(2).collect();
    let labels = train.targets().clone().into_raw_vec();
    // Partition each sample by its label; a missing or non-"accepted"
    // label counts as denied, matching the original behaviour.
    for (index, row) in rows.iter().enumerate() {
        match labels.get(index) {
            Some(&"accepted") => accepted_points.push((row[0], row[1])),
            _ => denied_points.push((row[0], row[1])),
        }
    }
    // Prepare a 600x400 bitmap with a white background.
    let drawing_area = BitMapBackend::new("plot.png", (600, 400)).into_drawing_area();
    drawing_area.fill(&WHITE).unwrap();
    // Build the chart: 40px label areas on the left and bottom, a
    // caption, and both axes spanning 0.0..120.0.
    let mut chart = ChartBuilder::on(&drawing_area)
        .set_label_area_size(LabelAreaPosition::Left, 40)
        .set_label_area_size(LabelAreaPosition::Bottom, 40)
        .caption("分数分布", ("sans-serif", 30))
        .build_cartesian_2d(0.0..120.0, 0.0..120.0)
        .unwrap();
    // Draw the grid lines.
    chart.configure_mesh().draw().unwrap();
    // Accepted samples: blue triangle markers of size 5.
    chart
        .draw_series(
            accepted_points
                .iter()
                .map(|point| TriangleMarker::new(*point, 5, &BLUE)),
        )
        .unwrap();
    // Denied samples: red circle markers of size 5.
    chart
        .draw_series(
            denied_points
                .iter()
                .map(|point| Circle::new(*point, 5, &RED)),
        )
        .unwrap();
}
// Trains a logistic-regression model on `train` with the given maximum
// number of iterations, predicts `test` using `threshold` as the
// positive-class probability cutoff (linfa's default would be 0.5), and
// returns the resulting confusion matrix. Called repeatedly by `main`
// to search for the best (threshold, max_iterations) pair.
fn iterate_with_values(
    // Training set
    train: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    // Test set
    test: &DatasetBase<
        ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
        ArrayBase<OwnedRepr<&'static str>, Dim<[usize; 1]>>,
    >,
    // Probability cutoff for predicting the positive class
    threshold: f64,
    // Upper bound on gradient-descent iterations
    max_iterations: u64,
) -> ConfusionMatrix<&'static str> {
    // Fit the model: stop after `max_iterations` steps, or earlier once
    // the gradient change drops below the tolerance.
    let fitted = LogisticRegression::default()
        .max_iterations(max_iterations)
        .gradient_tolerance(0.0001)
        .fit(train)
        .expect("can train model");
    // Predict the test set at the requested decision threshold and
    // compare the predictions against the true labels.
    fitted
        .set_threshold(threshold)
        .predict(test)
        .confusion_matrix(test)
        .expect("can create confusion matrix")
}
test.csv
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
train.csv
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
文章转载至:https://jarod.blog.csdn.net/article/details/128089875
标签:matrix,max,threshold,train,let,rust,confusion,Linfa,Rust From: https://www.cnblogs.com/-CO-/p/18152737