线性回归算法原理
线性回归的基本思想是用一个线性函数来拟合数据点(只有一个特征时是一条直线,多个特征时是一个超平面),使得残差平方和最小——即各样本的真实值与预测值之差的平方和最小(注意这里是沿 y 方向的误差,而不是点到直线的垂直距离)。其数学表达式为:
$$y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_n x_n$$
其中,$\beta_0$ 是偏置项(截距,intercept),$\beta_1, \beta_2, \cdots, \beta_n$ 分别是特征 $x_1, x_2, \cdots, x_n$ 的系数(coefficients)。
Java实现线性回归
以下是一个简单的Java实现,分为以下几个部分:
- 添加偏置项
- 计算系数
- 预测
- 矩阵运算
1. 添加偏置项
首先,我们需要在特征矩阵 X 的最前面添加一列全为 1 的列,这样该列对应的系数就是偏置项 $\beta_0$。
/**
 * Returns a copy of {@code X} with a column of ones prepended, so that the first
 * fitted coefficient acts as the intercept (beta_0).
 *
 * <p>Each output row is sized from its own input row, so an empty {@code X} yields an
 * empty result instead of failing on {@code X[0]}. This helper does not check that all
 * rows have the same length; callers that need a rectangular matrix should validate it.
 *
 * @param X feature matrix, one row per sample
 * @return a new matrix whose row i is {@code [1, X[i][0], ..., X[i][n - 1]]}
 */
private double[][] addIntercept(double[][] X) {
    double[][] xWithIntercept = new double[X.length][];
    for (int i = 0; i < X.length; i++) {
        double[] row = new double[X[i].length + 1];
        row[0] = 1.0; // intercept column
        System.arraycopy(X[i], 0, row, 1, X[i].length);
        xWithIntercept[i] = row;
    }
    return xWithIntercept;
}
2. 计算系数
接下来,我们使用最小二乘法来计算系数。通过矩阵运算,我们可以得到以下公式:
$$\beta = (X^T X)^{-1} X^T y$$
其中 $X$ 是已添加偏置列的特征矩阵。代码中并不显式计算逆矩阵,而是先构造 $X^T X$ 与 $X^T y$,再求解线性方程组 $(X^T X)\beta = X^T y$。
/**
 * Computes least-squares coefficients by building and solving the normal equations
 * (X^T X) beta = X^T y.
 *
 * <p>X^T X is symmetric, so only its upper triangle is accumulated and then mirrored
 * into the lower triangle; this roughly halves the work and gives the same values.
 *
 * @param X intercept-augmented design matrix (first column all ones), one row per sample
 * @param y target values, one per row of {@code X}
 * @return the coefficient vector beta, intercept first
 * @throws IllegalArgumentException if {@code X} is empty or {@code y.length != X.length}
 */
private double[] calculateCoefficients(double[][] X, double[] y) {
    if (X.length == 0) {
        throw new IllegalArgumentException("X must contain at least one sample");
    }
    if (y.length != X.length) {
        throw new IllegalArgumentException(
                "y must have one value per sample: expected " + X.length + ", got " + y.length);
    }
    int nParams = X[0].length;
    double[][] XtX = new double[nParams][nParams];
    double[] XtY = new double[nParams];
    for (int i = 0; i < X.length; i++) {
        double[] row = X[i];
        for (int j = 0; j < nParams; j++) {
            // Upper triangle only (k >= j); mirrored below.
            for (int k = j; k < nParams; k++) {
                XtX[j][k] += row[j] * row[k];
            }
            XtY[j] += row[j] * y[i];
        }
    }
    for (int j = 1; j < nParams; j++) {
        for (int k = 0; k < j; k++) {
            XtX[j][k] = XtX[k][j];
        }
    }
    return solveLinearEquation(XtX, XtY);
}
3. 预测
根据计算出的系数,我们可以对新的数据进行预测:
/**
 * Predicts target values for new samples using the fitted coefficients.
 *
 * @param X samples to predict, one row per sample; every row must have exactly as many
 *     features as the training data. An empty array yields an empty result.
 * @return one prediction per row of {@code X}
 * @throws IllegalStateException if the model has not been fitted yet
 * @throws IllegalArgumentException if a row has the wrong number of features
 */
public double[] predict(double[][] X) {
    if (coefficients == null) {
        throw new IllegalStateException("模型尚未训练,请先调用fit方法进行训练。");
    }
    // coefficients[0] is the intercept, so the model expects one fewer input feature.
    int expectedFeatures = coefficients.length - 1;
    for (int i = 0; i < X.length; i++) {
        // Reject mismatched rows instead of reading past a short row or silently
        // ignoring the extra columns of a long one.
        if (X[i].length != expectedFeatures) {
            throw new IllegalArgumentException("row " + i + " of X must have "
                    + expectedFeatures + " features, got " + X[i].length);
        }
    }
    if (X.length == 0) {
        return new double[0];
    }
    double[][] xWithIntercept = addIntercept(X);
    return calculatePredictions(xWithIntercept);
}
/**
 * Evaluates the linear model on every row of an intercept-augmented matrix.
 *
 * <p>Prediction r is the dot product of row r with {@code coefficients}, summed in
 * column order.
 *
 * @param X design matrix whose first column is the intercept column of ones
 * @return one predicted value per row of {@code X}
 */
private double[] calculatePredictions(double[][] X) {
    final int rows = X.length;
    double[] out = new double[rows];
    for (int r = 0; r < rows; r++) {
        double[] sample = X[r];
        double dot = 0.0;
        for (int c = 0; c < coefficients.length; c++) {
            dot += sample[c] * coefficients[c];
        }
        out[r] = dot;
    }
    return out;
}
4. 矩阵运算
我们使用 Jama 库来求解线性方程组 $A x = b$:
/**
 * Solves the linear system A * x = b using Jama and returns x as a plain array.
 *
 * @param A coefficient matrix
 * @param b right-hand side, one entry per row of {@code A}
 * @return the solution vector x
 */
private double[] solveLinearEquation(double[][] A, double[] b) {
    // Wrap b as an n x 1 column vector (column-packed constructor).
    Matrix rhs = new Matrix(b, b.length);
    Matrix x = new Matrix(A).solve(rhs);
    // x has exactly one column, so its column-packed copy is the solution vector.
    return x.getColumnPackedCopy();
}
5. 完整代码
以下是完整的代码实现:
package cn.intana.business.sdk.utils;
import Jama.Matrix;
/**
 * Ordinary least-squares linear regression with an intercept term.
 *
 * <p>Coefficients are found by solving the normal equations (X^T X) beta = X^T y, where X
 * is the feature matrix with a leading column of ones. Instances are mutable (state
 * changes on {@link #fit}) and are not synchronized.
 */
public class LinearRegression {
    /** Message used when the model is queried before {@link #fit} has succeeded. */
    private static final String NOT_FITTED_MESSAGE = "模型尚未训练,请先调用fit方法进行训练。";

    /**
     * Fitted parameters. Index 0 is the intercept (beta_0); index j >= 1 is the coefficient
     * of input feature j - 1. {@code null} until {@link #fit} completes successfully.
     */
    private double[] coefficients;

    /**
     * Fits the model to training data.
     *
     * <p>If fitting throws, any previously fitted coefficients are left unchanged.
     *
     * @param X training features, one row per sample; every row must have the same length
     * @param y target values, one per row of {@code X}
     * @throws IllegalArgumentException if {@code X} or {@code y} is null, {@code X} is empty
     *     or ragged, {@code y.length != X.length}, or there are fewer samples than
     *     parameters (features + 1), in which case X^T X is necessarily singular
     */
    public void fit(double[][] X, double[] y) {
        if (X == null || X.length == 0) {
            throw new IllegalArgumentException("X must contain at least one sample");
        }
        if (y == null || y.length != X.length) {
            throw new IllegalArgumentException("y must have one value per sample: expected "
                    + X.length + ", got " + (y == null ? "null" : String.valueOf(y.length)));
        }
        if (X[0] == null) {
            throw new IllegalArgumentException("row 0 of X is null");
        }
        int nFeatures = X[0].length;
        checkRows(X, nFeatures);
        // rank(X^T X) <= number of samples, so with fewer samples than parameters the
        // system is singular; fail clearly rather than return unstable numbers.
        if (X.length < nFeatures + 1) {
            throw new IllegalArgumentException("need at least " + (nFeatures + 1)
                    + " samples to fit " + nFeatures + " features plus an intercept, got "
                    + X.length);
        }
        double[][] xWithIntercept = addIntercept(X);
        coefficients = calculateCoefficients(xWithIntercept, y);
    }

    /**
     * Predicts target values for new samples using the fitted coefficients.
     *
     * @param X samples to predict, one row per sample; every row must have the same number
     *     of features as the training data. An empty array yields an empty result.
     * @return one prediction per row of {@code X}
     * @throws IllegalStateException if {@link #fit} has not completed successfully
     * @throws IllegalArgumentException if {@code X} is null or a row has the wrong number
     *     of features
     */
    public double[] predict(double[][] X) {
        if (coefficients == null) {
            throw new IllegalStateException(NOT_FITTED_MESSAGE);
        }
        if (X == null) {
            throw new IllegalArgumentException("X must not be null");
        }
        // Reject mismatched rows instead of reading past a short row or silently ignoring
        // the extra columns of a long one. coefficients[0] is the intercept.
        checkRows(X, coefficients.length - 1);
        return calculatePredictions(addIntercept(X));
    }

    /**
     * Returns a copy of the fitted parameters: element 0 is the intercept, element j >= 1
     * is the coefficient of feature j - 1.
     *
     * @return a new array; modifying it does not affect the model
     * @throws IllegalStateException if {@link #fit} has not completed successfully
     */
    public double[] getCoefficients() {
        if (coefficients == null) {
            throw new IllegalStateException(NOT_FITTED_MESSAGE);
        }
        return coefficients.clone();
    }

    /** Verifies every row of {@code X} is non-null and has exactly {@code nFeatures} columns. */
    private static void checkRows(double[][] X, int nFeatures) {
        for (int i = 0; i < X.length; i++) {
            if (X[i] == null || X[i].length != nFeatures) {
                throw new IllegalArgumentException("row " + i + " of X must have " + nFeatures
                        + " features, got "
                        + (X[i] == null ? "null" : String.valueOf(X[i].length)));
            }
        }
    }

    /**
     * Returns a copy of {@code X} with a column of ones prepended, so that the first
     * coefficient acts as the intercept. Each output row is sized from its own input row,
     * so an empty {@code X} yields an empty result.
     */
    private double[][] addIntercept(double[][] X) {
        double[][] xWithIntercept = new double[X.length][];
        for (int i = 0; i < X.length; i++) {
            double[] row = new double[X[i].length + 1];
            row[0] = 1.0; // intercept column
            System.arraycopy(X[i], 0, row, 1, X[i].length);
            xWithIntercept[i] = row;
        }
        return xWithIntercept;
    }

    /**
     * Builds X^T X and X^T y from the intercept-augmented matrix and solves for the
     * coefficients. X^T X is symmetric, so only its upper triangle is accumulated and then
     * mirrored, which gives the same values with about half the multiplications.
     */
    private double[] calculateCoefficients(double[][] X, double[] y) {
        int nParams = X[0].length;
        double[][] XtX = new double[nParams][nParams];
        double[] XtY = new double[nParams];
        for (int i = 0; i < X.length; i++) {
            double[] row = X[i];
            for (int j = 0; j < nParams; j++) {
                for (int k = j; k < nParams; k++) {
                    XtX[j][k] += row[j] * row[k];
                }
                XtY[j] += row[j] * y[i];
            }
        }
        for (int j = 1; j < nParams; j++) {
            for (int k = 0; k < j; k++) {
                XtX[j][k] = XtX[k][j];
            }
        }
        return solveLinearEquation(XtX, XtY);
    }

    /**
     * Solves the square system A * x = b with Jama.
     *
     * <p>NOTE(review): for a square matrix Jama's {@code solve} uses an LU decomposition,
     * which is documented to throw a RuntimeException when A is singular (e.g. perfectly
     * collinear features). Nearly singular systems are not detected and may produce
     * unstable coefficients.
     */
    private double[] solveLinearEquation(double[][] A, double[] b) {
        Matrix matrixA = new Matrix(A);
        Matrix matrixB = new Matrix(b, b.length); // b as an n x 1 column vector
        Matrix solution = matrixA.solve(matrixB);
        // The solution has a single column, so its column-packed copy is exactly x.
        return solution.getColumnPackedCopy();
    }

    /** Computes the dot product of each intercept-augmented row with the coefficients. */
    private double[] calculatePredictions(double[][] X) {
        double[] predictions = new double[X.length];
        for (int i = 0; i < X.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < coefficients.length; j++) {
                sum += X[i][j] * coefficients[j];
            }
            predictions[i] = sum;
        }
        return predictions;
    }
}
pom.xml 中需要添加的 Maven 依赖:
<!-- Jama: provides Jama.Matrix, which LinearRegression imports to solve the normal equations -->
<dependency>
<groupId>gov.nist.math</groupId>
<artifactId>jama</artifactId>
<version>1.0.3</version>
</dependency>
标签:Java,nFeatures,int,double,length,算法,intercept,线性,new
From: https://blog.csdn.net/qq_57489104/article/details/139296109