Weight Standardizationを読んだのでメモ
はじめに
Weight Standardizationを読んだのでメモ.
気持ち
バッチサイズに左右されない新しい正規化層の提案.BatchNormはミニバッチサイズが小さい時にうまく機能しないことはよく知られていて,代替となるようなバッチサイズに依存しな様々な正規化層は提案されて来たもののどれもいまいちというのが現状.
BatchNormがうまくいく理由として,元々は内部共変量シフトを抑えるためとして提案されたが,近年ではloss landscapeを滑らかにするためという説が提唱されている.そのためこの論文ではloss landscapeを滑らかにするような方法を提案するというもの.
Weight Standardization
主な考え方としては,畳み込み層の重みを標準化することで勾配にリプシッツ定数を抑える役割を与え,loss landscapeを滑らかにするというもの.リプシッツ定数は勾配の上界と言えるため,その定数が小さくなるということはlossのlandscapeに崖のような急勾配が現れなくなる,すなわち滑らかになると考えられるというもの.
畳み込み層による変換は,バイアス項を除いて次のように表現できる.
ここでは畳み込みの演算子で,は重みを表す.ただし,はそれぞれ出力のチャネル数と,入力のチャネル数とカーネルサイズをかけたもの,すなわちを表す.ここで重みは次のようなの関数として表されるものとする.
ただし,はそれぞれ重みの軸に対する平均と標準偏差を表す.この重みの変換をweight standardization として定義する.
以下,WSの勾配を計算して具体的にリプシッツ定数が抑えられる効果があることなど解析が続くが,夜も遅く疲れたのと,ただ計算するだけでそんな小難しい内容ではないのでここでは割愛.注意としては効果を得るためには後段にGroupNormが必要で,基本的にGN+WSという運用になる.
まとめ
バッチサイズの大小に関わらずBatchNormを超える性能を論文の実験では記録している.pytorchに夜著者実装も公開済み.個人的にはBatchNorm不要になってくれれば嬉しい.