首页 > 其他分享 >torch神经网络--线性回归

torch神经网络--线性回归

时间:2024-10-05 13:00:22浏览次数:1  
标签:dim -- torch 神经网络 train values model numpy

简单线性回归

y = 2*x + 1

import numpy as np
import torch
import torch.nn as nn


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        out = self.linear(x)
        return out


x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1, 1)
x_train.shape

y_values = [2*i+1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)
y_train.shape
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)

# 如果使用GPU训练,增加以下两行代码
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# model.to(device)


# 指定好参数和损失函数
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# 训练模型
for epoch in range(epochs):
    epoch += 1
    # 使用cpu时,注意转行成tensor
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    # 如果使用GPU训练,将以上两行代码修改为
    # inputs = torch.from_numpy(x_train).to(device)
    # labels = torch.from_numpy(y_train).to(device)

    # 梯度要清零每一次迭代
    optimizer.zero_grad()
    # 前向传播
    outputs = model(inputs)
    # 计算损失
    loss = criterion(outputs, labels)
    # 反向传播
    loss.backward()
    # 更新权重参数
    optimizer.step()

    # 打印
    if epoch % 50 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))


# CPU测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()

# 模型的保存
torch.save(model.state_dict(), 'model.pkl')
# 模型读取
model.load_state_dict(torch.load('model.pkl'))

标签:dim,--,torch,神经网络,train,values,model,numpy
From: https://www.cnblogs.com/jackchen28/p/18408065

相关文章

  • Android 11 如何不要验证Wi-Fi CA 凭证(手工连接WIFI, 需要ROOT)
    Android11如何不要验证Wi-FiCA凭证(手工连接WIFI,需要ROOT)在获取了ROOT权限的基础上,如果因为您机器所使用OS版本的限制无法在GUI界面选择符合您企业设置的WI-FI选项,可以使用本文教程中指出的手工连接WIFI的方式.Step1.检查adbshellsucat/data/misc/apexdata/c......
  • 如何给易语言软件加网络验证 永久免费的网络验证 文心云验证
    当我们自己幸幸苦苦编写了一个软件,又不想泛滥时,我们应该如何给软件添加一个授权呢我这边找了很久找到了一个方法 就是对接网络验证实现授权才能登录文心云验证是可以为开发的软件增加收费授权的功能,让作者开发的软件可以进行销售、充值、登陆等操作,并且提供防破验证功能,可以......
  • 图表不会做怎么办?AI一键生成好看图表!
    本期教你如何用AI一键生成各种数据图表!本文阅读难度:★☆☆☆☆看看别人做的这些图表,是不是挺好看的?特别是作为接商单的新写手,看到这些,头都大了,该怎么办呢?不用怕,我教你一键生成,别忘了收藏起来。数据图表类型有很多种,适合不同的场景,比如柱形图、饼图、条形图、人形图、折......
  • hadoop初学篇之三——公网全分布式部署主机IP导致的问题
    不管是内网集群,还是公网集群(当然一般情况不会这么做),建议这个步骤都不要忽略。内网不一定会出现这个问题,但是公网不做肯定有问题!前提:在阿里云公网部署三台ecs,都有公网IP,内网有通有不通(测试公网所以忽略);按照全分布模式部署,按照JDK(8)、Hadoop(2.10),各种配置完毕,namenode格式化成功后,s......
  • Leecode热题100-3.无重复字符最长子串
    给定一个字符串 s ,请你找出其中不含有重复字符的 最长 子串 的长度。示例 1:输入:s="abcabcbb"输出:3解释:因为无重复字符的最长子串是"abc",所以其长度为3。示例2:输入:s="bbbbb"输出:1解释:因为无重复字符的最长子串是"b",所以其长度为1。......
  • 【Canvas与艺术】金属底座洞眼红心按钮
    【成图】【代码】<!DOCTYPEhtml><htmllang="utf-8"><metahttp-equiv="Content-Type"content="text/html;charset=utf-8"/><head><title>金属底座洞眼红心按钮</title><styletype="text/css&quo......
  • 第2关:16位先行进位加法器设计-实验指导
    第二关也通过啦!!任务描述        本关任务:16位先行进位加法器实验目的1、组间先行进位设计2、CLU和CLA级联应用实验原理         对于一个16位加法器,可以分成4组,每组用一个4位先行进位加法器CLA实现。下图是一个由4个4位先行进位加法器CLA与一个组间......
  • blender贴图丢失,贴图显示紫色
    闲言一般在模型复制粘贴或转移过程中,发生贴图加载失败,导致模型贴图位置显示紫色.如果是上述相关情况,那么本文章应能为你提供相关帮助.本人配置:win11-blender3.6(本案例演示版本)-blender4.2打开丢失材质模型(.blend).fbx导入也是一样的,这里不赘述.打开材质......
  • 织梦CMS遇到数据库连接失败怎么办?
    当织梦CMS遇到数据库连接失败时,可能有多种原因导致此问题。以下是一些常见的故障排查和解决方法:1.检查数据库配置文件打开配置文件打开织梦CMS的数据库配置文件 include/config.inc.php。使用FTP工具或SSH连接到服务器,然后打开该文件。检查配置信息确认数据库配置......
  • uv --- replacement of conda + pip (python version + package version install) pyt
    uvhttps://docs.astral.sh/uv/AnextremelyfastPythonpackageandprojectmanager,writteninRust. InstallingTrio'sdependencieswithawarmcache.Highlights......