機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

PointAugment: an Auto-Augmentation Framework for Point Cloud Classificationを呼んだのでメモ

はじめに

PointAugment: an Auto-Augmentation Framework for Point Cloud Classificationを呼んだのでメモ. Point Cloudに対するAutoAugment. ただAutoAugmentというが,originalのAutoAugmentとは大きく異なり,フレームワークとしては

  • データ毎にdata augmentationを生成
  • data augmentationはニューラルネットで実行
  • モデルの学習とaugmentationをadversarial trainingの要領で同時最適化

となる. 前提として点群の分類問題を扱うものとする.

Overview

PointAugmentはaugmentor \mathcal{A}とclassifier \mathcal{C}の二つのネットワークから構成される. M個の学習データ\{\mathcal{P}_i\}_{i=1}^Mが与えられた際,\mathcal{C}\mathcal{P}_iで通常通り学習されるとともに,\mathcal{P}_i\mathcal{A}に入力して得られた出力\mathcal{P}^\prime_iでも学習される. 同時に,\mathcal{P}^\prime_iに対する\mathcal{C}の出力を受けて\mathcal{A}のパラメータも更新される.

Method

Augmentor \mathcal{A}は3D点群のshapeに関するaugmentation(回転など)と点ごとの変換に関するaugmentation(ノイズの加算など)それぞれを扱う. そのためaugmentor(PointNet構造)は入力を(x,y,z)座標を値として持つN点の点群とすると,3\times3の行列\mathcal{M}N\times 3の行列\mathcal{D}を出力する. その出力を使って入力の点群\mathcal{P}_i\mathcal{P}\cdot\mathcal{M}+\mathcal{D}として変換し,\mathcal{P}^\primeを生成する. ただし,augmentationに確率的な振る舞いを導入するため,\mathcal{M},\mathcal{D}を出力する全結合層に点群の特徴量だけでなくnoiseも入力する(詳細はFig. 4).

ここでの問題はaugmentorの学習で,この論文ではaugmentorは(i)元の点群\mathcal{P}より難しいサンプル,すなわちL(\mathcal{P}^\prime)\gt L(\mathcal{P})を満たす\mathcal{P}^\prime生成すべき(ただしL(\cdot)はクロスエントロピー),(ii)\mathcal{P}^\primeは元の点群の形から逸脱しない物であるべきという2点を満たすべきとしている.

まず(i)を満たすため,augmentorの目的関数は次のようになる.

\displaystyle
\mathcal{L}_\mathcal{A}=\exp[-(L(\mathcal{P}^\prime)-L(\mathcal{P}))]

これは1以下の値を取る場合にL(\mathcal{P}^\prime)\gt L(\mathcal{P})を満たす. なので\mathcal{L}_\mathcal{A}を最小化するようにaugmentorを学習すれば(i)を満たすサンプルを生成することが期待される. 一方で,\mathcal{L}_\mathcal{A}\rightarrow0のときにはL(\mathcal{P}^\prime)-L(\mathcal{P})\rightarrow\inftyとなり,このような場合\mathcal{P}^\primeはもはや(ii)を満たさない点群になっていることが予想される. そのため\xi= L(\mathcal{P}^\prime)-L(\mathcal{P})を小さくするように次のように\mathcal{L}_\mathcal{A}を非負のdynamic parameter\rhoを使って定義し直す.

\displaystyle
\mathcal{L}_\mathcal{A}=|1.0-\exp[(L(\mathcal{P}^\prime)-\rho L(\mathcal{P}))]|

このとき,\xiの上界は\xi= L(\mathcal{P}^\prime)-L(\mathcal{P})\leq(\rho-1)L(\mathcal{P})=\xi_\mathcal{O}となる.

Augmentorの学習時にはclassifierは固定となるため,この上界は\rhoにのみ依存する. 通常,学習初期はclassifierの挙動がセンシティブであるため,\mathcal{P}^\primeが難しすぎないように注意を払う必要がある. そのため,\rhoを次のように設計する.

\displaystyle
\rho=\max\left(1,\exp\left(\sum_{c=1}^K\hat{y}_c\cdot y_c\right)\right)

\hat{y}_cは真のonehotラベルで,y_c\mathcal{P}に対するclassifierの出力. これは,分類が容易な簡単なサンプルほど\rhoが大きくなり,難しい\mathcal{P}^\primeを生成することを目的としている.

最終的なaugmentorの損失は次のようになる.   \displaystyle
\mathcal{L}_\mathcal{A}=L(\mathcal{P}^\prime)+\lambda|1.0-\exp(L(\mathcal{P}^\prime)-\rho L(\mathcal{P}))|

\lambdaはハイパーパラメータ. L(\mathcal{P}^\prime)はaugmentorによって作り出されたサンプルが(ii)を満たすことをencourageする項で,\lambdaの大きさによってどれだけ重要視するかが変わる. 実験では\lambda=1とした.

Classifierの方は次の損失関数で学習される.

\displaystyle
\mathcal{L}_\mathcal{C}=L(\mathcal{P}^\prime)+L(\mathcal{P})+\gamma\|\mathbf{F}_g-\mathbf{F}_{g^\prime}\|_2

第1項目と第2項目は通常通りの損失で,最後の項は元の点群とaugmentorによって生成された点群の特徴ベクトル間の一致を強いる. \gammaは実験で10とした.

まとめ

PointNetやDGCNNなど代表的なモデルで1~2%の精度の向上が見られている. データによっては4%前後. 画像と違って点群はdata augmentationのバリエーションがそんなにないことが手法のシンプルさに繋がったかなと.