from numpy import mean
class LeastSquare:
    """
    Fit a straight line ``y = w * x + b`` by ordinary least squares.

    Reference: Zhou Zhihua, *Machine Learning* (the "watermelon book",
    西瓜书), p. 54, equations (3.7) and (3.8).

    Attributes:
        x: the input samples passed to the constructor.
        y: the target samples passed to the constructor.
        w: fitted slope.
        b: fitted intercept.
    """

    def __init__(self, x, y):
        """
        Fit the regression line to the paired samples ``x`` and ``y``.

        Args:
            x: sequence of input values (list, tuple or 1-D numpy array).
            y: sequence of target values; must have the same length as ``x``.

        Raises:
            ValueError: if ``x`` is empty or if ``x`` and ``y`` differ in
                length. ``calculate_w`` additionally raises ValueError when
                all ``x`` values are identical.
        """
        # len() instead of truthiness so numpy arrays are accepted too.
        if len(x) == 0:
            raise ValueError("x and y must not be empty")
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y
        self.__m = len(self.x)
        self.__mean_x = mean(self.x)
        self.w = self.calculate_w()
        self.b = self.calculate_b()  # depends on self.w, so it must come second

    def calculate_w(self):
        """
        Return the least-squares slope.

        Uses the centred form

            w = Σ (x_i - x̄)(y_i - ȳ) / Σ (x_i - x̄)²

        which is algebraically equal to equation (3.7),
        ``Σ y_i (x_i - x̄) / (Σ x_i² - (Σ x_i)² / m)``, but avoids the
        catastrophic cancellation of ``Σ x_i² - (Σ x_i)² / m`` when the
        x values share a large offset.

        Raises:
            ValueError: if all x values are identical (including a single
                sample), because the slope is undefined in that case.
        """
        # Exact check on the data itself; a rounded mean could leave tiny
        # non-zero deviations for constant x and yield a meaningless slope.
        if min(self.x) == max(self.x):
            raise ValueError("all x values are identical; the slope is undefined")
        mean_x = self.__mean_x
        mean_y = mean(self.y)
        dx = [xi - mean_x for xi in self.x]
        sxx = sum(d * d for d in dx)
        sxy = sum(d * (yi - mean_y) for d, yi in zip(dx, self.y))
        return sxy / sxx

    def calculate_b(self):
        """
        Return the least-squares intercept (equation 3.8):

            b = (1/m) Σ (y_i - w * x_i)

        Requires ``self.w`` to have been computed already.
        """
        return sum(yi - self.w * xi for xi, yi in zip(self.x, self.y)) / self.__m
测试代码
数据集由 faker_xy_data() 生成，详见：https://www.cnblogs.com/boran/p/16757677.html
from practice1.create_xy_data import faker_xy_data
from practice1.least_square_method import LeastSquare
# Generate a synthetic (x, y) dataset with faker_xy_data().
# NOTE(review): .a/.b look like the true slope/intercept the data was generated
# with (compare them with the fitted w/b printed below) — confirm in create_xy_data.
x1 = faker_xy_data()
print(f"a={x1.a} b={x1.b}")
# Fit the least-squares line and print the estimated slope w and intercept b.
E = LeastSquare(x=x1.x, y=x1.y)
print(f"w={E.w} b={E.b}")
一次测试结果（第一行为生成数据时使用的参数 a、b，第二行为最小二乘拟合得到的 w、b）：
a=1.033132634968986 b=10.570162394769465
w=0.9880089076924274 b=13.378591305915043
标签:直线,.__,calculate,求解,self,xy,x1,data,乘法
From: https://www.cnblogs.com/boran/p/16758024.html