機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Convolutional Neural Networks on Graphs with Fast Localized Spectral Filteringを読んだのでメモ

はじめに

Convolutional Neural Networks on Graphs with Fast Localized Spectral Filteringを読んだのでメモ.

Convolution on Graph

この論文ではグラフフーリエ変換に基づく畳み込みを考える.

Graph Fourier transform

無向グラフ\mathcal{g}=(\mathcal{V},\mathcal{E},W)で定義された信号を考える.\mathcal{V}|\mathcal{V}|=n個の頂点集合,\mathcal{E}はエッジの集合,W\in\mathbb{R}^{n\times n}は重み付きの隣接行列を表す.D\in\mathbb{R}^{n\times n}を対角行列D_{ii}=\sum_jW_{ij}とすれば,非正規化ラプラシアンL=D-W\in\mathbb{R}^{n\times n},正規化ラプラシアンL=I_n-D^{-1/2}WD^{-1/2}として表現される.グラフラプラシアンは実対称半正定値行列であるため直交固有ベクトル\{u_l\}_{l=0}^{n-1}\in\mathbb{R}^nを持ち,非負の固有値\{\lambda_l\}_{l=0}^{n-1}を持つ.グラフラプラシアン固有ベクトルフーリエ基底U=[u_0,\dots,u_{n-1}]\in\mathbb{R}^{n\times n}として知られ,固有値分解からグラフラプラシアンフーリエ基底をもちいてL=U\Lambda A^Tとして表現できる.グラフ上に定義された信号x\in\mathbb{R}^nフーリエ変換\hat{x}=U^Tx\in\mathbb{R}^n,逆フーリエ変換x=U\hat{x}として記述される.

Spectral filtering of graph signals

以前読んだ論文と変わりないのでさっくりと式だけ.Fourier domainで定義された信号xに対するパラメータ\theta\in\mathbb{R}^nを持つフィルタg_\thetaによる畳み込みは次のように表現できる.

\displaystyle
y=g_\theta(L)x=g_\theta(U\Lambda U^T)x=Ug_\theta(\Lambda)U^Tx

ただしフィルタはg_\theta(\Lambda)=\mathrm{diag}(\theta)

Polynomial parametrization for localized filters

グラフ上でのフィルタの学習はvertex domainでの局所性を考慮できないこととパラメータが\mathcal{O}(n)とデータの次元によるという問題がある.そこでこの論文ではこの問題を解決するため次のような多項式展開したフィルタを使う.

\displaystyle
g_\theta(\Lambda)=\sum_{k=0}^{K-1}\theta_k\Lambda^k

パラメータ\theta\in\mathbb{R}^K多項式の係数.頂点iを中心とするフィルタg_\thetaの頂点jの値は(g_\theta(L)\delta_i)_j=(g_\theta(L))_{i,j}=\sum_k\theta_k(L^k)_{i,j}で与えられ,カーネルクロネッカーのデルタ\delta_i\in\mathbb{R}^nによって表現される.ここで,グラフ上の2つの頂点を最短距離で結ぶパスのエッジの数がd_\mathcal{g}(i,j)\gt Kである時(L^K)_{i,j}=0であることが知られている.結果としてラプラシアンK次の多項式で表現されるspectral fileterはK近傍までの頂点を含んだ計算になっていることがわかる(つまりKを畳み込みのカーネルサイズとした時に(L^K)_{i,j}=0からKより遠い頂点は畳み込みの計算に含まれないということ).さらにパラメータ数も\mathcal{O}(K)になっていて従来のCNNと同様のオーダーになっている.

Recursive formulation for fast filtering

ここがこの論文の主な貢献部分.基本的にフィルタリングの処理y=Ug_\theta(\Lambda)U^Tx\mathcal{O}(n^2)の計算コストがかかる.解決方法としては前節のようなLから再帰的に計算可能な多項式関数としてg_\theta(L)を定義することが考えられる.疎な行列Lに対してK回行列積を計算する時その演算回数のオーダーは\mathcal{O}(K|\mathcal{E}|)\ll\mathcal{O}(n^2)になる.グラフ信号処理においてはグラフ上での畳み込みにおけるフィルタの多項式表現は昔から研究されていて,一般的な方法としてチェビシェフ多項式を使ってフィルタの近似をする方法がある.そこでここではチェビシェフ多項式を使ってフィルタを近似する.

k次のチェビシェフ多項式T_k(x)T_k(x)=2xT_{k-1}(x)-T_{k-2}(x),\:T_0=1,\:T_1=xとして計算することが可能.この多項式dy/\sqrt{1-y^2}を測度とするヒルベルト空間L^2([-1,1],dy/\sqrt{1-y^2})を作ることが知られている(ちなみにこのあたりの初等的な教科書として金谷 健一先生の「これならわかる応用数学教室」などが自分のような数学弱者にも優しい).K-1次のチェビシェフ多項式を使えばフィルタは次のように定義できる.

\displaystyle
g_\theta(\Lambda)=\sum_{k=0}^{K-1}\theta_kT_k(\tilde{\Lambda})

\theta\in\mathbb{R}^Kはチェビシェフ係数でT_k(\tilde{\Lambda})\in\mathbb{R}^{n\times n}\tilde{\Lambda}=2\Lambda/\lambda_{max}-I_nによって計算されるk次のチェビシェフ多項式\bar{x}_k=T_k(\tilde{L})x\in\mathbb{R}^nと定義すれば,チェビシェフ多項式再帰関係から\bar{x}_k=2\tilde{L}\bar{x}_{k-1}-\bar{x}_{k-2},\:\bar{x}_0=x,\:\bar{x}_1=\tilde{L}xとして計算することができ,フィルタリング計算y=g_\theta(L)x=[\bar{x}_0,\dots,\bar{x}_{K-1}]\thetaの演算回数は\mathcal{O}(K|\mathcal{E}|)になる.

Learning filters

サンプルsj番目の出力の特徴マップは次のように計算できる.

\displaystyle
y_{s,j}=\sum_{I=1}^{F_{in}}g_{\theta_{i,j}}(L)x_{s,j}\in\mathbb{R}^n

x_{s,i}は入力の特徴マップでチェビシェフ係数\theta\in\mathbb{R}^{F_{in}\times F_{out}\times K}が学習係数となっている.よって最終的な計算コストは\mathcal{O}(K|\mathcal{E}|F_{in}F_{out}S)になる.フィルタリング処理の入力と学習パラメータに関するbackprop中の微分計算は次のようになる.

\displaystyle
\frac{\partial E}{\partial \theta_{i,j}}=\sum_{s=1}^S[\bar{x}_{s,i,0},\dots,\bar{x}_{s,i,K-1}]^T\frac{\partial E}{\partial y_{s,j}}\\ \displaystyle
\frac{\partial E}{\partial x_{s,i}}=\sum_{j=1}^{F_{out}}g_{\theta_{i,j}}(L)\frac{\partial E}{\partial y_{s,j}}

ただし,Eは目的関数でSはミニバッチ内のサンプル数.

Graph Coarsening

ここではグラフ上でのプーリングについて考える.グラフ上のプーリングは基本的に近傍の似た頂点同士を一つのクラスタにまとめる処理を指す.ただし,グラフのクラスタリングはNP-hardであることが知られていて一般的には何らかの近似を用いなければならない.Grapn Convの文脈ではマルチレベルなクラスタリングが可能でダウンサンプリングのスケールがコントロールできる必要がある.この論文ではGraclus multilevel clustering algorithmを転用したとのこと.このアルゴリズムはgreedy algorithmの一種で,クラスタが割り振られていないノードを適当に取り出して局所的なnormalized cut W_{ij}(1/d_i+1/d_j)を最大化するような近傍のノードjを選んでクラスタリングする.

Fast Pooling of Graph Signals

グラフ上でプーリングの処理は単純に行うと対応関係を保持する必要があって計算コストが非常にかさむ.そこでこの論文ではちょっとした工夫をして1Dプーリングと同程度の計算効率でgraph poolingを行う.具体的にはbalanced binary treeの構築と頂点の振り直しの二つの処理によって効率化する.

グラフ上でのプーリング処理が行われた場合,プーリング後のグラフの各ノードはプーリング前の2つ,もしくは1つのノードと結びつくことになる.そこで,全てのノードが二つの子ノードを持つように擬似ノード(fake node)を作る(これがbalanced binary tree).すると,全てのノードは(1)二つの真のノードを持つ(2)1つの真のノードと1つの擬似ノードを持つ(3)2つの擬似ノードを持つ3種類のノードに分けられる.ただし,全ての擬似ノードは(3)にあたる.擬似ノードはプーリング後に影響を及ぼさないような値,例えばReLUを活性化関数に使ってmaxpoolingをする場合は0などの値で初期化される.当然擬似ノードを導入するとその分次元が増えて計算コストは増えるが,実験的にGraclus algorithmを使った時には孤立ノードは少量しか出ないため特に問題はないとのこと.

で,このbalanced binary treeを作ることの何がいいかというと,こうして作られたグラフは二つのノードを一つにまとめることでプーリングされる,すなわち孤立ノードがないため綺麗に番号を並び替えることで1Dプーリングと同様の演算でプーリングを行うことが可能なため高速に演算することができる.Figure 2にこのプーリング処理の図解があるのでそこを見ればわかりやすい.ちなみに,今までの説明はスケールを2分の1にする際の処理だが,ダウンサンプリングのスケールを大きくしたい場合にはスケールに対応した数の子ノードを持つように擬似ノードを導入すればいい.

まとめ

実験で面白かったのは普通にフィルタの重みを用意するより多項式近似の形にした方が精度がよかったというところ(MNISTでの実験だからなんとも言えないが).普通のspectral filteringの方法で学習するとパラメータ多すぎて過学習気味になるってことなのかなんなのか.