首页 > 其他分享 >手写pytorch线性回归

手写pytorch线性回归

时间:2023-06-08 23:33:50浏览次数:47  
标签:features data torch pytorch num 线性 手写 true size


Python下划线的五种用法

手写线性回归
教程地址 未解决的问题:plt.show()会阻塞

import torch
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random
from tqdm import tqdm
from multiprocessing import Pool


# generate dataset

true_w = torch.Tensor([2, -3.4])
true_w = torch.unsqueeze(true_w, 1)
true_b = torch.Tensor([4.2])

feature_num = 2
data_num = 10000

features = torch.randn(data_num, feature_num, dtype=torch.float32)

true_labels = features.mm(true_w) + true_b
# add Gaussian noise
noise_labels = true_labels + torch.tensor(np.random.normal(0, 0.01, size=true_labels.size()), dtype=torch.float32)

# visualize
def use_svg_display():
    display.set_matplotlib_formats('svg')


def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    plt.rcParams['figure.figsize'] = figsize

#  手写数据迭代器
def data_iter(batch_size, features, labels):
    data_num = len(features)
    indices = list(range(data_num))
    random.shuffle(indices)  # 将indices的列表顺序打乱
    for i in range(0, data_num, batch_size):  # 从0迭代到data_num-1
        # 构建long类型的张量,and防止最后一次不足一个batch
        j = torch.LongTensor(indices[i: min(i + batch_size, data_num)])
        #  yield相当于return,但在哪里跌倒就在哪里站起来,并且节约空间
        # 0代表按行索引,j代表索引哪些行
        yield features.index_select(0, j), labels.index_select(0, j)


# initialize
w = torch.tensor(np.random.normal(0, 0.01, (feature_num, 1)), dtype=torch.float32)
b = torch.zeros(1, dtype=torch.float32)

w.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)


# 线性回归计算
def linreg(X, w, b):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    return torch.mm(X, w) + b


def squared_loss(y_hat, y):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    # 注意这里返回的是向量, 另外, pytorch里的MSELoss并没有除以 2
    # size不需要用到
    # return (y_hat - y.view(y_hat.size())) ** 2 / 2
    return (y_hat - y.data) ** 2 / 2


def sgd(params, lr, batch_size):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    for param in params:
        # 注意这里更改param时用的param.data,这样不会被追踪自动微分
        # 因为微分是根据y的值用反向传播算出来的,如果直接改变x的值,会把改变的范围也算进微分里面
        # 或者会导致梯度算出来是None,因为x不再是叶子节点
        param.data -= lr * param.grad / batch_size


# configuration
batch_size = 100
lr = 0.03
num_epochs = 10
#  函数句柄
net = linreg
loss = squared_loss

for epoch in tqdm(range(num_epochs)):  # 训练模型一共需要num_epochs个迭代周期
    # 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X
    # 和y分别是小批量样本的特征和标签
    for X, y in data_iter(batch_size, features, noise_labels):
        l = loss(net(X, w, b), y).sum()  # l是有关小批量X和y的损失
        l.backward()  # 小批量的损失对模型参数求梯度
        sgd([w, b], lr, batch_size)  # 使用小批量随机梯度下降迭代模型参数

        # 不要忘了梯度清零
        w.grad.data.zero_()
        b.grad.data.zero_()
    train_l = loss(net(features, w, b), noise_labels)
    print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))

# 输出运行结果
print(true_w, '\n', w)
print(true_b, '\n', b)


# 画图的放在前面会阻塞后面的进程
set_figsize()
plt.scatter(features[:, 1].numpy(), noise_labels.numpy(), 1, c='r')
plt.scatter(features[:, 1].numpy(), true_labels.numpy(), 1, c='g')
plt.ioff()
plt.show()
plt.ioff()

运行结果

手写pytorch线性回归_线性回归


手写pytorch线性回归_Data_02


调用nn.linear自己构造类

import torch
import torch.utils.data as Data
import torch.nn as nn
import numpy as np

# generate data
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2

features = torch.tensor(np.random.normal(1, 1, (num_examples, num_inputs)), dtype=torch.float)
labels = true_w[0] * features[:, 0] + true_w[1] + features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

batch_size = 10
# 数据和标签组合
dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)


class LinearNet(nn.Module):
    def __init__(self, n_feature):
        ##  python2写法,class,self
        # super(LinearNet, self).__init__()
        # python3也可以这么写
        super().__init__()
        self.linear = nn.Linear(in_features=n_feature, out_features=1, bias=True)

    def forward(self, x):
        y = self.linear(x)
        return y


net = LinearNet(num_inputs)
print(net)

nn.init.normal_(net.linear.weight, mean=0, std=0.01)
nn.init.constant_(net.linear.bias, val=0)

loss = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.03)


num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        output = net(X)
        # output是(10,1),所以也要改y
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d, loss:%f' % (epoch, l.item()))

结果:

手写pytorch线性回归_迭代_03


用sequential构造网络

import torch
import torch.utils.data as Data
import torch.nn as nn
import numpy as np

# generate data
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2

features = torch.tensor(np.random.normal(1, 1, (num_examples, num_inputs)), dtype=torch.float)
labels = true_w[0] * features[:, 0] + true_w[1] + features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

batch_size = 10
# 数据和标签组合
dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)

net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))

print(net)
# 打印出第0层
print(net[0])

# 打印所有可学习参数
for param in net.parameters():
    print(param)

nn.init.normal_(net[0].weight, mean=0, std=0.01)
nn.init.constant_(net[0].bias, val=0)

loss = nn.MSELoss()

# 不同的网络设置不同的学习率
optimizer = torch.optim.SGD([
    {'params': net.linear.parameters()},  # lr=0.01
], lr=0.03)
print(optimizer)

num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        output = net(X)
        # output是(10,1),所以也要改y
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d, loss:%f' % (epoch, l.item()))

运行结果:

手写pytorch线性回归_线性回归_04


标签:features,data,torch,pytorch,num,线性,手写,true,size
From: https://blog.51cto.com/u_16131692/6444256

相关文章

  • 非线性规划凸优化——凸函数、凸规划(二)
    凸规划是指若最优化问题的目标函数为凸函数,不等式约束函数也为凸函数,等式约束函数是仿射的。凸规划的可行域为凸集,因而凸规划的局部最优解就是它的全局最优解。当凸规划的目标函数为严格凸函数时,若存在最优解,则这个最优解一定是唯一的最优解。一、凸集凸集:设\(C\)为\(n\)维欧式......
  • 第2章-线性表
    1.顺序表1.1顺序表的定义1.1.1静态分配:#include<stdio.h>#defineMaxSize10typedefstruct{ intdata[MaxSize]; intlength;}Sqlist;//初始化一个顺序表voidInitList(Sqlist&L){ for(inti=0;i<MaxSize;i++){ //TODO L.data[i]=0;//将所有数据元素设置为默认......
  • 第2章-线性表习题
    P1708#include<stdio.h>#include<iostream>usingnamespacestd;voidreverse(inta[],intn,intm,intsize){ for(inti=0;i<size;i++){ a[i]=i+1; } for(inti=0;i<size;i++) cout<<a[i]<<""; cout<<endl;......
  • orin上安装cuda pytorch gpu运行环境
    https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048一、先重新装jetpack【JetsonAgxOrin】执行sudoaptinstallnvidia-jetpack命令时报错:E:Unabletolocatepackagenvidia-jetpack二、查看是否有/usr/local/cuda-11.4jetsonnano查看CUDA版本:nvcc-V报错......
  • 深度学习项目之mnist手写数字识别实战(TensorFlow框架)
    mnist手写数字识别是所有深度学习开发者的必经之路,mnist数据集的图片十分简单,是二值化图像,像素个数为28x28。所以对于所有研究深度学习的开发者来说学会mnist数据集的模型十分有必要。以此为实例进行计算机视觉如何进行识别出图片中的数据。MNIST手写数字数据集来自美国国家标准与......
  • 【数学】各种积性函数的线性筛法
    【数学】各种积性函数的线性筛法前置芝士:几种特殊的积性函数的定义及基本性质。定义积性函数:若函数\(f(x)\)满足\(f(x)=1\)且\(\forallx,y\inN^+,gcd(x,y)=1\),都有\(f(xy)=f(x)f(y)\),则\(f(x)\)为积性函数。完全积性函数:若函数满足\(f(x)=1\)且\(\forallx,y\in......
  • Pytorch
    Pytorch张量直接张量创建依据数值创建依据概率创建拼接切分索引变换四则运算自动求导数据如何读取你自己的数据集?如何图像数据预处理及数据增强?模型如何构建神经网络?如何初始化参数?损失函数如何选择损失函数?如何设置损失函数?优化器如何管理参数?如何调整学习率?迭代过程:如何观察训练......
  • 从0开始学pytorch【4】--维度变换、拼接与拆分
    从0开始学pytorch【4】--维度变换、拼接与拆分学习内容:维度变换:张量拆分与拼接:小结学习内容:维度变换、张量拆分与拼接维度变换:1、viewimporttorcha=torch.rand(4,1,28,28)print(a.shape)print(a.view(4,28*28))print(a.shape)b=a.view(4,28,-1)b.view(4,28,28,-1......
  • 从0开始学pytorch【3】--张量数据类型
    从0开始学pytorch【3】--张量数据类型前言学习目标基本数据类型创建tensor索引、切片小结前言  在前两篇博文中,从0开始学pytorch【1】–线性函数的梯度下降、从0开始学pytorch【2】——手写数字集案例中介绍了人工智能入门最为基础的梯度下降算法实现,以及机器学习、深度网络编......
  • python线性脚本生成基本eml邮件,压缩文件,接口灌数据
    1importdatetime,zipfile,tarfile,logging,os,string,random,ipaddress,uuid,pytz,py7zr2importio,socket3fromemail.mime.textimportMIMEText4fromemail.mime.multipartimportMIMEMultipart5fromemail.mime.applicationimportMIMEA......