機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Supervised Contrastive Learningを読んだのでメモ

はじめに

Supervised Contrastive Learningを読んだのでメモ. 通常,教師なし表現学習で使われるcontrastive learningを教師あり学習に適用した論文. 通常のsoftmax+cross entropyに比べハイパーパラメータの設定に対し鈍感(ある程度調整が雑でも動く)かつ,精度が良い.

Method

ベースとするcontrastive learningはSimCLRとほぼ同じ. ただしaugmentationとして,AutoAugment,RandAugment,SimAugment(SimCLRで使われたaugmentation)の3つのどれかを利用. Contrastive learningのためのモデルは画像を特徴ベクトルへと変換するencoder network \mathrm{E}(\cdot) (ResNet-50もしくはResNet-200のGAPまで)とcontrastive lossの計算に用いる表現に写像するprojection network \mathrm{P}(\cdot) (中間層が1層のMLP)の二つから成る. encoderとprojectionの出力ベクトルは共にノルムが1に正規化される. projectionの方はcontrastive learningでコサイン類似度を利用するので一般的ではあるが,中間表現を正規化するのはほとんどの場合で精度を改善するためとのこと.

ここでは教師あり学習のためのcontrastive learningを次の様に改良する.

\displaystyle
\mathcal{L}^{sup}=\sum_{i=1}^{2N}\mathcal{L}^{sup}_i\\
\mathcal{L}_i^{sup}=\frac{-1}{2N_{\tilde{y}_i}-1}\sum^{2N}_{j=1}\mathbb{1}_{i\neq j}\cdot\mathbb{1}_{\tilde{y}_i=\tilde{y}_j}\cdot\log\frac{\exp(z_i\cdot z_j/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{i\neq k}\cdot\exp(z_i\cdot z_k/\tau)}

SimCLRのフレームワークに則っているので,入力のバッチサイズNに対し2Nのデータが生成される. 2N個のデータに対しそれぞれprojection networkの出力z_iが計算され,anchorとなるz_iとその他のデータ間の温度パラメータ付きコサイン類似度z_i\cdot z_j/\tauが計算される. コサイン類似度に対しsoftmax関数の対数をとった形で定義されるcontrastive lossを計算する. SimCLRと異なる点として,SimCLRではいわゆるポジティブペアは元となったデータが同じデータ同士のみで定義されたが,supervised contrastive learningではラベルが等しいデータ全てをポジティブペアとして扱う. その気持ちが先頭の\sum^{2N}_{j=1}\mathbb{1}_{\tilde{y}_i\cdot=\tilde{y}_j}に現れている. N_{\tilde{y}_i}はミニバッチ内におけるラベル\tilde{y}_iが付いているデータの個数.

このsupervised contrastive lossの勾配はhard positiveとhard negativeを重視した学習を引き起こす構造を持つことについて示す. まずwをprojection networkの正規化前の出力とし(つまりz=w/\|w\|),その勾配は

\displaystyle
\frac{\partial \mathcal{L}^{sup}_i}{\partial w_i}=\frac{\partial \mathcal{L}^{sup}_i}{\partial w_i}\left|_{pos}+\frac{\partial\mathcal{L}^{sup}_i}{\partial w_i}\right|_{neg}

となり,右辺はそれぞれ

\displaystyle
\left.\frac{\partial\mathcal{L}^{sup}_i}{\partial w_i}\right|_{pos}\propto\sum^{2N}_{j=1}\mathbb{1}_{i\neq j}\cdot\mathbb{1}_{\tilde{y}_i=\tilde{y}_j}\cdot( (z_i\cdot z_j)\cdot z_i-z_j)\cdot(1-P_ij)\\
\left.\frac{\partial\mathcal{L}^{sup}_i}{\partial w_i}\right|_{neg}\propto\sum^{2N}_{j=1}\mathbb{1}_{i\neq j}\cdot\mathbb{1}_{\tilde{y}_i=\tilde{y}_j}\cdot\sum_{k=1}^{2N}\mathbb{1}_{k\notin k}\cdot(z_i-(z_i\cdot z_k)\cdot z_i)\cdot P_{ik}

となる.ただし,P_{il}は以下の様に定義される.

\displaystyle
P_{il}=\frac{\exp(z_i\cdot z_l/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{i\neq k}\cdot\exp(z_k\cdot z_l/\tau)},\ i,l\in\{1,\dots,2N\},\ i\neq l

ここでは簡単なポジティブペアに関してはz_i\cdot z_j\approx 1が成り立つとし,このとき

\displaystyle
\|( (z_i\cdot z_j)\cdot z_i-z_j)\|\cdot(1-P_{ij})=\sqrt{1-(z_i\cdot z_j)^2}\cdot(1-P_{ij})\approx 0

となる.一方でhard positiveに関してはz_i\cdot z_j\approx 0が成り立つと考えられ,このとき

\displaystyle
\|( (z_i\cdot z_j)\cdot z_i-z_j)\|\cdot(1-P_{ij})=\sqrt{1-(z_i\cdot z_j)^2}\cdot(1-P_{ij})\gt 0

となる. そのため簡単なpositiveに関する勾配は小さくなり難しいpositiveに関する勾配は大きくなる. これはnegativeに関する勾配でも同様のことが言える. と論文で言っているがeasy,hardの議論は当たり前のことでは…(ちょっと理解が足りていない気がするが).

また,手法とは関係なく一般的な話としてcontrastive learningはtriplet lossともつながりがある. テイラー展開を2度利用することで次のように導出可能.

displaystyle
\mathcal{L}_{con}=-\log\frac{\exp(z_a\cdot z_p/\tau)}{\exp(z_a\cdot z_p/\tau)+\exp(z_a\cdot z_n/\tau)}\\
=\log(1+\exp( (z_a\cdot z_n-z_a\cdot z_p)/\tau))\\
\approx \exp( (z_a\cdot z_n-z_a\cdot z_p)/\tau)\\
\approx 1+\frac{1}{tau}\cdot(z_a\cdot z_n-z_a\cdot z_p)\\
=1-\frac{1}{2\tau}\cdot(\|z_a-z_n\|^2-\|z_a-z_p\|^2)\\
\propto\|z_a-z_p\|^2-\|z_a-z_n\|^2+2\tau

2\tau=\alphaとすれば最後の式はtriplet lossそのものとなる. 一方で,contrastive lossはtriplet lossより一般に良い結果をもたらす. また,triplet lossは計算コストのかかるhard negative miningを利用するが,先の議論の通りsupervised contrastive lossは自然にhard negative miningをするという利点がある.

まとめ

実験で,Fig. 4に示されている様にsupervised contrastive learningはcross entropyに比べ,先に挙げた3つのdata augmentationどれに対しても安定した学習が可能で,optimizerの選択(LARS,SGD,RMSProp)でも精度がぶれない. 一方で学習率に対してはcross entropyよりセンシティブな様子. また,学習にはsupervised contrastive learning後にcross entropyによるsupervised learningが必要となり学習コストは上がるという課題もある. これらを踏まえるとcross entropyに変わる学習方法になるのは難しそうだが,面白い結果だった.