機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Weight Standardizationを読んだのでメモ

はじめに

Weight Standardizationを読んだのでメモ.

気持ち

バッチサイズに左右されない新しい正規化層の提案.BatchNormはミニバッチサイズが小さい時にうまく機能しないことはよく知られていて,代替となるようなバッチサイズに依存しな様々な正規化層は提案されて来たもののどれもいまいちというのが現状.

BatchNormがうまくいく理由として,元々は内部共変量シフトを抑えるためとして提案されたが,近年ではloss landscapeを滑らかにするためという説が提唱されている.そのためこの論文ではloss landscapeを滑らかにするような方法を提案するというもの.

Weight Standardization

主な考え方としては,畳み込み層の重みを標準化することで勾配にリプシッツ定数を抑える役割を与え,loss landscapeを滑らかにするというもの.リプシッツ定数は勾配の上界と言えるため,その定数が小さくなるということはlossのlandscapeに崖のような急勾配が現れなくなる,すなわち滑らかになると考えられるというもの.

畳み込み層による変換は,バイアス項を除いて次のように表現できる. \displaystyle
\mathbf{y}+\hat{\mathbf{W}}\ast\mathbf{x}

ここで\astは畳み込みの演算子で,\hat{\mathbf{W}}\mathbb{R}^{O\times I}は重みを表す.ただし,O,Iはそれぞれ出力のチャネル数と,入力のチャネル数とカーネルサイズをかけたもの,すなわちI=C_{in}\times\mathrm{Kernel\ Size}を表す.ここで重み\hat{\mathbf{W}}は次のような\mathbf{W}の関数として表されるものとする.

\displaystyle
\hat{\mathbf{W}}=\left[\hat{\mathbf{W}}_{i,j}|\hat{\mathbf{W}}_{i,j}=\frac{\mathbf{W}_{i,j}-\mu_{\mathbf{W}_{i.\cdot}}}{\sigma_{\mathbf{W}_{i,\cdot}}+\epsilon}\right]\\

ただし,\mu_{\mathbf{W}_{i,\cdot}},\ \sigma_{\mathbf{W}_{i,\cdot}}はそれぞれ重みのI軸に対する平均と標準偏差を表す.この重みの変換をweight standardization \hat{\mathbf{W}}=\mathrm{WS}(\mathbf{W})として定義する.

以下,WSの勾配を計算して具体的にリプシッツ定数が抑えられる効果があることなど解析が続くが,夜も遅く疲れたのと,ただ計算するだけでそんな小難しい内容ではないのでここでは割愛.注意としては効果を得るためには後段にGroupNormが必要で,基本的にGN+WSという運用になる.

まとめ

バッチサイズの大小に関わらずBatchNormを超える性能を論文の実験では記録している.pytorchに夜著者実装も公開済み.個人的にはBatchNorm不要になってくれれば嬉しい.