Pytorch 自动梯度法 完整例子
下面是一个使用 PyTorch 自动梯度法的完整例子。这个例子展示了如何训练一个简单的线性回归模型来拟合一组数据。我们将从头到尾覆盖所有步骤,包括数据准备、模型定义、训练过程以及评估。
1. 安装 PyTorch
确保你已经安装了 PyTorch。如果没有,请先安装:
pip install torch
2. 完整代码示例
import torch import torch.nn as nn import torch.optim as optim import numpy as np import matplotlib.pyplot as plt # 1. 生成模拟数据 # 输入数据 x = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) # 目标数据 y = np.array([2.0, 4.0, 6.0, 8.0, 10.0]) # 将数据转换为 PyTorch 张量 x_tensor = torch.tensor(x, dtype=torch.float32).view(-1, 1) # shape: (5, 1) y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1) # shape: (5, 1) # 2. 定义线性回归模型 class LinearRegression(nn.Module): def __init__(self): super(LinearRegression, self).__init__() self.linear = nn.Linear(1, 1) # 输入特征1,输出特征1 def forward(self, x): return self.linear(x) # 实例化模型 model = LinearRegression() # 3. 定义损失函数和优化器 criterion = nn.MSELoss() # 均方误差损失 optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器 # 4. 训练模型 num_epochs = 1000 for epoch in range(num_epochs): # 前向传播 outputs = model(x_tensor) loss = criterion(outputs, y_tensor) # 反向传播 optimizer.zero_grad() # 清空之前的梯度 loss.backward() # 计算新的梯度 optimizer.step() # 更新参数 # 打印损失 if (epoch+1) % 100 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}') # 5. 评估模型 with torch.no_grad(): predicted = model(x_tensor).numpy() # 6. 可视化结果 plt.scatter(x, y, color='blue', label='Original data') plt.plot(x, predicted, color='red', label='Fitted line') plt.xlabel('x') plt.ylabel('y') plt.legend() plt.show()
代码解释
-
生成数据:
- 创建了简单的输入数据
x
和目标数据y
。 - 将这些数据转换为 PyTorch 张量,并调整形状以符合模型的输入要求。
- 创建了简单的输入数据
-
定义模型:
- 创建一个简单的线性回归模型
LinearRegression
,该模型包含一个线性层nn.Linear
。
- 创建一个简单的线性回归模型
-
定义损失函数和优化器:
- 使用均方误差损失函数
nn.MSELoss()
。 - 使用随机梯度下降优化器
optim.SGD()
,设置学习率为 0.01。
- 使用均方误差损失函数
-
训练模型:
- 在每个训练轮次中,进行前向传播计算输出和损失。
- 调用
loss.backward()
计算梯度,并使用optimizer.step()
更新模型参数。 - 每 100 个 epoch 打印一次损失。
-
评估模型:
- 在不需要计算梯度的情况下进行预测,并将结果转换为 NumPy 数组以便于绘图。
-
可视化结果:
- 使用 Matplotlib 绘制原始数据点和模型拟合的直线。
这个完整的例子展示了如何从头到尾使用 PyTorch 进行基本的深度学习任务,包括数据准备、模型定义、训练和评估。
标签:plt,tensor,nn,梯度,模型,torch,Pytorch,相关,第四篇 From: https://www.cnblogs.com/lovebay/p/18401647