Realistic Evaluation of Semi-Supervised Learning Algorithmsを読んだのでメモ
はじめに
Realistic Evaluation of Semi-Supervised Learning Algorithmsを読んだのでメモ.PyTorchで実装もしました.実装の話はこちら.
気持ち
データを作るコストが高いことからsemi-supervised learning (SSL)は重要で,最近はそれなりのラベルデータがあればsupervisedに匹敵する性能がでる.ただ,実世界の設定に対してSSLはちゃんと機能するのかというのがこの論文が問題として提案しているところ.そこで実世界の課題に対してSSLのアルゴリズムが機能するかを評価可能な新しい実験方法を提案するという論文.実験を行う上で実際以下の様な発見があったとのこと.
・同一構成のモデルに対してラベル有りのみを使った場合とラベルなしも使った場合の性能差が報告よりも小さい.
・強力な正則化をかけてラベル有りのみで学習された分類器の性能はSSLアルゴリズムを評価する上で重要.
・異なるデータで学習されたモデルをラベル有りのみでfine-tuneした分類器は全てのSSLアルゴリズムより性能がよかった.
・SSLの性能はラベルなしデータがラベル有りデータと異なる分布のデータを含む場合に劇的に下がる.
・アプローチごとにラベル有りとラベルなしデータの量に対する性能の感度が違う.
・Varidationデータが少ないと比較結果の信頼性がない.
いくつかは当たり前な様な気もするが,こう言ったことがあるからこの論文では実験方法とともに全てのいろんなstate-of-th-art SSL手法を再実装して評価できる様にしたものも提供するとのこと(現時点ではgithubのリポジトリはあるがコードは公開されてない).
Improved Evaluation
一般的なSSLの評価方法は,教師あり学習のためのデータセットのラベルを一部を除いて捨ててラベル有りデータとラベルなしデータに分け,このデータを使って学習した後にテストデータで評価するというもの.この時,使うデータセットとラベル有りデータの数が論文によって異なるため正確な比較ができていない.主に以下の6つの項目が実問題にかけ離れているとして,これに着目して実験/評価方法を整備する.
P.1 A Shared Implementation
SSLの比較に使われる元となるモデル構造を共通化する.というのも,従来は単純な13層のCNNを異なる実装(異なるパラメータの初期化方法やデータの前処理,正則化,最適化方法など)で使っていて,ちゃんと比較できているとは言えないため.
P.2 High-Quality Supervised Baseline
SSLの目的はだけの学習に比べてとを組み合わせた学習による精度向上.このベースラインとなるだけで学習したモデルが論文によって学習の具合が違っていて,同じ条件のはずが別々の論文のベースラインを比べると最大15%程の差があるとのこと.そのため,ベースラインとSSLのチューニングにかかる計算量を同一にしようというもの.
P.3 Comparison to Transfer Learning
限られたラベル有りデータで学習する際の有効な手法として転移学習があり,これと比較しないのはおかしいだろというもの.
P.4 Consider Class Distribution Mismatch
従来の実験設定では,データセットラベルを捨ててを作るため,はに含まれるデータと同じクラスのデータを持つ.ただ,実世界の設定においてはそうなるとは限らない.つまり,例えば10種類の識別をしたい際に,ラベルなしデータの中に識別したい10種類とは違うクラスのデータが混ざっている場合がある.なので実問題に対する評価をするためにはとが異なるクラスの分布を持つ時の影響も調べる必要がある.
P.5 Varying the Amount of Labeled and Unlabeled Data
今まではとの比率を決める体系的な方法なしで比較していたが,実問題においてはがめちゃくちゃ巨大(ネットから大量に集まる自然画像で作られている)か,比較的小さい(医療データなど)の2種類に分けられる.そのためアルゴリズムごとのデータ数による性能の変動を比較すべき.
P.6 Realistically Small Validation Sets
SSLの実験の設定ではValidationデータが非常に多い.というのも,例えばSVHNデータセットを使った実験では学習用のラベル有りサンプルは1000程度に対し,Varidationデータは元のまま7000サンプルを使う.これの問題は,実問題ではこんな大きなVaridationデータは用意できず,Varidationデータを使ったパラメータのチューニング等は不可能だということ.たとえcross-validationをしたとしても不十分な上計算コスト的に厳しいと論文では主張.
Semi-Supervised Learning Methods
この論文で使うSSL手法のざっくりとした復習.
Consistency Regularization
Consistency regularizationはRealisticな摂動をデータに加えても出力は変わらないだろうという前提に基づいた手法.元のデータを,摂動が加えられたデータをとすると,の最小化問題として記述され,はMSEやKLダイバージェンスが使われる.これはデータが分布する多様体が滑らかであるという仮定に基づいていて,この考えは様座mなSSLに応用されている.
Stochastic perturbations/-Model
Consistency Regularizationの最も単純な設定は,が確率的な出力をだす,つまり同一の入力に関して異なる出力を出すというもの.これはデータ拡張やdropoutなどのことで,この時の出力が元のデータを入力した際と変わらない様にを正則化項としてつけるというのがこのモデル.これはregularization with stochastic transformation and perturbations,-Modelとして同時に提案されていて,ここでは-Modelの呼称を使う.
Temporal ensembling/Mean teacher
-Modelの課題として,不安定なターゲット(が変化するということ)を推定するところにある.そこでより安定したターゲットを得る方法が提案された.一つはTemporal Ensemblingでの移動平均の様なものを代わりとして使おうというもの.もう一つはMean Teacherで学習中の学習パラメータの移動平均的なもので定義された関数で推定された値を使おうというもの.
Virtual adversarial training
を確率的なものにする代わりに,出力に大きな影響を及ぼす微小な摂動を近似して直接に加えることで学習しようというのがVirtual adversarial trainig (VAT).摂動は次の様に効率的に計算できる.
ただし,はハイパーパラメータ.consistency regularizationはをについて最小化するときに適用される.
Entropy-based
に適用される単純な損失項は情報量を下げる(つまりカテゴリカル分布を考えたときにどこか一つのクラスが尖る)というもの.出力がsoftmax関数を通して得られた[tex]K]次元のベクトルとすればentropy minimization項は次の様に表現できる.
ただこの方法はモデルの表現力が高いと簡単に過学習する.ただVATと組み合わせることで有用な結果を生んでいるとのこと.
Pseudo-labeling
Pseudo-labelingはヒューリスティックな手法だが,その単純さと適応範囲の広さから実際には広く使われている.Pseudo-labelingはあらかじめ設定された閾値を超える確率を持つクラスを正解のクラスとしてに擬似ラベルをつける方法.ただ,推定関数が役に立たないラベルを生成してしまった際には性能を悪化させる.ただ,pseudo-labelingはentropy minimizationと似た様な性質をもち,
Experiments
P.1
P.1の課題を解決するためにモデルと最適化方法の統一をした.モデルにはWide ResNetを,特に,一般的に使われているtensorflow/modelsのリポジトリにある"WRN-28-2"を使ってAdamによる最適化をした.あとは基本的な正則化とデータ拡張,前処理もしたらしい.ベースラインとSSLに関してgoogle cloud machine learningを使ったハイパーパラメータチューニングサービスを使って"Gaussian Process-based black box optimizationを1000回走らせて各アルゴリズムごとにハイパーパラメータを最適化したとのこと.
P.2
これはP.1の結果解決されて良いベースラインが作れたとのこと.実際他の文献で報告されているほどベースラインの精度は悪くない結果となった.ただ,現在提案されてる様々な正則化やデータ拡張をするとstate-of-the-art SSL手法並みに精度が出るとの事.
P.3
WRN-28-2を32x32にリサイズしたImageNetで学習したものをCIFAR-10で転移学習して比較したとの事.論文で実験した結果最も良いSSLであるVAT+EMのエラー率が13.13%に対し転移学習したものは12.0%だったそう.ただし,ImageNetに含まれるカテゴリはCIFAR-10とモロ被りしているのでこれはかなり恵まれた条件だったと論文には記載されている.後,転移学習とSSLを組み合わせても実験するのはfuture work.
P.4
クラスのミスマッチを扱うためにCIFAR-10の6クラス分類(bird, cat, deer, dog, frog, horse)をしたとの事,ラベルなしはその他4クラスを含む(分類の6クラスを含むかは微妙.The unlabeled data comes from four classesと書いてあったから多分含まない)ものとして,ラベル有りと無しのデータの比率を変えて実験.結果ラベルなしがある一定の割合を超えたらSSLより単純な教師ありが勝った.
P.5
ラベル付きのデータ数を変えた場合,ラベルなしのデータ数を変えた場合でSSL手法の比較をした結果,アルゴリズムごとに違う振る舞いでいい発見だったという感じ.
P.6
当たり前だけどVaridationデータの数を少なくしたら評価時の分散も大きくなったという結果.
まとめ
SSLに関する問題提起の論文.避けてきた点をつくいい論文だけどgoogleの宣伝感がすごい.githubで公開するとのことだけど使うにはtensorflow使う必要があるのと(よくわからないが)同一条件で自分のアルゴリズムを試そうとしたらgoogleのパラメータチューニングサービスを使う必要がありそう.
Appendixに探索したハイパーパラメータがまとまってるのでSSLする際には役に立ちそう.