上一题篇文章写了线性回归以及梯度下降法,这篇文章讲一下逻辑回归。虽然它叫逻辑回归,但是它并非回归模型,而是一个分类模型。那么回归和分类有什么区别呢?在上一篇文章中,我们以住房各特征预测了房价中位数。这个是给定数据,预测一个连续的数据。而分类呢?还是举出上面的例子,只不过这次我不需要预测价格中位数了,只需要预测这个房子的“好与坏”,值域只有(好、坏)。
最后注意:
求导后是
矩阵形式是:
下面是逻辑回归矩阵形式的推导:
实验操作:
要求:已知有数据(exam1,exam2,aeecpted),第一个和第二个是成绩,第三个是是否被大学录取,要求根据成绩来预测是否被大学录取。
直接给出数据,自己复制到txt中测试(数据在最后面):
第一步先看一下数据可视化:
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('tKAgg')
import matplotlib.pyplot as plt
path = "D:\JD\Documents\大学等等等\自学部分\machine_-learning-master\machine_-learning-master\ex_2\ex2data1.txt"
data = pd.read_csv(path, names=['Exam1', 'Exam2', 'Accepted'])
print(data.head())
fig, ax = plt.subplots()
ax.scatter(data[data['Accepted'] == 0]['Exam1'], data[data['Accepted'] == 0]['Exam2'], c='r', marker='x', label='y=0')
ax.scatter(data[data['Accepted'] == 1]['Exam1'], data[data['Accepted'] == 1]['Exam2'], c='g', marker='o', label='y=1')
ax.legend()
ax.set(
xlabel='exam1',
ylabel='exam2'
)
plt.show()
接下来写函数:
预测值为:
这里面代价函数:
注意矩阵乘法和*乘法最后得到的结果不一样哦!
def get_Xy(data):
data.insert(0, 'ones', 1)
X_ = data.iloc[:, :-1] # 获取除了最后一列的数据集
y_ = data.iloc[:, -1] # 获取最后一列的数据集
X = X_.values # 转化为数组
y = y_.values.reshape(len(y_.values), 1) # 从pandas中取出的只有一维的数据自动是行向量,或者(n,)没有第二维,所以reshape以下称为(n,1)
return X, y
# sigmoid函数
def sigmoid(z):
return 1 / (1 + np.exp(-z))
#损失函数
def costFunction(X, y, theta):
A = sigmoid(X @ theta) # 预测值一维矩阵
epsilon = 1e-5 # 用来避免对数计算中的无效值
first = y * np.log(A + epsilon) # 转置后直接得到一个数,如果不转置还需要对矩阵求和
second = (1 - y) * np.log(1 - A + epsilon)
return -np.sum(first + second) / len(y)
接下来,梯度下降!
上面推导的过程求出来了损失函数求导的结果:
然后对参数进行梯度下降,迭代公式为:
# 定义梯度下降
def gradientDescent(X, y, theta, iters, alpha):
m = len(X)
costs = []
for i in range(iters):
A = sigmoid(X @ theta)
theta = theta - alpha / m * X.T @ (A - y)
cost = costFunction(X, y, theta)
costs.append(cost)
if i % 1000 == 0:
print(f"Iteration {i}: cost = {cost}")
return costs, theta
下面是损失函数随着迭代次数值的变化:
最后的theta是[[-23.77498778],[ 0.18690941],[ 0.18046614]]
实现预测:
def predict(X,theta):
pre = sigmoid(X@theta)
return [1 if i >= 0.5 else 0 for i in pre ]
预测值与真实值之间对比:
完整代码:
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('tKAgg')
import matplotlib.pyplot as plt
# 读取数据
path = "D:\\JD\\Documents\\大学等等等\\自学部分\\machine_-learning-master\\machine_-learning-master\\ex_2\\ex2data1.txt"
data = pd.read_csv(path, names=['Exam1', 'Exam2', 'Accepted'])
print(data.head())
# 绘制散点图
fig, ax = plt.subplots()
ax.scatter(data[data['Accepted'] == 0]['Exam1'], data[data['Accepted'] == 0]['Exam2'], c='r', marker='x', label='y=0')
ax.scatter(data[data['Accepted'] == 1]['Exam1'], data[data['Accepted'] == 1]['Exam2'], c='g', marker='o', label='y=1')
ax.legend()
ax.set(
xlabel='exam1',
ylabel='exam2'
)
plt.show()
# 提取X和y
def get_Xy(data):
data.insert(0, 'ones', 1)
X_ = data.iloc[:, :-1] # 获取除了最后一列的数据集
y_ = data.iloc[:, -1] # 获取最后一列的数据集
X = X_.values # 转化为数组
y = y_.values.reshape(len(y_.values), 1) # 从pandas中取出的只有一维的数据自动是行向量,或者(n,)没有第二维,所以reshape以下称为(n,1)
return X, y
# sigmoid函数
def sigmoid(z):
return 1 / (1 + np.exp(-z))
# 损失函数
def costFunction(X, y, theta):
A = sigmoid(X @ theta) # 预测值一维矩阵
epsilon = 1e-5 # 用来避免对数计算中的无效值
first = y * np.log(A + epsilon) # 转置后直接得到一个数,如果不转置还需要对矩阵求和
second = (1 - y) * np.log(1 - A + epsilon)
return -np.sum(first + second) / len(y)
theta = np.zeros((3, 1))
X, y = get_Xy(data)
const_init = costFunction(X, y, theta)
print(const_init)
# 定义梯度下降
def gradientDescent(X, y, theta, iters, alpha):
m = len(X)
costs = []
for i in range(iters):
A = sigmoid(X @ theta)
theta = theta - alpha / m * X.T @ (A - y)
cost = costFunction(X, y, theta)
costs.append(cost)
if i % 1000 == 0:
print(f"Iteration {i}: cost = {cost}")
return costs, theta
alpha = 0.004
iters = 200000
costs, theta = gradientDescent(X, y, theta, iters, alpha)
print("---------------------------")
print(costs)
print("---------------------------")
print(theta)
plt.figure()
plt.plot(range(iters), costs, label='Cost')
plt.xlabel('Iterations')
plt.ylabel('Cost')
plt.title('Cost Function Convergence')
plt.legend()
plt.show()
print("---------------------------")
# print(costs)
print(theta)
def predict(X,theta):
pre = sigmoid(X@theta)
return [1 if i >= 0.5 else 0 for i in pre ]
y_pre = predict(X,theta)
# 绘制真实值与预测值的比较图
plt.figure()
plt.plot(range(len(y)), y, label='real_values', linestyle='-', marker='o', color='g')
plt.plot(range(len(y)), y_pre, label='pre_value', linestyle='--', marker='x', color='r')
plt.xlabel('label')
plt.ylabel('value')
plt.title('differ')
plt.legend()
plt.show()
附:使用数据集
标签:逻辑,plt,机器,回归,np,print,theta,data,def From: https://blog.csdn.net/2303_77275067/article/details/14045533034.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1