機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

GRAPH ATTENTION NETWORKSを読んだのでメモ

はじめに

GRAPH ATTENTION NETWORKSを読んだのでメモ.

気持ち

Kipf & Wellingの提案したGraph Convolutional Networks (GCN)は学習されたフィルタがグラフラプラシアン固有ベクトルに依存するため異なるグラフ構造に対応することができない.そこでフィルタがグラフ構造に依存しないようなアテンションに基づくGCN,Graph Attention Networks (GAT)を提案するというもの.

GAT

Graph attentional layer

まず,一層のgraph attentional layerを考える.入力は各ノードの特徴ベクトルで\mathbf{h}=\{h_1,\dots,h_N\},h_i\in\mathbb{R}^Fとして定義する.ただし,Nはノード数で,Fは特徴ベクトルの次元数.出力としては\mathbf{h}'=\{h_1',\dots,h_N'\},h_i'\in\mathbb{R}^{F'}となる.

ここでは学習パラメータ\mathbf{W}\in\mathbb{R}^{F'\times F}を持つ線形変換を考える.この線形変換は全てのノードに適用される.その際に,次のshared attentional mechanism a:\mathbb{R}^{F'}\times\mathbb{R}^{F'}\rightarrow\mathbb{R}によってアテンション係数を計算する.

\displaystyle
e_{ij}=a(\mathbf{W}h_i,\mathbf{W}h_j)

これはノードiにおけるノードjの特徴の重要度を表す.上記の形だと全てのノードの組み合わせを計算する必要が出てくるため,ここでは隣接しているノードj\in\mathcal{N}_iのみで計算を行う.計算された各係数のオーダーを合わせるため次のようにsoftmax関数を使って正規化を行う.

\displaystyle
\alpha_{ij}=\mathrm{softmax}_j(e_{ij})=\frac{\exp(e_{ij})}{\sum_{k\in\mathcal{N}_i}\exp(e_{ik})}

実験においては,アテンション係数を計算するための関数aは1層のニューラルネットワーク\mathbf{a}\in\mathbb{R}^{2F'}とLeakyReLUを使ったとのこと.そのため愚直に式を書けば次のようになる.

\displaystyle
\alpha_{ij}=\frac{\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^T[\mathbf{W}h_i\|\mathbf{w}h_j]\right)\right)}{\sum_{k\in\mathcal{N}_i}\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^T[\mathbf{W}h_i\|\mathbf{W}h_k]\right)\right)}

変数の右肩のTは行列の転置,\|は変数の結合(concatenation)を表す.

この係数を使って各層での変換は次のように記述される.

\displaystyle
h_i'=\sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}\mathbf{W}h_j\right)

さらに,学習過程におけるアテンション係数の安定化のために次のようなmulti-head attentionの構造を取り入れたとのこと.

\displaystyle
h_i'=\overset{K}{\underset{k=1}{\|}}\sigma\left(\sum_{j\in\mathcal{N}_i}\alpha_{ij}^k\mathbf{W}^kh_j\right)

要はK個の独立したアテンション機構を導入するということ.基本的には各アテンションごとの出力を結合するが,最後の識別層においては各出力の平均を計算したとのこと.

まとめ

なんとなくアテンションを計算するための線形変換と,各層における線形変換の重み係数が同じなのに違和感(単純に同じ文字を使っているだけで別なパラメータなのかもしれないが).multi-head attentionもただパラメータ増えて表現力が上がったからうまくいっただけな気もするがどうなのか.

感覚的にはグラフ構造を作る際の距離関数を学習ベースで決めるということだと思うが,PPIデータセットを使った実験ではベースラインと比較して飛躍的に性能が向上していたので実装して試したいところ.