最近做项目用到了GP-WGAN,所以感觉有必要了解一下,这里简要参考别人的博客自己做一个总结吧。
GAN通过训练判别器和生成器来使得生成器生成的数据分布上尽可能和真实样本的分布完全一致。但是在GAN训练的过程中常常会存在训练不稳定的现象。因此蒙特利尔大学的研究者对WGAN进行改进,提出了一种替代WGAN判别器中权重剪枝的方法,即具有梯度惩罚的WGAN,从而避免训练不稳定的情况。
在WGAN中,作者将权重剪切到一定范围内,如[-0.01,0.01],发生了这样的情况,如下图所示:
最后发现大部分的权重都在-0.01和0.01上,这说明大部分权重都只有这两个参数,无法发挥深度神经网络的泛化能力。而且这种剪切也很容易导致梯度爆炸或者梯度弥散,原因在于,剪切范围太小会导致梯度消失,损失无法有效传递,而剪切范围太大,梯度变大一点点,多层以后梯度就会爆炸。为了解决这个问题,这里引入了lipschitz连续性条件,通过梯度惩罚的方式满足连续性条件,如上图右部所示。
lipschitz连续性条件的一个表达如下所示,即我们希望y’->y时,希望$||D(y,\theta)-D(y’,\theta)||<=C||y-y’||$,满足这个条件,就可以满足稳定性要求,这是一个充分非必要条件。上式可写成:
梯度惩罚就是让上式不超过K,那么可以先求出判别器的梯度d(D(x)),然后建立与K之间的二范数实现简单的损失函数设计。如下所示:
由于样本数量较多,没必要对整个样本空间进行采样,重点抓住生成样本集中区域,真实样本集中区域,以及夹在中间的区域就行了。即产生一对真假样本,Xr,Xg,和一个随机数$\epsilon$在0到1之间。那么采样的样本为:
当然,因为对每个batch都做了梯度惩罚,因此判别器中不能使用batch norm,但是可以使用别的norm,如Layer Normalization,weight normalization。
最后的效果还是很稳定的,如下所示:
该方法优点是收敛速度快,稳定性好,生成样本质量高,且基本不需要调参。
参考:
1.http://www.pytorchtutorial.com/pytorch-improved-training-of-wasserstein-gans-wgan-gp/
2.https://www.cnblogs.com/bonelee/p/9166122.html
3.https://blog.csdn.net/qq_38826019/article/details/80786061