[1704.00028] Improved Training of Wasserstein GANs (arxiv.org)
Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of wasserstein gans[J]. Advances in neural information processing systems, 2017, 30.
引用:9962
摘要
生成对抗网络(GAN)是强大的生成模型,但遭受训练不稳定性。最近提出的WassersteinGAN(WGAN)在GAN的稳定训练方面取得了进展,但有时仍然可以只生成较差的样本或未能收敛。我们发现这些问题往往是由于在WGAN中使用权重裁剪来对批评者实施Lipschitz约束,这可能导致不希望的行为。我们提出了一种裁剪权重的替代方案:惩罚批评者相对于其输入的梯度范数。我们提出的方法比标准WGAN性能更好,并且能够在几乎没有超参数调整的情况下对各种GAN架构进行稳定训练,包括101层ResNet和具有连续生成器的语言模型。我们还在CIFAR-10和LSUN卧室上实现了高质量的生成。
1 简介
生成对抗网络(GAN)[9]是一类强大的生成模型,它将生成建模视为两个网络之间的博弈:生成网络在给定一些噪声源的情况下产生合成数据,鉴别器网络在生成器的输出和真实数据之间进行鉴别。GAN可以产生非常视觉上吸引人的样本,但通常很难训练,最近关于这个主题的大部分工作[23,19,2,21]都致力于寻找稳定训练的方法。尽管如此,GAN的持续稳定训练仍然是一个悬而未决的问题。
特别是,[1]提供了对GAN优化的价值函数的收敛性的分析。他们提出的替代方案,名为瓦瑟斯坦生成对抗网络(WGAN)[2],利用瓦瑟斯坦距离产生一个比原始函数具有更好理论性质的价值函数。WGAN要求判别器(在该作品中称为批评家)必须位于Lipschitz函数的空间内,作者通过权重剪裁来强制执行。
我们的贡献如下:
1.在玩具数据集上,我们演示了批评家权重裁剪如何导致不希望的行为。 2.我们提出了梯度惩罚(WGAN-GP),它不存在同样的问题。 3.我们展示了各种GAN架构的稳定训练、权重裁剪的性能改进、高质量的图像生成以及没有任何离散抽样的字符级GAN语言模型。
2 背景
2.1 生成对抗网络
GAN训练策略是定义两个竞争网络之间的博弈,生成器网络将噪声源映射到输入空间,判别网络要么接收生成的样本,要么接收真实的数据样本,必须区分两者,生成器被训练来愚弄判别器。
形式上,生成器G和判别器D之间的博弈是极大极小的目标:
其中$P_r$是数据分布,$P_g$是由x~=G(z), z~p(z)隐式定义的模型分布(生成器的输入z是从一些简单的噪声分布p中采样的,例如均匀分布或球形高斯分布)。
如果判别器在每次生成器更新之前被训练到最优,那么最小化价值函数相当于最小化Pr和Pg[9]之间的Jensen-Shannon散度,但这样做通常会导致随着判别器饱和梯度消失。在实践中,[9]主张改为训练生成器以最大化$\mathbb{E}_{\tilde{\boldsymbol{x}} \sim \mathbb{P}_g}[\log (D(\tilde{\boldsymbol{x}}))]$,这在某种程度上规避了这一困难。然而,即使是这种修改后的损失函数也可能在有好判别器[1]的情况下行为不端。
2.2 Wasserstein GANs
[2]认为GANs通常最小化的分歧在生成器参数方面可能不是连续的,这导致了训练困难。他们建议使用Earth-Mover(也称为Wasserstein-1)距离W(q,p),它被非正式地定义为运输质量的最小成本,以便将分布q转换为分布p(其中成本是质量乘以运输距离)。在温和的假设下,W(q,p)在任何地方都是连续的,几乎在任何地方都是可微的。
WGAN价值函数使用Kantorovich-Rubinstein对偶构造[25]得到
其中D是1-Lipschitz函数的集合,Pg是属于x = G(z), z ~ p(z)隐式定义的模型分布。在这种情况下,在一个最优鉴别器(在本文中称为批评家,因为它没有受过分类训练)下,最小化关于生成器参数的值函数使W(Pr, Pg)最小化。
WGAN价值函数产生了一个批评家函数,其相对于输入的梯度比GAN的对应函数表现得更好,使得生成器的最优化更容易。经验上,还观察到WGAN价值函数似乎与样本质量相关,而GAN的情况并非如此[2]。
为了对批评家实施利普希茨约束,[2]提出将批评家的权重剪辑在一个紧凑的空间[−c, c]中。满足这个约束的函数集是一些k的利普希茨函数的子集,它依赖于c和批评家体系结构。在下面的部分中,我们展示了这种方法的一些问题,并提出了一种替代方案。
2.3 最优WGAN批评家的性质
为了理解为什么权重裁剪在WGAN评论中是有问题的,以及激励我们的方法,我们强调了WGAN框架中最优评论的一些属性。我们在<u>附录</u>中证明了这些。
命题1. Let $\mathbb{P}_r$ and $\mathbb{P}g$ be two distributions in $\mathcal{X}$, a compact metric space. Then, there is a 1-Lipschitz function $f^$ which is the optimal solution of $\max _{|f|L \leq 1} \mathbb{E}{y \sim \mathbb{P}r}[f(y)]-\mathbb{E}{x \sim \mathbb{P}_g}[f(x)]$. Let $\pi$ be the optimal coupling between $\mathbb{P}_r$ and $\mathbb{P}_g$, defined as the minimizer of: $W\left(\mathbb{P}_r, \mathbb{P}_g\right)=$ $\inf _{\pi \in \Pi\left(\mathbb{P}_r, \mathbb{P}g\right)} \mathbb{E}{(x, y) \sim \pi}[|x-y|]$ where $\Pi\left(\mathbb{P}_r, \mathbb{P}_g\right)$ is the set of joint distributions $\pi(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$, respectively. Then, if $f^$ is differentiable ${ }^{\ddagger}, \pi(x=y)=0^{\S}$, and $x_t=$ $t x+(1-t) y$ with $0 \leq t \leq 1$, it holds that $\mathbb{P}{(x, y) \sim \pi}\left[\nabla f^*\left(x_t\right)=\frac{y-x_t}{\left|y-x_t\right|}\right]=1$.
推论1. $f^*$ has gradient norm 1 almost everywhere under $\mathbb{P}_r$ and $\mathbb{P}_g$.
$\ddagger$ 我们可以少做一些假设,只讨论直线方向上的方向导数;这在证明中是存在的。这意味着在每一个f∗是可微的点(因此我们可以在神经网络设置中采用梯度),该语句都成立。
$\S$ 这个假设是为了排除样本x的匹配点为x本身的情况。当Pr和Pg具有在一组测度0中相交的支撑时,它是满足的,例如当它们被两个低维流形支撑时,这些流形不完全匹配[1]。
3 权重约束条件的难点
我们发现WGAN中的权重裁剪会导致最优化困难,即使最优化成功,结果批评家也会有病态的价值表面。我们在下面解释这些问题并展示它们的影响;然而,我们并不声称每个问题都在实践中发生,也不认为它们是唯一的这样的机制。
我们的实验使用来自[2]的特定形式的权重约束(每个权重的大小的硬裁剪),但是我们也尝试了其他权重约束条件(L2范数裁剪、权重归一化),以及软约束条件(L1和L2权重衰减),发现它们表现出类似的问题。
在某种程度上,这些问题可以通过批处理归一化来缓解,[2]在他们所有的实验中都使用了<u>批处理归一化</u>。然而,即使使用批处理归一化,我们也观察到非常深的WGAN批评家经常无法收敛。
3.1 能力使用不足
通过权重裁剪实现k-Lipshitz约束会使批评家偏向更简单的函数。如前所述,在推论1中,最佳WGAN批评家在Pr和Pg下几乎到处都有单位梯度范数;在权重裁剪约束下,我们观察到试图达到最大梯度范数k的神经网络架构最终学习了极其简单的函数。
为了证明这一点,我们在几个玩具分布上用权重裁剪训练WGAN批评家到最优,保持生成器分布Pg固定在实分布加上单位方差高斯噪声。我们在图1a中绘制批评家的值表面。我们在批评家中省略了批量规范化。在每种情况下,用权重裁剪训练的批评家都会忽略数据分布的更高矩,而是为最优函数建模非常简单的近似值。相比之下,我们的方法没有受到这种行为的影响。
3.2 爆炸和消失梯度
我们观察到WGAN最优化过程是困难的,因为权重约束和代价函数之间的相互作用,不仔细调整裁剪阈值容易导致梯度消失或爆炸。
为了证明这一点,我们在Swiss Roll玩具数据集上训练WGAN,改变$[10^{−1},10^{−3},10^{−3}]$中的剪切阈值c,并绘制相对于连续激活层的临界损失梯度范数。发生器和批评家都是没有批量归一化的12层ReLU MLP。图1b显示,对于这些值中的每一个,随着我们在网络中进一步向后移动,梯度要么呈指数增长,要么呈指数衰减。我们发现我们的方法结果在更稳定的梯度,既不消失也不爆炸,允许训练更复杂的网络。
4 梯度惩罚
我们现在提出了另一种实施李普希茨约束的方法。当且仅当一个可微函数为1- lipschtiz时,它处处具有范数最多为1的梯度,因此我们考虑直接约束批评家的输出相对于其输入的梯度范数。为了避免易处理性问题,我们在随机样本的梯度范数上施加惩罚$\hat{\boldsymbol{x}} \sim \mathbb{P}_{\hat{\boldsymbol{x}}}$,从而实现约束的软版本。我们的新目标是
采样分布:我们隐式地定义了P^x沿从数据分布Pr和生成器分布Pg采样的点对之间的直线均匀采样。这是由这样一个事实驱动的,即最优批评家包含梯度范数1连接Pr和Pg耦合点的直线(见命题1)。鉴于在任何地方强制执行单位梯度范数约束是棘手的,仅沿着这些直线强制执行似乎就足够了,并且在实验上可以获得良好的性能。
惩罚系数:本文中的所有实验都使用λ = 10,我们发现它在各种架构和数据集(从玩具任务到大型ImageNet cnn)上都能很好地工作。
没有Critic批处理规范化:大多数之前的GAN实现[22,23,2]在生成器和鉴别器中都使用批量规范化来帮助稳定训练,但批量规范化改变了鉴别器问题的形式,从单个输入映射到单个输出,到从整个批量输入映射到批量输出[23]。我们的惩罚训练目标在这种情况下不再有效,因为我们针对每个输入单独惩罚了批评家梯度的规范,而不是整个批次。为了解决这个问题,我们简单地在我们的模型的批评家中忽略了批处理规范化,发现它们在没有它的情况下执行得很好。我们的方法适用于不引入例子之间相关性的标准化方案。特别是,我们建议将层规范化[3](LN)作为批量规范化(BN)的替代。
双面惩罚:我们鼓励标准的梯度趋向于1(双面惩罚),而不是保持在1以下(单方面惩罚)。从经验上看,这似乎并没有太多地限制批评家,这可能是因为最佳WGAN批评家在Pr和Pg下以及在这两者之间的大部分区域中几乎都具有标准1的梯度(见章节2.3)。在我们早期的观察中,我们发现这种方法的效果稍好一些,但我们并没有对此进行充分的研究。我们在附录中描述了关于单边惩罚的实验。
5 实验
5.1 在集合内训练随机架构
我们通过实验证明了我们的模型能够训练大量我们认为有用的架构。从DCGAN体系结构开始,我们通过将模型设置更改为表1中随机对应的值来定义一组体系结构变体。我们认为对这一集合中的许多架构进行可靠的训练是一个有用的目标,但我们并不声称我们的集合是整个有用架构空间的一个无偏或有代表性的样本:它的目的是为了演示我们方法的成功模式,读者应该评估它是否包含与他们预期的应用程序类似的架构。
从这个集合中,我们采样了200个架构,并使用WGAN-GP和标准GAN目标在32×32 ImageNet上对每个架构进行训练。表2列出了以下实例的数量:只有标准GAN成功,只有WGAN-GP成功,两者都成功,或两者都失败,其中成功定义为初始分数>分钟分数。对于大多数分数阈值的选择,WGAN-GP成功地从该集合中训练出许多我们无法用标准GAN目标训练的体系结构。我们在附录中给出了更多的实验细节。
图2:不同的GAN架构用不同的方法训练。我们只是使用WGAN-GP成功地用一组共享的超参数来训练每个体系结构。
5.2 在LSUN卧室数据集上用不同结构训练
为了证明我们的模型能够使用默认设置训练许多架构,我们在LSUN卧室数据集上训练了六个不同的GAN架构[31]。除了[22]中的基线DCGAN架构之外,我们选择了六个架构,我们证明了它们的成功训练:(1)生成器中没有BN和恒定数量的过滤器,如[2],(2)4层512维ReLU 函数MLP生成器,如[2],(3)别判器或生成器中没有normalization (4)门控乘法非线性,如[24],(5)tanh非线性,以及(6)101层ResNet生成器和判别器
尽管我们不认为没有我们的方法就不可能实现,但据我们所知,这是第一次在GAN设置中成功地训练非常深的残差网络。对于每个体系结构,我们使用四种不同的GAN方法来训练模型:WGAN- gp、带权重裁剪的WGAN、DCGAN[22]和最小二乘GAN[18]。对于每个目标,我们都使用了该工作中推荐的默认优化器超参数集(LSGAN除外,在那里我们搜索学习速率)。
对于WGAN-GP,我们将甄别器中的任何批量归一化替换为层归一化(参见第4节)。我们对每个模型进行200K次迭代训练,并在图2中给出样本。我们只是使用WGAN-GP成功地用一组共享的超参数来训练每个体系结构。对于其他的训练方法,有些架构是不稳定的或遭受模式崩溃。
5.3 改进权重裁剪的性能
我们的方法优于权重剪裁的一个优点是提高了训练速度和样本质量。为了证明这一点,我们用权重剪裁和我们在CIFAR10[13]上的梯度惩罚来训练WGAN,并在图3中的训练过程中绘制IS指标[23]。对于WGAN-GP,我们训练一个具有相同优化器(RMSProp)和学习率的模型,作为权重剪裁的WGAN,另一个具有Adam和更高学习率的模型。即使使用相同的优化器,我们的方法也比权重剪裁收敛得更快,得分也更好。使用Adam进一步提高了性能。我们还绘制了DCGAN[22]的性能,发现我们的方法比DCGAN收敛得更慢(在挂钟时间内),但其得分在收敛时更稳定。
图3:四个模型的cifar10 IS指标在生成器迭代(左)或墙上时钟时间(右)上:带有梯度剪切的WGAN,带有RMSProp和Adam(控制优化器)的WGAN- GP,以及DCGAN。WGAN-GP显著优于重量裁剪,性能与DCGAN相当。
5.4 CIFAR-10和LSUN卧室样品质量
对于等效架构,我们的方法实现了与标准GAN目标相当的样本质量。然而,增加的稳定性使我们能够通过探索更广泛的架构来提高样本质量。为了证明这一点,我们找到了一种架构,它在无监督CIFAR-10(表3)上建立了一个新的最先进的Inception 分数。当我们添加标签信息时(使用[20]中的方法),相同的架构优于除SGAN之外的所有其他已发布模型。
表3:在CIFAR-10的Inception 分数。我们的无监督模型达到了最先进的性能,我们的条件模型优于除了SGAN之外的所有其他模型。
我们还在128 × 128 LSUN卧室上训练了深度ResNet,并在图4中显示了样本。我们相信,这些样本至少在任何分辨率上都可以与迄今为止报道的最好的数据集相媲美。
5.5 使用连续发生器对离散数据进行建模
为了证明我们的方法对退化分布建模的能力,我们考虑了用一个GAN对复杂离散分布建模的问题,该GAN的生成器是在连续空间上定义的。作为这个问题的一个例子,我们在谷歌十亿字数据集上训练了一个字符级GAN语言模型[6]。我们的生成器是一个简单的一维CNN,它通过一维卷积将潜在向量确定性地转换为32个独热字符向量的序列。我们在输出处应用非线性Softmax,但不使用抽样步骤:在训练期间,Softmax输出直接传递给Critic(同样,Critic是一个简单的一维CNN)。解码样本时,我们只取每个每维输出向量的最大值。
图4:128×128 LSUN卧室的样本。我们相信这些样本至少可以与迄今为止发表的最好的结果相媲美。
我们在表4中展示了模型的样本。我们的模型经常出现拼写错误(可能是因为它必须独立输出每个字符),但尽管如此,我们还是设法学习了很多关于语言统计的知识。我们无法产生与标准GAN目标可比的结果,尽管我们并不声称这样做是不可能的。
表4:来自WGAN-GP字符级语言模型的样本,对来自十亿字数据集的句子进行训练,截断为32个字符。该模型学习直接从潜在向量输出独热字符嵌入,而无需任何离散抽样步骤。我们无法获得与标准GAN目标和连续生成器可比的结果。
WGAN与其他gan在性能上的差异可以解释如下。考虑单纯形 $\Delta_n=\left{p \in \mathbb{R}^n: p_i \geq 0, \sum_i p_i=1\right}$,以及单纯形(或独热向量)上的顶点集 $V_n=\left{p \in \mathbb{R}^n: p_i \in{0,1}, \sum_i p_i=1\right} \subseteq \Delta_n$。如果我们有一个大小为 $n$ 的词汇表,并且我们有一个大小为 $T$ 的序列上的分布 $\mathbb{P}_r$,我们有$\mathbb{P}_r$是 $V_n^T=V_n \times \cdots \times V_n$上的分布。由于 $V_n^T$是$\Delta_n^T$的子集,我们也可以将 $\mathbb{P}_r$视为$\Delta_n^T$上的分布(通过将零概率质量分配给所有不在 $V_n^T$ 中的点)。
$\mathbb{P}_r$ 在 $\Delta_n^T$上是离散的(或支持在有限数量的元素上,即$V_n^T$ ),但 $\mathbb{P}_g$ 很容易成为$\Delta_n^T$上的连续分布。两个这样的分布之间的KL散度是无限的,因此JS散度是饱和的。尽管GAN并没有从字面上最小化这些散度[16],但在实践中,这意味着判别器可能会很快学会拒绝所有不位于$\Delta_n^T$(独热向量序列)上的样本,并向生成器提供无意义的梯度。然而,很容易看出,即使在$X=\Delta_n^T$的非标准学习场景中,[2]的定理1和推论1的条件也是满足的。这意味着W(Pr,Pg)仍然是定义良好的,在任何地方都是连续的,几乎在任何地方都是可微分的,我们可以像在任何其他连续变量设置中一样优化它。这表现的方式是,在WGAN中,Lipschitz约束迫使Critic提供一个线性的梯度,从所有$\Delta_n^T$到$V_n^T$中的实数点。
其他使用GANs进行语言建模的尝试[32,14,30,5,15,10]通常使用离散模型和梯度估计器[28,12,17]。我们的方法实现起来更简单,但它是否能扩展到玩具语言模型之外还不清楚。
5.6 有意义的损耗曲线和过拟合检测
图5:(a)我们的模型在LSUN卧室上的负批评损失随着网络训练收敛到极小点。(b)当使用我们的方法(左)或梯度剪裁(右)时,MNIST随机1000位子集上的WGAN训练和验证损失显示过拟合。特别是,使用我们的方法,批评者比生成器更快地过拟合,导致训练损失随着时间的推移逐渐增加,即使验证损失下降。
<u>梯度裁剪WGAN的一个重要好处是它们的损失与样本质量相关</u>,并收敛到极小点。为了证明我们的方法保留了这个属性,我们在LSUN卧室数据集[31]上训练了一个WGAN-GP,并在图5a中绘制了批评家损失的负数。我们看到随着生成器最小化$W(P_r, P_g)$),损失收敛。
给定足够的容量和太少的训练数据,GAN将过拟合。为了探索网络过拟合时损失曲线的行为,我们在MNIST的随机1000个图像子集上训练大型非正则化WGAN,并在图5b中的训练集和验证集上绘制负批评家损失。在WGAN和WGAN-GP中,两种损失不同,这表明批评家过拟合并提供了$W(P_r, P_g)$的不准确估计,此时所有关于样本质量相关系数的赌注都关闭了。然而,在WGAN-GP中,即使验证损失下降,训练损失也逐渐增加。
[29]还通过估计生成器的对数似然来测量GAN中的过拟合。与这项工作相比,我们的方法在批评家(而不是生成器)中检测过拟合,并针对网络最小化的相同损失来测量过拟合。
6 结语
在这项工作中,我们展示了WGAN中权重裁剪的问题,并在批评损失中引入了一种不表现出相同问题的罚项形式的替代方案。使用我们的方法,我们展示了跨各种架构的强大建模性能和稳定性。现在我们有了更稳定的训练GAN的算法,我们希望我们的工作为在大规模图像数据集和语言上获得更强的建模性能开辟了道路。另一个有趣的方向是使我们的罚项适应标准GAN目标函数,在那里它可能会通过鼓励判别器学习更平滑的决策边界来稳定训练。
标签:mathbb,GANs,Training,Improved,训练,生成器,GAN,我们,WGAN From: https://blog.51cto.com/coderge/8215103