首页 > 其他分享 >《手动学习深度学习》3.2和3.3的代码对比

《手动学习深度学习》3.2和3.3的代码对比

时间:2024-03-02 11:33:19浏览次数:31  
标签:loss torch batch 学习 epoch 3.2 3.3 data size

3.2 线性回归的从零开始

这是我的第一个代码,也算是属于自己的hello world了,特此纪念,希望继续努力。
代码中引入了3.1中的计时模块,用来对比训练时间。

  import random
  import torch
  from d2l import torch as d2l
  import  sys
  sys.path.append("..")
  from timer import Timer

  # 定时器计时
  timer = Timer()

  # 生成数据集
  def synt_data(w, b, num):
      X = torch.normal(0, 1, (num, len(w)))
      y = torch.matmul(X, w) + b
      y += torch.normal(0, 0.01, y.shape)
      return X, y.reshape((-1, 1))

  true_w = torch.tensor([2, -3.4])
  true_b = 4.2
  features, labels = synt_data(true_w, true_b, 1000)

  # d2l.set_figsize()
  # d2l.plt.scatter(features[: , 0].detach().numpy(), labels.detach().numpy(), 1)
  # d2l.plt.show()

  # 读取数据集
  def data_iter(batch_size, features, labels):
      num = len(features)
      # 打乱下标
      indices  = list(range(num))
      random.shuffle(indices)
      for i in range(0, num, batch_size):
          # 每次获取10个数据作为一个batch
          batch_indices = torch.tensor(indices[i: min(i + batch_size, num)])
          # 获取数据
          yield features[batch_indices], labels[batch_indices]

  batch_size = 10

  # for X, y in data_iter(batch_size, features, labels):
  #     print(X, '\n', y)
  #     break

  # 初始化模型参数
  w = torch.normal(0, 0.01, (2,1), requires_grad=True)
  b = torch.zeros(1, requires_grad=True)

  # 定义模型
  def linreg(X, w, b):
      return torch.matmul(X, w) + b

  #损失函数
  def squ_loss(y_hat, y):
      return (y_hat - y.reshape(y_hat.shape))**2 / 2

  # 定义优化算法,简单梯度下降
  def sgd(params, lr, batch_size):
      with torch.no_grad():
          for param in params:
              param -= lr * param.grad /batch_size
              param.grad.zero_()

  # 训练
  lr  = 0.03
  num_epochs = 3
  net = linreg
  loss = squ_loss

  for epoch in range(num_epochs):
      for X, y in data_iter(batch_size, features, labels):
          l = loss(net(X, w ,b), y)
          l.sum().backward()
          sgd([w,b], lr, batch_size)
      with torch.no_grad():
          train_l = loss(net(features, w, b), labels)
          print(f'epoch {epoch + 1 }, loss {float(train_l.mean()):f}')

  print(f'time {timer.stop(): .5f} sec')

3.3 线性回归的简单实现

这段代码敲的时候有几个有趣的发现:

  1. 手动写的梯度下降函数里面,有自动的梯度清零和参数更新,但是torch实现的SGD应该是没有的,所以才需要使用trainer.zero_grad()trainer.step()
  2. yeild的使用可以实现迭代器一样的效果
  import numpy as np
  import torch
  from torch.utils import data
  from torch import nn
  from d2l import torch as d2l
  import  sys
  sys.path.append("..")
  from timer import Timer

  # 定时器计时
  timer = Timer()

  # 生成数据集
  true_w = torch.tensor([2, -3.4])
  true_b = 4.2
  features, labels = d2l.synthetic_data(true_w, true_b, 1000)

  # 读取数据集
  def load_array(data_arrays, batch_size, is_train = True):
      # 构建一个迭代器
      dataset = data.TensorDataset(*data_arrays)
      return data.DataLoader(dataset, batch_size, shuffle=is_train)

  batch_size = 10
  data_iter = load_array((features, labels), batch_size)

  # 定义模型
  net = nn.Sequential(nn.Linear(2, 1))

  # 定义参数
  net[0].weight.data.normal_(0, 0.01)
  net[0].bias.data.fill_(0)

  # 定义损失函数
  loss  = nn.MSELoss()

  # 定义优化算法
  trainer = torch.optim.SGD(net.parameters(), lr = 0.03)

  # 训练
  num_epoch = 3
  for epoch in range(num_epoch):
      for X, y in data_iter:
          l = loss(net(X), y)
          trainer.zero_grad()
          l.backward()
          trainer.step()
      l = loss(net(features), labels)
      print(f'epoch {epoch + 1}, loss {l:f}')

  print(f'time {timer.stop(): .5f} sec')

对比

两个代码的运行结果分别为:

  (d2l) z**@e****:~/deeplearning/linear_regression$ ***/miniconda3/envs/d2l/bin/python ***/deeplearning/linear_regression/model_simple.py
  epoch 1, loss 0.000301
  epoch 2, loss 0.000114
  epoch 3, loss 0.000114
  time  0.16474 sec
  (d2l) z**@e****:~/deeplearning/linear_regression$ ***/miniconda3/envs/d2l/bin/python ***/deeplearning/linear_regression/model.py
  epoch 1, loss 0.028095
  epoch 2, loss 0.000099
  epoch 3, loss 0.000052
  time  0.13666 sec

可以看到,针对于简单的线性回归而言,手动写的代码无论是最终的精度还是时间上都是更优的,搜索到的可能的原因是:使用现有的机器学习框架可能会带来一些开销,例如框架本身的启动时间、内存占用等,手动编写的代码可以避免这些开销。

以上就是全部的内容了,由于作者刚开始学习,能力浅薄,待我学成归来,也许会有更深的了解。

标签:loss,torch,batch,学习,epoch,3.2,3.3,data,size
From: https://www.cnblogs.com/zcry/p/18048424

相关文章

  • 学习随笔Vue
    v-if:v-if是用于条件性地渲染HTML元素,根据表达式的值来决定是否将元素添加到DOM中。当表达式的值为true时,元素会被渲染到DOM中,当表达式的值为false时,元素不会被渲染到DOM中,也就是说元素会被完全删除。当条件频繁变化时,使用v-if适合,因为它能够完全销毁和重建元......
  • 找实习学习第四天
       第二: 注意命名规范,子路由路径不能加“/”,浏览器会自动匹配上  elementui布局aside并不在侧边,而是纵向排列原因是没引入css布局,(当你并不知道这个代码是干什么但是他又出现了的时候,就应该把他加上,不要觉得没用就放弃他) 是简便写法,等同于   引入el......
  • LLMOps 学习记录
    在OpenAI的GPT,Meta的Llama和Google的BERT等大型语言模型(LLM)发布之后,它们可以生成类似人类的文本,理解上下文并执行广泛的自然语言处理(NLP)任务。LLM将彻底改变我们构建和维护人工智能系统和产品的方式。因此,一种被称为“LLMOps”的新方法已经发展并成为每个AI/ML社区的话题,以简化......
  • 笔记:Git学习之应用场景和使用经验
    目标:整理Git工具的应用场景和使用经验一、开发环境Git是代码版本控制工具;Github是代码托管平台。工具组合:VSCode+Git需要安装的软件:vscode、Git其中vscode需要安装的插件:GitLens、GitHistory二、应用场景工作场景:嵌入式开发,多人本地使用三、使用总结基础操作,参考廖雪峰的Git教......
  • 李宏毅2022机器学习HW4 Speaker Identification上(Dataset &Self-Attention)
    Homework4Dataset介绍及处理Datasetintroduction训练数据集metadata.json包括speakers和n_mels,前者表示每个speaker所包含的多条语音信息(每条信息有一个路径feature_path和改条信息的长度mel_len或理解为frame数即可),后者表示滤波器数量,简单理解为特征数即可,由此可知每个.pt......
  • 【李宏毅机器学习2021】(四)Self-attention
    引入Self-attention前面学到的内容输入都是一个向量,假如输入是一排向量,又应如何处理。来看下有什么例子需要将一排向量输入模型:当输入是一排向量时,输出有三种类型:输入和输出的长度一样,每一个向量对应一个label,如词性标注、音标识别、节点特性(如会不会买某件商品)。一......
  • vue3 js 方式实现学习时长正向计数器 时分秒转秒 秒转时分秒
    //学习时长constLocktime=ref('00:00:00');consttimeAlarmTWO=ref(null);consthour=ref(0);constminute=ref(0);constsecond=ref(10);constreckon=ref(true);//判断是否在计时//判断一下数值的变化consttimer=()=>{second.value=second......
  • vagrant学习笔记
    vagrant镜像网站:https://app.vagrantup.com/boxes/search?utf8=%E2%9C%93&sort=downloads&provider=&q=centos使用putty连接vagrant创建的虚拟机:IP:127.0.0.1 端口:2222  ==============>IP&PORT是你在启动虚拟机的时候出现的IP与PORT在vagrant中创建一个虚拟机的过程:1)......
  • Java学习笔记——第二天
    进制知识二进制、八进制和十六进制二进制:只有0和1两个数字,按照逢2进1的方式表示数据。八进制:只有0~7八个数字,按照逢8进1的方式表示数据。十六进制:由0~9以及A,B,C,D,E,F,共十六个数字,按照逢16进1的方式表示数据,其中A,B,C,D,E,F分别代表十进制的10,11,12,13,14,15。Java程序中支持书写二进制、......
  • Spring-Boot学习
    Spring-boot学习笔记从零开始创建项目先创建一个空的Maven项目,然后在pom.xml引入Spring-boot-starter的父依赖<parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.2......