Image Generation from Scene Graphsを読んだのでメモ
はじめに
Image Generation from Scene Graphsを読んだのでメモ.言語から画像を生成する研究.複雑な文章からでも安定して画像生成ができるとのこと.
概要
ここではシーングラフとノイズを入力として画像を生成するモデルの構築と学習を目標とする.モデルの特徴としてはシーングラフをgraph convolutionで処理するところで,graph convolutionで埋め込まれた各物体の特徴ベクトルを使って物体とその関係性を考慮したsegmentation maskとBB(bounding box)を生成する.後は得られたsegmentation maskとBBとノイズを合わせて,cascaded refinement network (CRN)で画像を生成するというのが処理の大まかな流れ.
Scene Graphs
モデルの入力となるのは画像のシーングラフで,シーングラフは画像中に何の物体がいて,画像中の物体間にどのような関係性があるかを記述したグラフ.具体的には物体のカテゴリと関係性のカテゴリが与えられた時,グラフはで記述され,は物体の集合で,有向エッジを表し,で形成される.
この論文の問題設定では物体とその関係性を示すカテゴリは言語として与えられるため,自然言語で使われる埋め込み処理によりdenseな特徴ベクトルを作る.
Graph Convolution Network
単語の埋め込みベクトルを持ったシーングラフをend-to-endで処理するために,ここではオリジナルのgraph convolution networkを使う.
入力のグラフが持つ特徴ベクトルをとした時,3つの関数を使って出力を作る.を出力する関数は関係性のベクトルが二つの物体からのみ決められるため単純で,として3つのベクトルを入力とする関数.逆に物体に関するベクトルを更新する場合は,ひとつの物体が複数の物体と関係性を持つための更新よりも複雑になる.もっと言えば,今回は有向グラフを考えているため,入ってくるエッジと出て行くエッジで意味合いが変わってくる.当然ある物体に関するベクトルはと関係性を持つ全ての物体のベクトルを使って更新されるべきである.そこで,から伸びる全てのエッジに対してを使ってcandidate vectorを計算する.それと同様にに入るエッジに対してを使ってcandidate vectorを計算する.
このように各方向のエッジに関するcandidate vectorの集合が得られたらに関する出力をとして計算する.ただし,は入力の集合をひとつのベクトルにするpooling関数.この論文では各関数は3つのベクトルを入力とするニューラルネットで構成したとのこと.また,pooling関数は入力のベクトルを平均する関数として定義したらしい.
Scene Layout
Graph convolution networkによってシーングラフからいい感じの情報を抽出した後はグラフで表現された情報を画像に起こす必要がある.ここでは画像生成の第1段階としてシーングラフをシーンレイアウトへと変換することを考える.シーンレイアウトはsegmentation maskとBBから作られ,maskとBBの推定にobject layout networkを使う.
Object layout networkは物体に関する埋め込みベクトルを入力とし,のsoft binary mask とBBの座標を出力する.要はGANのgeneratorの入力ノイズがgraph convolution networkで計算された物体の特徴ベクトルになったということ.細かいことを言えば,maskを出力するネットワークはtranspose convolutionで構成されたネットワークで最終層の活性化関数はsigmoid,BBを出力するネットワークは一般的な多層ニューラルネット.Object layout networkによって出力されたmaskはBBの領域にwarpされ,最終的に全ての物体に関するマスクを統合することでシーンレイアウトを作る.学習中はBBはground-truthを使ってシーンレイアウトをつくることに注意.
Cascaded Refinement Network
シーンレイアウトができたら後はそれをリアルな画像に起こすだけ.ここでは従来手法のCascaded Refinement Network (CRN)を使うとのこと.Cascadedと言うように,解像度を2倍にして行くmoduleを多段に積み上げていて,各moduleはシーンレイアウトと前のmoduleの出力を入力とする(ただし最初のmoduleは前段の出力の代わりにガウスノイズを入力とする).
Discriminator
ここまでで,シーングラフからの特徴抽出,シーンレイアウトの生成,画像の生成という3段構成で画像を生成したが,このプロセスをまとめて画像生成ネットワークとしてend-to-endで学習する.学習はGANの枠組みに沿って行われ,discriminatorとしての二つを用意する.は画像全体を入力とし,は物体領域を切り取って固定サイズにリサイズしたものを入力とする.要は画像全体のそれっぽさと物体のそれっぽさをそれぞれ学習するということ.ただし,は物体の分類問題も同時に解く.
Training
モデルが3段階構成でそれぞれGTを必要とするのでロスは結構複雑.まとめると次の6つのロスの重み付き和を最小化するように学習する.
・Box Loss Object layout networkのBBに関するロスで loss
・Mask Loss Object layout networkのmaskに関するロスでcross-entropy loss
・Pixel loss 画像の再構成に関するロスで,生成画像と真の画像と loss
・Image adversarial loss に関するmin-max game
・Object adversarial loss に関するmin-max game
・Auxiliarly classifier loss の物体の分類問題に関するロス.
まとめ
これ学習できるのすごいなという気持ち.複雑なモデルを思いついても実際に学習させきる力がないからそういう力も養っていきたい.