Supervised Contrastive Learningを読んだのでメモ
はじめに
Supervised Contrastive Learningを読んだのでメモ. 通常,教師なし表現学習で使われるcontrastive learningを教師あり学習に適用した論文. 通常のsoftmax+cross entropyに比べハイパーパラメータの設定に対し鈍感(ある程度調整が雑でも動く)かつ,精度が良い.
Method
ベースとするcontrastive learningはSimCLRとほぼ同じ. ただしaugmentationとして,AutoAugment,RandAugment,SimAugment(SimCLRで使われたaugmentation)の3つのどれかを利用. Contrastive learningのためのモデルは画像を特徴ベクトルへと変換するencoder network (ResNet-50もしくはResNet-200のGAPまで)とcontrastive lossの計算に用いる表現に写像するprojection network (中間層が1層のMLP)の二つから成る. encoderとprojectionの出力ベクトルは共にノルムが1に正規化される. projectionの方はcontrastive learningでコサイン類似度を利用するので一般的ではあるが,中間表現を正規化するのはほとんどの場合で精度を改善するためとのこと.
ここでは教師あり学習のためのcontrastive learningを次の様に改良する.
SimCLRのフレームワークに則っているので,入力のバッチサイズに対しのデータが生成される. 個のデータに対しそれぞれprojection networkの出力が計算され,anchorとなるとその他のデータ間の温度パラメータ付きコサイン類似度が計算される. コサイン類似度に対しsoftmax関数の対数をとった形で定義されるcontrastive lossを計算する. SimCLRと異なる点として,SimCLRではいわゆるポジティブペアは元となったデータが同じデータ同士のみで定義されたが,supervised contrastive learningではラベルが等しいデータ全てをポジティブペアとして扱う. その気持ちが先頭のとに現れている. はミニバッチ内におけるラベルが付いているデータの個数.
このsupervised contrastive lossの勾配はhard positiveとhard negativeを重視した学習を引き起こす構造を持つことについて示す. まずをprojection networkの正規化前の出力とし(つまり),その勾配は
となり,右辺はそれぞれ
となる.ただし,は以下の様に定義される.
ここでは簡単なポジティブペアに関してはが成り立つとし,このとき
となる.一方でhard positiveに関してはが成り立つと考えられ,このとき
となる. そのため簡単なpositiveに関する勾配は小さくなり難しいpositiveに関する勾配は大きくなる. これはnegativeに関する勾配でも同様のことが言える. と論文で言っているがeasy,hardの議論は当たり前のことでは…(ちょっと理解が足りていない気がするが).
また,手法とは関係なく一般的な話としてcontrastive learningはtriplet lossともつながりがある. テイラー展開を2度利用することで次のように導出可能.
とすれば最後の式はtriplet lossそのものとなる. 一方で,contrastive lossはtriplet lossより一般に良い結果をもたらす. また,triplet lossは計算コストのかかるhard negative miningを利用するが,先の議論の通りsupervised contrastive lossは自然にhard negative miningをするという利点がある.
まとめ
実験で,Fig. 4に示されている様にsupervised contrastive learningはcross entropyに比べ,先に挙げた3つのdata augmentationどれに対しても安定した学習が可能で,optimizerの選択(LARS,SGD,RMSProp)でも精度がぶれない. 一方で学習率に対してはcross entropyよりセンシティブな様子. また,学習にはsupervised contrastive learning後にcross entropyによるsupervised learningが必要となり学習コストは上がるという課題もある. これらを踏まえるとcross entropyに変わる学習方法になるのは難しそうだが,面白い結果だった.