首页 > 其他分享 >【深度学习】| 线性回归的简单实现

【深度学习】| 线性回归的简单实现

时间:2022-08-17 08:58:47浏览次数:56  
标签:loss data 回归 torch 点击 深度 线性 net true

1 概述

本文的主要目的是通过实现最简单的线性回归模型,理解pytorch在数据导入、模型定义、、损失计算、优化迭代、自动求导和批次训练等方面的特点。

2 数据导入

首先,生成真实的线性函数,参数为w和b;接着按照w和b的size来生成1000个样本数据

点击查看代码
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
构造出数据的Dataset类,将其放入到DataLoader中方便进行批次训练,DataLoader可以实现对Dataset中的数据进行shuffle和批次大小的划分。
点击查看代码
def load_array(data_arrays, batch_size, is_train = True):
    '''构造一个PyTorch数据迭代器'''
    dataset = data.TensorDataset(*data_arrays)# 此处*的作用
    return data.DataLoader(dataset, batch_size, shuffle = is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))

3 模型定义

使用框架预定义好的层,nn是神经网络的缩写

点击查看代码
from torch import nn
net = nn.Sequential(nn.Linear(2, 1))
初始化模型参数
点击查看代码
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

4 定义损失函数

点击查看代码
loss = nn.MSELoss()

5 定义优化算法

点击查看代码
trainer = torch.optim.SGD(net.parameters(),lr = 0.03)

6 训练

使用随机小批量梯度下降法
训练过程中打印的loss是为研究者观察模型是否往参数逐步优化的方向变化而给的一个参考指标。
data_iter中的loss是指一个批次的损失。
trainer.zero_grad()是梯度清零
l.backward()是求解这一个批次的样本的导数和
trainer.step()# 以求得的导数和,结合优化器,更新参数 w和b,然后进行下一批次的训练

点击查看代码
# 训练
num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
#         print(X.shape)
        l = loss(net(X), y)# 一个批次的损失
        trainer.zero_grad()
        l.backward()
        trainer.step()
#     print(features.shape)
    l = loss(net(features), labels)# 整个数据集的损失
    print(f'epoch {epoch + 1}, loss {l:f}')

7 打印结果

点击查看代码
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

image

标签:loss,data,回归,torch,点击,深度,线性,net,true
From: https://www.cnblogs.com/arkon/p/16593619.html

相关文章

  • 数模(1)—— 多重共线性的验证
    模型的解释变量之间存在线性关系若中心化之后自变量的相关系数矩阵R=X'X接近于退化就存在多重共线性R有多少个特征根接近于零,设计矩阵X就有多少个多重共线性关系......
  • 【数据结构与算法】线性表——顺序表的实现
    顺序表的实现C++代码使用了模板。使用的时候直接导入头文件即可。代码实现相关细节、解释都在注释里了。那么就直接上代码了。//MySeqList.h文件#ifndef__MYSEQLIS......
  • C# 深度复制对象 反序列化方式与复制构造函数方式的效率分析
    先看结果 所以复制构造函数优于序列化和反序列化代码如下:usingSystem;usingSystem.Collections.Generic;usingSystem.Diagnostics;usingSystem.Linq;using......
  • Linux内核深度解析 pdf
    高清扫描版下载链接:https://pan.baidu.com/s/1rfIX0DCTQeqXmNCCnFhteQ点击这里获取提取码。 ......
  • 一文看懂线性回归和非线性回归
    一文看懂线性回归和非线性回归           1.非线性回归           2.线性回归           3.总结1.非线性回归我们首先来看维基百......
  • autodl3-配置深度学习环境
    1.激活conda在jupyterlab终端输入vim ~/.bashrc  首先输入i,进入编辑模式在最后加上路径:(minconda路径)   按esc:wq保存退出    刷新    ---......
  • 深度领先 |《测试开发工程师质量监控实战训练营》开营啦!
    这个训练营有多难得,就不用我多说什么啦,懂的都懂。《测试开发工程师质量监控实战训练营》由资深测试架构师、开源项目作者亲授BAT大厂前沿最佳实践。手把手带你搭建质量监......
  • 线性基
    writtenon2022-08-14学高斯消元的时候顺便学到了线性基,线性基通常在异或运算中出现。这里先贴一下别人的博客,个人认为这篇博客总结的还是蛮好的,可以特别关注一下里面......
  • 【MIT18.06·线性代数02】
    线性方程组的矩阵形式可以将线性方程组写成\(Ax=b\)的矩阵相乘形式:比如线性方程组\(\left\{\begin{matrix}\begin{aligned}2x-y&=0\\-x+2y&=3\end{aligned}\end......
  • 【MIT18.06·线性代数01】过去对线性方程组的理解
    如何理解一个线性方程组?考虑这样一个方程组:\(\left\{\begin{matrix}\begin{aligned}2x-y&=0\\-x+2y&=3\end{aligned}\end{matrix}\right.\)在之前的理解方式中,求......