loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) z_q是codebook 找到的最接近z的向量. z是encoder生成的向量. L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach()) # 这个部分对于encoder做了训练. L对z_q求导=2(z_q - z.detach()) #这个部分对于codebook做了训练. 所以这个detach对于变量x虽然对x不求导,但是计算其他变量时候参与计算. 很早之前,在RVQ那篇文章里说到过,VQ-VAE中是通过在codebook中选择欧式距离最近的embedding对应的index作为离散token的。即其中涉及到argmin操作,该操作是不可导的。因此重建loss的梯度是无法传递到encoder网络的。 如果我们写成 loss= torch.mean((z-z_q)**2) 那么L对z_q求导=2(z_q-z) 对z求导=2(z_q-z)*(-1)=2(z-z_q). 这里面的两个导数是算不了的.因为argmin不可导. 导数没法从z_q传到变量x上.(x是输入网络参数) 所以我们只能用上面的方法来计算. 我们上面的方法. L对z求导 = 2(z_q.detach()-z)*(-1)=2(z - z_q.detach()) # 这个部分对于encoder做了训练. L对z_q求导=2(z_q - z.detach()) #这个部分对于codebook做了训练. 我们再计算z_q 对x求导.即可. z是没法对x求导的. 参考这个: https://zhuanlan.zhihu.com/p/644091516 讲的很好.
标签:loss,vqvae,torch,encoder,codebook,计算,求导,detach From: https://www.cnblogs.com/zhangbo2008/p/17854296.html