首页 > 其他分享 >RuntimeError: Trying to backward through the graph a second time

RuntimeError: Trying to backward through the graph a second time

时间:2022-12-07 10:57:17浏览次数:49  
标签:val graph RuntimeError torch grad https images backward com

起因是把别人的用clip做分割的模型加到自己的框架上,结果报这个错。Google了一下,发现可能是如下几种原因:多个loss都要backward却没有retain graphhttps://www.zhihu.com/question/414980879,或者是rnn时对于前一次的输出没有detach就送进网络等等,还有一些奇怪的原因比如https://www.zhihu.com/search?type=content&q=RuntimeError%3A%20Trying%20to%20backward%20through%20the,结果发现和自己的情况都不符合。后来看到某CSDN的一个帖子https://blog.csdn.net/qq_49030008/article/details/125440817,虽然和自己的情况也不太一样,但提到的预训练模型启发了我:现在跑的不就是clip做分割的任务吗!于是开始一通乱改,比如把embedding后的text feature给detach或者存储在循环外,每次forward的时候传进来,等等,结果都不work。无奈之下只好跑起来官方代码对拍,但官方代码用了Pytorch lighting,封装了不少东西,其余的地方看起来貌似都没啥特别的...最后在某一次调试的时候打印了一下text feature的is_leaf和requires_grad属性,发现两轮后这两个属性竟然会发生反转!仔细一看发现前两轮不是真正在train,而是进行了一次验证(可以自行查阅lighting框架的num_sanity_val_steps参数),猜想可能是在跑这两次测试的过程中对模型参数属性进行了一些奇妙的初始化,于是查看框架源码https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/trainer/trainer.py:

with torch.no_grad():
	val_loop.run()

发现其实就是跑了一下val的loop,于是对自己的代码在训练前加上一部分:

with torch.no_grad():
	for cur_step, (images, labels) in enumerate(train_loader):
    images = images.to(device, dtype=torch.float32)
    outputs = model(images, labelset='')
    break

果然不报错了,但原理是什么还未搞懂。已经在 GitHub提了一个issue,希望能看到作者给的答案https://github.com/isl-org/lang-seg/issues/38

标签:val,graph,RuntimeError,torch,grad,https,images,backward,com
From: https://www.cnblogs.com/lipoicyclic/p/16962437.html

相关文章