機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Neural Relational Inference for Interacting Systemsを読んだのでメモ

はじめに

Neural Relational Inference for Interacting Systemsを読んだのでメモ.ニューラルネットにより要素間のインタラクションを推定したいというもの.例としてバスケットボールをあげていて,ある選手のダイナミクスは相手選手や味方の選手の動きや位置によって影響を受ける.そのような関係性を教師なしで獲得するというのが論文のゴール.実験ではトラジェクトリ推定などに適用.

Neural Relational Inference Model

提案モデルであるNeural Relational Inference model(NRI model)は,入力に現時刻までの(バスケを例にとれば)各選手のトラジェクトリが与えられ,出力として次の時刻の各選手の位置と速度を返すようなモデル.これをVAEを使って実現するというもの.肝としては,VAEの潜在表現\mathbf{z}を各選手間のインタラクションを表すグラフ構造にしたというもの.

時刻tにおけるN個の物体の観測値(位置や速度)を\mathbf{x}^t=\{\mathbf{x}_1,\dots,\mathbf{x}_N^t\}で表し,物体v_iのトラジェクトリを\mathbf{x}_i=(\mathbf{x}_i^1,\dots,\mathbf{x}_i^T),全トラジェクトリは\mathbf{x}=(\mathbf{x}^1,\dots,\mathbf{x}^T)で表す.物体v_iv_j間の関係性を離散エッジ\mathbf{z}_{ij}で表す.各物体のダイナミクスは未知のグラフ\mathbf{z}を入力とするgraph neural network(GNN)でモデル化することを目標とする.

目的関数は次のVAEの目的関数がベースとなる.

\displaystyle
\mathcal{L}=\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})]-\mathrm{KL}[q_\phi(\mathbf{z}|\mathbf{x})\| p_\theta(\mathbf{z})]

ただし,今回の問題設定では時系列を扱うかつ潜在変数が離散エッジのため各分布は少し特殊となる.エンコーダーq_\phi(\mathbf{z}|\mathbf{x})は因子分解された\mathbf{z}_{ij}の分布を返す.\mathbf{z}_{ij}v_i,v_j間のエッジのタイプを表すカテゴリカル分布に従う確率変数で,具体的にはK個のインタラクションのタイプを表すonehot表現になる.

デコーダーは物体間の関係性と各物体のトラジェクトリを入力として次の時刻の各物体の状態を返す.すなわち次のように表現される.

\displaystyle
p_\theta(\mathbf{x}|\mathbf{z})=\prod_{t=1}^Tp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^{t},\dots,\mathbf{x}^1,\mathbf{z})

priorはp_\theta(\mathbf{z})=\prod_{i\neq j}p_\theta(\mathbf{z}_{ij})として表現される因子分解された形の一様分布.

Encoderについて

エンコーダーは次の4つの演算から構成される.

\displaystyle
\mathbf{h}_j^1=f_\mathrm{emb}(\mathbf{x}_j)\\
v\rightarrow e:\mathbf{h}_{(i,j)}^1=f_e^1([\mathbf{h}_i^1,\mathbf{h}_j^1])\\
e\rightarrow v:\mathbf{h}_j^2=f_v^1(\sum_{i\neq j}\mathbf{h}_{(i,j)}^1)\\
v\rightarrow e:\mathbf{h}_{(i,j)}^2=f_e^2([\mathbf{h}_i^2,\mathbf{h}_j^2])

最初の演算は各オブジェクトの観測値を特徴量に変換する計算.次の式は物体v_i,v_j間の関係性を記述する計算(物体vからエッジe).3番目は得られた物体間の関係から物体の特徴量を得るする演算(エッジeから物体v).最後に物体間の関係性(インタラクション)を得る(物体vからエッジe).得られた物体間の関係性を表す特徴量\mathbf{h}_{(i,j)}^2をsoftmax関数にかけることでインタラクションに関する分布q_\phi(\mathbf{z}_{ij}|\mathbf{x})=\mathrm{softmax}(\mathbf{h}^2_{(i,j)})を得る.ここでエッジ->物体->エッジの順で特徴量を計算しているのは,2番目の式から得られるエッジの特徴量は2つの物体間の情報しか入っていないため他のノードとのインタラクションが無視されているためとのこと.繰り返すことで物体全体のインタラクションが考慮可能ということ.各関数は全てニューラルネットにより表現(全結合か1次元のCNN).

今回はVAEの構造を考えているためエンコーダーから得られた分布q_\phi(\mathbf{z}_{ij}|\mathbf{x})からサンプリングする必要がある.この論文ではconcrete distributionによってサンプリングを行ったとのこと.concrete distributionについては過去に論文を読んでメモしたのでそちらを参照.

Decoderについて

最初に書いたように今回のデコーダーはp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\dots,\mathbf{x}^1,\mathbf{z})と表され,潜在変数と現時刻までのトラジェクトリの情報を入力とする.今回はマルコフ性を仮定してp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\dots,\mathbf{x}^1,\mathbf{z})=p_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\mathbf{z})という表現に置き換える.

今回のデコーダーは次の3つの演算から構成される.

\displaystyle
v\rightarrow e:\tilde{\mathbf{h}}_{(i,j)}^t=\sum_kz_{ij,k}\tilde{f}_e^k([\mathbf{x}_i^t,\mathbf{x}_j^t])\\
e\rightarrow v:\mathbf{\mu}_j^{t+1}=\mathbf{x}_j^t+\tilde{f}_v(\sum_{i\neq j}\tilde{h}_{(i,j)}^t)\\
p(\mathbf{x}_j^{t+1}|\mathbf{x}^t,\mathbf{z})=\mathcal{N}(\mathbf{\mu}_j^{t+1},\sigma^2\mathbf{I})

最初の計算は関係性z_{ij,k}を考慮した2物体のインタラクションを表す.ここでのkはインタラクションの種類を表し,理想的にはz_{ij,k}はonehotベクトルだが今回は連続値に緩和しているため各インタラクションの種類に関する重み付き和として計算.2番目の計算は関係する全ての物体とのインタラクションを考慮して現時刻から次の時刻での物体の状態を表現.最終的に固定の分散\sigma^2を持つ正規分布として次の時刻の状態を表現.

ここでの問題は,デコーダーが潜在変数を無視するlatent variable collapse(その他KL collapseなどとも呼ばれる)を起こしてしまうこと.特に,short-termでの等速直線運動を仮定したモデルでも次の予測がそれなりの精度で行えてしまうように,インタラクション\mathbf{z}を無視してもある程度所望の出力を得ることが可能になってしまう.なのでこの問題を回避するため,デコーダーは複数ステップ(Mステップ)の予測をするように学習.そうすることでshort-termでは成り立つようなモデルがうまく機能しなくなり,インタラクション\mathbf{z}の情報に頼らざるを得なくなるというもの.具体的には,次のようにある時刻の出力を使って再帰的に予測していく.

\displaystyle
\mathbf{\mu}_j^2=f_\mathrm{dec}(\mathbf{x}_j^1)\\
\mathbf{\mu}_j^{t+1}=f_\mathrm{dec}(\mathbf{\mu}_j^t)\ \ \ t=2,\dots,M\\
\mathbf{\mu}_j^{M+2}=f_\mathrm{dec}(\mathbf{x}_j^{M+1})\\
\mathbf{\mu}_j^{t+1}=f_\mathrm{dec}(\mathbf{\mu}_j^t)\ \ \ t=M+2,\dots,2M

今回はM=10で実験したとのこと.

Recurrent decoder

悲しいかな,実問題を考えた場合マルコフ性が成り立たない場合が多く存在する.なのでここではデコーダーにRNNを導入して全ての系列を考慮可能なデコーダーを考える.具体的にはGRUを使って次のような演算でデコーダーを構成する.

\displaystyle
v\rightarrow e:\tilde{\mathbf{h}}_{(i,j)}^t=\sum_kz_{ij,k}\tilde{f}_e^k([\tilde{\mathbf{h}}_i^t,\tilde{\mathbf{h}}_j^t])\\
e\rightarrow v:\mathrm{MSG}_j^t=\sum_{i\neq j}\tilde{\mathbf{h}}_{(i,j)}^t\\
\tilde{\mathbf{h}}_j^{t+1}=\mathrm{GRU}([\mathrm{MSG}_j^t,\mathbf{x}_j^t],\tilde{\mathbf{h}}_j^t)\\
\mathbf{\mu}_j^{t+1}=\mathbf{x}_j^t+f_\mathrm{out}(\tilde{\mathbf{h}}_j^{t+1})\\
p(\mathbf{x}^{t+1}|\mathbf{x}^t,\mathbf{z})=\mathcal{N}(\mathbf{\mu}^{t+1},\sigma^2\mathbf{I})

最初の演算はインタラクションに関する隠れ層(hidden variable)の計算.2から4番目の計算がrecurrent部分で,インタラクションに関する隠れ層の値の集約値\mathrm{MSG}^t_jと観測値\mathbf{x}_j^tをGRUに入力,その出力から次の時刻の状態を得るという流れ.

まとめ

トイデータでの実験結果がめちゃくちゃ美しい.実問題での性能はわからないが実験ではLSTMと比べて飛躍的に精度が向上している.