Concrete Distributionについて論文を読んだ
はじめに
The Concrete Distribution: A Continuous Relaxation of Discrete Random Variablesを読んだのでメモ. ICLR2017の論文でGumbel Max Trickをベースに離散分布を連続分布に緩和することでVAEなどに応用しようというもの.Gumbel Max Trickについては以前に記事を書いた.
Concrete: CONtinuous relaxation of disCRETE
VAEでは正規分からサンプリングする際,reparameterization trickを使って微分可能な形でサンプリングをすることで正規分布をデータの分布として仮定したautoencoderを学習した.今回読んだ論文では正規分布のような連続分布ではなく離散分布(カテゴリカル分布)をデータの分布として仮定したいというもの.
カテゴリカル分布からのサンプリングにはGumbel Max Trickを使った方法がある.Gumbel Max TrickではGumbel分布からサンプリングされた摂動をカテゴリカル分布に加え,argmaxを取ることでカテゴリカル分布からサンプリングを行う.しかしargmaxの処理が入ってしまうため,backpropで勾配を計算する際にargmaxが取られたインデックスにしか勾配が伝播しないので,分布を正しく学習することができないという問題がある.そこでカテゴリカル分布を連続になるように緩和することで解決しようというのがConcrete: CONtinuous relaxation of disCRETE.別な言い方をすれば,Gumbel Max Trickでサンプリングされた値はone-hotベクトルになってしまう,すなわち単体上の頂点に位置するため,これを単体の内部に行くようにしようと言うのがconcrete.以下が例.
どのように緩和するかという問題だが,argmaxの代わりにsoftmaxを使うだけ.とても単純.
はカテゴリカル分布のパラメータでは標準Gumbel分布からサンプリングされた値.重要なのが温度パラメータでこの値によって緩和の具合が変わる(上の図がを変化させた時の例.大きくなるほど単体の中心に近付く).上記緩和を使えばbackpropで勾配が伝播するので(誤解を招く言い方かもしれないが)分布の形状を正しく学習できる.上記の緩和を計算グラフにすると以下.このように確率分布からのサンプリングのノードが入った計算グラフをstochastic computation graphsというらしい.
ここで問題なのはargmaxを勝手にsoftmaxにしてしまったため,もはや確率変数はカテゴリカル分布に従わない.そこでをconcrete random variableとし,この確率変数が従う分布Concrete Distributionを導入する.Concrete Distributionの密度関数は以下で定義される.
この密度分布は以下のpropositionを満たす.
(a) (Reparametarization) がなら
(b) (Rounding)
(c) (Zero temperature)
(d) (Convex eventually) ならはlog-convex
(a)に関してはそもそもから定義されているので成り立つ.(b)と(c)も計算から求まって,元のカテゴリカル分布と一致する.(d)の証明に関してはスペースを使うので論文のAppendix参照.ただし(d)は重要で温度パラメータを決める一つの指標になる.
また,ここでは詳細は省くが,2クラスの場合(binary case)の導出も行なっており,その際には標準Gumbel分布ではなくLogistic分布からのサンプリング結果を摂動として与える(Gumbel分布の差がLogistic分布になるからとのこと).
VAEへの応用
VAEへの応用を考える.カテゴリカル分布を使った際の損失関数は以下のようになる.
は近似事後分布でをパラメータとしたカテゴリカル分布.分子はカテゴリとデータの同時分布を表しており,はをパラメータとするカテゴリカル分布.これをConcrete distributionを使って緩和することを考える.具体的にはをとすることを考える.問題は,backpropを行うことができればいいため損失関数の緩和の方法は以下のように複数考えられる.
はを満たすone-hotベクトル.上記において一番上のみが以下のlower boundを保証する.
基本的にはいずれかの緩和された損失を適用すればいいが,愚直に用いるとunderflowを起こしてしまうので実装の際には以下のようにlog-spaceで計算を行う方がいいらしい.
またこの確率変数の密度関数(log-density)は以下で定義される.
ということで実践的には上記のExpConcreteを使って緩和を行う.すると緩和された損失関数は以下のようになる.
この損失関数を使えばカテゴリカル分布をつかったVAEの学習ができる. 論文のAppendixの最後にここで説明したConcreteとExpConcreteのほかに2クラスの場合(Binary case)のBinConcreteや実験で使われた分布がまとまったチートシートが書かれているので便利.またその他Appendixに実装の際や問題に適用する際のTipsが色々書いてあるので参考に.
終わり
今回読んだ論文とほぼ同じコンセプトのGumbel-softmaxというものが同時期にarxivに投稿されており,同様にICLR2017に採択されている.どっちもgoogle(deepmindとbrain)でなんだかなという気持ち.
あと,softmax入っているから温度係数が小さいと容易に勾配消失する気がするけど論文で言及されてなかった.時間を見て確認がてら実装したい.