首页 > 编程语言 >线性回归算法(Java)

线性回归算法(Java)

时间:2024-10-11 14:00:15浏览次数:6  
标签:slope Java double 算法 intercept learningRate iterations 线性 public

 

 

手动实现线性回归(梯度下降法)

 1 public class LinearRegressionGD {
 2     private double learningRate;
 3     private int iterations;
 4     private double slope;
 5     private double intercept;
 6 
 7     public LinearRegressionGD(double learningRate, int iterations) {
 8         this.learningRate = learningRate;
 9         this.iterations = iterations;
10         this.slope = 0;
11         this.intercept = 0;
12     }
13 
14     public void fit(double[] x, double[] y) {
15         int n = x.length;
16 
17         for (int i = 0; i < iterations; i++) {
18             double slopeGradient = 0;
19             double interceptGradient = 0;
20 
21             // 计算梯度
22             for (int j = 0; j < n; j++) {
23                 double prediction = slope * x[j] + intercept;
24                 slopeGradient += - (2.0 / n) * x[j] * (y[j] - prediction);
25                 interceptGradient += - (2.0 / n) * (y[j] - prediction);
26             }
27 
28             // 更新参数
29             slope -= learningRate * slopeGradient;
30             intercept -= learningRate * interceptGradient;
31         }
32     }
33 
34     public double predict(double x) {
35         return slope * x + intercept;
36     }
37 
38     public static void main(String[] args) {
39         double[] x = {1, 2, 3, 4, 5};
40         double[] y = {50, 55, 65, 70, 85};
41 
42         LinearRegressionGD model = new LinearRegressionGD(0.01, 1000); // 设置学习率和迭代次数
43         model.fit(x, y);
44 
45         double predictedScore = model.predict(6); // 预测学习时间为6小时时的考试成绩
46         System.out.println("预测的考试成绩: " + predictedScore);
47     }
48 }

代码解析

  1. 构造函数:初始化学习率和迭代次数。
  2. fit 方法:使用梯度下降法更新斜率和截距。predict 方法:根据学习到的斜率和截距进行预测。
    • 计算当前预测值与实际值之间的误差,并根据误差计算梯度。
    • 通过梯度调整斜率和截距。
  3. main 方法:创建数据集,实例化模型,训练模型,并进行预测。

参数设置

  • learningRate:控制每次更新的步长。较小的学习率可能导致收敛速度慢,而较大的学习率可能导致不收敛。
  • iterations:指定训练模型的轮数。过少的迭代可能导致模型未收敛,过多的迭代则可能导致过拟合。

通过这种手动实现的方式,你可以灵活地控制学习率和迭代次数,以优化模型的性能。

标签:slope,Java,double,算法,intercept,learningRate,iterations,线性,public
From: https://www.cnblogs.com/erichi101/p/18458212

相关文章

  • [1380]基于JAVA的建筑物施工智慧管理系统的设计与实现
    毕业设计(论文)开题报告表姓名学院专业班级题目基于JAVA的建筑物施工智慧管理系统的设计与实现指导老师(一)选题的背景和意义在当前全球信息化、智能化的大背景下,建筑施工行业的管理模式也正经历着深刻变革。随着国家对智慧城市和智慧工地的大力推广与政策支持,基于信息技术......
  • 为什么 Java 中的时间类如此繁多而复杂?
    为什么Java中的时间类如此繁多而复杂?从事程序员这些年,在业务中处理最繁琐且容易出现的场景就是时间处理,而且Java当中的时间类繁琐又复杂,类型从字符串转Date,LocalDate等等,时间计算、时间差、区间计算等场景太多且不可避免。怎么回事呢?在Java的世界中,时间类显得尤为繁多......
  • 【hot100-java】LRU 缓存
    链表篇灵神题解  classLRUCache{privatestaticclassNode{intkey,value;Nodeprev,next;Node(intk,intv){key=k;value=v;}}privatefinalintcapacity;//哨兵节点......
  • 【hot100-java】二叉树的右视图
    二叉树篇tql /***Definitionforabinarytreenode.*publicclassTreeNode{*intval;*TreeNodeleft;*TreeNoderight;*TreeNode(){}*TreeNode(intval){this.val=val;}*TreeNode(intval,TreeNodeleft,Tre......
  • 【hot100-java】合并 K 个升序链表
    链表篇/***Definitionforsingly-linkedlist.*publicclassListNode{*intval;*ListNodenext;*ListNode(){}*ListNode(intval){this.val=val;}*ListNode(intval,ListNodenext){this.val=val;this.next=next;......
  • [Java/Spring] 深入理解 : Spring ApplicationContext
    [Java/Spring]深入理解:SpringApplicationContext1概述:ApplicationContext简介2源码分析ApplicationContextpackageorg.springframework.context;publicinterfaceApplicationContextextendsEnvironmentCapable,ListableBeanFactory,HierarchicalBeanFactor......
  • C++ 算法学习——1.8 倍增与ST表
    在C++中,"倍增"(也称为"指数增长"或"指数级别增长")是一种算法优化技术,它通常用于解决一些需要频繁查询某个区间内的信息的问题,例如在处理动态规划、搜索等算法中。倍增思想的主要目的是通过预处理和存储一些中间结果,以加速后续的查询操作。具体来说,倍增思想通常包括以下步骤:......
  • C++ 算法学习——1.8 单调队列算法
    单调队列(MonotonicQueue)是一种特殊类型的队列,通常用于解决一些数组或序列相关的问题。和单调栈类似,单调队列也具有一些特定的性质,在解决一些问题时非常有用。以下是关于单调队列的一些重要点:定义:单调队列是一种数据结构,队列中的元素满足单调递增或单调递减的性质。应用:单......
  • 一款Java CMS 网站管理系统,基于RuoYi-fast二次开发,网站后台采用SpringBoot + MyBati
    一款JavaCMS网站管理系统基于RuoYi-fast二次开发,网站后台采用SpringBoot+MyBatis文章目录前言一、开源地址二、环境要求三、功能亮点3.1扩展功能3.2内置功能四、安装方法4.1、拉取源码4.2、修改数据库链接配置4.3、创建数据库并导入数据4.4、配置资源上传......
  • 【Java】异常处理指南
     ......