VQ-VAE介绍
转载自https://spaces.ac.cn/archives/6760。
VQ-VAE(Vector Quantised Variational AutoEncoder)首先出现在论文《Neural Discrete Representation Learning》中,出自Google团队。
PixelCNN
要追溯VQ-VAE的思想,就不得不谈到自回归模型。可以说,VQ-VAE做生成模型的思路,源于PixelRNN、PixelCNN之类的自回归模型。这类模型在生成图像时,实际上是离散的而不是连续的。以cifar10为例,它是\(32 \times 32\)大小的3通道图像,换言之它是一个\(32 \times 32 \times 3\)的矩阵,矩阵的每个元素是0~255之间的任意整数。这样一来,我们可以将它看成是一个长度为\(32 \times 32 \times 3=3072\)的句子,而词表大小是256,从而用语言模型的方法,来逐像素地、递归地生成一张图片(传入前面的所有像素,来预测下一个像素)。这就是所谓的自回归方法:
\[p(x)=p(x_1)p(x_2|x_1)...p(x_{3072}|x_1,x_2,...,x_{3071})\]
其中\(p(x_1),p(x_2|x_1),...,p(x_{3072}|x_1,x_2,...,x_{3071})\)每一个都是256分类问题,只不过依赖的条件有所不同。
自回归的方法很稳妥,也能有效地做概率估计,但生成模型只能逐像素的生成,导致生成速度很慢。上面举例的cifar10已经算是很小的图像了,但展开也相当于生成3072长度的句子。目前做图像生成好歹也要做到\(128 \times 128 \times 3\) 的才有说服力了吧,这总像素接近5万个(想想看要生成一个长度为5万的句子),真要逐像素生成会非常耗时。而且这么长的序列,不管是RNN还是CNN模型都无法很好地捕捉这么长的依赖。
原始的自回归模型还有一个问题,就是割裂了类别之间的联系。虽然说因为每个像素是离散的,所以看成256分类问题也无妨,但事实上连续像素之间的差别是很小的,纯粹的分类问题无法捕捉到这种联系。假如目标像素值是100,如果预测成99,因为类别不同,就会带来一个很大的损失。但从视觉上来看,像素值是100还是99其实差别不大,不应该有这么大的损失。
VQ-VAE
针对自回归模型的固有毛病,VQ-VAE提出的解决方案是:先降维,然后再对编码向量用PixelCNN建模。如一张\(n \times n \times 3\)的图像,可以降维到\(m \times m\)的编码向量,其中\(m << n\)。这样用PixelCNN对编码向量建模时,生成长度就不会特别大。
降维离散化
看上去这个方案很自然,似乎没什么特别的,但事实上一点都不自然。
因为PixelCNN生成的是离散序列,你想用PixelCNN建模编码向量,那就意味着编码向量也是离散的才行。而我们常见的降维手段,比如自编码器,生成的编码向量都是连续性变量,无法直接生成离散变量。同时,生成离散型变量往往还意味着存在梯度消失的问题(梯度无法反向传播)。还有,降维、重构这个过程,如何保证重构之后出现的图像不失真?如果失真得太严重,甚至还比不上普通的VAE的话,那么VQ-VAE也没什么存在价值了。
幸运的是,VQ-VAE确实提供了有效的训练策略解决了这两个问题。
最近邻重构
在VQ-VAE中,一张\(n \times n \times 3\)的图片\(x\)先被传入一个encoder
中,得到连续的编码向量\(z \in R^{m \times m \times d}\):
\[z = encoder(x)\]
这里的\(z\)是一个大小为\(m \times m \times d\)的矩阵。另外,VQ-VAE还维护一个Embedding层,我们也可以称其为编码表,记为:
\[E=[e_1,e_2,...,e_K]\]
这里每个\(e_i\)都是一个大小为\(d\)的向量。接着,VQ-VAE通过最近邻搜索,将\(z\)中每个位置的向量映射为这\(K\)个向量之一:
\[z_t \rightarrow e_k,k=argmin_j\|z_t-e_j\|_2\]
我们可以将\(z\)对应的编码表矩阵记为\(z_q\),我们认为\(z_q\)才是最后的编码结果。最后将\(z_q\)传入一个decoder
,希望重构原图\(\hat{x}=decoder(z_q)\)。
整个流程是:
\[x \stackrel{encoder}{\longrightarrow} z \stackrel{最近邻}{\longrightarrow} z_q \stackrel{decoder}{\longrightarrow} \hat{x}\]
这样一来,因为\(z_q \in R^{m \times m \times d}\)中每个位置的向量是编码表\(E\)中的向量之一,所以它实际上就等价于\(1,2,...,K\)这\(K\)个整数组成的一个大小为\(m \times m\)的整数矩阵,这样就完成了编码向量的离散化。
自行设计梯度
我们知道,如果是普通的自编码器,直接用下述loss进行训练即可:
\[\|x-decoder(z)\|_2^2\]
但是,在VQ-VAE中,我们用来重构的是\(z_q\)而不是\(z\),那么似乎应该用这个loss才对:
\[\|x-decoder(z_q)\|_2^2\]
但问题是\(z_q\)的构建过程包含了\(argmin\),这个操作是没有梯度的,所以如果用第二个loss的话,我们没法更新encoder
。
换言之,我们的目标其实是让\(\|x-decoder(z_q)\|_2^2\)最小,但是却不好优化;而\(\|x-decoder(z)\|_2^2\)容易优化,但却不是我们的优化目标。那怎么办呢?
VQ-VAE使用了一个很精巧也很直接的方法,称为Straight-Through Estimator
,它最早源于Benjio的论文《Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation》,在VQ-VAE原论文中也是直接抛出这篇论文而没有做什么讲解。但事实上直接读这篇原始论文是一个很不友好的选择,还不如直接读源代码。
事实上Straight-Through的思想很简单,就是前向传播的时候可以用想要的变量(哪怕不可导),而反向传播的时候,用你自己为它所设计的梯度。根据这个思想,我们设计的目标函数是:
\[\|x-decoder(z+sg(z_q-z))\|_2^2\]
其中sg
表示stop gradient,即不需要它的梯度。这样一来,前向传播计算的时候,就直接等价于\(decoder(z+z_q-z)=decoder(z_q)\),然后反向传播的时候,由于\(z_q-z\)不提供梯度,所以等价于\(decoder(z)\),这样就允许我们对encoder
进行优化了。
顺便说一下,基于这个思想,我们可以为很多函数自定义梯度。比如\(x+sg[relu(x) - x]\)就是将\(relu(x)\)的梯度定义为恒为1,但是在前向传播是又和\(relu(x)\)完全等价。当然,用同样的方法我们可以随便指定一个函数的梯度,至于有没有实用价值,则要具体任务具体分析了。
维护编码表
要注意,根据VQ-VAE的最近邻搜索的设计,我们应该期望\(z_q\)和\(z\)是很接近的,但事实上未必如此,即使\(\|x-decoder(z)\|_2^2\)和\(\|x-decoder(z_q)\|_2^2\)都很小,也不意味着\(z_q\)和\(z\)差别很小(即\(f(z_1)=f(z_2)\)不意味着\(z_1=z_2\))。
所以,为了让\(z_q\)和\(z\)更接近,我们可以直接将\(\|z-z_q\|_2^2\)加入到loss中:
\[\|x-decoder(z+sg[z_q-z])\|_2^2 + \beta \|z-z_q\|_2^2\]
除此之外,还可以做的更仔细一些。由于编码表\(z_q\)相对是比较自由的,而\(z\)要尽量保证重构效果,所以我们应当尽量让\(z_q\)去靠近\(z\)而不是让\(z\)去靠近\(z_q\)。而因为\(\|z-z_q\|_2^2\)的梯度等于对\(z_q\)的梯度加上对\(z\)的梯度,所以我们将它等价地分解为:
\[\|sg[z]-z_q\|_2^2 + \|z-sg[z_q]\|_2^2\]
第一项等于固定\(z\),让\(z_q\)靠近\(z\);第二项则反过来固定\(z_q\),让\(z\)靠近\(z_q\)。注意这个“等价”是对于反向传播(求梯度)来说的,对于前向传播(求loss)它是原来的两倍。根据我们刚才的讨论,我们希望让\(z_q\)去靠近\(z\)多于让\(z\)去靠近\(z_q\),所以可以调一下最终的loss比例:
\[\|x-decoder(z+sg[z_q-z])\|_2^2 + \beta \|sg[z]-z_q\|_2^2 + \gamma \|z-sg[z_q]\|_2^2\]
其中\(\gamma < \beta\),在原论文中使用的是\(\gamma = 0.25 \beta\)
注:还可以用滑动平均的方式更新编码表,详情请看原论文。
\[z_q = \alpha z_q + (1 - \alpha)z\]
这等价于指定使用SGD优化\(\|sg[z]-z_q\|_2^2\)这一项loss,该方案被VQ-VAE-2所使用。
拟合编码分布
经过上述设计之后,我们终于将图片编码为\(m \times m\)的整数离散矩阵了,即编码矩阵。我们可以用自回归模型比如PixelCNN,来对编码矩阵进行拟合。
拟合过程为通过模型预测编码矩阵的分布,即直接通过模型预测输出一个\(m \times m\)大小的整数矩阵。
通过PixelCNN得到编码分布后,就可以随机生成一个新的编码矩阵,然后通过编码表\(E\)映射为浮点数矩阵\(z_q\),最后经过decoder
得到一张图片。
一般来说,现在的\(m \times m\)比原来的\(n \times n \times 3\)要小得多,所以用自回归模型对编码矩阵进行建模,要比直接对原始图片进行建模要容易得多。