機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

MixMatch: A Holistic Approach to Semi-Supervised Learningを読んだのでメモ

はじめに

MixMatch: A Holistic Approach to Semi-Supervised Learningを読んだのでメモ. Semi-supervisedのタスクで,ラベル付きデータが少量でも効果的なアルゴリズムであるMixMatchを提案.

MixMatch

MixMatchは最近の半教師付き学習の知見を全部詰め込んだみたいなアルゴリズムアルゴリズムの流れは擬似コードを見ればだいたい分かる.

まず,ラベル付きのデータ\mathcal{X}に対し任意のdata augmentationを適用して新たなデータ\hat{\mathcal{X}}を作る. その後ラベルなしデータ\mathcal{U}に対してもdata augmentationを適用して新たなデータ\hat{\mathcal{U}}を作る. ただし,ラベルなしのデータに関しては各サンプルK回ずつdata augmentationをしてラベルなしデータの数をK倍にする. K倍になったラベルなしのデータをモデルに入力し,その出力値をサンプルごとに平均することでラベルなしデータに対する予測ラベルを計算. 具体的には\bar{q}=\frac{1}{K}\sum_k\mathrm{model}(\hat{u}\in\hat{\mathcal{U}})として計算. さらに\bar{q}の値を温度パラメータTを使ってq_i=\bar{q}_i^\frac{1}{T}/\sum_{j=1}^L\bar{q}_j^{\frac{1}{T}}としてよりエントロピーの低いラベルにする. 注意として,ここでは分類問題を考えているため,添字i,jはカテゴリのインデックスを表す. 予測ラベルが作れたら,ラベル付きのデータとラベルなしのデータをMixUpして重み\lambda_\mathcal{U}付き二乗誤差をとるといった流れ. 予測ラベルは,MixUp時のラベルなしデータのラベルとして利用する.

MixMatchにはハイパーパラメータとして温度T,data augmentationの回数K,MixUpの係数をサンプリングするためのベータ分布のパラメータ\alpha,MixMatchのロス(上で説明した二乗誤差)の重み係数\lambda_\mathcal{U}が存在する. 温度と回数はそれぞれT=0.5,K=2として\alpha,\lambda_\mathcal{U}を調整するだけでだいたいうまくいくとのこと. \alpha,\lambda_\mathcal{U}もハイパーパラメータサーチする際には\lambda_\mathcal{U}=100,\alpha=0.75から始めると良いというtipsが書いてあった.

まとめ

実験結果を見る限るラベル付きデータが少ないときに圧倒的な性能を誇っている. すべてのデータにラベルがあるとして学習した場合の精度と1%前後しか変わらないのも驚き. アルゴリズム的にはInterpolation Consistency Trainingに対してmean teacher部分をdata augmentationの平均に置き換えたというような感じだが,少量ラベルでのICTの実験結果がないのでそちらも気になるところ. ICTの結果次第で,結局MixUpが強いということにもなる気はする.