智能驾驶峰会
您正在使用IE低版浏览器,为了您的雷锋网账号安全和更好的产品体验,强烈建议使用更快更安全的浏览器
人工智能 正文
发私信给亚萌
发送

2

令人拍案叫绝的Wasserstein GAN

本文作者:亚萌 2017-02-06 16:37
导语:一篇新鲜出炉的arXiv论文《Wassertein GAN》却在Reddit的Machine Learning频道火了!

雷锋网按:本文作者郑华滨,原载于知乎。雷锋网(公众号:雷锋网)已获转载授权。

令人拍案叫绝的Wasserstein GAN

在GAN的相关研究如火如荼甚至可以说是泛滥的今天,一篇新鲜出炉的arXiv论文《Wassertein GAN》却在Reddit的Machine Learning频道火了,连Goodfellow都在帖子里和大家热烈讨论,这篇论文究竟有什么了不得的地方呢?

要知道自从2014年Ian Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。从那时起,很多论文都在尝试解决,但是效果不尽人意,比如最有名的一个改进DCGAN依靠的是对判别器和生成器的架构进行实验枚举,最终找到一组比较好的网络架构设置,但是实际上是治标不治本,没有彻底解决问题。而今天的主角Wasserstein GAN(下面简称WGAN)成功地做到了以下爆炸性的几点:

  • 彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度

  • 基本解决了collapse mode的问题,确保了生成样本的多样性

  • 训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高(如题图所示)

  • 以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到

那以上好处来自哪里?这就是令人拍案叫绝的部分了——实际上作者整整花了两篇论文,在第一篇《Towards Principled Methods for Training Generative Adversarial Networks》里面推了一堆公式定理,从理论上分析了原始GAN的问题所在,从而针对性地给出了改进要点;在这第二篇《Wassertein GAN》里面,又再从这个改进点出发推了一堆公式定理,最终给出了改进的算法实现流程,而改进后相比原始GAN的算法实现流程却只改了四点:

  • 判别器最后一层去掉sigmoid

  • 生成器和判别器的loss不取log

  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

算法截图如下:

令人拍案叫绝的Wasserstein GAN

改动是如此简单,效果却惊人地好,以至于Reddit上不少人在感叹:就这样?没有别的了? 太简单了吧!这些反应让我想起了一个颇有年头的鸡汤段子,说是一个工程师在电机外壳上用粉笔划了一条线排除了故障,要价一万美元——画一条线,1美元;知道在哪画线,9999美元。上面这四点改进就是作者Martin Arjovsky划的简简单单四条线,对于工程实现便已足够,但是知道在哪划线,背后却是精巧的数学分析,而这也是本文想要整理的内容。

本文内容分为五个部分:

  • 原始GAN究竟出了什么问题?(此部分较长)

  • WGAN之前的一个过渡解决方案

  • Wasserstein距离的优越性质

  • 从Wasserstein距离到WGAN

  • 总结

理解原文的很多公式定理需要对测度论、 拓扑学等数学知识有所掌握,本文会从直观的角度对每一个重要公式进行解读,有时通过一些低维的例子帮助读者理解数学背后的思想,所以不免会失于严谨,如有引喻不当之处,欢迎在评论中指出。

以下简称《Wassertein GAN》为“WGAN本作”,简称《Towards Principled Methods for Training Generative Adversarial Networks》为“WGAN前作”。

WGAN源码实现:martinarjovsky/WassersteinGAN

第一部分:原始GAN究竟出了什么问题?

回顾一下,原始GAN中判别器要最小化如下损失函数,尽可能把真实样本分为正例,生成样本分为负例:

令人拍案叫绝的Wasserstein GAN(公式1 )

其中令人拍案叫绝的Wasserstein GAN是真实样本分布,令人拍案叫绝的Wasserstein GAN是由生成器产生的样本分布。对于生成器,Goodfellow一开始提出来一个损失函数,后来又提出了一个改进的损失函数,分别是

令人拍案叫绝的Wasserstein GAN(公式2)

令人拍案叫绝的Wasserstein GAN(公式3)

后者在WGAN两篇论文中称为“the - log D alternative”或“the - log D trick”。WGAN前作分别分析了这两种形式的原始GAN各自的问题所在,下面分别说明。

第一种原始GAN形式的问题

一句话概括:判别器越好,生成器梯度消失越严重。WGAN前作从两个角度进行了论证,第一个角度是从生成器的等价损失函数切入的。

首先从公式1可以得到,在生成器G固定参数时最优的判别器D应该是什么。对于一个具体的样本,它可能来自真实分布也可能来自生成分布,它对公式1损失函数的贡献是

令人拍案叫绝的Wasserstein GAN

令其关于令人拍案叫绝的Wasserstein GAN的导数为0,得

令人拍案叫绝的Wasserstein GAN

化简得最优判别器为:

令人拍案叫绝的Wasserstein GAN(公式4)

这个结果从直观上很容易理解,就是看一个样本令人拍案叫绝的Wasserstein GAN来自真实分布和生成分布的可能性的相对比例。如果令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN,最优判别器就应该非常自信地给出概率0;如果令人拍案叫绝的Wasserstein GAN,说明该样本是真是假的可能性刚好一半一半,此时最优判别器也应该给出概率0.5。

然而GAN训练有一个trick,就是别把判别器训练得太好,否则在实验中生成器会完全学不动(loss降不下去),为了探究背后的原因,我们就可以看看在极端情况——判别器最优时,生成器的损失函数变成什么。给公式2加上一个不依赖于生成器的项,使之变成

令人拍案叫绝的Wasserstein GAN

注意,最小化这个损失函数等价于最小化公式2,而且它刚好是判别器损失函数的反。代入最优判别器即公式4,再进行简单的变换可以得到

令人拍案叫绝的Wasserstein GAN(公式5)

变换成这个样子是为了引入Kullback–Leibler divergence(简称KL散度)和Jensen-Shannon divergence(简称JS散度)这两个重要的相似度衡量指标,后面的主角之一Wasserstein距离,就是要来吊打它们两个的。所以接下来介绍这两个重要的配角——KL散度和JS散度:

令人拍案叫绝的Wasserstein GAN(公式6)

令人拍案叫绝的Wasserstein GAN(公式7)

于是公式5就可以继续写成

令人拍案叫绝的Wasserstein GAN(公式8)

到这里读者可以先喘一口气,看看目前得到了什么结论:根据原始GAN定义的判别器loss,我们可以得到最优判别器的形式;而在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布令人拍案叫绝的Wasserstein GAN与生成分布令人拍案叫绝的Wasserstein GAN之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN之间的JS散度。

问题就出在这个JS散度上。我们会希望如果两个分布之间越接近它们的JS散度越小,我们通过优化JS散度就能将令人拍案叫绝的Wasserstein GAN“拉向”令人拍案叫绝的Wasserstein GAN,最终以假乱真。这个希望在两个分布有所重叠的时候是成立的,但是如果两个分布完全没有重叠的部分,或者它们重叠的部分可忽略(下面解释什么叫可忽略),它们的JS散度是多少呢?

答案是令人拍案叫绝的Wasserstein GAN,因为对于任意一个x只有四种可能:

令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN

令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN

令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN

令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN

第一种对计算JS散度无贡献,第二种情况由于重叠部分可忽略所以贡献也为0,第三种情况对公式7右边第一个项的贡献是令人拍案叫绝的Wasserstein GAN,第四种情况与之类似,所以最终令人拍案叫绝的Wasserstein GAN

换句话说,无论令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,JS散度就固定是常数令人拍案叫绝的Wasserstein GAN而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。

但是令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN不重叠或重叠部分可忽略的可能性有多大?不严谨的答案是:非常大。比较严谨的答案是:令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN的支撑集(support)是高维空间中的低维流形(manifold)时,令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN重叠部分测度(measure)为0的概率为1。

不用被奇怪的术语吓得关掉页面,虽然论文给出的是严格的数学表述,但是直观上其实很容易理解。首先简单介绍一下这几个概念:

  • 支撑集(support)其实就是函数的非零部分子集,比如ReLU函数的支撑集就是令人拍案叫绝的Wasserstein GAN,一个概率分布的支撑集就是所有概率密度非零部分的集合。

  • 流形(manifold)是高维空间中曲线、曲面概念的拓广,我们可以在低维上直观理解这个概念,比如我们说三维空间中的一个曲面是一个二维流形,因为它的本质维度(intrinsic dimension)只有2,一个点在这个二维流形上移动只有两个方向的自由度。同理,三维空间或者二维空间中的一条曲线都是一个一维流形。

  • 测度(measure)是高维空间中长度、面积、体积概念的拓广,可以理解为“超体积”。

回过头来看第一句话,“当令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN的支撑集是高维空间中的低维流形时”,基本上是成立的。原因是GAN中的生成器一般是从某个低维(比如100维)的随机分布中采样出一个编码向量,再经过一个神经网络生成出一个高维样本(比如64x64的图片就有4096维)。当生成器的参数固定时,生成样本的概率分布虽然是定义在4096维的空间上,但它本身所有可能产生的变化已经被那个100维的随机分布限定了,其本质维度就是100,再考虑到神经网络带来的映射降维,最终可能比100还小,所以生成样本分布的支撑集就在4096维空间中构成一个最多100维的低维流形,“撑不满”整个高维空间。

“撑不满”就会导致真实分布与生成分布难以“碰到面”,这很容易在二维空间中理解:一方面,二维平面中随机取两条曲线,它们之间刚好存在重叠线段的概率为0;另一方面,虽然它们很大可能会存在交叉点,但是相比于两条曲线而言,交叉点比曲线低一个维度,长度(测度)为0,可忽略。三维空间中也是类似的,随机取两个曲面,它们之间最多就是比较有可能存在交叉线,但是交叉线比曲面低一个维度,面积(测度)是0,可忽略。从低维空间拓展到高维空间,就有了如下逻辑:因为一开始生成器随机初始化,所以令人拍案叫绝的Wasserstein GAN几乎不可能与令人拍案叫绝的Wasserstein GAN有什么关联,所以它们的支撑集之间的重叠部分要么不存在,要么就比令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN的最小维度还要低至少一个维度,故而测度为0。所谓“重叠部分测度为0”,就是上文所言“不重叠或者重叠部分可忽略”的意思。

我们就得到了WGAN前作中关于生成器梯度消失的第一个论证:在(近似)最优判别器下,最小化生成器的loss等价于最小化令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN之间的JS散度,而由于令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数令人拍案叫绝的Wasserstein GAN,最终导致生成器的梯度(近似)为0,梯度消失。

接着作者写了很多公式定理从第二个角度进行论证,但是背后的思想也可以直观地解释:

首先,令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN之间几乎不可能有不可忽略的重叠,所以无论它们之间的“缝隙”多狭小,都肯定存在一个最优分割曲面把它们隔开,最多就是在那些可忽略的重叠处隔不开而已。

由于判别器作为一个神经网络可以无限拟合这个分隔曲面,所以存在一个最优判别器,对几乎所有真实样本给出概率1,对几乎所有生成样本给出概率0,而那些隔不开的部分就是难以被最优判别器分类的样本,但是它们的测度为0,可忽略。

最优判别器在真实分布和生成分布的支撑集上给出的概率都是常数(1和0),导致生成器的loss梯度为0,梯度消失。

有了这些理论分析,原始GAN不稳定的原因就彻底清楚了:判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。只有判别器训练得不好不坏才行,但是这个火候又很难把握,甚至在同一轮训练的前后不同阶段这个火候都可能不一样,所以GAN才那么难训练。

实验辅证如下:

令人拍案叫绝的Wasserstein GAN

WGAN前作Figure 2。先分别将DCGAN训练1,20,25个epoch,然后固定生成器不动,判别器重新随机初始化从头开始训练,对于第一种形式的生成器loss产生的梯度可以打印出其尺度的变化曲线,可以看到随着判别器的训练,生成器的梯度均迅速衰减。注意y轴是对数坐标轴。

第二种原始GAN形式的问题

一句话概括:最小化第二种生成器loss函数,会等价于最小化一个不合理的距离衡量,导致两个问题,一是梯度不稳定,二是collapse mode即多样性不足。WGAN前作又是从两个角度进行了论证,下面只说第一个角度,因为对于第二个角度我难以找到一个直观的解释方式,感兴趣的读者还是去看论文吧(逃)。

如前文所说,Ian Goodfellow提出的“- log D trick”是把生成器loss改成

令人拍案叫绝的Wasserstein GAN(公式3)

上文推导已经得到在最优判别器令人拍案叫绝的Wasserstein GAN

令人拍案叫绝的Wasserstein GAN(公式9)

我们可以把KL散度(注意下面是先g后r)变换成含的形式:

令人拍案叫绝的Wasserstein GAN(公式10)

由公式3,9,10可得最小化目标的等价变形

令人拍案叫绝的Wasserstein GAN

注意上式最后两项不依赖于生成器G,最终得到最小化公式3等价于最小化

令人拍案叫绝的Wasserstein GAN(公式11)

这个等价最小化目标存在两个严重的问题。第一是它同时要最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,一个要拉近,一个却要推远!这在直观上非常荒谬,在数值上则会导致梯度不稳定,这是后面那个JS散度项的毛病。

第二,即便是前面那个正常的KL散度项也有毛病。因为KL散度不是一个对称的衡量,令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN是有差别的。以前者为例

  • 令人拍案叫绝的Wasserstein GAN而时令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN对贡献趋近0

  • 令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN时,令人拍案叫绝的Wasserstein GAN令人拍案叫绝的Wasserstein GAN对贡献趋近正无穷

换言之,令人拍案叫绝的Wasserstein GAN对于上面两种错误的惩罚是不一样的,第一种错误对应的是“生成器没能生成真实的样本”,惩罚微小;第二种错误对应的是“生成器生成了不真实的样本” ,惩罚巨大。第一种错误对应的是缺乏多样性,第二种错误对应的是缺乏准确性。这一放一打之下,生成器宁可多生成一些重复但是很“安全”的样本,也不愿意去生成多样性的样本,因为那样一不小心就会产生第二种错误,得不偿失。这种现象就是大家常说的collapse mode。

第一部分小结:在原始GAN的(近似)最优判别器下,第一种生成器loss面临梯度消失问题,第二种生成器loss面临优化目标荒谬、梯度不稳定、对多样性与准确性惩罚不平衡导致mode collapse这几个问题。

实验辅证如下:

令人拍案叫绝的Wasserstein GAN

WGAN前作Figure 3。先分别将DCGAN训练1,20,25个epoch,然后固定生成器不动,判别器重新随机初始化从头开始训练,对于第二种形式的生成器loss产生的梯度可以打印出其尺度的变化曲线,可以看到随着判别器的训练,蓝色和绿色曲线中生成器的梯度迅速增长,说明梯度不稳定,红线对应的是DCGAN相对收敛的状态,梯度才比较稳定。

第二部分:WGAN之前的一个过渡解决方案

原始GAN问题的根源可以归结为两点,一是等价优化的距离衡量(KL散度、JS散度)不合理,二是生成器随机初始化后的生成分布很难与真实分布有不可忽略的重叠。

WGAN前作其实已经针对第二点提出了一个解决方案,就是对生成样本和真实样本加噪声,直观上说,使得原本的两个低维流形“弥散”到整个高维空间,强行让它们产生不可忽略的重叠。而一旦存在重叠,JS散度就能真正发挥作用,此时如果两个分布越靠近,它们“弥散”出来的部分重叠得越多,JS散度也会越小而不会一直是一个常数,于是(在第一种原始GAN形式下)梯度消失的问题就解决了