機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Invertible Residual Networksを読んだのでメモ

はじめに

Invertible Residual Networksを読んだのでメモ.

気持ち

昨今のdeepによる生成モデルは識別モデルに比べ汎用的なモデルの構造がないので,このgapを埋める新たなアーキテクチャを提案.具体的にはResNetそのままのアーキテクチャでflow-based generative modelを作るというもの.

Flow-based generative modelはNICEやRealNVP,GLOWなどでお馴染みのモデルで,今回は詳細は割愛.

Invertible ResNet

基本的にはResNetは逆変換(出力から入力を計算)することはできない.自己回帰構造にしたりすれば逆変換可能だが,そういった場合には元の構造からのちょっとした組み換えが必要になる.

ここではバナッハの不動点定理を利用して逆変換を求める.バナッハの不動点定理は縮小写像T:X\rightarrow Xに対して唯一の不動点T(X^\ast)=X^\astを持つというもの.

ResNetの逆変換が求まることを保証するためにはresidual block (すなわちx_{t+1}\leftarrow x_t+g_{\theta_t}(x_t)の変換) に対して次の定理を満たす必要がある.

\displaystyle
\text{Lip}(g_{\theta_t})\lt 1,\ \textit{for}\ \textit{all}\: t=1,\dots,T

この\text{Lip}(g_{\theta_t})g_{\theta_t}のリプシッツ定数を意味していて,residual blockのリプシッツ定数が1未満である必要がある.これにより出力値をx^0=yとして表現した時,バナッハの不動点定理から次の関係が成り立つ.

\displaystyle
\| x-x^n\|_2\leq\frac{\text{Lip}(g)^n}{1-\text{Lip}(g)}\| x^1-x^0\|_2

以上の関係から逆変換はx^{i+1}:=y-g(x^i)\ \text{for}\ i=0,\dots,nとして推定が可能.

問題はどのようにしてリプシッツ定数が1未満という制約を満たすか.今residual blockを陽に書き下すとするとg=W_3\phi(W_2\phi(W_1))として表現できる.ただし,W_iは畳み込み層,\phi非線形関数を表す.もしW_iのspectral norm \|W_i\|_2が1未満なら\text{Lip}(g)\lt 1が成り立つため,\|W_i\|_2に制約を与える.ここで,power-iterationによりW_iのspectral normが\tilde{\sigma_i}\leq\|W_i\|_2として推定可能なため,W_iに次のような正規化を行うことでspectral normを抑えることができる.

\displaystyle
\tilde{W}_i=cW_i/\tilde{\sigma}_i,\ \text{if}\: c/\tilde{\sigma}_i\lt 1

ここでcc\lt 1を満たすハイパラ.基本的には推定される\tilde{\sigma}_iは真のspectral norm以下の値であるため,\| W_i\|_2\leq cは保障されないが,学習後に特異値分解を使って正確に\| W_i\|_2を評価可能なため確実に\text{Lip}(g)\lt 1を満たすことができる.

以上から,residual blockのリプシッツ定数を1未満に正規化することでResNetをinvertibleにできることがわかったが,flow-based generative modelは目的関数に写像ヤコビアン行列式を計算する必要がある.今回の場合では,x=F^{-1}(z)とした時,\ln |\det J_F(x)|を計算しなければならない.ここでは|\det J_F(x)|=\det J_F(x)が成り立つことと,非特異な行列A\in\mathbb{R}^{d\times d}において\ln\det(A)=\mathrm{tr}(\ln(A))が成り立つことが知られているらしいので,次のように計算ができる.

\displaystyle
\ln|\det J_F(x)|=\mathrm{tr}(\ln J_F)

さらにF(x)=(I+g)(x)であることから最終的な尤度は\ln p_x(x)=\ln p_z(z)+\mathrm{tr}(\ln(I+J_g(x)))として評価することができる.また,行列の対数のトレースは|J_g(x)|\lt 1であるならば,次のように計算できることが知られているらしい.

\displaystyle
\mathrm{tr}(\ln(I+J_g(x)))=\sum_{k=1}^\infty(-1)^{k+1}\frac{\mathrm{tr}(J_g^k)}{k}

ここではリプシッツ定数が\text{Lip}(g_t)\lt 1を満たすため,次のように上限と下限を与えることができるらしい.

\displaystyle
d\sum_{t=1}^T\ln(1-\text{Lip}(g_t))\leq\ln |\det J_F(x)|\\
d\sum_{t=1}^T\ln(1+\text{LiP}(g_t))\geq\ln |\det J_F(x)|

ここでのdはおそらく次元数.基本的には前述のpower seriesによる計算は無限級数を扱うことやトーレスとその中身の行列がヤコビアンイテレーションとなっていて計算コストの面から扱いにくい.ただ,\mathbb{E}[v]=0,\ \mathrm{Cov}(v)=Iを満たすvを導入すれば,ヤコビアンベクターヤコビアンv^T J_gとして,トレースの計算を\mathrm{tr}(A)=\mathbb{E}_{p(v)}=[v^T Av]として行えるらしく,無限級数nで打ち切ることで問題を軽くすることが可能.

上記をまとめると\ln |\det(I+J_g)|は次のように計算される.

\displaystyle
PS(J_g,n):=\sum_{k=1}^n(-1)^{k+1}\frac{\mathrm{tr}(J_g^k)}{k}

ここでこの近似による誤差の上限は次のようになるらしい.

\displaystyle
|PS(J_g,n)-\ln\det(I+J_g)|\leq -d\left(\ln(1-\text{Lip}(g))+\sum_{k=1}^n\frac{\text{Lip}(g)^k}{k}\right)

また収束率はc:=\text{Lip}(g)として次のように与えられるとのこと.

\displaystyle
|\nabla_\theta(\ln\det(I+J_g))-PS(J_g,n))|_\infty=\mathcal{O}(c^n)

基本的にnは5から10くらいの値で得られる値のバイアスが0.001 bit per dimension以下になるとのこと.

まとめ

終始リプシッツ定数が役立っている手法.後半のトレースの計算方法とかよく知っているなという感想以外浮かばない.証明までは追えてないので推定値の誤差や収束率などはあまり理解できてないが,細々したハイパラ(c,nなど)の値が実際に応用を考える上では結構扱いにくいんじゃないかという印象.