機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Prescribed Generative Adversarial Networksを読んだのでメモ

はじめに

Prescribed Generative Adversarial Networksを読んだのでメモ.

気持ち

GANの問題の一つであるmode collapseとモデル分布(生成分布)が陽に定義されておらず尤度等の計算ができないという2つの問題にアプローチした論文.mode collapseはモデルの学習に尤度やエントロピーなどデータ分布に対する損失が入っていないことが原因で,これはGANのデータ分布の扱いにくさに起因する.そのため本質的にはひとつの問題に帰着する.この論文ではGANの生成過程を修正することでモデル分布を扱いやすくしこれらの課題の解決を試みている.

Prescribed Generative Adversarial Networks (PresGAN)

PresGANはGANの生成過程をVAE-likeに修正することでモデル分布がwell-definedなGANを作る.論文のプロローグにてdeep generative model(DGM)に関する詳しい議論がされているが,簡単に言えばDGMはモデル分布(密度)p_\theta(\mathbf{x})は多様かつ質の高い(真のデータと見分けがつかない)データを生成し,また対数尤度が計算可能であるという性質を満たすべきということを述べている.

PresGANはGANのAdversarial lossにエントロピーの項を加えたものでひとつの大きな工夫として,学習にはエントロピーそのものではなくその勾配が求まればパラメータを学習できるため,直接勾配を計算する方法を提案している.

PresGANは事前分布p(\mathbf{z})と尤度p_\theta(\mathbf{x}|\mathbf{z})を以下のガウス分布として定義する.

\displaystyle
p(\mathbf{z})=\mathcal{N}(\mathbf{z}|\mathbf{0},\mathbf{I}),p_\theta(\mathbf{x}|\mathbf{z})=\mathcal{N}(\mathbf{x}|\mu_\theta(\mathbf{z}),\mathbf{\Sigma}_\theta(\mathbf{z}))

p_\theta(\mathbf{x}|\mathbf{z})の平均は\mathbf{z}を入力とするニューラルネットの出力で与えられる.ただし分散は対角成分のみに値を持つ\mathbf{\Sigma}_\theta(\mathbf{z})=\mathrm{diag}(\mathbf{\sigma}^2)で定義され,\mathbf{\sigma}は直接学習パラメータとしてもつ.

生成過程を上記のように定義することで生成分布のエントロピーがwell-definedになり,mode collapseを避けるために損失にエントロピー正則化を加えることができるようになる.さらに対数尤度の推定も可能となる.

Avoiding mode collapse via entropy regularization

まずエントロピー正則化について説明する.PresGANは次の損失関数の最適化を行う.

\displaystyle
\mathcal{L}_\text{PresGAN}(\theta,\phi)=\mathcal{L}_\text{GAN}(\theta,\phi)-\lambda\mathcal{H}(p_\theta(\mathbf{x}))

\mathcal{H}(p_\theta(\mathbf{x}))エントロピーで次のように定義される.

\displaystyle
\mathcal{H}(p_\theta(\mathbf{x}))=-\mathbb{E}_{p_\theta(\mathbf{x})}[\log p_\theta(\mathbf{x})]

\lambdaはハイパーパラメータで0のときは通常のGANになり,無限大のときはエントロピー最小化と一致する.すなわち第一項が生成データの質を,第二項が多様性を向上させる役割を持つ.

このエントロピーp_\theta(\mathbf{x})に関する期待値となっていて計算が困難だが,学習のためには勾配のみが必要となるためこれを次の直接不偏モンテカルロ推定によって求める.

Fitting Prescribed Generative Adversarial Networks

まず生成分布のパラメータを\theta,discriminatorのパラメータを\phiとする.\thetaの学習には次の勾配を計算する必要がある.

\displaystyle
\nabla_\theta\mathcal{L}_\text{PreGAN}(\theta,\phi)=\nabla_\theta\mathcal{L}_\text{GAN}(\theta,\phi)-\lambda\nabla_\theta\mathcal{H}(p_\theta(\mathbf{x}))

\nabla_\theta\mathcal{L}_\text{GAN}(\theta,\phi)の計算は正規分布からのサンプリングをVAEと同様次のreparameterizationされた形式で計算することで直接計算可能.

\displaystyle
\mathbf{x}(\mathbf{z},\epsilon;\theta)=\mu_\eta(\mathbf{z})+\sigma\odot\epsilon

ただし\etaニューラルネットのパラメータとし\epsilonは標準正規分布からサンプリングされた値.論文では勾配の細かい計算が説明されているが通常のGANと何ら変わりないので割愛.

問題は第二項のエントロピー微分.ここではTitsias and Ruizが提案した方法で勾配の推論を行う.reparameterizationされた表現を使ってエントロピーの勾配を次の形で表現する.

\displaystyle
\nabla_\theta\mathcal{H}(p_\theta(\mathbf{x}))=-\nabla_\theta\mathbb{E}_{p_\theta(\mathbf{x})}[\log p_\theta(\mathbf{x})]=-\nabla_\theta\mathbb{E}_{p(\epsilon)p(\mathbf{z})}[\log p_\theta(\mathbf{x})|_{\mathbf{x}=\mathbf{x}(\mathbf{z},\epsilon;\theta)}]\\
=-\mathbb{E}_{p(\epsilon)p(\mathbf{z})}[\nabla_\theta\log p_\theta(\mathbf{x})|_{\mathbf{x}=\mathbf{x}(\mathbf{z},\epsilon;\theta)}]\\
=-\mathbb{E}_{p(\epsilon)p(\mathbf{z})}[\nabla_\mathbf{x}\log p_\theta(\mathbf{x})|_{\mathbf{x}=\mathbf{x}(\mathbf{z},\epsilon;\theta)}\nabla_\theta\mathbf{x}(\mathbf{z},\epsilon;\theta)]

一行目の変形はデータの生成過程からreparameterizationのための標準ガウス分布からのサンプリング値\epsilonと平均を出力するニューラルネットへの入力\mathbf{z}の期待値に書き換えただけで,三行目は連鎖律から導かれる.二行目は単純に線形性という解釈をしたが,論文でscore function identity \mathbb{E}_{p_\theta(\mathbf{x})}[\nabla_\theta\log p_\theta(\mathbf{x})]=0を使ったと書いてある(勉強不足でよくわからなかった).\nabla_\theta\mathbf{x}(\mathbf{z},\epsilon;\theta)は期待値をモンテカルロ近似すれば勾配降下法により容易に微分の計算ができる.一方,\nabla_\mathbf{x}\log p_\theta(\mathbf{x})は次のように計算される.

\displaystyle
\nabla_\mathbf{x}\log p_\theta(\mathbf{x})=\frac{\nabla_\mathbf{x}p_\theta(\mathbf{x})}{p_\theta(\mathbf{x})}=\frac{\int\nabla_\mathbf{x}p_\theta(\mathbf{x},\mathbf{z})d\mathbf{z}}{p_\theta(\mathbf{x})}=\int\frac{\frac{\nabla_\mathbf{x}p_\theta(\mathbf{x}|\mathbf{z})}{p_\theta(\mathbf{x}|\mathbf{z})}p_\theta(\mathbf{x},\mathbf{z})}{p_\theta(\mathbf{x})}d\mathbf{z}\\
=\int\nabla_\mathbf{x}\log p_\theta(\mathbf{x}|\mathbf{z})p_\theta(\mathbf{z}|\mathbf{x})d\mathbf{z}=\mathbb{E}_{p_\theta(\mathbf{z}|\mathbf{x})}[\nabla_\mathbf{x}\log p_\theta(\mathbf{x}|\mathbf{z})]

これはこれでp_\theta(\mathbf{z}|\mathbf{x})の期待値が扱いにくく,Dieng and Paisleyの方法で推論することもできるが不偏推定量とはならない.ここでは事後分布から\mathbf{z}^{(1)},\dots,\mathbf{z}^{(M)}Hamiltonian Monte Carlo (HMC)によってサンプリングすることで勾配の不偏推定量を求める.

\displaystyle
\hat{\nabla}_\mathbf{x}\log p_\theta(\mathbf{x})=\frac{1}{M}\sum_{m=1}^M\nabla_\mathbf{x}\log p_\theta(\mathbf{x}|\mathbf{z}^{(m)}),\mathbf{z}^{(m)}\sim p_\theta(\mathbf{z}|\mathbf{x})

\mathbf{z}の初期値に\mathbf{x}(\mathbf{z},\epsilon;\theta)を生成するときに使われた値を用いることでサンプリングの収束を早めることができ,少ない繰り返しで良い勾配を得ることができる.実験的にはburn-inはわずか二回で良いとのこと.

最終的に\thetaに関する目的関数の勾配は次のように計算される.

\displaystyle
\hat{\nabla}_\theta\mathcal{L}_\text{PresGAN}(\theta,\phi)=\nabla_\theta\log(1-D_\phi(\mathbf{x}(\mathbf{z},\epsilon;\theta)))+\frac{\lambda}{M}\sum_{m=1}^M\nabla_\mathbf{x}\log p_\theta(\mathbf{x}|\mathbf{z}^{(m)})|_{\mathbf{x}=\mathbf{x}(\mathbf{z}^{(m)},\epsilon;\theta)}\times\nabla_\theta\mathbf{x}(\mathbf{z}^{(m)},\epsilon;\theta)

特に,ニューラルネットのパラメータ\etaのみに着目すると次のようになる.

\displaystyle
\hat{\nabla}_\eta\mathcal{L}_\text{PresGAN}(\theta,\phi)=\nabla_\eta\log(1-D_\phi(\mathbf{x}(\mathbf{z},\epsilon;\theta)))-\frac{\lambda}{M}\sum_{m=1}^M\frac{\mathbf{x}(\mathbf{z}^{(m)},\epsilon;\theta)-\mu_\eta(\mathbf{z}^{(m)})}{\mathbf{\sigma}^2}\nabla_\eta\mu_\eta(\mathbf{z}^{(m)})

そして標準偏差\sigmaに関する微分は次のようになる.

\displaystyle
\hat{\nabla}_\sigma\mathcal{L}_\text{PresGAN}(\theta,\phi)=\nabla_\sigma\log(1-D_\phi(\mathbf{x}(\mathbf{z},\epsilon;\theta)))-\frac{\lambda}{M}\sum_{m=1}^M\frac{\mathbf{x}(\mathbf{z}^{(m)},\epsilon;\theta)-\mu_\eta(\mathbf{z}^{(m)})}{\mathbf{\sigma}^2}\cdot\epsilon

また,エントロピーの項はdiscriminatorに依存しないためdiscriminatorのパラメータはAdversarial lossの勾配のみで計算される.PresGANでは生成データはガウス分布のreparameterizationの形から計算されるため常にノイズが加算された形になっており,普通に学習しようとするとdiscriminatorが簡単に局所解(完璧にrealとfakeを識別する)にはまってしまう.なのでこれを防ぐために生成されたデータと同様にノイズをrealデータにも加える.

\displaystyle
\hat{\mathbf{x}}=\mathbf{x}+\mathbf{\sigma}\odot\epsilon

\mathbf{\sigma}は生成分布の\mathbf{\sigma}に等しい.このノイズの加算は局所解にはまるのを防ぐと同時に学習も安定させるらしく,詳しくはHuszarを参照とのこと.ただPresGANの学習を安定させるにはノイズ加算だけでは不十分で,これは\mathbf{\sigma}が学習パラメータになっていることに起因するとのこと.具体的に学習失敗に2つの場合があり,標準偏差が大きくなりすぎてrealなデータもほぼノイズ画像になってしまう場合と,逆に小さくなりすぎてエントロピー項に対する勾配が支配的になってしまう場合がある.そのためここでは標準偏差\mathbf{\sigma}_\text{low}\leq\mathbf{\sigma}\leq\mathbf{\sigma}_\text{high}を満たすように値をtruncateする.ただし最大値,最小値はハイパーパラメータとのこと.

Enabling tractable predictive log-likelihood approximation

最後に対数尤度の推定について考える.ここでは未観測のデータ\mathbf{x}^\ast)の対数尤度\log p_\theta(\mathbf{x}^\ast)を重点サンプリングを使って推定する.

\displaystyle
\log p_\theta(\mathbf{x}^\ast)\approx\log\left(\frac{1}{S}\sum_{s=1}^S\frac{p_\theta(\mathbf{x}^\ast|\mathbf{z}^{(s)})\cdot p(\mathbf{z}^{(s)})}{r(\mathbf{z}^{(s)}|\mathbf{x}^\ast)}\right)

サンプル\mathbf{z}^{(1)},\dots,\mathbf{z}^{(S)}は提案分布r(\mathbf{z}|\mathbf{x}^\ast)からサンプリングされる.

良いr(\mathbf{z}|\mathbf{x}^\ast)を作るための方法は複数あるがここでは次のガウス分布とする.

\displaystyle
r(\mathbf{z}|\mathbf{x}^\ast)=\mathcal{N}(\mu_r,\mathbf{\Sigma}_r)

平均パラメータは事後確率最大化の解\mathrm{arg}\max_z(\log p_\theta(\mathbf{x}^\ast|\mathbf{z})+\log p(\mathbf{z}))とし,この解を事前に学習されたエンコーダーq_\gamma(\mathbf{z}|\mathbf{x}^\ast)で近似する.エンコーダーq_\gamma(\mathbf{z}|\mathbf{x})p_\theta(\mathbf{z}|\mathbf{x})間のreverse KLの最小化によって学習される.

\displaystyle
\mathrm{KL}(q_\gamma(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z}|\mathbf{x}))=\log p_\theta(\mathbf{x})-\mathbb{E}_{q_\gamma(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})p(\mathbf{z})-\log q_\gamma(\mathbf{z}|\mathbf{x})]

この式の二項目はVAEのELBOに等しく,このKLの最小化はELBOの最大化を含む.おそらくこれが通常のKLではなくreverse KLを使った理由.

提案分布の共分散は対角行列とし,エンコーダーの共分散行列のoverdispersed versionとする(勉強不足でoverdispersed version(過分散?)がわからないので詳細は割愛.ただ重点サンプリングの推定結果が良くなるという効果があるそう).操作としてはエンコーダーの共分散の各要素にハイパーパラメータ\gammaをかけるとのこと.\gammaの値は実験では1.2に設定.

まとめ

モデルの全体像を誤解を恐れずまとめるとデータ分布に関する損失を考慮したVAEGANという感じ.実験結果は期待通りでmode collapseを防ぎつつ生成画像の質も高い.