- 二维绘图
import matplotlib.pyplot as plt
plt.plot (x,y)
- 三维绘图
import matplotlib.pyplot as plt
fig = plt.figure ()
//创建一个图形窗口
ax = fig.add_subplot(111, projection = '3d')
//111指的是一行一列子图的第一个是这个图
ax.plot_surface (w_grid,b_grid,mse_grid)
//x,y,z轴
w_grid,b_grid = numpy.meshgrid (numpy.arange (0.0,4.1,0.1),numpy.arange (0.0,4.1,0.1))
mse_grid = numpy.zeros_like (w_grid)
//meshgrid是将两个数组合为一个二维数组每一行是w_grid的copy,每一列是b_grid的copy,numpy.arrange(0.0,4.1,0.1)是指produce了一个从0.0到4.0步长为0.1的数组
- 反向传播
l.backward()
//反向传播
w.grad.data.zero_()
//每次梯度下降之后要记得将grad归零,不归零的话会把每次的grad给加起来
标签:plt,4.1,0.1,0.0,pytorch,grid,numpy
From: https://www.cnblogs.com/currytrey/p/18372665