首页 > 其他分享 >GAN中对生成器和判别器更新机制的理解

GAN中对生成器和判别器更新机制的理解

时间:2023-02-18 20:11:25浏览次数:49  
标签:loss 判别 训练 样本 生成器 网络 生成 GAN

首先是这2个网络是交替迭代训练。

可以是1:1迭代

即D_update_ratio==1,那么G和D之间是1:1的方式进行参数更新

若D_update_ratio==2,那么首先更新两次D再更新一次G

更新G的时候需要冻结D的梯度,避免其计算梯度耗费时间。

对于判别网络:

假设现在有了生成网络(当然可能不是最好的),那么给一堆随机数组,就会得到一堆假的样本集(因为不是最终的生成模型,现在生成网络可能处于劣势,导致生成的样本不太好,很容易就被判别网络判别为假)

现在有了这个假样本集(真样本集一直都有),我们再人为地定义真假样本集的标签,很明显,这里我们默认真样本集的类标签为1,而假样本集的类标签为0,因为我们希望真样本集的输出尽可能为1,假样本集为0。

现在有了真样本集以及它们的label(都是1)、假样本集以及它们的label(都是0)。这样一来,单就判别网络来说,问题变成了有监督的二分类问题了,直接送进神经网络中训练就好。(判别真样本集的loss+判别假样本集的loss)

def backward_D_basic(self, netD, real, fake):
    """Calculate GAN loss for the discriminator

    Parameters:
        netD (network)      -- the discriminator D
        real (tensor array) -- real images
        fake (tensor array) -- images generated by a generator

    Return the discriminator loss.
    We also call loss_D.backward() to calculate the gradients.
    """
    # Real
    pred_real = netD(real) # patchGAN: torch.Size([1, 1, 30, 30])
    loss_D_real = self.criterionGAN(pred_real, True)
    # Fake
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined loss and calculate gradients
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D
对于生成网络:

对于生成网络,我们的目的是生成尽可能逼真的样本。

而原始的生成网络生成的样本的真实程度只能通过判别网络才知道,所以在训练生成网络时,需要联合判别网络才能达到训练的目的。

所以生成网络的训练其实是对生成-判别网络串接的训练,像上图显示的那样。因为如果只使用生成网络,那么无法得到误差,也就无法训练。
当通过原始的噪声数组Z生成了假样本后,把这些假样本的标签都设置为1,即认为这些假样本在生成网络训练的时候是真样本。因为此时是通过判别器来生成误差的,而误差回传的目的是使得生成器生成的假样本逐渐逼近为真样本(当假样本不真实,标签却为1时,判别器给出的误差会很大,这就迫使生成器进行很大的调整;反之,当假样本足够真实,标签为1时,判别器给出的误差就会减小,这就完成了假样本向真样本逐渐逼近的过程),起到迷惑判别器的目的。

现在对于生成网络的训练,有了样本集(只有假样本集,没有真样本集),有了对应的label(全为1),有了误差,就可以开始训练了。

在训练这个串接网络时,一个很重要的操作是固定判别网络的参数,不让判别网络参数更新(使用detach对其进行分离,将其变为叶子节点,避免其在G中向后回传梯度。),只是让判别网络将误差传到生成网络,更新生成网络的参数。

detach():截断node反向传播的梯度流,将某个node变成不需要梯度的Varibale,因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。pytorch训练GAN时的detach()

在生成网络训练完后,可以根据用新的生成网络对先前的噪声Z生成新的假样本了,不出意外,这次生成的假样本会更真实。有了新的真假样本集(其实是新的假样本集),就又可以重复上述过程了。

整个过程就叫单独交替训练。可以定义一个迭代次数,交替迭代到一定次数后停止即可。不出意外,这时噪声Z生成的假样本就会很真实了。

GAN设计的巧妙处之一,在于假样本在训练过程中的真假变换,这也是博弈得以进行的关键之处

标签:loss,判别,训练,样本,生成器,网络,生成,GAN
From: https://www.cnblogs.com/xyf9474/p/17133421.html

相关文章