首页 > 其他分享 >神经网络训练速度相关学习--1

神经网络训练速度相关学习--1

时间:2024-04-18 23:11:08浏览次数:15  
标签:显存 训练 -- batch 利用率 神经网络 GPU net CPU

2024-04-18

程序执行的调用顺序:

cpu接收到指令,执行——从存储器中加载数据到cpu,对数据进行预处理——预处理后的数据传输gpu——gpu执行运算——将运算结果存储到存储器——开始新一轮batch运算(每一次计算都需要从内存中读取数据)

另外参考:先将硬盘中的数据读取到内存中,然后CPU从内存中读取数据,再将数据传输到显存中,等待GPU处理,处理完成之后会传递一部分参数数据到CPU。同时CPU还会将网络和损失函数等传递给GPU。https://blog.csdn.net/qq_39333636/article/details/124873848

在训练神经网络的过程中在CPU主要负责对数据进行预处理以及记录训练日志,GPU主要负责进行前向传播与反向传播。

多GPU数据并行加速原理:

假设一台机器上有k块GPU。给定需要训练的模型,每块GPU及其相应的显存将分别独立维护一份完整的模型参数。在模型训练的任意一次迭代中,给定一个随机小批量,我们将该批量中的样本划分成k份并分给每块显卡的显存一份。然后,每块GPU将根据相应显存所分到的小批量子集和所维护的模型参数分别计算模型参数的本地梯度。接下来,我们把k块显卡的显存上的本地梯度相加,便得到当前的小批量随机梯度。之后,每块GPU都使用这个小批量随机梯度分别更新相应显存所维护的那一份完整的模型参数。

原文链接:https://blog.csdn.net/xiaosongshine/article/details/99440144

影响因素:

GPU显存:GPU的内存占用率主要由两部分组成:一是优化器参数,模型自身的参数,模型中间每一层的缓存,都会在内存中开辟空间来进行保存;二是batch size的大小。每个GPU都会分配一部分显存用于存储中间变量、梯度、权重等。GPU显存的使用量取决于模型的复杂度、批量大小(batch size)以及数据类型等因素(查阅,有如下公式:
显存占用 = 模型显存占用 + batch_size × 每个样本的显存占用)https://blog.csdn.net/zwhdldz/article/details/134711996

计算模型参数量:torchsummary

import torch as t
from torchsummary import summary
 
rgb = t.randn(1,3,352,480).cuda()
net = FCN(12).cuda()
out = net(rgb)
summary(net,input_size=(3,352,480),batch_size=1)

GPU利用率:高利用率表示GPU多数时间都在进行计算操作,低利用率表示GPU空闲。

Num Workers:在数据加载过程中cpu用于并行处理数据加载任务的线程数。增加Num Workers可以加速数据加载,但也需要更多的CPU资源。(常为4 8 16等数字)

查看GPU内存利用率和GPU利用率:

方法一:终端

nvidia-smi  #查看当前GPU情况
nvidia-smi -l 3 #每隔3秒查看GPU情况

方法二:任务管理器

 

现象:

GPU利用率波动问题(利用率变化类似0%,50%,80%,99%,0%)

其实是GPU在等待数据从CPU传输过来,当从总线传输到GPU之后,GPU逐渐起计算来,利用率会突然升高,但是GPU的算力很强大,0.5秒就基本能处理完数据,所以利用率接下来又会降下去,等待下一个batch的传入。

解决方法:调高num workers.但num_work不是越大越好。因为越大,虽然线程多了,但是切分到每一个线程消耗也大了,所以会增加CPU的负荷,一些进程间的分配和协作+I/O问题反而会拖慢速度,从而降低对GPU的利用。极端点可以理解为CPU处理数据很快只要1s,但是因为线程数太多而分发数据的时候花了5s。通常可以根据cpu和核数去设置num_work,查看cpu核数的命令:

lscpu

GPU内存利用率低

batchsize太少,占用的GPU内存空间少。会进一步导致GPU利用率低的情况。

解决方法:提高batchsize,一次性多载入数据到显存中,可以提高它的占用率。但batchsize很大,例如达到百万数量级,训练速度往往会很慢,因为每次迭代都要对所有样本进行进行求和运算和矩阵运算。同时 batch size 太大,也不利于网络跳出局部极小点,每次更新参数都是相同的样本,下降方向基本确定,这会导致模型的泛化性能下降。

其他优化方法

梯度累加:https://blog.csdn.net/qq_45756171/article/details/122910838

 

标签:显存,训练,--,batch,利用率,神经网络,GPU,net,CPU
From: https://www.cnblogs.com/pioner/p/18144735

相关文章

  • hyperf统一请求响应
    2024年4月18日08:48:45以下是两个方案:1,使用注解,直接返回<?phpnamespaceApp\Utils;useApp\Utils\GlobalCode;useApp\Utils\GlobalMsg;useHyperf\Contract\ContainerInterface;useHyperf\Di\Annotation\Inject;useHyperf\HttpServer\Contract\RequestInterfac......
  • Spring 源码阅读(一)环境搭建
    注意事项:使用2024-03-14发布的Spring5.3.33版本IDE工具使用了IntellijIDEA,同时为了简化不必要的内容没单独配置Gradle环境JDK版本采用EclipseTemurin1.8/11均可下载源码下载SpringFramework源码,本次选择5.3.33版本,发布日期2024-03-14,通过IntellijIDE......
  • EF有关include的优化
     1、不使用include。共执行SQL 3276800000次,数据库连接接近3276800000次2、使用include,执行一次,CPU执行3276800000次。或者胜出。负载代码//嵌套关联的优化使用varusersWithDetails=context.Users.Include(u=>u.Orders)......
  • 这个网络爬虫代码,拿到数据之后如何存到csv文件中去?
    大家好,我是皮皮。一、前言还是昨天的那个网络爬虫问题,大佬们,帮忙看看这个网络爬虫代码怎么修改?那个粉丝说自己不熟悉pandas,用pandas做的爬虫,虽然简洁,但是自己不习惯,想要在他自己的代码基础上进行修改,获取数据的代码已经写好了,就差存储到csv中去了。他的原始代码如下:importreq......
  • 性能测试——性能测试-常见linux性能指标监控命令
    vmstat命令:           top命令:            free-h命令:             df-h命令:         mpstat命令:             sar–收集和报告系统活动 ......
  • [题解]ABC282E Choose Two and Eat One
    ABC282EChooseTwoandEatOne又一个图论的回顾——Kruskal最小(最大)生成树算法。看到\(n\)的范围只有\(500\),应该没有什么特别的算法。那么我们考虑建一个*\(n\)个顶点的完全图,节点\(x\)到节点\(y\)的边权值就是\(x^y+y^x\)。然后跑一遍最大生成树,得到的和就是最大结果了。如......
  • 二叉树遍历根据先序中序创建二叉树根据中序后序创建二叉树计算二叉树的深度线索二叉树遍历线索二叉树赫夫曼树二叉排序树平衡二叉排序树C++代码#include<iostream>usingnamespacestd;structBiTree{ intval=0; BiTree*left; BiTree*right;};//二叉树的......
  • stm32例子 串口接受oled屏幕不卡显示
      #include<Arduino.h>#include<U8g2lib.h>#ifdefU8X8_HAVE_HW_SPI#include<SPI.h>#endif#ifdefU8X8_HAVE_HW_I2C#include<Wire.h>#endifU8G2_SSD1306_128X64_NONAME_F_SW_I2Cu8g2(U8G2_R0,/*clock=*/PB6,/*data=*/PB......
  • 位运算
    1.数组内存在一个个数为奇数的数,其余数个数都为偶数,求这个数。解决该问题要用到位运算中的异或,首先要了解异或的特性:1.a^0=a2.a^a=03.异或满足交换律和集合律:a^b=b^a,a^b^c=a^(b^c)因此面对该问题,可以将数组内的数全部异或,由于其余数为偶数个,异......
  • 4.18
    APP端连接数据库publicclassDBUtils{staticList<ESP8266>ESPlist=newArrayList<>();//存放设备的数组privatestaticESP8266Device=newESP8266();//初始化数组privatestaticStringdriver="com.mysql.jdbc.Driver";//MySql驱动,需要驱动才能接入MyS......