Meta Pseudo Labelsを読んだのでメモ
はじめに
Meta Pseudo Labelsを読んだのでメモ. policy gradientを使ったteacher networkの学習により半教師あり学習の精度向上.
気持ち
半教師あり学習では何らかの出力とラベルなしデータの出力のconsistencyをコストとする手法が多く,何らかの出力(target distribution/pseudo label)はヒューリスティクスな方法で作られることが多い(e.g. mixmatchのshapenedやmean teacherなど). そこでこの手法ではmeta-learn的にtarget distributionを作る. 具体的にはtarget distributionを作るteacher modelを学習したいモデル(student model)のvalidationでの精度が高くなるようにpolicy gradientで学習する. ある種,擬似ラベルをmeta learning的に決定しているためこの手法をMeta Pseudo Label(MPL)と呼ぶ.
Motivation
まずここではパラメータを持つクラス分類器を考える. 通常,分類機はtaget distribution とモデル分布 間のクロスエントロピーの最小化で学習される.
target distributionは完全な教師あり学習ではone-hotベクトル,knowledge distillationでは蒸留したいモデルの出力,半教師あり学習では何らかのモデルの出力に次のように手を加えたベクトルで与えられる.
また,ヒューリスティクスを加えた方法として,label smoothingやtemperature tuningなどが存在する. 一方で,taget distributionはどうのように振る舞うのがsutudent modelを学習する上で最も良いかは議論されてこなかった. なのでこの論文ではtarget distributionを手動設計する代わりにmeta learningで得る方が合理的としている. さらに,は悪い局所解へのトラップを避けるための学習度合いによって適応的に變化すべきとしている.
Meta Pseudo Labels
まず,target distribution をとパラメータで表現し,を学習することを考える. ここではをteacher model,をstudent modelと呼ぶ. の学習は擬似ラベルの獲得に等しい.
学習則
ここではteacherとstudentの学習則について述べる.
Phase 1: the student learns from the teacher
タイトルの通り,student modelをteacher modelの出力に従って学習する. なので勾配法により次のようにを更新.
Phase 2: the teacher learns from the student's validation loss
teacher modelはstudent modelのvalidationにおける損失に基づく以下の報酬を使ったpolicy gradientによりパラメータの更新を行う.
はこのを最小化するように勾配を使って更新される.
これはstudentの学習状況に応じてteacherを適応的に変化させることが可能であることを意味するが,一方でteacherを学習するには観測が少ない(studentのvalidation lossだけである)ため,teacherの学習が難しいという側面もある. なのでここでは従来のやり方に従い,教師あり学習も同時に行う. すなわちteacherは先ほどのstudentのvalidation lossを使った学習に加え,教師ありのクロスエントロピーの勾配を使っての更新を行う.
一方,課題として二つのモデルを学習するため計算コストが大きく,大規模なモデルを学習するのが難しいということ. なのでReducedMPLも提案. ReducedMPLは巨大なネットワークをまず(完全教師あり?)学習して,そのモデルで全学習データのtarget distributionを計算. その後巨大なネットワークは捨てて,小さいteacher を作る. このteacherはMLPなど.は巨大なネットワークで計算されたtarget distributionを入力として新しいtarget distributionを出力する. 直感的には巨大なネットワークで計算されたtarget distributionを学習中にstudentのvalidation lossに合わせて動的に修正していくという感じ.
MPLは他の半教師あり学習手法と合わせて使えるため実験ではRandAugmentやUDAと組み合わせて使っている. UDA+MPLが最も良い精度. さらに完全な教師あり学習にも組み込むことができ,NoisyStudentをoutperformしている. また,実験の中で面白かったのが,MPLの勾配はvalidation setにおけるstudentの購買に近しいものになっているという仮説の検証. 実際,MPLを使った場合と使わなかった場合の勾配とvalidationに対する勾配のコサイン類似度を測った結果,MPLを使った時の類似度が高かった.
まとめ
毎度思うが半教師あり学習のベンチマークでvalidationの精度(損失)を使った強化学習的アプローチは,ある種validationのデータも学習データとして使ってモデルのパラメータを学習してるのでフェアではない気がする.まして学習時に用いるラベル付きデータよりvalidationデータのが多いので尚更.ちゃんと学習時に使うラベル付きデータをsplitとして評価,もしくはベースラインはvalidationも学習データに含めて学習して欲しいところ.