機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Gumbel Max Trickについて勉強した

はじめに

Gumbel Max Trickについて勉強したからメモ. Gumbel Max Trickのお気持ちとしてはカテゴリカル分布から効率よくサンプリングしたいというもの(多分。。。). 言ってしまえばカテゴリカル分布におけるreparameterization trick

accept-rejectアルゴリズム

カテゴリカル分布からの基本的なサンプリング方法の一つにaccept-rejectアルゴリズムがある. カテゴリカル分布のパラメータ(各カテゴリがサンプリングされる確率)を \mathbf{\alpha}=\alpha_1,\alpha_2,...,\alpha_mとする. u\sim\mathrm{Uniform}(0,\max\alpha_i) j\sim{1,2,...,m} を一様にサンプリングし,\alpha_j\geq uの時に jを返し,\alpha_j\lt uの時にはまたサンプリングを繰り返すというもの.\alpha_j\lt uである限り二つの変数をサンプリングし続けなければならないから効率がわるい.

Gumbel Max Trick

Gumbel Max Trickでは標準Gumbel分布からカテゴリの数だけサンプリングを行い,得られた値を摂動としてカテゴリカル分布のパラメータに加えた後,argmaxを取ることでサンプリングをしようというもの.


\displaystyle g_i \sim \mathrm{Gumbel}


\displaystyle j = \mathrm{arg}\max(\log\alpha_i+g_i)

これならサンプリングの回数がカテゴリ数 mになりaccept-rejectアルゴリズムに比べ効率が良い.

Gumbel分布とは

突然出てきたGumbel分布だが,Gumbel分布は極値分布と呼ばれる分布の一つで,確率変数の最大値や最小値が従う分布らしい.極値分布の詳しい性質等は教科書とかで勉強しないと理解できなさそうなので割愛. Gumbel分布の密度関数と分布関数は以下で定義される.


\displaystyle Gu_{\mu,\eta}(x)=\frac{1}{\eta}\exp\left(-\frac{x-\mu}{\eta}-\exp\left(-\frac{x-\mu}{\eta}\right)\right)


\displaystyle F_{\mu,\eta}(x)=\exp\left(-\exp\left(-\frac{x-\mu}{\eta}\right)\right)

 \muはlocation parameter, \etaはscale parameterと呼ばれる分布のパラメータで,標準Gumbel分布の場合\mu=0,\eta=1となる.

なぜGumbel分布

ここで疑問になるのはなぜGumbel分布を使ってサンプリングができるのかということ.証明の方法は色々あるようだけど一番分かり易かったのがこれ. 手順としては,location parameterが\log(\alpha_i),scale parameterが1のGumbel分布からサンプリングされた値を z_{1,...,m}とし,z_iがもっとも大きくなる確率(すなわちiがサンプリングされる確率)が元のカテゴリカル分布のパラメータ \alpha_iと一致すればいいでしょうというもの.すなわち


\displaystyle P(z_i \mathrm{が最大}|z_i,\alpha_{1,..,m})=\alpha_i

を証明する.location parameterが\log(\alpha_i)のgumbel分布に関して考えているのは,locaton parameterが正規分布の平均と同様の役割をするため,VAEの論文で使用されている正規分布のreparameterization trickと同様に Gu_{\log\alpha_i, 1}(x)=\log\alpha_i+Gu_{0,1}(x)と書けるから.z_iはi.i.dにサンプリングされるため,  P(z_i \mathrm{が最大}|z_i,\alpha_{1,..,m})=P(z_i\gt z_1,...,z_i\gt z_m|\alpha_{1,...,m})=P(z_i\gt z_1|\alpha_1)...P(z_i\gt z_m|\alpha_m)と分解され,P(z_i\gt z_j)はGumble分布の分布関数であることから


\displaystyle P(z_i \mathrm{が最大}|z_i,\alpha_{1,..,m}) = \prod_{j\neq i}\exp(-\exp(-(z_i-\log\alpha_j)))

と書直せる.さらにこれを z_iに対して周辺化,すなわち P(z_i \mathrm{が最大} | \alpha_{1,..,m}) = \int P(z_{i}|\alpha_{i})P(z_{i} \mathrm{が最大}|z_{i},\alpha_{1,..,m})dz_{i}を計算する.


\displaystyle
\int P(z_i|\alpha_{i})P(z_i \mathrm{が最大}|z_i,\alpha_{1,..,m})dz_i = \int\exp(-(z_i-\log\alpha_i)-\exp(-(z_i-\log\alpha_i)))\prod_{j\neq i}\exp(-\exp(-(z_i-\log\alpha_j)))dz_i

ここで \prod_{j\neq i} P(z_i|\alpha_i)\exp(-\exp(-(z_i-\log\alpha_i)))の項をくくると\prod_{j}でまとめて書ける.


\displaystyle
\int\exp(-(z_i-\log\alpha_i)-\exp(-(z_i-\log\alpha_i)))\prod_{j\neq i}\exp(-\exp(-(z_i-\log\alpha_j)))dz_i \\
\displaystyle= \int\exp(-(z_i-\log\alpha_i))\prod_j\exp(-\exp(-(z_i-\log\alpha_j)))dz_i \\
\displaystyle= \int\exp(-(z_i-\log\alpha_i))\exp(-\sum_j\exp(-(z_i-\log\alpha_j)))dz_i \\
\displaystyle=\int\exp(-(z_i-\log\alpha_i)-\exp(-z_i)\sum_j\exp(\log\alpha_j))dz_i

4行目はz_i\prod_jに関係ないことと\expの掛け算を足し算にして書き直した.後は出てきた積分を計算するだけ.ここはちょっとしたテクニックが必要でs=\exp(-(z_i+\log\alpha_i))という変数変換をするとガンマ分布の確率密度関数の形になるから積分はclosed formに求まって最終的には


\displaystyle P(z_i \mathrm{が最大} | \alpha_{1,..,m}) = \frac{\exp(\log\alpha_i)}{\sum_j\exp(\log\alpha_j)}=\frac{\alpha_i}{\sum_j\alpha_j}

となり,元のカテゴリカル分布の値になる.素晴らしい.

Gumbel分布からサンプリング

Gumbel分布を使えばいいのは納得したとして,標準Gumbel分布からサンプリングするのは簡単なのかという疑問が残る.これは逆関数法を使って簡単にできる. 分布関数は広義単調増加関数で値が区間(0,1)に一様に分布するため,一般にある確率変数Xが従う分布関数 F(X) U=F(X)という関係を満たす.ただし, UFによりXから一意にきまり 0\lt U\lt1の範囲で値をとる.この時Fに関してXUは一対一に対応するので分布関数には逆関数が存在する.するとX=F^{-1}(U)からXF^{-1}によりUから一意に決まることがわかる.よって,U\sim\mathrm{Uniform}(0,1)のように一様分布からサンプリングすることで任意の確率変数を一様分布からサンプリングできる.以上が逆関数法. 標準Gumbel分布の分布関数の逆関数F^{-1}(x)=-\log(-\log(x))となるため


\displaystyle u\sim\mathrm{Uniform}(0,1) \\
\displaystyle g = -\log(-\log(u))

として簡単にサンプリングが行える.

Gumbel Max Trickを応用した論文は最近Google中心にいっぱい出ているみたいなので読んで理解を深めていきたいという思い.