首页 > 其他分享 >with torch.no_grad():注意事项

with torch.no_grad():注意事项

时间:2023-07-21 10:37:12浏览次数:35  
标签:no 梯度 torch param 计算 grad

1。 当执行原地操作时,例如 tensor.add_(x),将会在一个张量上直接修改数据,而不会创建新的张量。由于修改了张量的数据,因此计算图会失效,即计算图中的操作和输入输出关系都会发生变化。这会导致反向传播无法正确计算梯度。因此,PyTorch 禁止在需要梯度计算的张量上执行原地操作。为了解决这个问题,可以使用 with torch.no_grad() 上下文管理器来禁用梯度计算,并对结果进行复制,从而得到一个新的张量,从而避免原地操作对计算图的破坏。

点击查看代码
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
	for param in params:
		param -= lr * param.grad / batch_size
		param.grad.zero_()
在这段代码中,我们使用 with torch.no_grad() 上下文管理器来禁用梯度计算,并使用小批量随机梯度下降算法来更新模型参数。在每次迭代中,我们遍历模型的每个参数,并根据其梯度信息更新参数的值。在更新参数时,我们使用 param -= lr * param.grad / batch_size 的原地操作来更新参数的值,并使用 param.grad.zero_() 将梯度信息清零,以便进行下一个小批量的更新。

2.具体来说,with torch.no_grad() 会创建一个上下文环境,在该环境中计算的所有张量都不会被跟踪其操作历史,并且梯度信息也不会被记录。这可用于评估模型,例如在验证集上计算模型的性能指标,或者在训练过程中定期输出模型的训练进度。由于我们不需要计算梯度信息,因此使用 with torch.no_grad() 可以显著降低计算开销并减少内存使用。

点击查看代码
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
在这段代码中,我们使用 with torch.no_grad() 上下文管理器计算了训练集的损失,并输出了当前迭代的训练进度。由于这些计算不需要梯度信息,因此可以使用 with torch.no_grad() 来禁用梯度计算,从而提高计算效率并减少内存使用。

标签:no,梯度,torch,param,计算,grad
From: https://www.cnblogs.com/ML-WG/p/17570601.html

相关文章

  • Codeforces 856F - To Play or not to Play
    首先,DP肯定是逃不掉的,因为直接贪心其实不好判断在两个人都可以上线的时间段究竟是哪个人上线,需要通过后面的情况来做出判断,但是这题值域比较大直接维护DP值肯定不行,因此考虑先设计一个与值域有关的DP然后优化。将时间区间离散化,然后依次考虑每个时间区间。一个很自然的想法......
  • 7824. 【2023.07.20NOI模拟】哈密顿路
    Description大家最喜欢的典中典环节它来了。在图论中,无向图的哈密顿路径是恰好能将图中所有顶点各访问一次的路径。给定一张\(n\)个点的简单无向图。对于每个\(1\leqx,y\leqn(x\neqy)\),你想要知道,是否存在一条以顶点\(x\)为起点,以顶点\(y\)为终点的哈密顿路径......
  • nologin
    nologin拒绝用户登录系统补充说明nologin命令可以实现礼貌地拒绝用户登录系统,同时给出信息。如果尝试以这类用户登录,就在log里添加记录,然后在终端输出Thisaccountiscurrentlynotavailable信息,就是这样。一般设置这样的帐号是给启动服务的账号所用的,这只是让服务启动起来,......
  • 【雕爷学编程】Arduino动手做(49)---有源和无源蜂鸣器模块5
    37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的,这里准备逐一动手试试多做实验,不管成功与否,都会记录下来——小小的进步或是搞不掂的问题......
  • __use_no_semihosting 与调用C库函数冲突
    在已经移植freetype代码(使用了大量的C库函数)中,将printf硬件重定向到串口后,#pragmaimport(__use_no_semihosting_swi)与调用C库函数冲突,错误如下:ERROR:L6915E:Libraryreportserror:__use_no_semihosting_swiwasrequested,but_sys_openwasreferenced勾选Options->Target......
  • 解释一下为为什么使用 instance normalization可以消除说话人信息,保留说话人内容
    在contentencoder中使用instancenormalization,可以起到去除说话者信息的作用。首先来看一下instancenormalization的原理,一般会对输入语音做conv1d得到featuremap,有几个conv1dfilter就会得到几个featuremap,可以将这个过程理解为每一个filter都在提取声音的一个特征,通俗一点......
  • pytorch使用(三)torch.zeros用法
    torch.zeros用法torch.zeros()是PyTorch中用来创建全0张量的函数。用法为torch.zeros(size,out=None,dtype=None,layout=torch.strided,device=None,requires_grad=False)。其中,size参数表示张量的形状(shape),可以是一个整数或者一个包含多个整数的tuple。例如,torch.......
  • pytorch使用(四)np.random.randint用法
    np.random.randint用法np.random.randint是numpy库中用于生成随机整数的函数。它的用法如下:numpy.random.randint(low,high=None,size=None,dtype='l')其中,各个参数的含义如下:low:生成的随机整数的下限(包含)。high:生成的随机整数的上限(不包含)。如果不提供high参数,则生......
  • Interleaving Retrieval with Chain-of-Thought Reasoning for Knowledge-Intensive M
    目录概IRCoT代码TrivediH.,BalasubramanianN.,KhotT.,SabharwalA.Interleavingretrievalwithchain-of-thoughtreasoningforknowledge-intensivemulti-stepquestions.ACL,2023.概CoT(ChainofThought)+检索.IRCoT对于如上的问题,"Inwhatcountry......
  • (_mysql_exceptions.OperationalError) (2061, 'RSA Encryption not supported -
    RSA加密与数据库操作的关系在进行数据库操作时,我们有时会遇到类似于“(_mysql_exceptions.OperationalError)(2061,'RSAEncryptionnotsupported'”的错误提示。这个错误提示通常表示我们正在尝试使用RSA加密算法进行数据库操作,但是数据库不支持RSA加密。本文将介绍RSA加密算......