機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

THE UNREASONABLE EFFECTIVENESS OF (ZERO) INITIALIZATION IN DEEP RESIDUAL LEARNINGを読んだのでメモ

はじめに

THE UNREASONABLE EFFECTIVENESS OF (ZERO) INITIALIZATION IN DEEP RESIDUAL LEARNINGを読んだのでメモ.ICLR2019の査読中論文.

2019/2/4 ICLR2019に採択されていたがタイトルがFIXUP INITIALIZATION:RESIDUAL LEARNING WITHOUT NORMALIZATIONに変わっていた.内容の再確認はしていないのでもしかしたら細かな議論が変わっているかも.

気持ち

NNの重みをゼロで初期化してしまおうという論文.ここではbatch normalizationのような正規化を行うことなく深いネットワークが学習可能か,さらに学習可能だとして正規化を行う場合とそうでない場合で同じ学習率を使って学習可能か,学習の速度や汎化性は等しいかという問題を投げかけている.論文曰くどちらもyesでゼロ初期化によってそれは可能ということ.

ResNetの勾配爆発問題

勾配爆発や勾配消失問題を解決するために様々なdeep learning toolでも実装されているxavierの初期化やheの初期化が提案されている.ただこれらの初期化方法はbatch normalizationを含まないネットワークで議論されていて,実はresidual connectionのせいでこれらの初期化方法は勾配爆発を引き起こすことがあるとのこと.この問題はReLUを使ったネットワークについて言及されていたのでここではpositively homogeneous functionの場合に拡張する(positively homogeneous functionは論文の定義1参照).まずResNetのresidual blocksを\{F_1,\dots,F_L\}とし,入力を\mathbf{x}_0とするとresnetは次のように定式化される.

\displaystyle
\mathbf{x}_l=\mathbf{x}_0+\sum_{i=0}^{l-1}F_i(\mathbf{x}_i)

ここでは初期化について考えるので入力\mathbf{x}_0は固定として,重みの曖昧性のみ考える.\mathbf{x}_lの各座標の分散の和を\mathrm{Var}[\mathbf{x}_l]とする.簡単のため各ブロックが平均ゼロ\mathbb{E}[F_l(\mathbf{x}_l)|\mathbf{x}_l]=0で初期化されていると仮定すれば,\mathbf{x}_{l+1}=\mathbf{x}_l+F_l(\mathbf{x}_l)の関係から\mathrm{Var}[\mathbf{x}_{l+1}]=\mathbb{E}[\mathrm{Var}[F(\mathbf{x}_l)|\mathbf{x}_l]]+\mathrm{Var}(\mathbf{x}_l)となる.すなわちResNetの構造は分散が層の増加に伴って増えることで\mathbf{x}_lが0になることを防いでいると言える.言い換えれば\mathbb{E}[\mathrm{Var}[F(\mathbf{x}_l)|\mathbf{x}_l]]\gt 0ならば,\mathrm{Var}[\mathbf{x}_l]\lt \mathrm{Var}[\mathbf{x}_l]が成り立つということ.Heの初期化を使えば\mathrm{Var}[F_l(\mathbf{x}_l)|\mathbf{x}_l]が入力の分散\mathrm{Var}[\mathbf{x}_l]とほぼ等しくなるため\mathrm{Val}[\mathbf{x}_{l+1}]\approx 2\mathrm{Var}[\mathbf{x}_l]となる.positively homogeneous functionを使った正規化がないblockの場合には次のように出力の分散が深さに対して指数的に増加する.

\displaystyle
\mathrm{Var}[\mathbf{x}_l]=\mathrm{Var}[\mathbf{x}_0]+\sum_{i=0}^{l-1}\mathrm{Var}[\mathbf{x}_i]\mathbb{E}\left[\mathrm{Var}\left[F_i\left(\frac{\mathbf{x}_i}{\sqrt{\mathrm{Var}[\mathbf{x}_i]}}\right)\mid\mathbf{x}_i\right]\right]=\Omega(2^l)

この分散の増加は勾配爆発の原因となり学習においてはあまり嬉しくない性質.そこで,初期状態においてある時点のactivationと重みの勾配のノルムがcross-entropy lossにおいてある下限を持つことを示す.

Definition 1 (positively homogeneous function of first degree)

ある関数f:\mathbb{R}^m\rightarrow\mathbb{R}^nがpositively homogeneous of first degreeであるとは任意の入力\mathbf{x}\in\mathbb{R}^mにおいて\alpha\gt 0とした時,f(\alpha\mathbf{x})=\alpha f(\mathbf{x})という性質を持つ.

Definition 2 (positively homogeneous set of first degree)

\Thetaf(\mathbf{x})のパラメータの集合とし,\Theta_{ph}=\{\theta_i\}_{i\in S}\subset\Thetaとする.\Theta_{pH}がpositively homogeneous setであるとは,任意の[tex:\alpha\gt 0においてf(\mathbf{x};\Theta\setminus\Theta_{ph},\alpha\Theta_{ph})=\alpha f(\mathbf{x};\Theta\setminus\Theta_{ph},\Theta_{ph})という性質を持つ.ただし\alpha\Theta_{ph}\{\alpha\theta_i\}_{i\in S}

P.h.(positively homogeneous)関数の例としてはNNで使われる演算(fully-connected,convolution,pooling,addition,concatenation,dropout,ReLUなど)があげられる.さらにp.h.関数に関して次の命題が成り立つ.

Propostion 1

P.h.関数で構成される関数もまたp.h.関数.

ここで一般的な分類問題を考える.すなわちcクラスの分類をcross-entropy lossを用いて解く.fを最終出力がsoftmax層のNNを表す関数として,cross-entropy lossをl(\mathbf{z},\mathbf{y})=-\mathbf{y}^T(\mathbf{z}-\log\sum\exp(\mathbf{z}))と定義する.ただし,\mathbf{y}はone-hotラベルで,\mathbf{z}=f(\mathbf{x})\in\mathbb{R}^cとする.ミニバッチ\mathcal{D}_M=\{(\mathbf{x}^{(m)},\mathbf{y}^{(m)}\}_{m=1}^Mを使った学習を考えれば,cross-entropy lossはl_{avg}(\mathcal{D}_M)=\frac{1}{M}\sum_{m=1}^Ml(f(\mathbf{x}^{(m)}),\mathbf{y}^{(m)})となる.ここでネットワークfについて次の過程を置く.

  1. fはsequentialな構成,すなわちネットワークのブロックをp.h.関数\{f_i\}_{i=1}^Lとすれば入力から出力までをf(\mathbf{x}_0)=f_L(f_{L-1}(\dots f_1(\mathbf{x}_0)))と計算可能である.

  2. 全結合層の重みの要素はi.i.dに平均0の対象な分布からサンプリングされている.

Theorem 1

i番目のブロックの入力を\mathbf{x}_{i-1}とすれば,仮定1から次の関係が得られる.

\displaystyle
\left|\frac{\partial l}{\partial\mathbf{x}_{i-1}}\right|\geq\frac{l(\mathbf{z},\mathbf{y})-H(\mathbf{p})}{|\mathbf{x}_{i-1}|}

\mathbf{p}はsoftmaxによって得られる確率で,Hはシャノンエントロピー.証明は記事が長くなりそうなのでAppendix A参照.エントロピーは上界を持ち,|\mathbf{x}_{i-1}|は下位のブロックでは小さい.ロスにおける爆発は下位のブロックの入力に依存して勾配のノルムが大きくなることで起こる.次の定理でp.h. setの勾配のノルムが下限を持つことを示す.

Theorem 2

仮定1から次の関係が成り立つ.

\displaystyle
\left|\frac{\partial l_{avg}}{\partial\Theta_{ph}}\right|\geq\frac{1}{M|\Theta_{ph}|}\sum_{m=1}^Ml(\mathbf{z}^{(m)},\mathbf{y}^{(m)})-H(\mathbf{p}^{(m)})=G(\Theta_{ph})

さらに,仮定1と2から次の関係が得られる.

\displaystyle
\mathbb{E}G(\Theta_{ph})\geq\frac{\mathbb{E}[\max_{i\in[c]}z_i]-\log(c)}{|\Theta_{ph}|}

正規化のないResNetにおけるp.h. setsの例としては以下の3つがある.

  1. 一番最初のmax pooling前のconvolution層

  2. ダウンサンプリング等の際に生じるskip connection中のconvolution層とその時のresidual branchのconvolution層

  3. softmax前の全結合層

Theorem 2はもし初期状態において\mathbf{z}が爆発していたならば,この3つの層が勾配爆発の原因となることを示していて,仮に正規化のないResNetを従来の方法で初期化すればこの現象が起こるということを言っている.

ZeroInit

Scale of the output

これまでの流れから勾配爆発を起こさないためには初期状態において出力が爆発しないことを保証する必要がある.この考えから次の初期化のための良い方針が得られる.

(a.) 出力のスケールは\mathbb{E}[\max_{i\in[c]}z_i]=\mathcal{O}(1)のように深さとは独立であるべき

ナイーブな方法としては出力層を0に初期化すること.ただし,この場合には最終層にスケールを抑えることを押し付けているため,ノルム\mathcal{O}(1)の勾配で数回更新をしてしまえば爆発することになる.

Scale of the residual branches

Residual branchesのスケールが学習初期に深さに応じて爆発することを防ぐ必要がある.すると次のような初期化方針が得られる.

(b.) Residual branchesのスケールは\mathrm{Var}[F_l(\mathbf{x}_l)]=\mathcal{O}(\frac{1}{L})のようにバランスが取れているべき

この方針はL個のresidual blockを持つネットワークfが与えられた時,m層を持つresidual branchに対して次の初期化方法を導く.(b.)を満たすことに関しての細かい話は省略.

ZeroInit (How to train a deep residual network without normalization)

  1. Classification layerと各residual branchの最後の畳み込み層を0に初期化

  2. その他の層はHeの初期化など従来の方法で初期化し,residual branches内のconvolution層においては\sqrt[2-2m]{L}でスケーリング

  3. 各ブランチにおいてconvolution, linear, elemnt-wise activation layerの前に1で初期化されたscalar multiplierと0で初期化されたscalar biasを挿入

3に関してはbatch normalization層のaffine変換と同様の変換で論文のFigure 1の一番右の図を見ればどこに挿入されているかわかる.この初期化は勾配の爆発を防ぐことができて,DenseNetのような場合にも適用可能とのこと(ただしその辺りはfuture workとしてる).

まとめ

かなり良好な結果を出しているよう.ただやはりBatchNormの破壊力は高いなという印象.

ZeroInit自体は生成モデルのGlowやfacebookのAccurate, Large Minibatch SGD: Training ImageNet in 1 Hourの論文でも使われているので学習を安定化させたりするのにはかなり効果があるっぽいのでそこに理論的に踏み込んだと見るといい論文.