首页 > 其他分享 >torch反向传播backward()函数解析

torch反向传播backward()函数解析

时间:2023-12-06 16:45:28浏览次数:30  
标签:loss optimizer torch loader train 反向 model backward fn

参考网址:

https://blog.csdn.net/weixin_44179269/article/details/124573992?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522170167791616800197042802%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=170167791616800197042802&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-124573992-null-null.142^v96^pc_search_result_base1&utm_term=backword&spm=1018.2226.3001.4187

以下伪代码只供理解模型训练的一些步骤和流程

 1、网络类定义:

 1 class MYNET(nn.Module):
 2     def __init__(self,in_channels=3,out_channels=2):
 3         super(MYNET,self).__init__()
 4        ....
 5        .....
 6 
 7 
 8     def forward(self.x):
 9           ......
10           ......
11           return ... ...

2、训练定义 前向传播,loss计算,loss反向传播计算梯度
1 def main():
#训练网络 2 model = UNET(in_channels=3, out_channels=2).to(DEVICE) 3 #优化器 4 optimizer = optim.Adam(model.parameters(),lr= LEARNING_RATE)
    #损失函数
    loss_fn = nn.BCEWithLogitsLoss() 5 #数据集 6 train_loader, val_loader = get_loaders(train_pths, val_pths.batchsize.....)
#预训练模型加载
if LOAD_MODEL: load_ckpt(torch.load(Pretrained_Model),model)
   #tensorboard可视化,创建writer
global writer
writer = SummaryWriter(log_dir = '/path/to/log_dir')
for epoch in range(NUM_EPOCHS):
...
train_fn(epoch,train_loader,model,optimizer,loss_fn,scaler)
      checkpoint={
          "state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
        }
      save_ckpt(checkpoint)
    writer.close()

   
def train_fn(epoch, loader, model, optimizer, loss_fn):
            loop = tqdm(loader)#进度条
            for step,(data,targets) in enumerate(loop):
                 ....
                 #前向传播
                predict = model(data)
                 #计算loss
                loss = loss_fn(predict,targets)

                 #loss(代价函数)反向传播计算梯度
                 ‘’‘’backward执行之后,会自动对loss函数表达式中的可求导变量进行偏导数求解(梯度是个向量),并赋值到自变量(可训练的网络参数).grad中‘’


                 loss.backward()
                 #优化器执行参数更新
                 optimizer.step()

3、有关代价函数,反向求梯度

参考:https://www.cnblogs.com/xfuture/p/17869521.html

代价函数(是个标量):

图(二维):

 

梯度(是代价函数在某点对所有自变量的偏导数构成的向量):

 

 梯度更新(alpha为学习率):

 

           

标签:loss,optimizer,torch,loader,train,反向,model,backward,fn
From: https://www.cnblogs.com/zzc-Andy/p/17879064.html

相关文章

  • 【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误
    【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误报错详情错误产生背景原理解决方案RuntimeError:oneofthevariablesneededforgradientcomputationhasbeenmodifiedbyaninplaceoperation报错详情  模型在backward时,发现如下报错......
  • 自有AI芯片接入AI框架Pytorch的方案
    现在AI框架主要用Pytorch,包括一些常用的库对Pytorch支持都较好一、华为昇腾npu能够跟上Pytorch的更新,直接和Pytorch兼容,而且有较多人来维护,代码风格不错,之前是通过注入Pytorch预留的的xla搞的接入,现在被官方接收了。非常推荐,笔者用他的框架实现了自有GPGPU芯片手写数字识别的训......
  • [PyTorch] 如何判定运算维度
    实际上无论是几维,方法都是一样。假设以torch.softmax()为例:#下面运行结果所使用的代码importtorchimportnumpyasnpz=np.arange(1,33).reshape((2,2,2,4))z=torch.tensor(z,dtype=torch.float32)#为了使各元素softmax的结果相差不至于过大,这里简单处理一下,......
  • 使用 PyTorch 完全分片数据并行技术加速大模型训练
    本文,我们将了解如何基于PyTorch最新的完全分片数据并行(FullyShardedDataParallel,FSDP)功能用Accelerate库来训练大模型。动机......
  • 10. 从零用Rust编写正反向代理, HTTP内网穿透支持修改头信息
    wmproxywmproxy是由Rust编写,已实现http/https代理,socks5代理,反向代理,静态文件服务器,内网穿透,配置热更新等,后续将实现websocket代理等,同时会将实现过程分享出来,感兴趣的可以一起造个轮子法项目++wmproxy++gite:https://gitee.com/tickbh/wmproxygithub:https://github.com/tic......
  • torch版本真的很重要!!!
    事情的经过就是,跑深度学习代码的时候,遇到了一系列的错误参数维度对不上1.运行时,发现预训练模型得到的参数跟我模型要的对不上,傻逼了,当时没看见github得issues里面就有解答,找了大半天,还尝试去改模型参数。其实就是因为下载的预训练模型参数的版本不对,应该用旧的版本。cuda用不......
  • PyTorch解説
    PyTorch是一种面向Python的开源机器学习库。它是由Facebook的人工智能研究团队基于最初支持多范式脚本语言“Lua”的Torch开发而来。Python是一种广泛用于“利用机器学习进行人工智能开发”、“Web服务和Web应用开发”、“区块链开发”以及“物联网开发”等多个领域的编程语言。......
  • Keras 3.0正式发布:可用于TensorFlow、JAX和PyTorch
    前言 Keras3.0正式发布:可用于TensorFlow、JAX和PyTorch本文转载自机器之心仅用于学术分享,若侵权请联系删除欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。CV各大方向专栏与各个部署框架最全教程整理【CV技术指南】CV全栈......
  • pytorch 学习记录——计算图
    1.pytorch的计算图是动态更新的(tensorflow是静态计算图),数据流向可以是双向的。2.pytorchvariable(用于封装tensor,便于自动求导的变量类型,在pytorch0.4.0之后版本已被并入tensor)基本属性:data,dtype,shape,device,requires_grad,is_leaf,grad,grad_fn3.is_leaf是否为叶子节点:用户创......
  • 基于对象的跨表查询(正向反向)
    #跨表查询有两种方式-基于对象的跨表查询:子查询  -基于双下划线的跨表查询:关联查询,连表查询    #基于对象的跨表查询-查询主键为1的书籍的出版社所在的城市    #基于对象的跨表查询(子查询)  #一对多  #查询主键为1的书籍的出......