手动实现线性回归(梯度下降法)
1 public class LinearRegressionGD { 2 private double learningRate; 3 private int iterations; 4 private double slope; 5 private double intercept; 6 7 public LinearRegressionGD(double learningRate, int iterations) { 8 this.learningRate = learningRate; 9 this.iterations = iterations; 10 this.slope = 0; 11 this.intercept = 0; 12 } 13 14 public void fit(double[] x, double[] y) { 15 int n = x.length; 16 17 for (int i = 0; i < iterations; i++) { 18 double slopeGradient = 0; 19 double interceptGradient = 0; 20 21 // 计算梯度 22 for (int j = 0; j < n; j++) { 23 double prediction = slope * x[j] + intercept; 24 slopeGradient += - (2.0 / n) * x[j] * (y[j] - prediction); 25 interceptGradient += - (2.0 / n) * (y[j] - prediction); 26 } 27 28 // 更新参数 29 slope -= learningRate * slopeGradient; 30 intercept -= learningRate * interceptGradient; 31 } 32 } 33 34 public double predict(double x) { 35 return slope * x + intercept; 36 } 37 38 public static void main(String[] args) { 39 double[] x = {1, 2, 3, 4, 5}; 40 double[] y = {50, 55, 65, 70, 85}; 41 42 LinearRegressionGD model = new LinearRegressionGD(0.01, 1000); // 设置学习率和迭代次数 43 model.fit(x, y); 44 45 double predictedScore = model.predict(6); // 预测学习时间为6小时时的考试成绩 46 System.out.println("预测的考试成绩: " + predictedScore); 47 } 48 }
代码解析
- 构造函数:初始化学习率和迭代次数。
- fit 方法:使用梯度下降法更新斜率和截距。predict 方法:根据学习到的斜率和截距进行预测。
- 计算当前预测值与实际值之间的误差,并根据误差计算梯度。
- 通过梯度调整斜率和截距。
- main 方法:创建数据集,实例化模型,训练模型,并进行预测。
参数设置
learningRate
:控制每次更新的步长。较小的学习率可能导致收敛速度慢,而较大的学习率可能导致不收敛。iterations
:指定训练模型的轮数。过少的迭代可能导致模型未收敛,过多的迭代则可能导致过拟合。
通过这种手动实现的方式,你可以灵活地控制学习率和迭代次数,以优化模型的性能。
标签:slope,Java,double,算法,intercept,learningRate,iterations,线性,public From: https://www.cnblogs.com/erichi101/p/18458212