首页 > 其他分享 >5-用PyTorch实现线性回归

5-用PyTorch实现线性回归

时间:2024-08-13 20:16:19浏览次数:6  
标签:linear 回归 torch print PyTorch 线性 model data self





下面是损失函数

下面是优化器
下面通过model.parameters()可以获得model中所有的参数


点击查看代码
import torch
from torch import device

x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1) # 权重和偏置

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel() # 定义模型

# 定义损失函数
criterion = torch.nn.MSELoss()
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 优化器对哪些参数进行更新

for epoch in range(5000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data) # 计算损失
    print(epoch, loss.item())

    optimizer.zero_grad() # 梯度清零
    loss.backward() # 反向传播
    optimizer.step() # 参数更新

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.tensor([4.0])
print('predict(4)=', model(x_test).item())

标签:linear,回归,torch,print,PyTorch,线性,model,data,self
From: https://www.cnblogs.com/morehair/p/18357625

相关文章

  • 【视频讲解】滚动回归Rolling Regression、ARIMAX时间序列预测Python、R实现应用
    原文链接: https://tecdat.cn/?p=37338原文出处:拓端数据部落公众号分析师:JixinZhong  本文将通过视频讲解,展示如何用滚动回归预测,并结合一个R语言多元时间序列滚动预测:ARIMA、回归、ARIMAX模型分析实例的代码数据,为读者提供一套完整的实践数据分析流程。滚动回归估计是于一......
  • 2024亚太杯数学建模b题基于机器学习回归的洪水预测模型研究
    本届亚太杯中文赛项已经结束,本文分享我的解决思路。摘 要洪水的频率和严重程度与人口增长趋势相近。迅猛的人口增长,扩大耕地,围湖造田,乱砍滥伐等人为破坏不断地改变着地表状态,改变了汇流条件,加剧了洪灾程度。2023年,全球洪水造成了数十亿美元的经济损失。因此构建与研究洪水......
  • 【人工智能】 使用线性回归预测波士顿房价 paddlepaddle 框架 飞桨
    一、简要介绍经典的线性回归模型主要用来预测一些存在着线性关系的数据集。回归模型可以理解为:存在一个点集,用一条曲线去拟合它分布的过程。如果拟合曲线是一条直线,则称为线性回归。如果是一条二次曲线,则被称为二次回归。线性回归是回归模型中最简单的一种。本示例简要介......
  • 【Python机器学习】树回归——使用Python的tkinter库创建GUI
    机器学习给我们提供了一些强大的工具,能从未知数据中抽取出有用的信息。因此,能否这些信息以易于人们理解的方式呈现十分重要。如果人们可以直接与算法和数据交互,将可以比较轻松的进行解释。其中一个能够同时支持数据呈现和用户交互的方式就是构建一个图形用户界面(GUI)。利用GUI......
  • 【图像去噪】论文复现:新手入门必看!DnCNN的Pytorch源码训练测试全流程解析!为源码做详细
    第一次来请先看【专栏介绍文章】:源码只提供了noiselevel为25的DnCNN-S模型文件。本文末尾有完整代码和训练好的σ=15,25,50的DnCNN-S、σ∈[0,55]的DnCNN-B和CDnCNN-B、DnCNN-3共6个模型文件!读者可以自行下载!本文亮点:以官方Pytorch源代码为基础,在DnCNN-S的基础上,增添Dn......
  • CF1615H-Reindeer Games【保序回归,整体二分,网络流】
    正题题目链接:https://www.luogu.com.cn/problem/CF1615H题目大意有\(n\)个点,每个点有个初始权值\(a_i\),你每次可以让一个点权值\(+1\)或者\(-1\)。有\(m\)个限制要求某个点最终权值小于等于另一个点。求最少的操作次数使得满足所有限制。\(2\leqn,m\leq1000,1......
  • 24/8/12算法笔记 复习_线性回归
    importnumpyasnp#导入包X=np.array([[1,1],[2,1]])#构造矩阵y=np.array([14,10])np.linalg.solve(X,y)#linalg是线性代数,用于求解线性方程AX=b,solve计算线性代数回归问题X.T#转置a=X.T.dot(X)#矩阵乘法B=np.linalg.inv(a)#求逆矩阵fromsklearn.line......
  • [WC2019] 数树纯组合线性做法
    NaCly_Fish的博客激发了继续思考的欲望。我是多项式白痴,所以让我们来思考组合意义做法!本题本质上是需要让我们求\(\sum_{E_1\text{是树}}\sum_{E_2\text{是树}}y^{-|E1\cupE2|}\)的值。我们容斥一下交集,发现考虑上容斥系数就是将\(y\leftarrow\frac{1}{y}-1\)。剩下......
  • 用Python实现9大回归算法详解——01线形回归算法
    1.线性回归的基本概念线性回归是一种最基本的监督学习算法,用于预测因变量(目标变量)和一个或多个自变量(特征变量)之间的关系。线性回归假设因变量与自变量之间的关系是线性的,即可以用以下形式的线性方程来表示:其中: 是因变量(目标变量)。 是自变量(特征变量)。是截距项,表示当所......
  • 构造用于线性回归分析使用的波动上升随机数据并绘制散点图
    一、简介进行线性数据回归分析经常需要用到波动上升的随机数据,本文给出了使用python构建的由线性数据+随机数据+正弦数据的波动上升数据并绘制散点图的代码和效果展示。该数据共5段100个可用于进行线性回归数据分析。二、代码#-*-coding:utf-8-*-#导入第三方库import......