学习地址:
https://www.bilibili.com/video/BV1Y7411d7Ys?p=2&vd_source=001ba1b001e88ca6d09a9b0de2a86d71
Colab 链接:
https://colab.research.google.com/gist/cyberangelisme/e5d90757fa1fafe0068b344298d05d7e/.ipynb
import numpy as np
import matplotlib.pyplot as plt
# Training samples for the linear model; every target equals 2 * input,
# so the weight that minimises the loss is 2.0.
x_data = [
    1.0,
    2.0,
    3.0,
]
y_data = [
    2.0,
    4.0,
    6.0,
]
def forward(x, weight=None):
    """Return the linear model's prediction ``x * weight``.

    Args:
        x: Input value.
        weight: Weight to multiply by. When omitted, the module-level
            variable ``w`` is used, so existing ``forward(x)`` calls keep
            reading whatever weight the enclosing sweep has bound to ``w``.

    Returns:
        The product of ``x`` and the selected weight.

    Raises:
        NameError: If ``weight`` is omitted and no global ``w`` exists yet.
    """
    if weight is None:
        # Fall back to the global weight set by the caller's loop.
        weight = w
    return x * weight
def loss(x, y):
    """Return the squared error of the model's prediction for one sample.

    Args:
        x: Input value passed to ``forward``.
        y: Target value for that input.

    Returns:
        ``(y - forward(x))`` multiplied by itself.
    """
    # Compute the residual once instead of repeating the subtraction.
    residual = y - forward(x)
    return residual * residual
# Sweep over candidate weights and record the mean squared error of the
# model on (x_data, y_data) for each one, then plot MSE against the weight.
w_list = []
mse_list = []

# NOTE: the loop variable must stay named `w` at module level, because
# forward() reads the global `w` as the current weight.
for w in np.arange(0, 4.1, 0.1):
    print("w = ", w)
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        # The prediction is computed here only for the per-sample printout;
        # loss() evaluates the model again internally.
        y_pred_val = forward(x_val)
        loss_val = loss(x_val, y_val)
        l_sum += loss_val
        print('\t', x_val, y_val, y_pred_val, loss_val)
    # Average over the actual number of samples rather than a literal 3.
    mse = l_sum / len(x_data)
    print("MSE = ", mse)
    w_list.append(w)
    mse_list.append(mse)

# Plot the loss curve: weight on the x-axis, mean squared error on the y-axis.
plt.plot(w_list, mse_list)
plt.xlabel("w")
plt.ylabel("MSE")
plt.show()
标签:loss,plt,val,--,pred,list,pytorch,线性,data
From: https://www.cnblogs.com/rabbithacker/p/17023117.html