首页 > 其他分享 >torch工具箱

torch工具箱

时间:2022-08-19 16:23:09浏览次数:82  
标签:loss torch seed 反向 print 工具箱 grad

  • Autograde

    用户自己创建的叫叶子变量,计算得来的是中间变量;

    前向传播时,torch自动构建计算图,从input到loss;

    反向求导时,沿着计算图,从loss到input;

inpt = torch.ones(size=(4, ))
w = torch.tensor(2.0, requires_grad=True)
l = inpt * w
loss = l.mean()

# 钩子函数,在反向传播时将l的梯度打印,随后销毁
l.register_hook(lambda grad: print(f'l.grad:{grad}'))  

# 保存中间变量的梯度
# l.retain_grad()

loss.backward()  # 执行反向传播

print(w.grad)
print(w.grad_fn)  # None,因为是用户自创建
print(l.grad)     # None,应为是非叶子变量,所以默认不保存梯度

# <MulBackward0 object at xxx>,torch内定义了基本操作的反向传播函数
print(l.grad_fn)  

  • seed
# 以下两命令执行顺序不同,有不同的效果

torch.manual_seed(1)  # 指定seed

torch.seed()  # 随机seed

标签:loss,torch,seed,反向,print,工具箱,grad
From: https://www.cnblogs.com/wjw-cat/p/16602374.html

相关文章

  • 【PyTorch学习笔记】1.Tensor 与 Variable
    在PyTorch0.4.0之前,torch.autograd包中存在Variable这种数据类型,主要是用于封装Tensor,进行自动求导。Variable主要包含下面几种属性。 data:被包装的......
  • juypter notebook中报找不到scipy,torchvision的问题
    在初入深度学习使用juypter这块经常遇到各种问题,每次都被搞的很痛苦; 下面给大家带来我的一点问题解决方案: 首先检查下anaconda中有没有安装scipy这些模块,没有的话在......
  • DW组队学习——深入浅出PyTorch笔记
    本篇是针对DataWhale组队学习项目——深入浅出PyTorch而整理的学习笔记。由于水平实在有限,不免产生谬误,欢迎读者多多批评指正。安装PyTorch安装Anaconda这里为了避免手......
  • PyTorch 剪枝
    pytorch实现剪枝的思路是生成一个掩码,然后同时保存原参数、mask、新参数,如下图 pytorch剪枝分为局部剪枝、全局剪枝、自定义剪枝;局部剪枝是对模型内的部分模......
  • torch.nn.Dropout()
    1.torch.nn.Dropout()classtorch.nn.Dropout(p=0.5,inplace=False)随机将输入张量中部分元素设置为\(0\)。对于每次前向调用,被置\(0\)的元素都是随机的。参数:p......
  • PyTorch 环境配置及安装
    目录1.创建Python子环境:2.Pytorch的安装2.1.查看电脑GPU支持的CUDA版本2.2.CUDA驱动检查2.3.Pytorch包下载(GPU)2.4.检查安装3.JupyterNotebook1.创建P......
  • torch.utils.data
    classtorch.utils.data.Dataset表示\(Dataset\)的抽象类。所有其他数据集都应该进行子类化。所以子类应该覆写__len__和__getitem__,前者提供了数据集的大小,后者支持......
  • 1. Pytorch - 初识
    1.1学习动机2020-2022,Pytorch框架已经陪伴我两年,它是我研究生生活中必不可少的工具,在研究生最后的一年时光以及未来的工作中也同样是必不可少的工具。现已秋招,......
  • 关于安装Anaconda,以及GPU版的tensorflow,pytorch,最后配置jupyter
    1.首先是关于Anaconda的安装:  直接到官网上下载对应版本,直接安装,可以自定义安装目录,但是要注意的是你安装的目录必须是全英文(就很烦!)    然后是接下来的步骤......
  • Four---pytorch学习---基本数据类型/标量/张量/dim值
    pytorch学习(1)pytorch的基本数据类型在torch中默认的数据类型是32位浮点型(torch.FloatTensor)可以通过torch.set_default_tensor_type()函数设置默认的数据类型,但该函......