首页 > 其他分享 >【2022.11.15】pytorch的使用相关(四)

【2022.11.15】pytorch的使用相关(四)

时间:2022-11-16 13:44:22浏览次数:83  
标签:torch 15 features labels pytorch num true 2022.11 size

参考资料

ShusenTang/Dive-into-DL-PyTorch: 本项目将《动手学深度学习》(Dive into Deep Learning)原书中的MXNet实现改为PyTorch实现。 (github.com)

python数组冒号取值操作 - python中冒号的用法 - 实验室设备网 (zztongyun.com)

生成数据集

image-20221115104751837

num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
# 创建一个随机1000样本,每个样本2个特征
features = torch.randn(num_examples, num_inputs,
                       dtype=torch.float32)

labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()),
                       dtype=torch.float32)
# 这里得到的label是1000*1的矩阵
print(labels.shape)

冒号的使用

技术分享图片

生成数据集

num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
# 创建一个随机1000样本,每个样本2个特征
features = torch.randn(num_examples, num_inputs,
                       dtype=torch.float32)

labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()),
                       dtype=torch.float32)
# 这里得到的label是1000*1的矩阵
print(labels.shape)

def use_svg_display():
    # 用矢量图显示
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(5.5, 5.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize

# # 在../d2lzh_pytorch里面添加上面两个函数后就可以这样导入
# import sys
# sys.path.append("..")
# from d2lzh_pytorch import * 

set_figsize()

根据以上内容生成代码

plt.scatter(np.arange(1000), features[:, 0].numpy())
plt.scatter(np.arange(1000), features[:, 1].numpy())

image-20221115131628053

而计算得到的label得到如下结果

plt.scatter(np.arange(1000), labels.numpy())

image-20221115131902009

plt.scatter(features[:, 0].numpy(), labels.numpy(), 1);
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);

可以看到两个不同方向的点集,这是因为权重是一正一负导致的,此时是[2, -3.4]

image-20221115130309531

当我们拉得极限一些,比如[2, -23.4],可以看到图像变化,呈现一个明显的斜率

可以大致估算出是(0-80)/(4-0)=-20,而另一个数字,难以从图上估算出来

image-20221115134514741

读取数据

函数

# 本函数已保存在d2lzh包中方便以后使用
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 样本的读取顺序是随机的
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) # 最后一次可能不足一个batch
        yield  features.index_select(0, j), labels.index_select(0, j)

读取

batch_size = 5
for X, y in data_iter(batch_size, features, labels):
    print(X,'\n', y)
    break

# 以下为输出
tensor([[-1.4393, -1.0121],
        [ 0.0677, -0.1193],
        [ 0.7968, -0.2002],
        [ 0.0873, -0.8061],
        [ 1.0987,  0.2934]]) 
 tensor([4.7532, 4.7421, 6.4511, 7.1202, 5.3975])

取其中[ 0.7968, -0.2002]进行计算可得6.4743,与实际例子的6.4511不等

初始化模型参数

w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, 1)), dtype=torch.float32, requires_grad=True)
b = torch.zeros(1, dtype=torch.float32, requires_grad=True)

print(w.requires_grad)
print(b.requires_grad)

将权重初始化成均值为0、标准差为0.01的正态随机数,偏差则初始化成0。

之后的模型训练中,需要对这些参数求梯度来迭代参数的值,因此我们要让它们的requires_grad=True

定义模型

线性回归的矢量计算表达式的实现。我们使用mm函数做矩阵乘法。

def linreg(X, w, b):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    return torch.mm(X, w) + b

定义损失函数

使用平方损失来定义线性回归的损失函数。在实现中,我们需要把真实值y变形成预测值y_hat的形状。以下函数返回的结果也将和y_hat的形状相同

def squared_loss(y_hat, y):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    # 注意这里返回的是向量, 另外, pytorch里的MSELoss并没有除以 2
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

定义优化算法

def sgd(params, lr, batch_size):  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    for param in params:
        param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data

训练模型

在训练中,我们将多次迭代模型参数。在每次迭代中,我们根据当前读取的小批量数据样本(特征X和标签y),通过调用反向函数backward计算小批量随机梯度,并调用优化算法sgd迭代模型参数。由于我们之前设批量大小batch_size为10,每个小批量的损失l的形状为(10, 1)。回忆一下自动求梯度一节。由于变量l并不是一个标量,所以我们可以调用.sum()将其求和得到一个标量,再运行l.backward()得到该变量有关模型参数的梯度。注意在每次更新完参数后不要忘了将参数的梯度清零。

在一个迭代周期(epoch)中,我们将完整遍历一遍data_iter函数,并对训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。这里的迭代周期个数num_epochs和学习率lr都是超参数,分别设3和0.03。在实践中,大多超参数都需要通过反复试错来不断调节。虽然迭代周期数设得越大模型可能越有效,但是训练时间可能过长。而有关学习率对模型的影响,我们会在后面“优化算法”一章中详细介绍。

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):  # 训练模型一共需要num_epochs个迭代周期
    # 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X
    # 和y分别是小批量样本的特征和标签
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y).sum()  # l是有关小批量X和y的损失
        l.backward()  # 小批量的损失对模型参数求梯度
        sgd([w, b], lr, batch_size)  # 使用小批量随机梯度下降迭代模型参数

        # 不要忘了梯度清零
        w.grad.data.zero_()
        b.grad.data.zero_()
    train_l = loss(net(features, w, b), labels)
    print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))

训练完成后,我们可以比较学到的参数和用来生成训练集的真实参数。它们应该很接近

print(true_w, '\n', w)
print(true_b, '\n', b)

# 输出
[2, -3.4] 
 tensor([[ 1.9992],
        [-3.3994]], requires_grad=True)
4.2 
 tensor([4.1998], requires_grad=True)

标签:torch,15,features,labels,pytorch,num,true,2022.11,size
From: https://www.cnblogs.com/mokou/p/16895614.html

相关文章

  • Word15 财务部年度报告office真题
    1.课程的讲解之前,先来对题目进行分析,首先需要在考生文件夹下,将Wrod素材.docx文件另存为Word.docx,后续操作均基于此文件,否则不得分。   2.这一步非常的简单,打开下载素......
  • NTMFS4C810NAT3G场效应管30V NCH,DEC1515H-D0-I/Z2集成电路TQFP
    产品参数1、型号:DEC1515H-D0-I/Z2封装:TQFP128批次:新年份2、型号:NTMFS4C810NAT3GFET类型:N通道技术:MOSFET(金属氧化物)漏源电压(Vdss):30V25°C时电流-连续漏极(Id):8.2A(Ta......
  • 2022.11.14模拟赛题解
    树的覆盖\(dp_{i,j,0/1/2}\)表示以\(i\)为根的子树中覆盖\(j\)个点的方案数。其中\(0/1/2\)分别表示了\(3\)种情况。\(0\)表示示当前节点和子节点都没被选中......
  • 记录visiual studio 编译qt5.15.6
    准备工作1.qt源码下载可通过gitee下载,具体不介绍2.perl安装建议通过360软件管家安装,或者其他软件市场。比去官方下载快多了3.python我是用3.9的4.visiualstudio......
  • Navicat premium 15安装+激活-九五小庞
    下载安装包:https://wwz.lanzoue.com/b021z9k1e密码:7pm1 NavicatPremium15安装教程:一、解压后,双击安装包 ​二、点击下一步​三、勾选同意 ,点击下一步......
  • 11.15 解题报告
    T1考场用时:\(40\)min期望得分:\(30\)pts实际得分:\(30\)pts这题以前做过。首先显然的一点是小Y行走的路径是一棵树,这题可以分两部分来做,首先对于每一个节点按照节......
  • 力扣 153. 寻找旋转排序数组中的最小值 [二分变种]
    153.寻找旋转排序数组中的最小值已知一个长度为 n 的数组,预先按照升序排列,经由 1 到 n 次 旋转 后,得到输入数组。例如,原数组 nums=[0,1,2,4,5,6,7] 在变......
  • Python 文本文件拖上转自适应图片 - 学习笔记(2022.11.16)
    Python文本文件拖上转自适应图片功能:1、支持拖拽执行2、文本文件转为自适应尺寸图片1importre2importos3importsys4importtime5fromPI......
  • GL-Learning new words 20221115
    GLLearingnewwords20221115HowdoyoulearnnewwordsinEnglish?WhenIcomeacrossawordthatIdon'tknow,Idirectlycheckthepronunciationanddefini......
  • CF刷题计划?(upd:11.15)
    CF刷题计划?CF1285F太nb了这个题暴力一点的做法是二分后直接莫反,但是不够快考虑枚举一个\(\gcd\),令其为\(d\)然后从大到小枚举数,然后把\(\gcd(\frac{x}{d},\frac{y}{d}......