機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Meta Pseudo Labelsを読んだのでメモ

はじめに

Meta Pseudo Labelsを読んだのでメモ. policy gradientを使ったteacher networkの学習により半教師あり学習の精度向上.

気持ち

教師あり学習では何らかの出力とラベルなしデータの出力のconsistencyをコストとする手法が多く,何らかの出力(target distribution/pseudo label)はヒューリスティクスな方法で作られることが多い(e.g. mixmatchのshapenedやmean teacherなど). そこでこの手法ではmeta-learn的にtarget distributionを作る. 具体的にはtarget distributionを作るteacher modelを学習したいモデル(student model)のvalidationでの精度が高くなるようにpolicy gradientで学習する. ある種,擬似ラベルをmeta learning的に決定しているためこの手法をMeta Pseudo Label(MPL)と呼ぶ.

Motivation

まずここではパラメータ\Thetaを持つCクラス分類器を考える. 通常,分類機はtaget distribution q_\ast(\mathbf{Y}|\mathbf{X})とモデル分布 p_\Theta(\mathbf{Y}|\mathbf{X})間のクロスエントロピーの最小化で学習される.

\displaystyle
\min_\Theta\mathcal{L}_\text{CE}(\Theta)=-\mathbb{E}_{\mathbf{x}\sim\mathcal{D}}\left[\sum_{c=1}^Cq_\ast(c|\mathbf{x})\log p_\Theta(c|\mathbf{x})\right]

target distributionは完全な教師あり学習ではone-hotベクトル,knowledge distillationでは蒸留したいモデルの出力,半教師あり学習では何らかのモデルq_\xiの出力に次のように手を加えたベクトルで与えられる.

\displaystyle
\text{Hard label}:\ q_\ast(\mathbf{Y}|\mathbf{x})\triangleq\text{one-hot}(\underset{\mathbf{y}}{\mathrm{arg}\max}q_\xi(\mathbf{y}|\mathbf{x}))\\
\text{Soft label}:\ q_\ast(\mathbf{Y}|\mathbf{x})\triangleq q_\xi(\mathbf{Y}|\mathbf{x})

また,ヒューリスティクスを加えた方法として,label smoothingやtemperature tuningなどが存在する. 一方で,taget distributionはどうのように振る舞うのがsutudent modelを学習する上で最も良いかは議論されてこなかった. なのでこの論文ではtarget distributionを手動設計する代わりにmeta learningで得る方が合理的としている. さらに,q_\astは悪い局所解へのトラップを避けるためp_\Thetaの学習度合いによって適応的に變化すべきとしている.

Meta Pseudo Labels

まず,target distribution q_\ast(\mathbf{x})q_\Psi(\mathbf{x})とパラメータ\Psiで表現し,\Psiを学習することを考える. ここではp_\Psiをteacher model,p_\Thetaをstudent modelと呼ぶ. p_\Psiの学習は擬似ラベルの獲得に等しい.

学習則

ここではteacherとstudentの学習則について述べる.

Phase 1: the student learns from the teacher

タイトルの通り,student modelをteacher modelの出力に従って学習する. なので勾配法により次のように\Thetaを更新.

\displaystyle
\Theta^{(t+1)}\triangleq\Theta^{(t)}-\eta\nabla_\Theta\mathcal{L}_\text{CE}(q_\Psi(\mathbf{x}),p_\Theta(\mathbf{x}))|_{\Theta^{(t)}}

Phase 2: the teacher learns from the student's validation loss

teacher modelはstudent modelのvalidationにおける損失に基づく以下の報酬を使ったpolicy gradientによりパラメータの更新を行う.

\displaystyle
\mathcal{L}_\text{CE}(\mathbf{y}_\text{val},p_{\Theta^{(t+1)}}(\mathbf{x}_\text{val}))\triangleq\mathcal{R}(\Theta^{(t+1)})\\
=\mathcal{R}(\Theta^{(t)}-\eta\nabla_\Theta\mathcal{L}_\text{CE}(q_\Psi(\mathbf{x}),p_\theta(\mathbf{x}))|_{\Theta^{(t)}})

\Psiはこの\mathcal{R}(\Theta^{(t+1)})を最小化するように勾配\nabla_\Psi\mathcal{R}を使って更新される.

これはstudentの学習状況に応じてteacherを適応的に変化させることが可能であることを意味するが,一方でteacherを学習するには観測が少ない(studentのvalidation lossだけである)ため,teacherの学習が難しいという側面もある. なのでここでは従来のやり方に従い,教師あり学習も同時に行う. すなわちteacherは先ほどのstudentのvalidation lossを使った学習に加え,教師ありのクロスエントロピーの勾配\nabla_\Psi\mathcal{L}_\text{CE}(\mathbf{y},q_\Psi(\mathbf{x}))を使って\Psiの更新を行う.

一方,課題として二つのモデルを学習するため計算コストが大きく,大規模なモデルを学習するのが難しいということ. なのでReducedMPLも提案. ReducedMPLは巨大なネットワークをまず(完全教師あり?)学習して,そのモデルで全学習データのtarget distributionを計算. その後巨大なネットワークは捨てて,小さいteacher q_\Psiを作る. このteacherはMLPなど.q_\Psiは巨大なネットワークで計算されたtarget distributionを入力として新しいtarget distributionを出力する. 直感的には巨大なネットワークで計算されたtarget distributionを学習中にstudentのvalidation lossに合わせて動的に修正していくという感じ.

MPLは他の半教師あり学習手法と合わせて使えるため実験ではRandAugmentやUDAと組み合わせて使っている. UDA+MPLが最も良い精度. さらに完全な教師あり学習にも組み込むことができ,NoisyStudentをoutperformしている. また,実験の中で面白かったのが,MPLの勾配はvalidation setにおけるstudentの購買に近しいものになっているという仮説の検証. 実際,MPLを使った場合と使わなかった場合の勾配とvalidationに対する勾配のコサイン類似度を測った結果,MPLを使った時の類似度が高かった.

まとめ

毎度思うが半教師あり学習ベンチマークでvalidationの精度(損失)を使った強化学習的アプローチは,ある種validationのデータも学習データとして使ってモデルのパラメータを学習してるのでフェアではない気がする.まして学習時に用いるラベル付きデータよりvalidationデータのが多いので尚更.ちゃんと学習時に使うラベル付きデータをsplitとして評価,もしくはベースラインはvalidationも学習データに含めて学習して欲しいところ.