Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networksを読んだのでメモ
はじめに
Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networksを読んだのでメモ. BatchNormをoutperformする非バッチ依存の正規化層を提案する論文.
手法
前提として入力データはバッチ次元を持つ画像,すなわちとする.番目の画像の番目のチャネルをとし,この平均二乗ノルムをとする.
提案手法はFilter Response Normalization (FRN)とThresholded Linear Unit(TLU)の組み合わせからなる.FRNは以下に定義される.
はゼロ割を防ぐための値.の値によって学習時の勾配の値が大きく異なるため,ここではを学習パラメータとする場合を実験したとのこと.実験的にを学習パラメータ化した方が良かったらしく,さらにとしてを学習すると良いとのこと.学習パラメータとする際,チャネルごとに持つのか共通のスカラー値かは不明.
FRNの狙いとしてはスケールの影響をなくすことで,従来との大きな違いは平均を引く演算がないこと.これはバッチに非依存な正規化処理では平均を引くことが正当化されないためとか.この正規化の処理はチャネルごとに行われ,これは全てのチャネルが同一のノルムを持つことからモデルの持つCNNのフィルタが同列に重要になることを保証する.
FRNの後は従来の正規化層と同様次の様なアフィン変換をおこなう.
FRNは平均を引く操作がない分任意の大きさのバイアスを与えることができる.その様な任意のバイアスはReLU関数を後段で使う場合にdead unitsを引き起こす原因となるため,ここでは学習パラメータを持つ次のReLUの様な関数であるTLUを提案する.
式的には
のようにReLUの前後に値を共有したバイアスを加えることと等しい. なぜか実験ではの方がよかったとかなんとか. 論文にはTLUの定式化の方が最適化に適していると書いてあるがよくわからない.数値誤差のレベルで違いが出るとか?
平均中心化
BatchNormは(実際の効果は置いておいて)共変量シフトを防ぐために1次と2次のモーメントによる正規化を行なっている.またInstanceNormはFRNの1次モーメントを使うバージョンと見れる.InstanceNormを考えると,画像の解像度が下がるとInstanceNormはzero activationを生成してしまう可能性がある.LayerNormやGroupNormはチャンネル方向にまたぐことでこれを防いでいるが,個々のフィルタは特定のチャンネルに対して反応を得るのでチャネルを跨いだ正規化は学習を複雑にする問題がある.結果として,フィルタ間の干渉とzero activationを防ぐという観点から平均を引く処理をなくすという方法に行き着いた.ただし,平均を引く処理をなくすと今度は任意のバイアスを許すという別な問題を生じさせるため合わせてTLUを用いるとのこと.
実装
PyTorchだと多分こんな感じ.
import torch import torch.nn as nn import torch.nn.functional as F class FRN(nn.Module): def __init__(self, channels, affine=True, learnable_eps=True, eps=1e-6): super().__init__() self.register_parameter("tau", nn.Parameter(torch.zeros(1))) if affine: self.register_parameter("weight", nn.Parameter(torch.ones(1, channels, 1, 1))) self.register_parameter("bias", nn.Parameter(torch.zeros(1, channels, 1, 1))) else: self.register_buffer("weight", None) self.register_buffer("bias", None) if learnable_eps: self.register_parameter("eps_l", nn.Parameter(torch.zeros(1).fill_(1e-4))) else: self.register_buffer("eps_l", None) self.register_buffer("eps", torch.zeros(1).fill_(eps)) def forward(self, x): if self.eps_l is not None: eps = self.eps_l.abs() + self.eps else: eps = self.eps x = x / (x.pow(2).mean((1,2), keepdim=True) + eps).sqrt() if self.weight is not None: x = self.weight * x + self.bias return torch.max(x, self.tau)
まとめ
Weight StandardizationもBNいらずとして非常に驚きだったが,こちらはBNを超える精度を達成しているのでさらに驚き.