首页 > 编程语言 >使用Java实现线性回归算法

使用Java实现线性回归算法

时间:2024-05-30 23:58:02浏览次数:26  
标签:Java nFeatures int double length 算法 intercept 线性 new

线性回归算法原理

线性回归的基本思想是通过一条直线来拟合数据点,使得数据点到这条直线的距离平方和最小。其数学表达式为:

y = β 0 + β 1 x 1 + β 2 x 2 + ⋯ + β n x n y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_n x_n y=β0​+β1​x1​+β2​x2​+⋯+βn​xn​

其中, β 0 \beta_0 β0​是偏置项(intercept), β 1 , β 2 , ⋯   , β n \beta_1, \beta_2, \cdots, \beta_n β1​,β2​,⋯,βn​是各个特征的系数(coefficients)。

Java实现线性回归

以下是一个简单的Java实现,分为以下几个部分:

  • 添加偏置项
  • 计算系数
  • 预测
  • 矩阵运算

1. 添加偏置项

首先,我们需要在特征矩阵X中添加一列全为1的偏置项。

private double[][] addIntercept(double[][] X) {
    int nSamples = X.length;
    int nFeatures = X[0].length;
    double[][] X_with_intercept = new double[nSamples][nFeatures + 1];

    for (int i = 0; i < nSamples; i++) {
        X_with_intercept[i][0] = 1;  // intercept
        System.arraycopy(X[i], 0, X_with_intercept[i], 1, nFeatures);
    }

    return X_with_intercept;
}

2. 计算系数

接下来,我们使用最小二乘法来计算系数。通过矩阵运算,我们可以得到以下公式:

β = ( X T X ) − 1 X T y \beta = (X^T X)^{-1} X^T y β=(XTX)−1XTy

private double[] calculateCoefficients(double[][] X, double[] y) {
    int nFeatures = X[0].length;
    double[][] XtX = new double[nFeatures][nFeatures];
    double[] XtY = new double[nFeatures];

    for (int i = 0; i < X.length; i++) {
        for (int j = 0; j < nFeatures; j++) {
            for (int k = 0; k < nFeatures; k++) {
                XtX[j][k] += X[i][j] * X[i][k];
            }
            XtY[j] += X[i][j] * y[i];
        }
    }

    return solveLinearEquation(XtX, XtY);
}

3. 预测

根据计算出的系数,我们可以对新的数据进行预测:

public double[] predict(double[][] X) {
    if (coefficients == null) {
        throw new IllegalStateException("模型尚未训练,请先调用fit方法进行训练。");
    }

    double[][] X_with_intercept = addIntercept(X);
    double[] predictions = calculatePredictions(X_with_intercept);
    return predictions;
}

private double[] calculatePredictions(double[][] X) {
    double[] predictions = new double[X.length];
    for (int i = 0; i < X.length; i++) {
        for (int j = 0; j < coefficients.length; j++) {
            predictions[i] += X[i][j] * coefficients[j];
        }
    }
    return predictions;
}

4. 矩阵运算

我们使用Jama库来解决线性方程:

private double[] solveLinearEquation(double[][] A, double[] b) {
    Matrix matrixA = new Matrix(A);
    Matrix matrixB = new Matrix(b, b.length);
    Matrix solution = matrixA.solve(matrixB);
    double[] result = new double[solution.getRowDimension()];
    for (int i = 0; i < result.length; i++) {
        result[i] = solution.get(i, 0);
    }
    return result;
}

5. 完整代码

以下是完整的代码实现:

package cn.intana.business.sdk.utils;

import Jama.Matrix;

public class LinearRegression {
    private double[] coefficients;

    public void fit(double[][] X, double[] y) {
        double[][] X_with_intercept = addIntercept(X);
        coefficients = calculateCoefficients(X_with_intercept, y);
    }

    public double[] predict(double[][] X) {
        if (coefficients == null) {
            throw new IllegalStateException("模型尚未训练,请先调用fit方法进行训练。");
        }

        double[][] X_with_intercept = addIntercept(X);
        double[] predictions = calculatePredictions(X_with_intercept);
        return predictions;
    }

    private double[][] addIntercept(double[][] X) {
        int nSamples = X.length;
        int nFeatures = X[0].length;
        double[][] X_with_intercept = new double[nSamples][nFeatures + 1];

        for (int i = 0; i < nSamples; i++) {
            X_with_intercept[i][0] = 1;
            System.arraycopy(X[i], 0, X_with_intercept[i], 1, nFeatures);
        }

        return X_with_intercept;
    }

    private double[] calculateCoefficients(double[][] X, double[] y) {
        int nFeatures = X[0].length;
        double[][] XtX = new double[nFeatures][nFeatures];
        double[] XtY = new double[nFeatures];

        for (int i = 0; i < X.length; i++) {
            for (int j = 0; j < nFeatures; j++) {
                for (int k = 0; k < nFeatures; k++) {
                    XtX[j][k] += X[i][j] * X[i][k];
                }
                XtY[j] += X[i][j] * y[i];
            }
        }

        return solveLinearEquation(XtX, XtY);
    }

    private double[] solveLinearEquation(double[][] A, double[] b) {
        Matrix matrixA = new Matrix(A);
        Matrix matrixB = new Matrix(b, b.length);
        Matrix solution = matrixA.solve(matrixB);
        double[] result = new double[solution.getRowDimension()];
        for (int i = 0; i < result.length; i++) {
            result[i] = solution.get(i, 0);
        }
        return result;
    }

    private double[] calculatePredictions(double[][] X) {
        double[] predictions = new double[X.length];
        for (int i = 0; i < X.length; i++) {
            for (int j = 0; j < coefficients.length; j++) {
                predictions[i] += X[i][j] * coefficients[j];
            }
        }
        return predictions;
    }
}

pom

<dependency>
            <groupId>gov.nist.math</groupId>
            <artifactId>jama</artifactId>
            <version>1.0.3</version>
</dependency>

标签:Java,nFeatures,int,double,length,算法,intercept,线性,new
From: https://blog.csdn.net/qq_57489104/article/details/139296109

相关文章

  • C语言贪心算法——解硬币
    题目:有1元,5元,10元,100元,500元的硬币各从c1枚,c5枚,c10枚,c50枚,c100枚,c500枚,现在要用这些硬币支付A元,最少需要多少枚硬币输入:第一行有六个数字,分别代表从小到大6种面值的硬币的个数:第二行为A案例:输入:321302620输出:6#include<stdio.h>intmain(){ intnumber[6]......
  • 算法金 | 详解过拟合和欠拟合!性感妩媚 VS 大杀四方
    大侠幸会,在下全网同名「算法金」0基础转AI上岸,多个算法赛Top「日更万日,让更多人享受智能乐趣」今天我们来战过拟合和欠拟合,特别是令江湖侠客闻风丧胆的过拟合,简称过儿,Emmm过儿听起来有点怪怪的1.楔子机器学习模型是一种能够从数据中学习规律并进行预测的算法。......
  • Vue前端的搭建(与后端JavaEE的连接)
    目录前端平台搭建(Vue2.6,App:HBulderX)创建Vue2.6项目下载相应插件方便开发路由配置对连接后端进行一些配置(main.js文件)导入ElementUI组件组件|Element同步与异步axios异步请求框架前端平台搭建(Vue2.6,App:HBulderX)创建Vue2.6项目如图,创完之后的样子下载相应插件方......
  • JAVA开发 利用代码生成奖状pdf-中文版
    利用代码生成奖状pdf-中文版1、图片模板2、实现代码3、生成模板(pdf文件截取)1、图片模板2、实现代码importorg.apache.pdfbox.pdmodel.PDDocument;importorg.apache.pdfbox.pdmodel.PDPage;importorg.apache.pdfbox.pdmodel.PDPageContentStream;importorg......
  • 【Java】类和对象
    类和类的实例化类就是一类对象的统称。对象就是这一类具体化的一个实例。声明一个类就是创建一个新的数据类型,而类在Java中属于引用类型,Java使用关键字class来声明类。我们来看以下简单的声明一个类。基本语法:class<class_name>{ ......
  • Java-IO-IO模型
    参考:UNIX下的五种IO模型10分钟看懂,JavaNIO底层原理Linux五种网络IO模式(阻塞IO、非阻塞IO、IO多路复用、信号驱动IO、异步IO)1.什么是IO根据冯.诺依曼结构,计算机结构分为5大部分:运算器、控制器、存储器、输入设备、输出设备。操作系统为了保证稳定性和安全性,一个进......
  • 什么是状态机,用简单的java示例说明状态机的概念
    1.什么是状态机状态机(StateMachine)是一种抽象的计算模型,用于描述一个系统在不同状态之间的转换以及触发这些转换的事件。它由状态、事件、动作和转换规则组成。状态代表系统在某个时刻的行为模式;事件是引起状态转换的外部或内部信号;动作是在状态转换时执行的操作;转换规则定义......
  • 在javascript中定义三个状态机
    //定义基础状态机类classBaseStateMachine{constructor(initialState){this.currentState=initialState;}//转换状态的方法,子类需要根据实际逻辑重写此方法transition(event){thrownewError("transitionmethodmustbeimp......
  • Java集合(一)
    集合概念:集合是JavaAPI所提供的一系列类,可以用于动态存放多个对象。集合只能存对象集合与数组的不同在于,集合是大小可变的序列,而且元素类型可以不受限定,只要是引用类型。(集合中不能放基本数据类型,但可以放基本数据类型的包装类),集合类全部支持泛型,是一种数据安全的用法。......
  • Java函数式编程
    Java函数式编程Java8引入了对函数式编程的支持。Java8中引入的主要特性1.Lambda表达式和函数式接口:Lambda表达式允许以更简洁的方式表达一个方法的实现。函数式接口,只定义了一个抽象方法的接口(使用@FunctionalInterface注解来标记此类接口),与Lambda表达式一起使用,以便可......