機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

Graph Warp Module: an Auxiliary Module for Boosting the Power of Graph Neural Networksを読んだのでメモ

はじめに

Graph Warp Module: an Auxiliary Module for Boosting the Power of Graph Neural Networksを読んだのでメモ.

気持ち

GNNの表現力の低さを問題視した論文.この問題は以前読んだICLR2019のHow Powerful are Graph Neural Networks?でも議論されている.GNNは(データとタスクによるが)一般的なNNに比べ学習データにoverfitすることすらできないことがある.そのためこの論文ではGNNの表現力を向上させるgraph warp module (GWM)を提案するというもの.GWMは既存のGNNに追加モジュールとして組み合わせることが可能.

GWMのポイントはグラフの持つグローバルな特性を捉えるためのvirtual supernodeの導入と,正確なノード間の情報のやりとりを達成するためのGRUとattention利用した構造の2点とのこと.

notation

グラフをG=(V,E)とし,V,Eはそれぞれノードとエッジの集合を表す.ノードのラベルをi=1,2,\dots,|V|として,各エッジをノードのペアとして表現する.隣接行列\mathcal{A}\in\mathbb{R}^{|V|\times |V|}は重み付きとし,各ノードには特徴ベクトルx_iを割り振る.ここでは多層構造のGNNを,平滑化関数\mathcal{F}_lを利用して再帰的に出力を計算するものとして考える.初期値をx_j=h_{0,j}としh_{l,i}=\mathcal{F}_{l-1,i}(h_{l-1,j};j\in V)\in\mathbb{R}^dをGNNのl層目によってi番目のノードに割り当てられた特徴ベクトルとする.最終的に得られた特徴ベクトルの集合\{ h_{L,i};i\in V\}を集約することで出力を得る.

Graph Warp Module

GWMはsuper node, transmitter unit, warp gate unitの3つのブロックから構成されていて,GWM付きのGNNはl-1層目からのメッセージ\mathcal{F}_{l-1,i}(h_{l-1,i};j\in V)\in\mathbb{R}^dをsupernodeの値と合わせ混んでl層目の出力h_lとして返す.

Supernode

Supernodeは全てのノードと接続を持つノードで,GNNの各層に用意される.Supernodeは全てのノードと接続を持つため,グローバルな情報を伝達する助けとなるとのこと.l層目のsupernodeの特徴ベクトルをg_lとすると,各層においてtransmitterと呼ばれるunitはsupernodeからl+1層目に渡すメッセージ\mathcal{g}_l(g_l)を要求し,メインのGNNにメッセージを送る.ただし,\mathcal{g}_lはsmooth function.

l層目のsupernodeの特徴ベクトルはl-1層目のsupernodeの特徴ベクトルを利用して作られるため,初期値g_0は何らかの値で初期化されている必要がある.ここでは初期値の例としてノードやエッジの数,入力ノードの特徴量の平均やヒストグラムなどを挙げている.

Transmitter Unit

Transmitter unitはGNNとGWM間のメッセージのやり取りを扱うユニット.ここでは複数のタイプのメッセージを扱うために,attentionを利用する工夫を入れている.GNNからsupernodeにメッセージを送る前に,transmitterはK-head attention mechanismを使ってメッセージをタイプごとに集約する.処理の流れとしては,GNNからsupernodeへと送るメッセージを作成m_{l,k}^{main\rightarrow super},伝送h_l^{main\rightarrow super}.その後supernodeからのメッセージをGNNへ伝送g_l^{super\rightarrow main}の流れ.細かい計算は次のようになる.

\displaystyle
h_l^{main\rightarrow super}=\mathrm{tanh}\left(W_lm_{l,1:k}^{main\rightarrow super}\right)\in\mathbb{R}^{D'}\\
m_{l,k}^{main\rightarrow super}=\sum_i\alpha_{l,i,k}U_{l,k}h_{l-1,i}\in\mathbb{R}^{D'}\\
g_{l}^{super\rightarrow main}=\mathrm{tanh}(F_lg_{l-1})\in\mathbb{R}^D
\alpha_{l,i,k}=\mathrm{softmax}(a(h_{l-1},g_{l-1};A_{l,k}))\in(0,1)\\
a(h_{l-1},g_{l-1};A_{l,k}):=h_{l-1}^T,A_{l,k}g_{l-1}

\alpha_{l,i,k}の部分がattention.

Warp Gate

Warp geteは送られてきたメッセージをマージしてその結果をGRUを通してsupernodeとGNNに送るというもの.構成要素は以下.

h_l^0:l層目のGNNにメッセージを伝送するGRUの入力 ・g_l^0:l層目のsupernodeにメッセージを伝送するGRUの入力 ・\hat{g}_l:l-1層目のsupernodeからのメッセージ\mathcal{g}_{l-1}(g_{l-1}\in\mathbb{R}^{D'}\hat{h}_{l,i}:l-1層目のGNNからのメッセージ\mathcal{F}_{l-1,i};k\in V)\in\mathbb{R}^D
・[tex:z_{l,i}:supernodeからGNNへの伝送のためのwarp gate coefficients ・z_{l,i}^{(S)}:GNNからsupernodeへの伝送のためのwarp gate coefficients

各変数の具体的な定義は以下.

\displaystyle
h_{l,i}^0=(1-z_{l,i})\odot\hat{h}_{l-1,i}+z_{l,i}\odot g_l^{super\rightarrow main}\in\mathbb{R}^D\\
z_{l,i}=\sigma\left(H_l\tilde{h}_{l,i}+G_lg_l^{super\rightarrow main}\right)\\
g_l^0=z_l^{(S)}\odot h_l^{main\rightarrow super}+(1-z_l^{(S)})\odot\hat{g}_l\in\mathbb{R}^{D'}\\
z_l^{(S)}=\sigma\left(H_l^{(S)}h_l^{main\rightarrow super}+G_l^{(S)}\hat{g}_l\right)

これらをGRUに通してGNNとsupernodeのメッセージを組み合わせた値を返す.

\displaystyle
h_{l,i}=\mathrm{GRU}(h_{l-1,i},h_{l,i}^0)\in\mathbb{R}^D\\
g_l=\mathrm{RGU}(g_{l-1},g_l^0)\in\mathbb{R}^{D'}

まとめ

モチベーションははHow Powerful are Graph Neural Networks?と同じだがこの論文は問題の解決方法がpracticalで個人的にはHow Powerful are Graph Neural Networks?の方が好き.ただ,既存の枠組みに導入するだけで性能改善可能という点は扱いやすく,良い手法.

Emerging Convolutions for Generative Normalizing Flowsを読んだのでメモ

はじめに

Emerging Convolutions for Generative Normalizing Flowsを読んだのでメモ.

気持ち

Deepな生成モデルでGlowが登場して,flow-basedな生成モデルの利点が広く認知されたが,結局全単射な関数からしかモデルを構成できず,表現力を上げるには多層にする必要があった.そのせいで計算コストが上がってしまい扱いにくいという欠点が存在する.この論文ではGlowで提案された1x1Convをdxd Convに拡張して表現力を上げたというもの.

Generative flows

ここは過去散々メモを書いてきたので流石に今回は割愛.

Convolution

前提知識としてCNNにおける畳み込み(数学的にはcross-correlations)は行列とベクトルの積として表現可能ということを抑える.\mathbf{x}をベクトル化された画像(t=i+w\cdot jで展開,ただしi,jはそれぞれ列と行,wは幅を表す.)とし,\mathbb{W}hwc_{c_{out}}\times hwn_{c_{in}}の重みとする.すると畳み込みの演算は\mathbf{W}\cdot\mathbf{x}として表現可能.これはspatial domainでのgraph convolutionと同じロジック.Figure 2と3を見るとよくわかるはず.

一方で,normalizing flowの分野ではautoregressive convolutionという畳み込みが利用されることがある.autoregressive flow等についても過去いくつかメモを取ったので割愛.自己回帰モデルは各確率変数の依存関係からflow-basedなものと相性が良いのはよく知られていてautoregressive flowという文脈で広く研究されている.というのも,flow-basedな生成モデルは尤度の計算にヤコビアン行列式を計算する必要があるが,自己回帰モデルではここの重い計算を単純化しやすいため.このautoregressive convolutionに従えば先ほどの畳み込みの定式化において\mathbf{W}が三角行列になり,ヤコビアン行列式\mathbf{W}の体格成分の積として計算可能になる.

なぜ三角行列になるかという点において.autoregressive flowの文脈においては各画素はv_t\prod p(v_t|v_{t-1,t-2,\dots})という依存関係を持つ.すなわちある画素は以前に生成された画素とのみ依存関係を持つというもの.これを先ほどの畳み込みの枠組みで考えれば画像はt=i+w\cdot jと展開されているため,t番目の画素はt-1番目の画素としか依存関係がない.そのため,\mathbf{W}t行目においてt+1列目以降の値は0となる.(ここは論文に書いてない説明なのでnotation等は適当.)

とりあえずまとめると,\mathbf{W}\cdot \mathbf{x}は逆変換が可能でautoregressiveな設定を考えればヤコビアン行列式も計算が楽ということ.

Emerging convolutions

この論文の提案手法であるemerging convolutionについて.先ほど見たようにautoregressive convolutionは逆変換可能だが,変換の自由度はautoregressiveのorder(ここでのorderは自己回帰の繰り返しの回数?)に制限される(convolutionは先ほどの定式化だとヤコビアン行列式の計算が非現実的).そのためここではより自由度の高いinvertibleなconvolutionであるemerging convolutionを提案する.概要としては種々のautoregressive convolutionをつなぎ合わせるというもの(Figure 5がその図).

Square emerging convolutions

Autoregressive convはフィルタの形状が正方形ではないことが多く,ここで提案しているautoregressive convの組み合わせで得られる受容野(論文ではemerging receptive fieldと呼称)はFigure 5のように様々な形状をとる.それに対し,通常のDNNの枠組みにおいてはフィルタや得られる受容野の形状は正方形を成すことが多く,ライブラリ等における実装も正方形に最適化されている.実はある特定の二つのautoregressive convolutionによるemerging receptive fieldは正方形になるとのこと(どのようなconvかはFigure 5の一番下参照).これをsquare emerging convolutionとする.

d\times dのsquare emerging convolutionは二つの\frac{d+1}{2}\times\frac{d+1}{2}convolutionとして表現することが可能.k_1,k_2をその二つの畳み込みとし,fを入力の特徴マップとすると次のような計算として表現可能.

\displaystyle
k_2\star(k_1\star f)=(k_2\ast k_1)\star f

\star,\astはそれぞれcross-correlationとconvolutionを表す.

Invertible periodic convolutions

データによって周期性等を持つ場合があるため,そのようなデータにおいて効果的なinvertible periodic convolutionを提案.周期性のあるデータに対しては周波数空間における計算として扱えばヤコビアン行列式も逆変換も扱いやすいという性質に基づくもの.別な言い方をすれば,周期性を仮定すれば重畳積分定理からデータとフィルタをDFTすれば要素積として計算可能ということ.これだけのことではあるが一応論文で丁寧に説明しているので軽く.

フィルタ\mathbf{w}と入力\mathbf{x}における畳み込み(cross-correlations)は次のように表現可能.

\displaystyle
\mathbf{z}_{out}=\sum_{c_{in}}\mathbf{w}_{c_{out},c_{in}}\star\mathbf{x}_{c_{in}}

\mathcal{F}(\cdot)フーリエ変換\mathcal{F}^{-1}(\cdot)を逆フーリエ変換とする.また,入力出力フィルタのフーリエ変換\hat{\mathbf{x}}_{c_{in}}=\mathcal{F}(\mathbf{x}_{c_{in}}),\hat{\mathbf{z}}_{c_{out}}=\mathcal{F}(\mathbf{z}_{c_{out}}),\hat{\mathbf{w}}_{c_{out},c_{in}}=\mathcal{F}(\mathbf{w}^\ast_{c_{out},c_{in}})とする.ただし,CNNにおける畳み込みはcross-correlationなので\mathbf{w}^\ast_{c_{out},c_{in}})をspatial domainにおいて反転したフィルタとして,convolutionのフーリエ変換として扱えるようにしている.すると周波数空間において元の計算は次のようになる.

\displaystyle
\hat{\mathbf{z}}_{c_{out}}=\sum_{c_{in}}\hat{\mathbf{w}}_{c_{out},c_{in}}\odot\hat{\mathbf{x}}_{c_{in}}

これをchannelを考慮して\hat{z}_{:,uv},\hat{\mathbf{W}}_{:,:,u,v},\hat{\mathbf{x}}_{:,uv}として表せば次のような行列式の形で表現できる.

\displaystyle
\hat{\mathbf{z}}_{:,:,uv}=\hat{\mathbf{W}}_{uv}\hat{\mathbf{x}}_{:,uv}

\hat{\mathbf{W}}c_{out}\times c_{in}次元の行列で,\hat{\mathbf{x}}_{:,uv},\hat{\mathbf{z}}_{:,uv}はそれぞれc_{in},c_{out}次元のベクトル.最後に\mathcal{F}^{-1}(\hat{\mathbf{z}}_{c_{out}})と逆フーリエ変換すれば元の計算に一致する.フーリエ変換を利用した計算によりヤコビアン行列式は次のようになる.

\displaystyle
\log\left|\det\frac{\partial\mathbf{z}}{\partial\mathbf{x}}\right|=\log\left|\det\frac{\partial\hat{\mathbf{z}}}{\partial\hat{\mathbf{x}}}\right|=\sum_{u,v}\log\left|\det\hat{\mathbf{W}}_{uv}\right|

また,逆変換は逆行列を使って次のように計算できる.

\displaystyle
\hat{\mathbf{x}}_{:,uv}=\hat{\mathbf{W}}_{uv}^{-1}\hat{\mathbf{z}}_{:,uv}

でこの計算自体はある意味でglowにおける1\times 1convの計算と同じなので,glowで提案されたLU分解を利用した計算が可能.ただしここではQR分解に修正したバージョンを利用.というのもglowのLU分解はpermutation行列が固定なので自由度が下がるということを問題視したため.QR分解を使った場合,重み行列は\mathbf{W}=\mathbf{Q}(\mathbf{R}+\mathrm{diag}(\mathbf{s}))と表現される.直行行列\mathbf{Q}はHouseholder reflectionから\mathbf{Q}=\mathbf{Q}_1\mathbf{Q}_2\dots\mathbf{Q}_nとして得られ,\mathbf{Q}_iは次のように表現される.

\displaystyle
\mathbf{Q}_i=\mathbf{I}-2\frac{\mathbf{v}_i\mathbf{i}^T}{\mathbf{v}^T_i\mathbf{v}_i}

\mathbf{v}_iは学習パラメータ.このQR分解を利用した表現により,ヤコビアン行列式\logh\cdot w\cdot\mathrm{sum}(\log|\mathbf{s}|)として計算できる.

まとめ

Spatial domainでのgraph convolutionは確かに逆変換可能で盲点.畳み込みにautoregressiveの要請を入れて三角行列にすることで行列式逆行列の計算を簡略化したのもなるほどな点.

Bottom-up Object Detection by Grouping Extreme and Center Pointsを読んだのでメモ

はじめに

Bottom-up Object Detection by Grouping Extreme and Center Pointsを読んだのでメモ.

気持ち

Object detectionにbounding boxを使うのは良くないのでextreme pointsと中心位置の4点をheatmapとして回帰しようというもの.Extreme pointというのはICCV2017のExtreme clicking for efficient object annotationで提案された概念で,物体の一番上の点,下の点,左の点,右の点を表す(詳しくは論文の図を参照).このextreme pointsで物体検出することでよりタイトで無駄のない物体検出が可能とのこと.特にextreme pointsベースの検出のいいところとしては,物体の種類ごとにextreme pointの位置が大体決まっているということ.例えば人ならtopとbottomを表すextreme pointsは大体頭と足に来るため,bounding boxでの検出よりも優しいタスクになっていると考えられる.また,CVPR2018で提案されたextreme pointを使ったsegmentation手法であるDeepExtremeCutを使えば物体のマスクを生成することが可能で,それによってinstance segmentationによってstate-of-the-artを達成したというもの.

ちなみに乱暴に要約するとCornerNetの出力をextreme pointsに変えたというもの.

ExtremeNet

ExtremeNetの出力は5\times Cのheatmapsと4\times 2のoffset maps.Cはクラス数.heatmapsはextreme pointsと中心位置を表すもので,offset mapsはextreme pointsのsub-pixelレベルでの精度を保証するもので,cornerNetで提案されているものと同じ.cornerNetについては以前にメモをした.検出されたextreme pointsは最終的に幾何的な整合性が取れるようにグルーピングされる.

Extreme pointsのグルーピングは検出された各extreme pointsから計算される中心位置c=\left(\frac{l_x+t_x}{2},\:\frac{t_x+b_x}{2}\right)が検出された中心位置\hat{Y}^{(c)}から閾値\tau_c以内に存在すればそのextreme pointsを一つの物体のextreme pointsとしてまとめる.ちなみにextreme pointsはheatmapとして得られるため,座標として得るために3x3のパッチ内で閾値\tau_pを超える値を持つ最大の点を探索する.

Edge aggregation

Extreme pointsの検出は常にユニークなものではなく,一つの物体に対して複数の点が現れる場合がある.特に車など直線的な物体ではその直線上のどこでもextreme pointになり得る場合がある.そのため,そのような冗長なextreme pointsはエッジに沿って並んでいるという仮定を置いて,エッジ上に並ぶextreme pointsを統合することを考える.方法としては通常通り閾値\tau_pを超えるextreme pointsを選び,水平もしくは垂直方向にextreme pointsのスコアをaggregateするというもの.ただ,闇雲に水平または垂直方向に足していっても仕方がないので,aggregate対象のextreme pointsのスコアが単調減少し続ける限りaggregateする.つまりここでは暗に,真のextreme pointを中心にガウス分布的に冗長なextreme pointsが存在するということを仮定している.スコアは単純な足し合わせではなく次のように\lambda_{aggr}を導入している.

\displaystyle
\tilde{Y}_m=\hat{Y}_m+\lambda_{agar}\sum_{i=I_0}^{i_1}N_i^{(m)}

Y_mは注目しているextreme pointを表し,N_iはaggregate対象のextreme point.\lambda_{aggr}は実験では0.1に設定したとのこと.

学習

Extreme pointをアノテーションしたデータセットは公開されていないのでCOCOのpolygon maskを利用してextreme pointの真値を作成したとのこと.論文を書いたチームは計算資源があまりないのでpretrain済みのCornerNetを5台のGPUで250kイテレーションのfine-tuneをしたとのこと(CornerNetの学習は10GPUで500kイテレーション140GPU日).

まとめ

内容的にはcornerNetの検出対象をbounding boxの隅ではなくextreme pointにしたというもの.そのおかげでcorner poolingのようなトリッキーなことをせず精度よく検出が可能になった.またDeepExtremeCutとの相性が良く,Instance Maskまで高精度で生成が可能というもの.

Neural Relational Inference for Interacting Systemsを読んだのでメモ

はじめに

Neural Relational Inference for Interacting Systemsを読んだのでメモ.ニューラルネットにより要素間のインタラクションを推定したいというもの.例としてバスケットボールをあげていて,ある選手のダイナミクスは相手選手や味方の選手の動きや位置によって影響を受ける.そのような関係性を教師なしで獲得するというのが論文のゴール.実験ではトラジェクトリ推定などに適用.

Neural Relational Inference Model

提案モデルであるNeural Relational Inference model(NRI model)は,入力に現時刻までの(バスケを例にとれば)各選手のトラジェクトリが与えられ,出力として次の時刻の各選手の位置と速度を返すようなモデル.これをVAEを使って実現するというもの.肝としては,VAEの潜在表現\mathbf{z}を各選手間のインタラクションを表すグラフ構造にしたというもの.

時刻tにおけるN個の物体の観測値(位置や速度)を\mathbf{x}^t=\{\mathbf{x}_1,\dots,\mathbf{x}_N^t\}で表し,物体v_iのトラジェクトリを\mathbf{x}_i=(\mathbf{x}_i^1,\dots,\mathbf{x}_i^T),全トラジェクトリは\mathbf{x}=(\mathbf{x}^1,\dots,\mathbf{x}^T)で表す.物体v_iv_j間の関係性を離散エッジ\mathbf{z}_{ij}で表す.各物体のダイナミクスは未知のグラフ\mathbf{z}を入力とするgraph neural network(GNN)でモデル化することを目標とする.

目的関数は次のVAEの目的関数がベースとなる.

\displaystyle
\mathcal{L}=\mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})]-\mathrm{KL}[q_\phi(\mathbf{z}|\mathbf{x})\| p_\theta(\mathbf{z})]

ただし,今回の問題設定では時系列を扱うかつ潜在変数が離散エッジのため各分布は少し特殊となる.エンコーダーq_\phi(\mathbf{z}|\mathbf{x})は因子分解された\mathbf{z}_{ij}の分布を返す.\mathbf{z}_{ij}v_i,v_j間のエッジのタイプを表すカテゴリカル分布に従う確率変数で,具体的にはK個のインタラクションのタイプを表すonehot表現になる.

デコーダーは物体間の関係性と各物体のトラジェクトリを入力として次の時刻の各物体の状態を返す.すなわち次のように表現される.

\displaystyle
p_\theta(\mathbf{x}|\mathbf{z})=\prod_{t=1}^Tp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^{t},\dots,\mathbf{x}^1,\mathbf{z})

priorはp_\theta(\mathbf{z})=\prod_{i\neq j}p_\theta(\mathbf{z}_{ij})として表現される因子分解された形の一様分布.

Encoderについて

エンコーダーは次の4つの演算から構成される.

\displaystyle
\mathbf{h}_j^1=f_\mathrm{emb}(\mathbf{x}_j)\\
v\rightarrow e:\mathbf{h}_{(i,j)}^1=f_e^1([\mathbf{h}_i^1,\mathbf{h}_j^1])\\
e\rightarrow v:\mathbf{h}_j^2=f_v^1(\sum_{i\neq j}\mathbf{h}_{(i,j)}^1)\\
v\rightarrow e:\mathbf{h}_{(i,j)}^2=f_e^2([\mathbf{h}_i^2,\mathbf{h}_j^2])

最初の演算は各オブジェクトの観測値を特徴量に変換する計算.次の式は物体v_i,v_j間の関係性を記述する計算(物体vからエッジe).3番目は得られた物体間の関係から物体の特徴量を得るする演算(エッジeから物体v).最後に物体間の関係性(インタラクション)を得る(物体vからエッジe).得られた物体間の関係性を表す特徴量\mathbf{h}_{(i,j)}^2をsoftmax関数にかけることでインタラクションに関する分布q_\phi(\mathbf{z}_{ij}|\mathbf{x})=\mathrm{softmax}(\mathbf{h}^2_{(i,j)})を得る.ここでエッジ->物体->エッジの順で特徴量を計算しているのは,2番目の式から得られるエッジの特徴量は2つの物体間の情報しか入っていないため他のノードとのインタラクションが無視されているためとのこと.繰り返すことで物体全体のインタラクションが考慮可能ということ.各関数は全てニューラルネットにより表現(全結合か1次元のCNN).

今回はVAEの構造を考えているためエンコーダーから得られた分布q_\phi(\mathbf{z}_{ij}|\mathbf{x})からサンプリングする必要がある.この論文ではconcrete distributionによってサンプリングを行ったとのこと.concrete distributionについては過去に論文を読んでメモしたのでそちらを参照.

Decoderについて

最初に書いたように今回のデコーダーはp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\dots,\mathbf{x}^1,\mathbf{z})と表され,潜在変数と現時刻までのトラジェクトリの情報を入力とする.今回はマルコフ性を仮定してp_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\dots,\mathbf{x}^1,\mathbf{z})=p_\theta(\mathbf{x}^{t+1}|\mathbf{x}^t,\mathbf{z})という表現に置き換える.

今回のデコーダーは次の3つの演算から構成される.

\displaystyle
v\rightarrow e:\tilde{\mathbf{h}}_{(i,j)}^t=\sum_kz_{ij,k}\tilde{f}_e^k([\mathbf{x}_i^t,\mathbf{x}_j^t])\\
e\rightarrow v:\mathbf{\mu}_j^{t+1}=\mathbf{x}_j^t+\tilde{f}_v(\sum_{i\neq j}\tilde{h}_{(i,j)}^t)\\
p(\mathbf{x}_j^{t+1}|\mathbf{x}^t,\mathbf{z})=\mathcal{N}(\mathbf{\mu}_j^{t+1},\sigma^2\mathbf{I})

最初の計算は関係性z_{ij,k}を考慮した2物体のインタラクションを表す.ここでのkはインタラクションの種類を表し,理想的にはz_{ij,k}はonehotベクトルだが今回は連続値に緩和しているため各インタラクションの種類に関する重み付き和として計算.2番目の計算は関係する全ての物体とのインタラクションを考慮して現時刻から次の時刻での物体の状態を表現.最終的に固定の分散\sigma^2を持つ正規分布として次の時刻の状態を表現.

ここでの問題は,デコーダーが潜在変数を無視するlatent variable collapse(その他KL collapseなどとも呼ばれる)を起こしてしまうこと.特に,short-termでの等速直線運動を仮定したモデルでも次の予測がそれなりの精度で行えてしまうように,インタラクション\mathbf{z}を無視してもある程度所望の出力を得ることが可能になってしまう.なのでこの問題を回避するため,デコーダーは複数ステップ(Mステップ)の予測をするように学習.そうすることでshort-termでは成り立つようなモデルがうまく機能しなくなり,インタラクション\mathbf{z}の情報に頼らざるを得なくなるというもの.具体的には,次のようにある時刻の出力を使って再帰的に予測していく.

\displaystyle
\mathbf{\mu}_j^2=f_\mathrm{dec}(\mathbf{x}_j^1)\\
\mathbf{\mu}_j^{t+1}=f_\mathrm{dec}(\mathbf{\mu}_j^t)\ \ \ t=2,\dots,M\\
\mathbf{\mu}_j^{M+2}=f_\mathrm{dec}(\mathbf{x}_j^{M+1})\\
\mathbf{\mu}_j^{t+1}=f_\mathrm{dec}(\mathbf{\mu}_j^t)\ \ \ t=M+2,\dots,2M

今回はM=10で実験したとのこと.

Recurrent decoder

悲しいかな,実問題を考えた場合マルコフ性が成り立たない場合が多く存在する.なのでここではデコーダーにRNNを導入して全ての系列を考慮可能なデコーダーを考える.具体的にはGRUを使って次のような演算でデコーダーを構成する.

\displaystyle
v\rightarrow e:\tilde{\mathbf{h}}_{(i,j)}^t=\sum_kz_{ij,k}\tilde{f}_e^k([\tilde{\mathbf{h}}_i^t,\tilde{\mathbf{h}}_j^t])\\
e\rightarrow v:\mathrm{MSG}_j^t=\sum_{i\neq j}\tilde{\mathbf{h}}_{(i,j)}^t\\
\tilde{\mathbf{h}}_j^{t+1}=\mathrm{GRU}([\mathrm{MSG}_j^t,\mathbf{x}_j^t],\tilde{\mathbf{h}}_j^t)\\
\mathbf{\mu}_j^{t+1}=\mathbf{x}_j^t+f_\mathrm{out}(\tilde{\mathbf{h}}_j^{t+1})\\
p(\mathbf{x}^{t+1}|\mathbf{x}^t,\mathbf{z})=\mathcal{N}(\mathbf{\mu}^{t+1},\sigma^2\mathbf{I})

最初の演算はインタラクションに関する隠れ層(hidden variable)の計算.2から4番目の計算がrecurrent部分で,インタラクションに関する隠れ層の値の集約値\mathrm{MSG}^t_jと観測値\mathbf{x}_j^tをGRUに入力,その出力から次の時刻の状態を得るという流れ.

まとめ

トイデータでの実験結果がめちゃくちゃ美しい.実問題での性能はわからないが実験ではLSTMと比べて飛躍的に精度が向上している.

Understanding the Effective Receptive Field in Deep Convolutional Neural Networksを読んだのでメモ

はじめに

Understanding the Effective Receptive Field in Deep Convolutional Neural Networksを読んだのでメモ.

気持ち

CNNにおけるreceptive fieldをちゃんと解析しようというもの.ここではeffective receptive field(ERF)という上位の概念を導入する.ERFは通常のreceptive fieldの考え方と違い,出力にどれだけ貢献しているかに着目してreceptive fieldの大きさを考える.

Effective receptive field

p番目の層の特徴マップの位置(i,j)の値をx_{i,j}^pとし,入力をx_{i,j}^0,出力をy_{i,j}=x_{i,j}^nとする.ただし,中心ピクセル(0,0)とする.ERFは出力に影響を与えた入力の領域として定義する.ここでの影響は出力の中心ピクセルの入力に対する勾配\partial y_{0,0}/\partial x_{i,j}^0によって測る.議論を簡単にするためここでは,任意の損失関数lに対して,\partial l/\partial y_{0,0}=1,\partial l/\partial y_{i,j}のように,出力の中心ピクセル以外は勾配を考えないものとする.また,g(i,j,p)p層目の特徴マップの(i,j)ピクセルの勾配\partial l/\partial x_{i,j}^pとする.

まず単純なケースとしてカーネルサイズk\times kの畳み込み層n層からなるモデルを考える.つまりプーリングや活性化関数はないものとする.さらに特殊な場合として全ての重みが1でバイアスがないものとする.するとこの場合は重み行列がランク落ちするため1Dの畳み込みに分解可能.なので以下1Dの畳み込みとして考える.初期の勾配(中心のみ1でその他0)をu(t),畳み込みフィルタをv(t)とすると畳み込みはデルタ関数を使って

\displaystyle
y(t)=\delta (t),\ v(t)=\sum_{m=0}^{k-1}\delta(t-m),\ \mathrm{where}\:\delta(t)=\left\{\begin{matrix}1,\ t=0\\ 0,\ t\neq 0\end{matrix}\right.

と表現できる.ただし,t=0,1,-1,2,-2ピクセルのインデックスを表す.今畳み込み層のみで構成されたモデルを考えているため入力に対する勾配はo=u\ast v\ast\dots\ast vとして畳み込みの繰り返しを使って表現できる.この畳み込みは離散時間フーリエ変換

\displaystyle
U(\omega)=\sum_{t=-\infty}^\infty u(t)e^{-j\omega t}=1,\ V(\omega)=\sum_{t=-\infty}^\infty v(t)e^{-j\omega t}=\sum_{m=0}^{k-1}e^{-j\omega m}

を使って,重畳積分定理から次のように周波数空間で計算できる.

\displaystyle
\mathcal{F}(o)=\mathcal{F}(u\ast v\ast\dots\ast v)(\omega)=U(\omega)\cdot V(\omega)^n=\left(\sum_{m=0}^{k-1}e^{-j\omega m}\right)^n

これを逆フーリエ変換すれば欲しい勾配が計算可能.

\displaystyle
o(t)=\frac{1}{2\pi}\int^\pi_{-\pi}\left(\sum_{m=0}^{k-1}e^{-j\omega m}\right)^ne^{j\omega t}d\omega\\
\frac{1}{2\pi}\int^\pi_{-\pi}e^{-j\omega s}e^{j\omega t}d\omega=\left(
\begin{matrix}
1,\ s=t\\
0,\ s\neq t
\end{matrix}\right.

k=2の場合ではこの勾配は2次元ガウス分布に従う値となり,k\lt 2の場合にもガウス分布のような分布になる.細かいことは別な文献に譲っていたのでここでは割愛.

次に重みの値がランダムな場合を考える.この時勾配は次のようにかける.

\displaystyle
g(i,j,p-1)=\sum_{a=0}^{k-1}\sum_{b=0}^{k-1}w_{a,b}^pg(i+a,i+b,p)

w_{alb}^pp層目の位置(a,b)の重みを表す.初期の重みは平均0,分散Cの固定の分布から独立にサンプリングされているものとし,勾配gは重みと独立であると仮定する.この仮定は非線形関数のないモデルだから成り立つ仮定で一般的ではないことに注意.重みの平均が0,勾配と重みが独立であるという仮定から勾配の期待値は

\displaystyle
\mathbb{E}_{w,input}[g(i,j,p-1)]=\sum_{a=0}^{k-1}\sum_{b=0}^{k-1}\mathbb{E}_w[w_{a,b}^p]\mathbb{E}_{input}[g(i+a,i+b,p)]=0,\forall p

となる.また,分散は

\displaystyle
\mathrm{Var}[g(i,j,p-1)]=\sum_{a=0}^{k-1}\sum_{b=0}^{k-1}\mathrm{Var}[w_{a,b}^p]\mathrm{Var}[g(i+a,i+b,p)]=C\sum_{a=0}^{k-1}\sum_{b=0}^{k-1}\mathrm{Var}[g(i+a,j+b,p)]

となる.つまり,receptive fieldはガウス分布の形を取る.また,分散の値を見ればn層の場合には分散の値がC^nとなる.

さらに一般的な状況として,各重みが異なる値,異なる分散を持つとする.ここでは再び1次元の畳み込みv(t)=\sum_{m=0}^{k-1}w(m)\delta(t-m)を考える.ここでは一般性を失わず重みが\sum_mw(m)=1と正規化されているとする.

先ほどと同様フーリエ変換と重畳積分定理から

\displaystyle
U(\omega)\cdot V(\omega)\dots V(\omega)=\left(\sum_{m=0}^{k-1}w(m)e^{-j\omega m}\right)^n

が得られる.これは一番単純なケースでの解析に対して重みw(m)がかかっただけの違い.w(m)が正規化されている場合,o(t)p(S_n=t)という確率に等しくなるらしい.ただし,S_n=\sum_{i=1}^nX_iの関係を満たし,X_iw(m)に関するi.i.dなmultinomial variable.この辺りも詳しいことは参考文献に譲っているため細かいことは割愛.重要な点は重みがp(S_n=t)という分布に従い,これは中心極限定理から層の数が増えればガウス分布\mathcal{N}(0,\mathrm{Var}[X])として表現できるということ.するとこのガウス分布は次のような平均と分散を持つ.

\displaystyle
\mathbb{E}[S_n]=n\sum_{m=0}^{k-1}mw(m),\ \mathrm{Var}[S_n]=n\left(\sum_{m=0}^{k-1}m^2w(m)-\left(\sum_{m=0}^{k-1}mw(m)\right)^2\right)

ここで面白いのは,effective receptive fieldはこの分布の標準偏差だけ広がると考えるとERFのサイズは\mathcal{O}(\sqrt{n})のオーダーになるということ.通常のreceptive fieldが層の数に線形,すなわち\mathcal{O}(n)に従って広がることを考えるとERFのサイズは1/\sqrt{n}だけ小さくなっていることがわかる.

ここで重みがw(m)=1/kと一様である場合にはERFの大きさは\mathcal{O}(k\sqrt{n})になる.

さらに非線形関数を導入した場合を考える.任意の非線形関数を\sigmaで表現する.ここでは一般的でないが解析のしやすさから,非線形関数を通ったのち畳み込みという順番で処理が進むとする.すると勾配は次のように計算できる.

\displaystyle
g(i,j,p-1)=\sigma_{i,j}^p\:'\sum_{a=0}^{k-1}\sum_{b=0}^{k-1}w_{a,b}^pg(i+a,j+b,p)

ただし\sigma_{i,j}^p\:'p層目のi,jピクセルに対する非線形関数の微分を表す.ReLUの場合には指示関数を使って\sigma_{i,j}^p\:'=\mathbf{I}[x_{i,j}^p\lt 0]と表すことができる.さらなる仮定としてx_{i,j}^p\:'が平均0で分散が1の対象な分布に従うとし,\sigma'は重みと上の層の勾配gに独立であるとする.すると,分散は

\displaystyle
\mathrm{Var}[g(i,j,p-1)]=\mathbb{E}[\sigma_{i,j}^p\:'^2]\sum_a\sum_b\mathrm{Var}[g(i+a,j+b,p)]

となる.また,\mathbb{E}[\sigma_{i,j}^p\:'^2]=\mathrm{Var}[\sigma_{i,j}^p\:']=1/4と定数になる.

Experiments

基本的には解析結果と一致して導出したガウス分布に従うERFを持つ.ただし,ReLUを使った場合にはERFがガウス分布に従わないという結果が得られたが,これはReLUにより値が0になった部分は勾配を伝えないことに起因するという.また,cifar10等で実験した場合,receptive fieldは画像サイズより大きくなるはずがERFは画像サイズ未満の領域しか持たないことも確認した.興味深いのは学習が進むにつれてERFが広がっていくらしい.

解析結果から新しい初期化方法として畳み込みカーネルの中心に比べ周りの値を大きくするという初期化方法を実験的に検証.条件次第では学習速度が速くなるらしいが基本的にはそんなに恩恵はないとのこと.

まとめ

一部の細かい議論が参考文献に投げているのと自分の数学力のなさから少しわかりにくかった.解析的に出力への依存が高い領域を求めていて,なかなか一般的な状況までは解析できていなくとも非常に有用な内容だった.特にERFが層の数の平方根に比例するというのは興味深い知見.

強化学習勉強まとめ

主にOpenAIが公開している強化学習のプログラムであるSpinning upで勉強してみたメモのまとめ.

その1 Introduction to RL Part1についてのメモ

その2 Introduction to RL Part2についてのメモ

その3 Introduction to RL Part3についてのメモ

その4 Algorithms DocsのVPGについてのメモ

その5 Algorithms DocsのDDPGについてのメモと実装

その6 Algorithms DocsのTD3についてのメモと実装

その7 Algorithms DocsのSACについてのメモと実装

OpenAIのSpinning Upで強化学習を勉強してみた その7

はじめに

その7ということで今度はSoft Actor-Critic(SAC)をpytorchで実装する.

Soft Actor-Critic

SACはTD3とほぼ同時期にpublishされた論文.内容の肝としてはDDPGをベースにentropy regularizationを加えたというもの.簡単に言ってしまえば報酬に対して確率的な方策のエントロピーを加えるというもの.なので価値関数とQ関数は次のように表現される.

\displaystyle
H(P)=\mathbb{E}_{x\sim P}[-\log P(x)]\\
V^\pi(s)=\mathbb{E}_{\tau\sim\pi}\left[\left.\sum_{t=0}^{\infty}\gamma^t\left(R(s_t,a_t,s_{t+1})+\alpha H(\pi(\cdot|s_t))\right)\right| s_0=s\right]\\
Q^\pi(s,a)=\mathbb{E}_{\tau\sim\pi}\left[\left.\sum_{t=0}^\infty\gamma^tR(s_t,a_t,s_{t+1})+\alpha\sum_{t=1}^\infty\gamma^tH(\pi(\cdot|s_t))\right| s_0=s,a_0=a\right]

\alpha \lt 0はハイパーパラメータ.\alphaを大きな値にするとエントロピーを大きくしようとするため方策はランダムな値を取りやすくなる.この式においては価値関数とQ関数の関係は以下のようになる.

\displaystyle
V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)]+\alpha H(\pi(\cdot|s))

なのでQ関数に関するBellman方程式も次のように書き変えられる.

\displaystyle
Q^\pi(s,a)=\mathbb{E}_{s'\sim P}[R(s,a,s')+\gamma(Q^\pi(s',a')+\alpha H(\pi(\cdot|s')))]=\mathbb{E}_{s'\sim P}[R(s,a,s')+\gamma V^\pi(s')]

この式を利用してSACではpolicyと二つのQ関数と一つの価値関数の学習を行う.Q関数が二つあるのはTD3と同様の理由.

Qの学習

Q関数の学習はDDPG同様target networkを利用してMSBEの最小化を行う.ただし,今回Q関数のBellman方程式は価値関数を使ってかけるためここでのtarget networkは価値関数のtarget networkになる.なので目的関数は次のようにかける.

\displaystyle
L(\phi_i,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_i}(s,a)-(r+\gamma(1-d)V_{\phi_\mathrm{targ}}(s'))\right)^2\right]

Target networkはDDPGと同様,移動平均でパラメータを計算.

Vの学習

価値関数の学習は以下の価値関数とQ関数の関係を利用する.

\displaystyle
V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)]+\alpha H(\pi(\cdot|s))=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)-\alpha\log\pi(a|s)]

ここでの期待値計算は確率的な方策からのサンプリングを使って次のように近似する.

\displaystyle
V^\pi(s)\approx Q^\pi(s,\tilde{a})-\alpha\log\pi(\tilde{a}|s),\ \tilde{a}\sim\pi(\cdot|s)

なので方策はサンプリングしやすい分布である必要がある.ここでのQ関数はTD3と同様に二つのQ関数の最小値として計算する(clipped couble-Q).なのでVに関する目的関数は次のようになる.

\displaystyle
L(\phi,\mathcal{D})=\underset{s\sim\mathcal{D},a\sim\pi_\theta}{\mathbb{E}}\left[\left(V_\phi(s)-\left(\min_{i=1,2}Q_{\phi_i}(s,\tilde{a})-\alpha\log\pi_\theta(\tilde{a}|s)\right)\right)^2\right]

実装上の注意としてはサンプリングによる近似を用いているためreplay bufferのactionは使わないということ.

Policyの学習

Policyは今までと同様Q関数を最大とする行動を返すように学習するが,今回はentropy regularizationが入っているため次の値の最大化として学習する.

\displaystyle
\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)-\alpha\log\pi(a|s)]

ここでの計算にもサンプリングが必要になるが,実装ではpolicyにガウス分布を仮定しているため,VAEでおなじみのreparameterization trickを使うことで計算可能.ただpolicyからのサンプリング部分は実装を見ればわかるが学習を安定させるためのいくつかのtipsがあるので注意.具体的にはガウス分布の平均や分散,選択される行動の値がぶっ飛んだ値にならないようtanhやclippingによって有限の値になるように抑えている.ここら辺は言葉よりも実装見た方が早いので細かいことは割愛.

諸々を端折って結論だけ書くとpolicyの目的関数は次のようになる.

\displaystyle
\max_\theta\underset{s\sim\mathcal{D},\xi\sim\mathcal{N}}[Q_{\phi_1}(s,\tilde{a}_\theta(s,\xi))-\alpha\log\pi_\theta(\tilde{a}_\theta(s,\xi)|s)]

ただし,\tilde{a}_\theta(s,\xi)はpolicyからreparameterization tricktanhやclippingによるまるめ込みを使ってサンプリングされたaction.

更新の順番としてはVとQの更新->更新されたQを使ったpolicyの更新->target networkの更新の順番.

実装

以下,実装.policyがgaussianに変わったのとサンプリングに関する細かいテクニック以外は特に変わったところはない.

"""core.py"""
import torch
import torch.nn as nn

import math

EPS = 1e-8
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class gaussian_policy(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        self.main = nn.Sequential(*layer)
        self.mu   = nn.Linear(hidden_layer[-1], act_dim)
        self.log_std = nn.Sequential(
            nn.Linear(hidden_layer[-1], act_dim),
            nn.Tanh()
        )

    def forward(self, obs):
        f = self.main(obs)
        mu, log_std = self.mu(f), self.log_std(f)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
        
        pi = mu + torch.randn_like(mu) * log_std.exp()
        logp_pi = self.gaussian_likelihood(pi, mu, log_std)
        return mu, pi, logp_pi

    def gaussian_likelihood(self, pi, mu, log_std):
        return  torch.sum(-0.5 * (((pi-mu)/(log_std.exp()+EPS))**2 + 2 * log_std + math.log(2*math.pi)), 1)

class value_function(nn.Module):
    def __init__(self, inp_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(inp_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], 1))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class actor_critic(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2):
        super().__init__()
        self.policy = gaussian_policy(act_dim, obs_dim, hidden_layer)

        self.q1 = value_function(obs_dim+act_dim, hidden_layer)
        self.q2 = value_function(obs_dim+act_dim, hidden_layer)
        self.v  = value_function(obs_dim, hidden_layer)
        self.act_limit = act_limit

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.v_targ = value_function(obs_dim, hidden_layer)

        self.copy_param()

    def pass_gradient_clip(self, x, low, high):
        clip_high = (x > high).float()
        clip_low  = (x < low).float()
        return x + ((high - x) * clip_high + (low - x) * clip_low).detach()

    def squashing(self, mu, pi, logp_pi):
        mu = mu.tanh()
        pi = pi.tanh()
        logp_pi = logp_pi - (self.pass_gradient_clip(1 - pi**2, 0, 1) + 1e-6).log().sum(1)
        return mu, pi, logp_pi

    def copy_param(self):
        self.v_targ.load_state_dict(self.v.state_dict())

    def get_action(self, obs):
        mu, pi, logp_pi = self.policy(obs)
        mu, pi, logp_pi = self.squashing(mu, pi, logp_pi)
        mu = self.act_limit * mu
        pi = self.act_limit * pi
        return mu, pi, logp_pi

    def update_target(self, rho):
        # compute rho * targ_p + (1 - rho) * main_p
        for v_p, v_targ_p in zip(self.v.parameters(), self.v_targ.parameters()):
            v_targ_p.data = rho * v_targ_p.data + (1-rho) * v_p.data

    def compute_v_target(self, obs, alpha):
        _, pi, logp = self.get_action(obs)
        q1, q2 = self.q1(torch.cat([obs, pi], 1)), self.q2(torch.cat([obs, pi], 1))
        q = torch.min(q1, q2).squeeze()
        return (q - alpha * logp.squeeze()).detach()

    def compute_q_target(self, obs, gamma, rewards, done):
        # compute r + gamma * (1 - d) * V(s')
        return (rewards + gamma * (1-done) * self.v_targ(obs).squeeze()).detach()

    def q_function(self, obs, pi):
        q1, q2 = self.q1(torch.cat([obs, pi], 1)), self.q2(torch.cat([obs, pi], 1))
        return q1.squeeze(), q2.squeeze()

    def q_function_w_entropy(self, obs, alpha):
        _, pi, logp_pi = self.get_action(obs)
        q1 = self.q1(torch.cat([obs, pi], 1)).squeeze()
        H = -logp_pi * alpha
        return q1 + H.squeeze()
"""sac.py"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gym, time
import numpy as np

from spinup.utils.logx import EpochLogger
from core import actor_critic as ac

class ReplayBuffer:
    def __init__(self, size):
        self.size, self.max_size = 0, size
        self.obs1_buf = []
        self.obs2_buf = []
        self.acts_buf = []
        self.rews_buf = []
        self.done_buf = []

    def store(self, obs, act, rew, next_obs, done):
        self.obs1_buf.append(obs)
        self.obs2_buf.append(next_obs)
        self.acts_buf.append(act)
        self.rews_buf.append(rew)
        self.done_buf.append(int(done))
        while len(self.obs1_buf) > self.max_size:
            self.obs1_buf.pop(0)
            self.obs2_buf.pop(0)
            self.acts_buf.pop(0)
            self.rews_buf.pop(0)
            self.done_buf.pop(0)

        self.size = len(self.obs1_buf)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(low=0, high=self.size, size=(batch_size,))
        obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs])
        obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs])
        acts = torch.FloatTensor([self.acts_buf[i] for i in idxs])
        rews = torch.FloatTensor([self.rews_buf[i] for i in idxs])
        done = torch.FloatTensor([self.done_buf[i] for i in idxs])
        return [obs1, obs2, acts, rews, done]

def ddpg(env_name, actor_critic_function, hidden_size,
        steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000, 
        max_ep_len=1000, logger_kwargs=dict()):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    replay_buffer = ReplayBuffer(replay_size)

    env, test_env = gym.make(env_name), gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_limit = int(env.action_space.high[0])

    actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit)

    value_optimizer = optim.Adam([
        {"params":actor_critic.q1.parameters()},
        {"params":actor_critic.q2.parameters()},
        {"params":actor_critic.v.parameters()}
    ], lr)
    policy_optimizer = optim.Adam(actor_critic.policy.parameters(), lr)

    start_time = time.time()

    obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    for t in range(total_steps):
        if t > 50000:
            env.render()
        if t > start_steps:
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            _, act, _ = actor_critic.get_action(obs_tens)
            act = act.detach().numpy().reshape(-1)
        else:
            act = env.action_space.sample()

        obs2, ret, done, _ = env.step(act)

        ep_ret += ret
        ep_len += 1

        done = False if ep_len==max_ep_len else done

        replay_buffer.store(obs, act, ret, obs2, done)

        obs = obs2

        if done or (ep_len == max_ep_len):
            for _ in range(ep_len):
                obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size)

                q_targ = actor_critic.compute_q_target(obs2_tens, gamma, rews_tens, done_tens)
                v_targ = actor_critic.compute_v_target(obs1_tens, alpha)

                q1_val, q2_val = actor_critic.q_function(obs1_tens, acts_tens)
                q_loss = 0.5 * (q_targ - q1_val).pow(2).mean() + 0.5 * (q_targ - q2_val).pow(2).mean()

                v_val  = actor_critic.v(obs1_tens).squeeze()
                v_loss = 0.5 * (v_targ - v_val).pow(2).mean()

                value_loss = q_loss + v_loss

                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

                policy_loss = -actor_critic.q_function_w_entropy(obs1_tens, alpha).mean()
                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()


                logger.store(LossQ=q_loss.item(), Q1Vals=q1_val.detach().numpy(), Q2Vals=q2_val.detach().numpy())
                logger.store(LossV=v_loss.item(), VVals=v_val.detach().numpy())
                logger.store(LossPi=policy_loss.item())

                actor_critic.update_target(polyak)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # test_agent()
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('VVals', with_min_and_max=True) 
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Pendulum-v0')
    parser.add_argument('--hid', type=int, default=300)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='sac')
    args = parser.parse_args()

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    ddpg(args.env, actor_critic_function=ac,
        hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs,
        logger_kwargs=logger_kwargs)

まとめ

これでspinning upにあるQ-learningに関するアルゴリズムの実装は終了.Policy gradient関連はTRPOがヘシアンの計算等を必要として実装までは面倒でやらなさそう.ひとまずspinning upでの強化学習勉強はここで終わる予定.

Spinning upの感想としてはよくまとまっていてpolicy gradientからq-learningまで非常にわかりやすく解説されている気がする.前に強化学習を本で勉強しようとした時には長ったらしい理論的な背景をガンガン説明されて辟易したが,spinning upはアルゴリズムの理解と実装に焦点を当てて理論的な点は飛ばしているので直感的な理解が非常にしやすかった.逆にいえば細かい理論は飛ばされているのでその辺に関する知識をつけたい人にはあまり意味のない内容かと.