最近听几场讲座都听到了一些有关GAN的应用,今年GAN更是刷了很多会议论文,了解一些常用的GAN网络,对于后面做科研还是做工程项目都很有帮助。一般的GAN网络存在着以下一些问题:训练不稳定,需要平衡生成器和判别器之间的训练;缺乏相应指示GAN训练好坏的指标等。下面依次简要介绍目前常用的几种GAN:WGAN,LSGAN,DCGAN。
WGAN
WGAN是Wasserstein GAN的jiancheng简称,其目标是使GAN的训练达到稳定。WGAN主要从损失函数对GAN进行了改进,主要包括判别器最后一层去掉sigmoid,生成器和判别器的loss不取log,将更新后的权重限制在一定范围内,比如【-0.01,0.01】,以满足lipschitz条件。
WGAN的依据是交叉熵(JS散度)不适合衡量具有不相交部分分布之间的距离,而是使用Wasserstein距离去衡量生成数据和真实数据之间的距离,理论上解决了训练不稳定的问题。同时给定一个说明GAN训练好坏的指标,该指标越小,表明GAN训练越差。距离表达公式摘自下图。
注:Lipschitz限制是在样本空间中,要求判别器函数D(x)梯度值不大于一个有限的常数K,通过权重值限制的方式保证了权重参数的有界性,间接限制了其梯度信息。
DCGAN
DCGAN还是解决GAN训练不稳定的问题,将生成器和判别器换成了两个卷积神经网络。其结构如下图所示:
DCGAN对GAN的主要变化如下:
1.将pooling层convolutions替代,其中,在discriminator上用strided convolutions替代,在generator上用fractional-strided convolutions替代。
2.在generator和discriminator上都使用batchnorm。好处是帮助梯度传播到每一层,防止generator把所有的样本都收敛到同一个点。直接将BN应用到所有层会导致样本震荡和模型不稳定,通过在generator输出层和discriminator输入层不采用BN可以防止这种现象。
3.移除全连接层。
4.此外,还在预处理中将图像scale到tanh的[-1, 1]。LeakyReLU的斜率是0.2。将momentum参数beta从0.9降为0.5来防止震荡和不稳定,等。
LSGAN
LSGANs的英文全称是Least Squares GANs。其核心思想是将GAN的目标函数由交叉熵损失换成最小二乘损失,缓解了GAN训练不稳定和生成图像质量差多样性不足的问题。
作者认为GAN以交叉熵作为损失,会使得生成器不会再优化那些被判别器识别为真实图片的生成图片,即使这些生成图片距离判别器的决策边界仍然很远,也就是距真实数据比较远。这意味着生成器的生成图片质量并不高。为什么生成器不再优化优化生成图片呢?是因为生成器已经完成我们为它设定的目标——尽可能地混淆判别器,所以交叉熵损失已经很小了。而最小二乘就不一样了,要想最小二乘损失比较小,在混淆判别器的前提下还得让生成器把距离决策边界比较远的生成图片拉向决策边界。
为什么最小二乘损失可以使得GAN的训练更稳定呢?作者对这一点介绍的并不是很详细,只是说sigmoid交叉熵损失很容易就达到饱和状态(饱和是指梯度为0),而最小二乘损失只在一点达到饱和,如图2所示:
损失函数如下所示:
参考:
1.https://blog.csdn.net/victoriaw/article/details/56486471
2.https://blog.csdn.net/qq_25737169/article/details/78857788
3.https://blog.csdn.net/stdcoutzyx/article/details/53872121
4.https://blog.csdn.net/victoriaw/article/details/60755698