Neural Relational Inference for Interacting Systemsを読んだのでメモ
はじめに
Neural Relational Inference for Interacting Systemsを読んだのでメモ.ニューラルネットにより要素間のインタラクションを推定したいというもの.例としてバスケットボールをあげていて,ある選手のダイナミクスは相手選手や味方の選手の動きや位置によって影響を受ける.そのような関係性を教師なしで獲得するというのが論文のゴール.実験ではトラジェクトリ推定などに適用.
Neural Relational Inference Model
提案モデルであるNeural Relational Inference model(NRI model)は,入力に現時刻までの(バスケを例にとれば)各選手のトラジェクトリが与えられ,出力として次の時刻の各選手の位置と速度を返すようなモデル.これをVAEを使って実現するというもの.肝としては,VAEの潜在表現を各選手間のインタラクションを表すグラフ構造にしたというもの.
時刻における個の物体の観測値(位置や速度)をで表し,物体のトラジェクトリを,全トラジェクトリはで表す.物体と間の関係性を離散エッジで表す.各物体のダイナミクスは未知のグラフを入力とするgraph neural network(GNN)でモデル化することを目標とする.
目的関数は次のVAEの目的関数がベースとなる.
ただし,今回の問題設定では時系列を扱うかつ潜在変数が離散エッジのため各分布は少し特殊となる.エンコーダーは因子分解されたの分布を返す.は間のエッジのタイプを表すカテゴリカル分布に従う確率変数で,具体的には個のインタラクションのタイプを表すonehot表現になる.
デコーダーは物体間の関係性と各物体のトラジェクトリを入力として次の時刻の各物体の状態を返す.すなわち次のように表現される.
priorはとして表現される因子分解された形の一様分布.
Encoderについて
エンコーダーは次の4つの演算から構成される.
最初の演算は各オブジェクトの観測値を特徴量に変換する計算.次の式は物体間の関係性を記述する計算(物体からエッジ).3番目は得られた物体間の関係から物体の特徴量を得るする演算(エッジから物体).最後に物体間の関係性(インタラクション)を得る(物体からエッジ).得られた物体間の関係性を表す特徴量をsoftmax関数にかけることでインタラクションに関する分布を得る.ここでエッジ->物体->エッジの順で特徴量を計算しているのは,2番目の式から得られるエッジの特徴量は2つの物体間の情報しか入っていないため他のノードとのインタラクションが無視されているためとのこと.繰り返すことで物体全体のインタラクションが考慮可能ということ.各関数は全てニューラルネットにより表現(全結合か1次元のCNN).
今回はVAEの構造を考えているためエンコーダーから得られた分布からサンプリングする必要がある.この論文ではconcrete distributionによってサンプリングを行ったとのこと.concrete distributionについては過去に論文を読んでメモしたのでそちらを参照.
Decoderについて
最初に書いたように今回のデコーダーはと表され,潜在変数と現時刻までのトラジェクトリの情報を入力とする.今回はマルコフ性を仮定してという表現に置き換える.
今回のデコーダーは次の3つの演算から構成される.
最初の計算は関係性を考慮した2物体のインタラクションを表す.ここでのはインタラクションの種類を表し,理想的にははonehotベクトルだが今回は連続値に緩和しているため各インタラクションの種類に関する重み付き和として計算.2番目の計算は関係する全ての物体とのインタラクションを考慮して現時刻から次の時刻での物体の状態を表現.最終的に固定の分散を持つ正規分布として次の時刻の状態を表現.
ここでの問題は,デコーダーが潜在変数を無視するlatent variable collapse(その他KL collapseなどとも呼ばれる)を起こしてしまうこと.特に,short-termでの等速直線運動を仮定したモデルでも次の予測がそれなりの精度で行えてしまうように,インタラクションを無視してもある程度所望の出力を得ることが可能になってしまう.なのでこの問題を回避するため,デコーダーは複数ステップ(ステップ)の予測をするように学習.そうすることでshort-termでは成り立つようなモデルがうまく機能しなくなり,インタラクションの情報に頼らざるを得なくなるというもの.具体的には,次のようにある時刻の出力を使って再帰的に予測していく.
今回はで実験したとのこと.
Recurrent decoder
悲しいかな,実問題を考えた場合マルコフ性が成り立たない場合が多く存在する.なのでここではデコーダーにRNNを導入して全ての系列を考慮可能なデコーダーを考える.具体的にはGRUを使って次のような演算でデコーダーを構成する.
最初の演算はインタラクションに関する隠れ層(hidden variable)の計算.2から4番目の計算がrecurrent部分で,インタラクションに関する隠れ層の値の集約値と観測値をGRUに入力,その出力から次の時刻の状態を得るという流れ.
まとめ
トイデータでの実験結果がめちゃくちゃ美しい.実問題での性能はわからないが実験ではLSTMと比べて飛躍的に精度が向上している.