機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

HOW POWERFUL ARE GRAPH NEURAL NETWORKS?を読んだのでメモ

はじめに

HOW POWERFUL ARE GRAPH NEURAL NETWORKS?を読んだのでメモ.各命題,定理の証明は今回は割愛.

2018.11.24

Openreviewにて論文で提案しているGINが定理3を満たさない(= WL testと同等の識別能力を有すると言えない)ということが示されて内容が変更されていたのでこちらのメモも少し修正.(修正抜けがあると思うので正確な内容を知りたい場合は論文を参照).

気持ち

Graph convolutional networks(GNN)がグラフデータに関する様々な課題においてstate-of-the-artを記録しているが,どの研究もヒューリスティックスに頼ったtrial-and-errorで手法が作られていることを問題視.もう少し理論的な理解とともに新しいGNNを提案するというもの.

Graph Neural networks

まずは一般的なGNNについて.G=(V,E)を各ノードが特徴ベクトルX_v,\:v\in Vを持つグラフとする.今回はnode classification(各ノードのラベルy_vの予測)とgraph classification(グラフそのもののラベルy_Gの予測)の二つに焦点を当てる.

GNNは主にノードの表現ベクトルh_vまたはグラフそのものの表現ベクトルh_Gを学習するために用いられ,基本的な戦略としては隣接ノードの情報を集約を繰り返すことでそのノードの持つ表現ベクトルを更新していく.一般的にk回集約を繰り返すことでk-hopの情報を得ることが可能.この集約の手続きは次のように表現できる.

\displaystyle
a_v^{(k)}=\mathrm{AGGREGATE}^{(k)}\left(\left\{h_u^{(k-1)}:u\in\mathcal{N}(v)\right\}\right),\:h_v^{(k)}=\mathrm{COMBINE}^{(k)}\left(h_v^{(k-1)},a_v^{(k)}\right)

h_v^{(k)}k層目のノードvが持つ特徴ベクトルで,h_v^{(0)}=X_v\mathcal{N}(v)vに隣接するノード集合.\mathrm{AGGREGATE}^{(k)}(\cdot),\mathrm{COMBINE}^{(k)}(\cdot)の選び方によってGNNは区別される.例えば,KipfとWellingが提案したGCNは次のような\mathrm{AGGREGATE}関数として定義される.

\displaystyle
a_v^{(k)}=\mathrm{MEAN}\left(\left\{\mathrm{ReLU}\left(W\cdot h_u^{(k-1)}\right),\forall u\in\mathcal{N}(v)\right\}\right)

Wは学習パラメータとなる行列.GraphSAGE\mathrm{MEAN}関数を要素ごとのmax-poolingに置き換えたものとして表現できる.またGraphSAGEにおいては\mathrm{COMBINE}のステップとしてW\cdot\left[h_v^{(k-1)}|a_v^{(k)}\right]という演算が存在する(GCNには\mathrm{COMBINE}の処理はない).

Node classificationにおいては,ノードの表現ベクトルh_v^{(K)}を使って予測を行い,graph classificationにおいては次のような\mathrm{READOUT}関数によりノードの表現ベクトルの集約が行われる.

\displaystyle
h_G=\mathrm{READOUT}\left(\left\{h_v^{(K)}|v\in G\right\}\right)

\mathrm{READOUT}は基本的にpermutation invariantな関数である必要がある.

Weisfeiler-Lehman test

Weisfeiler-Lehman test(WL test)は各ノードにカテゴリカルラベルが付与されていると仮定した元で,(1)ノードとその隣接ノードのラベルをaggregate,(2)集約されたラベルをユニークな新しいラベルに変換,という処理を繰り返すもので,この処理を繰り返した結果二つのグラフの各ノードが異なるかどうかでグラフの同型性を判断するというもの.この処理はGNNの処理と 似ているというのがここでの主張.

WL testを使ったGNN的な手法としてShervashidzeらが提案したWL subtree kernelがある.これはWL testの異なる繰り返し回数でのノードラベルをノードの特徴ベクトルとして使う手法で,k-th iterationにおいてあるノードをルートとしたのきの高さk木構造を表現可能(Figure 1がその例).

Theoretical framework

GNNの表現力の高さを分析する.仮定としてグラフの各ノードの特徴ベクトルとしてunique label\in\{a,b,c,\dots\}を割り当てる.この時隣接ノードの集合の特徴ベクトルは以下で定義されるmultisetを形成する.これは異なるノードが同一の特徴ベクトルを持つため同一の要素が複数回現れることを意味する.

Definition 1 (Multiset)

Multisetは集合の一般化された概念で,その要素として複数のインスタンスを持つことを許す.すなわち,multisetは2-tuple X=(S,m)であり,この時SXの基底となる集合で個別の要素によって形成され,m:S\rightarrow\mathbb{N}_{\geq 1}は要素の多様性を与える.

ここではGNNの表現力を測るため,GNNに二つのノードを同一の場所に埋め込む時を考える.直感的には,これは二つのノードが同一の特徴量,同一のsubtree構造を持つ時にのみ可能だと思える.面白いのはmultisetsな状況においてGNNはこれは不可能だと言っていて,aggregation schemeはinjectiveであるという.よってここではCNNをmultisets上の関数の集合として考え,それがinjective multiset functionsを表現できるかどうかを分析する.

Generalizing the WL test with graph neural networks

理想的にGNNの表現力の高さは異なるグラフを区別できるかどうかであるが,これはグラフの同型性の識別問題を解決したことになってしまう.ここでの分析においては少しゆるい指標を使ってGNNの表現力を分析する.

Lemma 2.

G_1,G_2を非同型なグラフとする.もしneighborhood aggregation schemeに従うGNN \mathcal{A}:\mathcal{g}\rightarrow\mathbb{R}^bG_1G_2に異なる埋め込みを与えるならば,WL graph isomorphism testを使ってG_1G_2は非同型であると判別できる.

Lemma 2からaggregation-based GNNはグラフ識別においてWL testと同等の能力がある.では現状提案されてるGNNはどれもWL testと同等の能力があるのかというのが疑問.論文では次のTheoremから同等な能力があるとしている.

Theorem 3

\mathcal{A}:\mathcal{g}\rightarrow\mathbb{R}^dをneighborhood aggregation schemeに従うGNNとする.十分な回数集約の操作が行われたとき,もし次の条件が満たされれば\mathcal{A}はWL testにおいて非同型と判断されたG_1G_2を異なる埋め込みへと写像する.

a) \mathcal{A}はノードの特徴ベクトルを次のように繰り返し集約,更新する.

\displaystyle
h_v^{(k)}=\phi\left(h_v^{(k-1)},f\left(\left\{h_u^{(k-1)}:u\in\mathcal{N}(v)\right\}\right)\right)\:or\:h_v^{(k)}=f\left(\left\{h_v^{(k-1)},h_u^{(k-1)}:u\in\mathcal{N}(v)\right\}\right)

関数fはmultisets上の演算子で,\phi単射

b)\mathcal{A}のグラフレベルのreadout(multiset上の演算)は単射

ここまでGNNとWL testは同等とは言いつつも,GNNはWL test以上に重要な利益をもたらす.というのもWL testでのノードの特徴ベクトルはone-hotベクトルでsubtree間の類似性などを得ることはできない.つまりGNNはWL testを連続空間に拡張かつ学習ベースにしたものと言える.これはGNNが構造の識別だけでなく埋め込みの学習やグラフ構造間のdependenciesも取得可能ということを意味する.

Graph isomorphism network (GIN)

ここではTheorem 3を満たすことで一般化されたWL testという保証のあるGraph Isomorphism Network (GIN)を提案.

Lemma 4によりsum aggregatorsが単射,特にmultisets上のuniversal functionsを表現できるということを示す.

Lemma 4

\mathcal{X}が可算であるとする.h(c,X)=(1+\epsilon)\cdot f(c)+\sum_{x\in X}f(x)が有限のmultiset X\subset\mathcal{X}においてuniqueとなるような関数f:\mathcal{X}\rightarrow\mathbb{R}^nが存在する.さらに,任意のmultiset function gはある関数\phiを用いてg(X)=\phi( (1+\epsilon)\cdot f(c)+\sum_{x\in X}f(x) )のように分解される.

Multi-layer perceptrons(MLPs)のuniversal approximation theoremから,f,\phiをMLPsでモデリングすることができる.実践的にはf^{(k+1)}\circ\phi^{(k)}は一つのMLPでモデル化したとのこと.するとGINは次のような形でノードの特徴量を更新する.

\displaystyle
h_v^{(k)}=\mathrm{MLP}^{(k)}\left((1+\epsilon)^{(k)}h_v^{(k-1)}+\sum_{u\in\mathcal{N}(v)}h_u^{(k-1)}\right)

\epsilonは実験では0または学習パラメータとしていた(ちなみに0の場合はTheorem 3を満たさない).この演算は言ってしまえば自分自身に適当な係数をかけて全ての隣接ノードとの和をとってMLPに入力するというもの.またGINは\mathrm{AGGREGATE}のみで\mathrm{COMBINE}は行わない.めちゃくちゃシンプルで直感的にはうまくいかなさそうだがTheorem 3からより複雑なものと比べても同程度の性能を発揮できる保証がある.

Readout subtree structures of different depths

グラフレベルのreadoutの重要な側面としては繰り返し回数が増えるほどノードの表現はrefinementされ,かつglobalな情報を得ることができるという点.この繰り返しの回数は重要だが,稀に回数が少ない時に得られた表現が良いものである場合もある.そのためGINでは全てのdepths/iterationsから得られた情報を用いることにする.これをJumping Knowledge Networks (JK-Nets)に似た構造で実現する.すなわち次のようなreadoutの関数を用いる.

\displaystyle
h_G=\mathrm{CONCAT}\left(\mathrm{READOUT}\left(\left\{h_v^{(k)}|v\in G\right\}\right)|k=0,1,\dots,K\right)

つまり全てのiterationにまたがって特徴ベクトルをconcatするというもの.Theorem 3とLemma 4に従えばこのreadoutの処理はWL testとWL subtree kernelの一般化になっている.

Less powerful but still interesting GNNs

最後にGCNやGraphSAGEを含むTheorem 3を満たさないGNNについて考える.ここではGINの集約の処理を(1)MLPsではなく単層のパーセプトロンにした場合,(2)sumの代わりにmeanまたはmax-poolingを使った場合のモデルについて掘り下げる.これらのGNNのバリエーションがWL testより劣り,単純なグラフにおいて機能しないことを示す.それでもなおsumの代わりにmeanを使うGCNはnode classificationにおいてよく機能するらしい.この理解を深めるためにGNNがグラフの何を捉えることができて何を捉えることができないかを論じる.

1-layer perceptron is insufficient for capturing structures

Lemma 4における関数fは異なるmultisetsをuniqueなembeddingsに写像する役割を果たし,universal approximation theoremからMLPによってパラメタライズされる.にも関わらず多くのGNNは単層のパーセプトロンを使っている.するとグラフの学習に単層のパーセプトロンで十分かという疑問がわく.実際には単層のパーセプトロンモデルはできないことが次のLemma 5から言える.

Lemma 5

任意の線形変換W,\:\sum_{x\in X_1}\mathrm{ReLU}(Wx)=\sum_{x\in X_2}\mathrm{ReLU}(Wx)に対して有限のmultisets X_1\neq X_2が存在する.

すなわち違うmultisetsに対して同じ表現を割り当ててしまう(区別できなる)場合があるということ.これを後ほど実験的にも証明した.

Structures that confuse mean and max-pooling

今度はsumをmeanまたはmax-poolingに置き換えたらどうなるかを考える.基本的にmeanまたはmax-poolingはpermutation invariantなため良いmultisite functionであるが,単射でないという問題がある.Figure 2と3にmeanやmaxにした場合に非常に単純なグラフの識別に失敗する例が示されている.

Mean learns distributions

ここではmean aggregatorが区別可能なmultisetsを調べるため,X_1=(S,m),X_2=(S,k\cdot m)という例を考える.この二つのmultisetsはmean aggregatorによって同じ埋め込みが与えられる.そのため,mean aggregatorはmultisetの要素の分布を捉えていて,正確なmultisetを捉えることはしていないと言える.

Corollary 6

\mathcal{X}が可算であるとする.もし,有限のmultisetsX_1,X_2が同じ分布を持つならば,h(X)=\frac{1}{|X|}\sum_{x\in X}f(x),\:h(X_1)=h(X_2)に関して関数f:\mathcal{X}\rightarrow\mathbb{R}^nが存在する.

もし解きたい問題がグラフの構造よりも統計的な情報を重視する場合にはmean aggregatorはうまく機能する.さらに言えば,ノードの特徴が多様な場合にはsum aggregatorと同程度の性能を発揮する.これがnode classificationにおいてmean aggregator (GCN)が機能する理由だと考えられる.

Max-pooling leans sets with distinct elements

Max-poolingは残念なことに正確な構造も分布も捉えることができない.ただし,代表的な要素を捉えるのが重要なタスクにおいては適していると言える.実際PointNet (point cloudのためのNN)ではmax poolingを使っていて,ノイズや外れ値に頑健な認識を可能にしている.

Corollary 7

\mathcal{X}が可算であるとする.もしX_1,X_2が同じunderlying setを持つなら,h(X)=\max_{x\in X}f(x),\:h(X_1)=h(X_2)に関してf:\mathcal{X}\rightarrow\mathbb{R}^\inftyが存在する.

まとめ

個人的にめちゃくちゃ面白かった.GNNの論文は色々読んだけど確かに表現力に関して突っ込んだ論文はなかったなと.読み応えあってメモも一部和訳レベルで残してしまった.ICLR2019のレビュー中なようだけど3人目のレビュワーがしきりに"I quite liked the paper"て言っててテンション上がってたのが印象的.