一、引言
生成对抗网络(GANs)自诞生以来就在机器学习领域掀起了巨大波澜,它为图像生成、数据增强等诸多任务提供了强大的解决方案。在MJ(Midjourney,一款知名的AI绘画工具,以生成高质量图像著称,其背后大概率也运用到了类似GAN的技术理念)相关的应用场景下,GAN技术更是助力创造出令人惊叹的视觉内容。
二、GAN的基本架构
GAN含生成器(G)与判别器(D)。生成器力求生成逼真数据骗过判别器,判别器负责区分真假数据,二者训练时相互博弈,达动态平衡后生成器便能产出高质量近似真实数据样本
生成器(Generator)部分
- 输入:通常是一个随机噪声向量,一般用
z
来表示,这个向量可以是从某个预设的随机分布(比如高斯分布等)中采样得到的,维度根据具体设计而定,例如可以是 100 维等。 - 网络结构:
- 由多个神经网络层构成,常见的是多层全连接层(在简单示例中)或者是卷积层(在处理图像等数据时更常用)依次堆叠。例如,先是一个线性层(全连接层)将输入的噪声向量进行维度变换(比如从 100 维变换到 256 维等),接着后面跟随激活函数,像 ReLU 或者 LeakyReLU 激活函数来引入非线性因素,使其能够学习到更复杂的映射关系。
- 然后可能再经过几个类似的线性层或者卷积层等的组合,不断对数据进行变换和特征提取,最终输出一个生成的数据样本,这个样本的维度和形式要与真实数据相匹配(比如生成图像时,输出的就是图像的像素矩阵,其大小和真实图像一样)。
- 输出:生成的数据样本,标记为
G(z)
,例如在图像生成任务中,输出就是一张伪造的图像,这个图像会被送入判别器去判断真假。
判别器(Discriminator)部分
- 输入:有两个来源的输入,一是来自真实数据集中的真实样本(标记为
x
,比如真实的图像数据),二是来自生成器生成的虚假样本(即G(z)
)。 - 网络结构:
- 同样由多个神经网络层搭建而成,和生成器类似,可以是多层全连接层或者卷积层等的组合。例如,对于输入的数据(不管是真实的还是生成的)先经过一个线性层将其维度进行合适的压缩和特征提取,然后使用激活函数(如 LeakyReLU 等)处理后再传入后续的层,后续层不断地对数据进行分析判断,提取特征来分辨输入的数据是真实的还是生成器生成的。
- 最后一层一般是一个输出层,输出一个概率值,这个概率值代表输入的数据是真实数据的概率,取值范围在 0 到 1 之间,越接近 1 表示判别器认为该数据是真实数据的可能性越大,越接近 0 则认为是生成器生成的虚假数据的可能性越大。
- 输出:针对输入的数据输出一个 0 到 1 之间的概率值,分别对真实样本
x
和生成样本G(z)
进行判别并输出相应的概率值,用于后续计算损失以及指导生成器和判别器的优化训练。
说人话就是:
生成器是一个卖假货的人,坑蒙拐骗多年,积累很多经验,这些经验就是它以后坑蒙拐骗的素材库,以后遇到新的场景,就会从素材库中提取经验,营造新骗局。从而制出新的假货骗人。
判别器就是专门的鉴别专家,他会对假货各种分析,从而鉴别是不是家伙,并且给假货打分,是不是高仿。
黑心商家与鉴别专家在不断对抗也在不断促进,从而使假货越来越逼真,甄别手段也越来越先进。
终生成器能够生成非常逼真的图像,而判别器也能达到很高的鉴别准确率。
二、GAN数学原理
设真实数据分布为 ,生成数据分布为 ,判别器输出数据为真实的概率 (取值 0 到 1)。
判别器目标是最大化:
[V(D) = \mathbb{E}{x \sim P{data}}[\log D(x)] + \mathbb{E}{z \sim P{z}}[\log(1 - D(G(z)))]]
生成器目标是最小化判别其生成数据为真实的概率,等价于最大化:
[\mathbb{E}{z \sim P{z}}[\log D(G(z))]]
训练时交替优化二者,先定生成器优化判别器,再反之。
三、GAN在mj中的应用实现
网络架构搭建
- 生成器:采用多层神经网络结构,如卷积神经网络(CNN)或 Transformer 架构。可对输入的随机噪声向量进行处理,通过一系列的卷积层、反卷积层、激活函数等操作,将低维噪声向量逐步变换为高维的图像数据,使其符合真实图像的尺寸和特征分布。如通过反卷积层增加图像的分辨率,使用激活函数引入非线性因素,学习复杂的图像特征映射关系。
- 判别器:也基于神经网络构建,通常使用卷积神经网络来提取图像的特征。将生成器生成的图像和真实图像作为输入,通过卷积层、池化层等操作提取图像的特征,最后通过全连接层输出一个概率值,代表输入图像是真实图像的概率。
数据准备与预处理(骗子的经验)
- 收集图像数据:收集海量的各类高质量图像数据,涵盖各种风格、主题和场景,为模型提供丰富的学习样本,让模型能够学习到不同图像的特征和规律。
- 数据清洗:对收集到的数据进行清洗,去除模糊、损坏、低质量的图像,以及与目标任务不相关的图像,提高数据的质量和纯度。
- 数据标注:根据图像的内容、风格等信息进行标注,如标注图像的主题类别、艺术风格类型等,以便在训练过程中为模型提供监督信息,帮助模型更好地学习和理解图像的特征与用户需求之间的关系。
对抗训练过程(骗子与鉴定专家的博弈)
- 初始化参数:随机初始化生成器和判别器的网络参数,确定训练的超参数,如学习率、批次大小、训练轮数等。
- 生成器训练:固定判别器的参数,给生成器输入随机噪声向量,生成器根据噪声生成图像。将生成的图像送入判别器,根据判别器的输出计算生成器的损失,使用反向传播算法更新生成器的参数,使生成器生成的图像更接近真实图像,以骗过判别器。
- 判别器训练:固定生成器的参数,将真实图像和生成器生成的图像同时送入判别器,判别器对输入图像进行真假判断,根据判断结果计算判别器的损失,通过反向传播算法更新判别器的参数,提高判别器区分真假图像的能力。
- 交替训练:重复上述生成器和判别器的训练步骤,使二者不断对抗和优化,直到模型收敛或达到预设的训练轮数等停止条件。
用户交互与图像生成
- 文本输入理解:用户输入文本描述,MJ 对文本进行解析和特征提取,将文本信息转换为模型能够理解的向量表示,作为生成器的输入条件之一,引导生成器生成符合用户描述的图像。
- 生成图像:训练好的模型根据用户输入的文本向量和随机噪声向量,通过生成器生成图像。生成过程中,模型会根据学习到的图像特征和文本条件,组合和生成具有相应风格、内容和细节的图像。
- 结果反馈与优化:将生成的图像展示给用户,用户可根据自己的需求和期望,对生成的图像提出反馈意见,如要求调整图像的某个细节、改变风格等。MJ 根据用户反馈,对输入文本进行调整或重新生成图像,直到用户满意。