機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Semi-Supervised Semantic Segmentation with Cross-Consistency Trainingを読んだのでメモ

はじめに

Semi-Supervised Semantic Segmentation with Cross-Consistency Trainingを読んだのでメモ. semantic segmentationのための半教師あり学習手法.

Cross-Consistency Training

semantic segmentationでは入力の空間では半教師あり学習のcluster assumption(同じクラスに属するデータが入力の空間で近いという仮説)が成り立たない(領域分割ではデータ点はピクセルのため同じクラスに属すると言って色が近いとは限らない)が,CNNの中間特徴では成り立つという観測の下,中間層に摂動を与えてconsistencyを課すというもの.

まず,ラベル付きデータを集合\mathcal{D}_l=\{(\mathbf{x}_1^l,y_1),\dots,(\mathbf{x}_n^l,y_n)\}で表し,ラベルなしデータを集合\mathcal{D}_u=\{\mathbf{x}_1^u,\dots,\mathbf{x}_m^u\}で表す. 仮定としてm\gg nとする.

encoder hとdecoder gから成り立つモデルf=g\circ hを考える. 提案するモデルでは上記モデルに対しK個のauxiliary decoder \{g_a^k\}_k^Kを導入する. ベースのdecoder (main decoder)gは教師付きラベルによって学習が進み,auxiliary decoderはラベルなしデータによって学習が進んでいく.

より具体的には教師付きデータに関しては次の損失をencoderとmain decoderに関して最小化する.

\displaystyle
\mathcal{L}_s=\frac{1}{|\mathcal{D}_l|}\sum_{\mathbf{x}^l_i,y_i\in\mathcal{D}_l}\mathbf{H}(y_i,f(\mathbf{x}_i^l))

ただし,\mathbf{H}(\cdot, \cdot)はクロスエントロピーを表す. ラベルなしデータに関してはまず共通のencoderで中間表現\mathcal{z}_i=h(\mathbf{x}_i^u)とした後,R個の確率的な摂動を与える関数p_r,r\in[1,R]のどれかをランダムに選び,K個の異なる摂動が加えられた中間表現\tilde{\mathbf{z}}_i^kを作り,それぞれをK個のauxiliary decoderへ入力する. これに対し下記の損失を最小化するようにauxiliary decoderとencoderを学習する(main decoderに関しては勾配を計算しない).

\displaystyle
\mathcal{L}_u=\frac{1}{|\mathcal{D}_u|}\frac{1}{K}\sum_{\mathbf{x}_i^u\in\mathcal{D}_u}\sum^K_{k=1}\mathbf{d}(g(\mathbf{z}_i),g_a^k(\tilde{\mathbf{z}}_i))

\mathbf{d}(\cdot,\cdot)は距離関数で自乗誤差やJSダイバージェンス など. よって最終的な損失は次のようになる.

\displaystyle
\mathcal{L}=\mathcal{L}_s+\omega_u\mathcal{L}_u

\omega_uはハイパーパラメータで,他のconsistency regularizationと同様に学習中にrampupする.

auxiliary decoderに関して,イントロダクションにはラベルなしデータを利用するため導入すると書いてあるが,アイディアのベースとしているcluster assumptionやアルゴリズム的に導入する必然性がないため,いまいちどういう役割を担うか分からない.

Perturbaation functions

摂動としては下記の5つを利用する.

  • F-Noise:N\sim\mathcal{U}(-0.3,0.3)としてサンプリングしたノイズを\tilde{\mathbf{z}}=(\mathbf{z}\odot N)+\mathbf{z}と加える.
  • F-Drop:\gamma\sim\mathcal{U}(0.6,0.9)としてサンプリングした閾値を使って入力を\tilde{\mathbf{z}}\odot\mathbf{M}_\text{drop},\mathbf{M}_\text{drop}=\{\mathbf{z}'\lt\gamma\}_1としてマスクする.ただし,\{\cdot\}_1は指示関数.大体10%から40%くらいの領域がマスクされるとのこと.
  • Guided Masking:main decoderの予測結果\hat{y}=g\circ h(\mathbf{x})を使って生成したマスク\mathbf{M}_\text{obj}からcontext mask \mathbf{M}_\text{con}=1-\mathbf{M}_\text{obj}を作り,このcontext maskにより\mathbf{z}をマスクする.\mathcal{M}_\text{obj}の生成は2007年に提案された手法を利用とのこと.
  • Guided Cutout:\mathbf{M}_\text{obj}から物体らしい領域のbounding boxを取得し,それに基づき\mathbf{z}に対しcutoutを行う.
  • Intermediate VAT:\mathbf{z}に対しVATを計算.

Practical considerations

UDAで提案されたTSAと同様に,段階的に教師あり損失を計算するサンプルを増やしていく方法,an annealed version of the bootstrapped-CE (ab-CE)を使う. 具体的には予測結果に対し,予測確率がある閾値\eta以下のピクセルのみ損失を計算するというもの.式的には下記.

\displaystyle
\mathcal{L}_s=\frac{1}{|\mathcal{D}_l|}\sum_{\mathbf{x}_i^l,y_i\in\mathcal{D}_l}\{f(\mathbf{x}_i^l)\lt\eta\}_1\mathbf{H}(y_i,f(\mathbf{x}_i^l))

式にも文章に書いてはないが,閾値判定はおそらく予測確率の最大値に対して評価する. \etaは学習中に徐々に大きくしていく.意味としては,大きく間違えているサンプルから優先的に学習していくというもの.

もう一つのpracticalな要素として,弱教師付きセグメンテーションの利用をする.image-level labelはセグメンテーションラベルより入手コストが低いため存在するものとし,そのようなデータを集合\mathcal{D}_w=\{(\mathbf{x}_1^w,y_1^w),\dots,(\mathbf{x}_m^w,y_m^w)\}とする. 元々のモデルの構成に加えて,image-level labelを分類するglobal average poolingを持つブランチg_cを追加し,学習する. このブランチからclass activation mappingを使ってマスクMを作り,適当な閾値を使って擬似ラベルy_pを作る.この擬似ラベルをCRFによりrefinementし教師データとして他の2つの損失に加え,次の損失を最小化するようにauxiliary decoderとencoderを学習する.

\displaystyle
\mathcal{L}_w=\frac{1}{|\mathcal{D}_w|}\frac{1}{K}\sum_{\mathbf{x}_i^w\in\mathcal{D}_w}\sum_{k=1}^K\mathbf{H}(y_p,g_a^k(\mathbf{z}_i))

その他,デコーダーを追加することで複数ドメインでの学習にも利用可能とのこと. ざっくり言えば,ドメイン1とドメイン2があれば,encoderは共通でそれぞれのドメインに対応するmain decoder g^{(1)},g^{(2)}とauxiliary decoder g^{k(1)}_a,g^{k(2)}_aを用意して学習するというもの,

まとめ

ablationを見るとab-CEの利用が精度向上にかなり寄与するよう. auxiliary decoderの数は増やしても精度が上がったり下がったりで設定にかなり依存がありそう. 個人的にはauxiliary decoderの必要性に関してもう少し説明もしくは実験が欲しかった.