Neural Nearest Neighbors Networksを読んだのでメモ
はじめに
Neural Nearest Neighbors Networksを読んだのでメモ.differentiableなKNN selection ruleを提案するというもの.NIPS 2018でgithubにコードも公開されている.
Differentiable -Nearest Neighbors
クエリとなるアイテムを,データベースの候補を,距離関数をとする.はデータベースの中にないものと仮定し,はクエリからの距離に従ってデータベースのアイテムのランキングを生成する.ここで[tes:\pi_q:I\rightarrow I]をの距離にしたがってソートを行うpermutationと定義する.
するとのKNNはpermutation において最初の個のアイテムの集合として与えられる.
KNNのselection ruleはdeterministicなため微分不可能で,距離に関する微分を伝播することができない.そこでこの問題をone-hotベクトルの連続値への緩和を使って解消する.
KNN rule as limit distribution
ここではKNN selection ruleをカテゴリカル分布の極限分布としてみなすことで緩和する.をデータベースのアイテムのインデックスに関するカテゴリカル分布とする.クエリアイテムとの距離をとし,温度パラメータを導入すれば,番目のアイテムを選択する確率は次のように与えられる.
Gumbel softmaxやconcrete distributionで見たような式だが,ここではアイテム間の距離によって決定的にサンプリングされるアイテムが決まるためgumbel分布からのサンプリングがない.の時はone-hotベクトルとなる.この式を1-NNのstochastic relaxationとしてみなすことができるため,これをの場合に拡張することでKNNを微分可能な形に緩和することができる.そのために条件付き分布へと拡張する.ここでは次のように計算される.
つまり一度選択されたアイテムは選択される確率が0に近くなるようクエリアイテムとの距離を離すというもの.実際がone-hotベクトルの場合にははにおいてになる.これによって番目のサンプルのインデックスは次のように得られる.
Index vector からのstochastic nearest neighbors を次のように定義する.
温度の時がone-hotベクトルとなり nearest neighborsと一致する.ただ基本的にはone-hotベクトルではないので多分mixup的な形になるはず(というかone-hotベクトルに近づけば近づくほど勾配が消失していくので意味がなくなる).
Neural Nearest Neighbors Block
今までの議論を踏まえてニューラルネットの一つの層としてneural nearest neighbors block ( blocks)を提案する. blocksは二つの部分からなり,一つはembedding networkで,もう一つは連続緩和されたnearest neighborsを計算する部分がある.
Embedding network
Embedding networkはでparameterizeされたCNNを使って入力から特徴ベクトルを得る部分.得られた特徴をとし,ユークリッド距離を用いてというpairwise distance matrix を得る.さらに,ここでは温度を計算するもう一つのネットワークを用意する.ただし,とは出力層以外は重み共有している.
Continuous nearest neighbors selection
距離行列,温度テンソル,入力の特徴から continuous nearest neighbors feature volumes を計算する.nearest neighborsとは入力の次元が同じなのでelement-wiseな操作によってまとめることができるが今回はチャネル方向に結合することで情報をまとめたとのこと.
block for image data
基本的に blockは入力のドメインに関係なく適用可能だが,画像データにカニsてはちょっとした修正が必要.従来画像におけるnon-local methodはパッチレベルで行われ,パッチレベルの処理には様々な利点が存在する.そのため画像に blockを適用する際にはim2colを使ってパッチレベルで計算するとのこと.
まとめ
基本的にはdeterministicなGumbel softmaxという感じ.実験ではdenoisingや超解像などの逆問題を解くのに使っていたが,いろいろなことに使えそう.ただ,複雑なことをしようとすると距離(というかembedding network)に何らかのpriorを持たせないと学習は難しそう.