機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

OpenAIのSpinning Upで強化学習を勉強してみた その6

はじめに

その6ということで今度はTwin Delayed DDPG(TD3)をpytorchで実装する.

Twin Delayed DDPG

DDPGは基本的にはいいアルゴリズムだが,時たま学習が破綻する場合があるとのこと.その理由としてはQ関数が学習初期において過大評価を行なってしまい,そこに含まれる誤差がpolicyを悪い方向へと引っ張ってしまうことがあげられる.なのでTD3ではちょっとしたテクニックでこれを回避する(つまり裏を返せば理論的な大きな変更はない).

Trick One: Clipped Double-Q Learning

これはQ関数を二つ用意してしまおうというもの.二つのQ関数だからアルゴリズムにtwinがつくとかなんとか.

Trick Two: Delayed Policy Updates

Policyのパラメータ更新をQ関数を複数回更新した後に行おうというもの.ここの実装ではQ関数を2回更新するごとにpolicyを1回更新する.

Trick Three: Target Policy Smoothing

Target actionに対してノイズを与えるというもの.これによりpolicyがQ関数に含まれる誤差に引っ張られることを防ぐとのこと.Smoothというのはpolicyの出力の周りでQ関数が似たような値を出力できるようにという意味かと.

Key Equations

まずtrick threeについて,target policyの出力に対し以下のようなガウスノイズを加える.

\displaystyle
A'(s')=\mathrm{clip}(\mu_{\theta_\mathrm{targ}}(s')+\mathrm{clip}(\epsilon,-c,c),a_{Low},a_{High}),\ \epsilon\sim\mathcal{N}(0,\sigma)

このノイズのおかげでQ関数がシャープなピークを持つことができなくなるという気持ちが入っている.

さらにtargetの計算としてclipped double-Q learning,つまり用意された二つのQ関数のうち小さい値をtargetの算出に使う.

\displaystyle
y(r,s',d)=r+\gamma(1-d)\min_{I=1,2}Q_{\phi,\mathrm{targ}}(s',a'(s'))

このtargetを使って各Q関数を次の値を最小化するよう学習.

\displaystyle
L(\phi_1,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_1}(s,a)-y(r,s',d)\right)^2\right]\\
L(\phi_1,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_2}(s,a)-y(r,s',d)\right)^2\right]

Policyはどちらか一方のQ関数を最大化するよう学習.ここではQ_{\phi_1}を最大化するように学習.

\displaystyle
\max_\theta\mathbb{E}_{s\sim\mathcal{D}}[Q_{\phi_1}(s,\mu_\theta(s))]

実装

実装はDDPGとほぼ一緒.

"""core.py"""
import torch
import torch.nn as nn

import copy

class continuous_policy(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], act_dim))
        layer.append(nn.Tanh())
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class q_function(nn.Module):
    def __init__(self, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], 1))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class actor_critic(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2):
        super().__init__()
        self.policy = continuous_policy(act_dim, obs_dim, hidden_layer)

        self.q1 = q_function(obs_dim+act_dim, hidden_layer)
        self.q2 = q_function(obs_dim+act_dim, hidden_layer)
        self.act_limit = act_limit

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.policy_targ = continuous_policy(act_dim, obs_dim, hidden_layer)
        self.q1_targ = q_function(obs_dim+act_dim, hidden_layer)
        self.q2_targ = q_function(obs_dim+act_dim, hidden_layer)

        self.copy_param()

    def copy_param(self):
        self.policy_targ.load_state_dict(self.policy.state_dict())
        self.q1_targ.load_state_dict(self.q1.state_dict())
        self.q2_targ.load_state_dict(self.q2.state_dict())
        # for m_targ, m_main in zip(self.policy_targ.modules(), self.policy.modules()):
        #     if isinstance(m_targ, nn.Linear):
        #         m_targ.weight.data = m_main.weight.data
        #         m_targ.bias.data = m_main.bias.data

        # for m_targ, m_main in zip(self.q_targ.modules(), self.q.modules()):
        #     if isinstance(m_targ, nn.Linear):
        #         m_targ.weight.data = m_main.weight.data
        #         m_targ.bias.data = m_main.bias.data

    def get_action(self, obs, noise_scale):
        pi = self.act_limit * self.policy(obs)
        pi += noise_scale * torch.randn_like(pi)
        pi.clamp_(max=self.act_limit, min=-self.act_limit)
        return pi.squeeze()

    def get_target_action(self, obs, noise_scale, clip_param):
        pi = self.act_limit * self.policy_targ(obs)
        eps = noise_scale * torch.randn_like(pi)
        eps.clamp_(max=clip_param, min=-clip_param)
        pi += eps
        pi.clamp_(max=self.act_limit, min=-self.act_limit)
        return pi.detach()

    def update_target(self, rho):
        # compute rho * targ_p + (1 - rho) * main_p
        for poly_p, poly_targ_p in zip(self.policy.parameters(), self.policy_targ.parameters()):
            poly_targ_p.data = rho * poly_targ_p.data + (1-rho) * poly_p.data

        for q_p, q_targ_p in zip(self.q1.parameters(), self.q1_targ.parameters()):
            q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data

        for q_p, q_targ_p in zip(self.q2.parameters(), self.q2_targ.parameters()):
            q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data


    def compute_target(self, obs, pi, gamma, rewards, done):
        # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
        q1 = self.q1_targ(torch.cat([obs, pi], -1))
        q2 = self.q2_targ(torch.cat([obs, pi], -1))
        q  = torch.min(q1, q2)
        return (rewards + gamma * (1-done) * q.squeeze()).detach()

    def q_function(self, obs, detach=True, action=None):
        # compute Q(s, a) or Q(s, mu(s))
        if action is None:
            pi = self.act_limit * self.policy(obs)
        else:
            pi = action
        if detach:
            pi = pi.detach()
        return self.q1(torch.cat([obs, pi], -1)).squeeze(), self.q2(torch.cat([obs, pi], -1)).squeeze()
"""main.py"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gym, time
import numpy as np

from spinup.utils.logx import EpochLogger
from core import actor_critic as ac

class ReplayBuffer:
    def __init__(self, size):
        self.size, self.max_size = 0, size
        self.obs1_buf = []
        self.obs2_buf = []
        self.acts_buf = []
        self.rews_buf = []
        self.done_buf = []

    def store(self, obs, act, rew, next_obs, done):
        self.obs1_buf.append(obs)
        self.obs2_buf.append(next_obs)
        self.acts_buf.append(act)
        self.rews_buf.append(rew)
        self.done_buf.append(int(done))
        while len(self.obs1_buf) > self.max_size:
            self.obs1_buf.pop(0)
            self.obs2_buf.pop(0)
            self.acts_buf.pop(0)
            self.rews_buf.pop(0)
            self.done_buf.pop(0)

        self.size = len(self.obs1_buf)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(low=0, high=self.size, size=(batch_size,))
        obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs])
        obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs])
        acts = torch.FloatTensor([self.acts_buf[i] for i in idxs])
        rews = torch.FloatTensor([self.rews_buf[i] for i in idxs])
        done = torch.FloatTensor([self.done_buf[i] for i in idxs])
        return [obs1, obs2, acts, rews, done]

def td3(env_name, actor_critic_function, hidden_size,
        steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, 
        act_noise=0.1, target_noise=0.2, noise_clip=0.5, max_ep_len=1000,
        policy_delay=2, logger_kwargs=dict()):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    replay_buffer = ReplayBuffer(replay_size)

    env, test_env = gym.make(env_name), gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_limit = int(env.action_space.high[0])

    actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit)

    q1_optimizer = optim.Adam(actor_critic.q1.parameters(), q_lr)
    q2_optimizer = optim.Adam(actor_critic.q2.parameters(), q_lr)
    policy_optimizer = optim.Adam(actor_critic.policy.parameters(), pi_lr)

    start_time = time.time()

    obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    for t in range(total_steps):
        if t > 50000:
            env.render()
        if t > start_steps:
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            act = actor_critic.get_action(obs_tens, act_noise).detach().numpy().reshape(-1)
        else:
            act = env.action_space.sample()

        obs2, ret, done, _ = env.step(act)

        ep_ret += ret
        ep_len += 1

        done = False if ep_len==max_ep_len else done

        replay_buffer.store(obs, act, ret, obs2, done)

        obs = obs2

        if done or (ep_len == max_ep_len):
            for j in range(ep_len):
                obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size)
                # compute Q(s, a)
                q1, q2 = actor_critic.q_function(obs1_tens, action=acts_tens)
                # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
                pi_targ = actor_critic.get_target_action(obs2_tens, target_noise, noise_clip)
                q_targ = actor_critic.compute_target(obs2_tens, pi_targ, gamma, rews_tens, done_tens)
                # compute (Q(s, a) - y(r, s', d))^2
                q_loss = (q1-q_targ).pow(2).mean() + (q2-q_targ).pow(2).mean()

                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                q_loss.backward()
                q1_optimizer.step()
                q2_optimizer.step()

                logger.store(LossQ=q_loss.item(), Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy())

                if j % policy_delay == 0:
                    # compute Q(s, mu(s))
                    policy_loss, _ = actor_critic.q_function(obs1_tens, detach=False)
                    policy_loss = -policy_loss.mean()
                    policy_optimizer.zero_grad()
                    policy_loss.backward()
                    policy_optimizer.step()

                    logger.store(LossPi=policy_loss.item())

                    # compute rho * targ_p + (1 - rho) * main_p
                    actor_critic.update_target(polyak)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # test_agent()
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()


if __name__ == '__main__':
    import argparse
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Pendulum-v0')
    parser.add_argument('--hid', type=int, default=300)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='td3')
    args = parser.parse_args()

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    td3(args.env, actor_critic_function=ac,
        hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs,
        logger_kwargs=logger_kwargs)

まとめ

急に実践的な改良の多い手法がきてこういう論文もあるのかという感じ.試した環境が単純なものすぎてDDPGに対するアドバンテージは結果から感じられなかったけど実装は単純でやりたいこともわかりやすいといえばわかりやすかった.

OpenAIのSpinning Upで強化学習を勉強してみた その5

はじめに

その5ということで今度はDeep Deterministic Policy Gradient(DDPG)をpytorchで実装する.

Deep Deterministic Policy Gradient

DDPGは今までと違いQ-learningの枠組みを取り入れた(論文の背景的にはQ-learningにpolicy gradientを取り柄れたという方が正確かも)アルゴリズムDQNで有名なQ-learningは行動が離散でないと適用ができないという課題があったが,policyを同時に学習することで行動が連続な課題に対してもQ-learningを適用可能にしたというもの.

Q-learning

簡単にQ-learningについて述べる(といってもspinning upにはQ-learningの解説はないので完全に個人的な理解だが).Q関数は現在の状態においてある行動をとったときにどれくらいの価値があるかを図る関数で,価値が最大となる行動を選択することでエージェントを動かす.この時,行動が離散で高々有限個の行動しかとれない場合には全ての行動に対して価値を計算することで最適な行動を得ることができる.行動が連続になると全ての行動に対して価値を計算するわけにはいかなくなるため今回のような戦略が必要になるというもの.

Q関数自体はbellman方程式に従ってQ関数を推定するようなモデルを学習する.すなわち次の方程式を満たすような関数Q_\phi(s,a)を学習する.

\displaystyle
Q^\ast(s,a)=\mathbb{E}_{s'\sim P}\left[r(s,a)+\gamma\max_{a'}Q^\ast(s',a')\right]

s'は状態sにおいて行動aをとった時に起こりうる状態を表す.この方程式を満たすためには次の最小化問題を解けばいい.

\displaystyle
\underset{\phi}{\min}\mathbb{E}_{(s,a,r,s',d)}\sim\mathcal{D}\left[\left(Q_\phi(s,a)-\left(r+\gamma(1-d)\max_{a'}Q_\phi(s',a')\right)\right)^2\right]

この目的関数は mean-squared Bellman error (MSBE)と呼ばれる.本題のDDPGに戻れば,DDPGはQ-learningを連続な行動に対して適用できるよう,Q関数を最大とするような行動を返すpolicyを同時に学習する.具体的には次の目的関数を最小化する.

\displaystyle
\underset{\phi}{\min}\mathbb{E}_{(s,a,r,s',d)}\sim\mathcal{D}\left[\left(Q_\phi(s,a)-\left(r+\gamma(1-d)\max_{a'}Q_\phi(s',\mu_\theta(s'))\right)\right)^2\right]

単に状態s'におけるQ関数を最大にする行動を選択する部分を\mu_\thetaというpolicyを使って求めるというもの.これにより,離散の時のような全ての行動に対する価値を陽に計算する必要がなくなるため連続な行動に対してもQ-learningを適用可能となる.

Policy自体はQ関数を最大にする行動を返すような関数であって欲しいため,次のような最大化の問題によって学習を行う.

\displaystyle
\max_\theta\mathbb{E}_{s\sim\mathcal{D}}[Q_\phi(s,\mu_\theta(s))]

なのでこのpolicyを学習する最大化問題とQ関数を学習する最小化問題を交互に最適化していけば良いpolicyを得ることができる.

注意としてこのDDPGは離散な行動の場合には利用することができない.これはQ関数を最大にする行動を返すようなpolicyを学習する際に,policyに対する勾配が(policyの出力をargmax等で離散表現に落とすため)計算不可能になってしまうため.ただ,離散の場合にはpolicyを導入することなく普通のQ-learningを行えばいいので特に問題はない気もする.

Tips

ここでは二つのテクニックを利用した学習法を用いている.

Replay Buffers

これは過去のエピソードを保存しておいて学習に利用することで過学習を防ぐとともに学習データのサンプリング効率を上げるというもの.

Target Networks

MSBEの最小化はbellman方程式の右辺をtargetとすれば,targetとtargetに近づけたい部分が同じパラメータを持っているため(すなわちQ_\phi(s,a)\left(r+\gamma(1-d)\max_{a'}Q_\phi(s',a')\right)の差を最小にしたいのにどちらも同じパラメータを持つQ_\phiがいるため)不安定な最適化となる.そこでtarget用のネットワークを新たに用意することで学習の安定化を図ろうというのがこの方法.

具体的には学習したいQ関数と全く同じパラメータを持つtarget network Q_{\phi_{targ}}を用意し,学習したいQ関数は通常通り勾配降下法で,target networkの方は次のような移動平均によってパラメータを計算する.


\phi_{targ}\leftarrow\rho\phi_{targ}+(1-\rho)\phi

\rhoはハイパーパラメータで基本的に1に近い値を用いる.このtarget networkはpolicy側にも用意することができ,ここでは最終的なMSBEを次のような形で与えている.

\displaystyle
\underset{\phi}{\min}\mathbb{E}_{(s,a,r,s',d)}\sim\mathcal{D}\left[\left(Q_\phi(s,a)-\left(r+\gamma(1-d)\max_{a'}Q_{\phi_{targ}}(s',\mu_{\theta_{targ}}(s'))\right)\right)^2\right]

Policyに関するtarget networkも移動平均によりパラメータを更新する.

Exploration vs. Exploitation

Q-learningはある意味で行動と状態を軸とするルックアップテーブルを埋めていく作業としてみることができる(と思う).なのでoff-policyで学習するとPolicyが決定的なものになってしまい多様な範囲を探索できなくなる恐れがある.なのでpolicyの出力にガウス雑音を加えることで行動を確率的なものとしてより広い範囲を探索できるようにするというもの.

実装

以下実装.spinning upのページにはアルゴリズムテーブルもあるので参考に.

"""core.py"""
import torch
import torch.nn as nn

import copy

class continuous_policy(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], act_dim))
        layer.append(nn.Tanh())
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class q_function(nn.Module):
    def __init__(self, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], 1))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class actor_critic(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2):
        super().__init__()
        self.policy = continuous_policy(act_dim, obs_dim, hidden_layer)

        self.q = q_function(obs_dim+act_dim, hidden_layer)
        self.act_limit = act_limit

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.policy_targ = continuous_policy(act_dim, obs_dim, hidden_layer)
        self.q_targ = q_function(obs_dim+act_dim, hidden_layer)

        self.copy_param()

    def copy_param(self):
        self.policy_targ.load_state_dict(self.policy.state_dict())
        self.q_targ.load_state_dict(self.q.state_dict())

    def get_action(self, obs, noise_scale):
        pi = self.act_limit * self.policy(obs)
        pi += noise_scale * torch.randn_like(pi)
        pi.clamp_(max=self.act_limit, min=-self.act_limit)
        return pi.squeeze()

    def update_target(self, rho):
        # compute rho * targ_p + (1 - rho) * main_p
        for poly_p, poly_targ_p in zip(self.policy.parameters(), self.policy_targ.parameters()):
            poly_targ_p.data = rho * poly_targ_p.data + (1-rho) * poly_p.data

        for q_p, q_targ_p in zip(self.q.parameters(), self.q_targ.parameters()):
            q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data

    def compute_target(self, obs, gamma, rewards, done):
        # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
        pi = self.act_limit * self.policy_targ(obs)
        return (rewards + gamma * (1-done) * self.q_targ(torch.cat([obs, pi], -1)).squeeze()).detach()

    def q_function(self, obs, detach=True, action=None):
        # compute Q(s, a) or Q(s, mu(s))
        if action is None:
            pi = self.act_limit * self.policy(obs)
        else:
            pi = action
        if detach:
            pi = pi.detach()
        return self.q(torch.cat([obs, pi], -1))
"""ddpg.py"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gym, time
import numpy as np

from spinup.utils.logx import EpochLogger
from core import actor_critic as ac

class ReplayBuffer:
    def __init__(self, size):
        self.size, self.max_size = 0, size
        self.obs1_buf = []
        self.obs2_buf = []
        self.acts_buf = []
        self.rews_buf = []
        self.done_buf = []

    def store(self, obs, act, rew, next_obs, done):
        self.obs1_buf.append(obs)
        self.obs2_buf.append(next_obs)
        self.acts_buf.append(act)
        self.rews_buf.append(rew)
        self.done_buf.append(int(done))
        while len(self.obs1_buf) > self.max_size:
            self.obs1_buf.pop(0)
            self.obs2_buf.pop(0)
            self.acts_buf.pop(0)
            self.rews_buf.pop(0)
            self.done_buf.pop(0)

        self.size = len(self.obs1_buf)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(low=0, high=self.size, size=(batch_size,))
        obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs])
        obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs])
        acts = torch.FloatTensor([self.acts_buf[i] for i in idxs])
        rews = torch.FloatTensor([self.rews_buf[i] for i in idxs])
        done = torch.FloatTensor([self.done_buf[i] for i in idxs])
        return [obs1, obs2, acts, rews, done]

def ddpg(env_name, actor_critic_function, hidden_size,
        steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, 
        act_noise=0.1, max_ep_len=1000, logger_kwargs=dict()):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    replay_buffer = ReplayBuffer(replay_size)

    env, test_env = gym.make(env_name), gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_limit = int(env.action_space.high[0])

    actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit)

    q_optimizer = optim.Adam(actor_critic.q.parameters(), q_lr)
    policy_optimizer = optim.Adam(actor_critic.policy.parameters(), pi_lr)

    start_time = time.time()

    obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    for t in range(total_steps):
        if t > 50000:
            env.render()
        if t > start_steps:
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            act = actor_critic.get_action(obs_tens, act_noise).detach().numpy().reshape(-1)
        else:
            act = env.action_space.sample()

        obs2, ret, done, _ = env.step(act)

        ep_ret += ret
        ep_len += 1

        done = False if ep_len==max_ep_len else done

        replay_buffer.store(obs, act, ret, obs2, done)

        obs = obs2

        if done or (ep_len == max_ep_len):
            for _ in range(ep_len):
                obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size)
                # compute Q(s, a)
                q = actor_critic.q_function(obs1_tens, action=acts_tens)
                # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
                q_targ = actor_critic.compute_target(obs2_tens, gamma, rews_tens, done_tens)
                # compute (Q(s, a) - y(r, s', d))^2
                q_loss = (q.squeeze()-q_targ).pow(2).mean()

                q_optimizer.zero_grad()
                q_loss.backward()
                q_optimizer.step()

                logger.store(LossQ=q_loss.item(), QVals=q.detach().numpy())

                # compute Q(s, mu(s))
                policy_loss = -actor_critic.q_function(obs1_tens, detach=False).mean()
                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()

                logger.store(LossPi=policy_loss.item())

                # compute rho * targ_p + (1 - rho) * main_p
                actor_critic.update_target(polyak)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # test_agent()
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('QVals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Pendulum-v0')
    parser.add_argument('--hid', type=int, default=300)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='ddpg')
    args = parser.parse_args()

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    ddpg(args.env, actor_critic_function=ac,
        hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs,
        logger_kwargs=logger_kwargs)

Pytorchを使った実装している上で踏んだバグ?があったので注意.具体的には,最初target networkを定義する時に元のネットワークをcopy.deepcopy()することでtarget networkを作ったが,こうすると学習パラメータがtorchのnn.Paramterからtorch.Tensorに置き換わるというもの.このせいでtarget networkのパラメータを更新することができなくなり学習が進まなかった(このバグに気づくのに半日とかした).なのでこのコードではstate_dictを利用してネットワークのパラメータをコピーしている.

まとめ

Q-learningとadvantage functionを使ったpolicy gradientはどっちの方が筋がいいとかあるんだろうか.ここまでの感じだとQ-learning側はいくつかのテクニックを用いないと学習が安定しない分policy gradientの方が手法的には筋がいい?

OpenAIのSpinning Upで強化学習を勉強してみた その4

はじめに

その4ということで具体的なアルゴリズムの実装をpytorchでしてみる.今回はalgorithms docsの中にあるVanilla Policy Gradient(VPG)の実装をする.

Vanilla Policy Gradient

VPGはSpinning Upのintroduction to rlのpart 3で扱ったsimplest policy gradientの拡張.Policy gradientは勾配法を用いて得られる報酬を最大化するようなpolicyを求めるというもの.具体的にはpolicyに関する勾配は次のように与えられる.

\displaystyle
\nabla_\theta J(\pi_\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\log\pi_\theta(a_t|s_t)\Phi_t\right]

具体的な導出とnotationはイントロのpart 3を参照.\Phi_tにはrewardやQ関数などが使われる.前回のsimplest policy gradientではrewardを用いていたが今回のVPGではadvantage functionを用いる.advantage functionは次の形で与えらえる.

\displaystyle
A^{\pi}(s_t,a_t)=Q^{\pi}(s_t,a_t)-V^\pi(s_t)

Simplest policy gradientのように\Phiにrewardを用いた場合,得られる勾配は実際に取られたactionの尤度を最大化するような勾配になっており,必ずしも最適なpolicyを学習するとは限らない.つまり,取られた行動が全体として報酬を減らすような行動だったとしてもその尤度を最大化するように学習してしまう.それに対しadvantage functionは良いpolicyを学習することができる.というのも,価値関数がQ関数をactionに対して周辺化したものであることを思い出せば,このadvantage functionは実際に取られたactionが行動の平均的な価値より良いか悪いかを表現することができるため,価値の低い行動をとった場合にはその尤度を下げることが可能.この尤度を下げるという勾配は\Phiにrewardを使った場合には得られない.

ここでの問題はadvantage functionをどのようにして得るかということ.advantage function(というかQ関数と価値関数)は陽に求めることができないため何らかの方法で推定する必要がある.ここではGeneralized Advantage Estimation (GAE)という方法を用いてadvantage functionを推定する.細かい導出は長くなるのと論文に丁寧に描かれているのでここでは省略し,次の最終的な形だけ.

\displaystyle
\hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)}:=-V(s_t)+\sum_{l=0}^\infty(\lambda\gamma)^lr_{t+l}

\lambda,\gammaは減衰係数でハイパーパラメータ.\lambdaが1の時には推定されるadvantage functionはhigh variance,low biasで0の時にはhigh bias,low varianceになる.基本的にはbiasのない値が欲しいためどちらの係数も1に近いものを選ぶ.

GAEの式を見ると価値関数が入っていて,やはりまだ陽に計算することができない.なのでこの価値関数をニューラルネットで推定しようというのがここでの解決方法.具体的には以下の最小化問題を解くことで価値関数を推定するニューラルネットを学習する.

\displaystyle
\phi_{k+1}=\underset{\phi}{\mathrm{argmin}}\frac{1}{|\mathcal{D}_k|T}\sum_{\tau\in\mathcal{D}_k}\sum_{t=0}^T\left(V_\phi(s_t)-\hat{R}_t\right)^2

\phiニューラルネットのパラメータで\mathcal{D}はエピソードの数.この最小化問題を利用して学習したニューラルネットを用いてadvantage functionを推定し,推定されたadvantage functionを使って次の勾配を使って勾配上昇法によりpolicyを学習する.ただ,実際にはpolicyとvalue functionは交互に学習していく.詳細はspinning upのページにあるPseudocodeを参照.

\displaystyle
\nabla_\theta J(\pi_\theta)=\frac{1}{|\mathcal{D}_k|}\sum_{\tau\in\mathcal{D}_k}\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|s_t)|_{\theta_k}\hat{A}_t

アルゴリズムの目的や式が直感的にもわかりやすい.このコードをpytorchで書くと次のような感じ.

"""core.py"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal

class categorical_policy(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=(64,64)):
        super().__init__()
        layer = [
            nn.Linear(obs_dim, hidden_size[0]),
            nn.ReLU()
        ]
        for i in range(1, len(hidden_size)):
            layer.append(nn.Linear(hidden_size[i-1], hidden_size[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_size[-1], act_dim))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        pi = self.policy(obs)
        return pi

class actor_critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=(64,64)):
        super().__init__()
        layer = [
            nn.Linear(obs_dim, hidden_size[0]),
            nn.ReLU()
        ]
        for i in range(1, len(hidden_size)):
            layer.append(nn.Linear(hidden_size[i-1], hidden_size[i]))
            layer.append(nn.ReLU())

        layer.append(nn.Linear(hidden_size[-1], 1))
        self.value_function = nn.Sequential(*layer)

        self.policy = categorical_policy(obs_dim, act_dim, hidden_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, obs):
        pi = self.policy(obs)
        a  = torch.multinomial(pi.softmax(1), 1).squeeze()
        log_pi = F.log_softmax(pi, 1)
        log_pi = (torch.eye(log_pi.shape[-1])[a] * F.log_softmax(pi, 1)).sum(1)

        v = self.value_function(obs).squeeze()
        return a, log_pi, v

    def value(self, obs):
        return self.value_function(obs).squeeze()

    def likelihood(self, obs, a):
        pi = self.policy(obs)
        log_pi = F.log_softmax(pi, 1)
        log_pi = (torch.eye(log_pi.shape[-1])[a] * F.log_softmax(pi, 1)).sum(1)
        return log_pi

def discount_cumsum(x, discount):
    """
    magic from rllab for computing discounted cumulative sums of vectors.
    input: 
        vector x, 
        [x0, 
         x1, 
         x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,  
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
"""vpg.py"""
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

import time, gym

from core import actor_critic as ac
from core import discount_cumsum

class VPGBuffer:
    """
    A buffer for storing trajectories experienced by a VPG agent interacting
    with the environment, and using Generalized Advantage Estimation (GAE-Lambda)
    for calculating the advantages of state-action pairs.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.act_buf = np.zeros([size], dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        """
        Append one timestep of agent-environment interaction to the buffer.
        """
        assert self.ptr < self.max_size     # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        """
        Call this at the end of a trajectory, or when one gets cut off
        by an epoch ending. This looks back in the buffer to where the
        trajectory started, and uses rewards and value estimates from
        the whole trajectory to compute advantage estimates with GAE-Lambda,
        as well as compute the rewards-to-go for each state, to use as
        the targets for the value function.
        The "last_val" argument should be 0 if the trajectory ended
        because the agent reached a terminal state (died), and otherwise
        should be V(s_T), the value function estimated for the last state.
        This allows us to bootstrap the reward-to-go calculation to account
        for timesteps beyond the arbitrary episode horizon (or epoch cutoff).
        """

        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        
        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        
        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        
        self.path_start_idx = self.ptr

    def get(self):
        """
        Call this at the end of an epoch to get all of the data from
        the buffer, with advantages appropriately normalized (shifted to have
        mean zero and std one). Also, resets some pointers in the buffer.
        """
        assert self.ptr == self.max_size    # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
        # the next two lines implement the advantage normalization trick
        adv_mean, adv_std = self.adv_buf.mean(), self.adv_buf.std()
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        return [self.obs_buf, self.act_buf, self.adv_buf, 
                self.ret_buf, self.logp_buf]

def vpg(env_name, actor_critic_func, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=50, gamma=0.99, pi_lr=3e-4,
        vf_lr=1e-3, train_v_iters=80, lam=0.97, max_ep_len=1000,
        save_freq=10):

    env = gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    buf = VPGBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)

    actor_critic = actor_critic_func(obs_dim, act_dim, ac_kwargs["hidden_sizes"])

    torch.random.manual_seed(seed)
    np.random.seed(seed)

    optim_policy = optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    optim_value  = optim.Adam(actor_critic.value_function.parameters(), lr=vf_lr)

    def update():
        optim_policy.zero_grad()
        inputs = buf.get()

        obs_tens = torch.from_numpy(inputs[0]).float()
        act = torch.from_numpy(inputs[1]).long()

        log_pi = actor_critic.likelihood(obs_tens, act)

        pi_loss = -(log_pi * torch.from_numpy(inputs[2]).float()).mean()
        pi_loss.backward()
        optim_policy.step()

        for _ in range(train_v_iters):
            optim_value.zero_grad()
            v_t = actor_critic.value(obs_tens)
            v_loss  = (torch.from_numpy(inputs[3]).float() - v_t).pow(2).mean()
            v_loss.backward()

        optim_value.step()

        return pi_loss.item(), v_loss.item()

    start_time = time.time()
    obs, ret, kill, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            # if epoch > 40:
                # env.render()
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            act, log_pi, v_t = actor_critic(obs_tens)
            buf.store(obs, act[0].item(), ret, v_t.item(), log_pi.item())

            obs, ret, kill, _ = env.step(act[0].item())

            ep_ret += ret
            ep_len += 1

            terminal = kill or (ep_len == max_ep_len)
            if terminal or (t==steps_per_epoch-1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len)
                last_val = ret if kill else actor_critic.value(torch.from_numpy(obs).float()).item()
                buf.finish_path(last_val)
                print(ep_len)
                obs, ret, kill, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        pi_loss, v_loss = update()
        print("pi loss : ", pi_loss)
        print("v loss : ", v_loss)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='CartPole-v0')
    parser.add_argument('--hid', type=int, default=64)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--cpu', type=int, default=4)
    parser.add_argument('--steps', type=int, default=4000)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='vpg')
    args = parser.parse_args()
    vpg(args.env, actor_critic_func=ac,
        ac_kwargs=dict(hidden_sizes=[args.hid]*args.l), gamma=args.gamma, 
        seed=args.seed, steps_per_epoch=args.steps, epochs=args.epochs)

元のtensorflowのコードを流用して書いたのでだいぶ汚いコードになっているので注意(次の実装からはちゃんとやります).cartpoleでしか試していないが,simplest policy gradientに比べ学習自体は遅くなっている.実装ミスではなく学習パラメータが増えたためと思いたい…

まとめ

Advantage functionを推定するGAE自体はICLR2016と意外と新しくて驚き.論文ではspinning upのalgorithms docs内のTRPOとして実装されていたが,単純な形の実装に落として紹介しているあたり教材としてよく考えられているなという印象.強化学習初学者の自分にも気持ちや実装方法などがわかりやすくよかった.

Grid R-CNNを読んだのでメモ

はじめに

Grid R-CNNを読んだのでメモ.前に読んだcornerNetと同じくkey pointベースの検出方法.投稿時期とフォーマットからおそらくcvprに投稿された論文.また今年もSenseTimeから沢山の論文がアクセプトされるのでしょうか.

Grid R-CNN

基本的にCornerNetと同じコンセプトで,bounding boxの角(key point)をヒートマップとして出力させて後処理的にBounding Boxを生成する.CornerNetはSSDやYOLOのようなSingle stageの検出方法だったが,ここではtwo stage検出として構成する.CornerNetとの本質的な違いはkey pointの推論に必要な情報がない場合,つまりkey pointに対応するピクセルに物体の情報が無い場合の対処法.CornerNetではCorner poolingによりこれを克服していたが,Grid R-CNNでは検出するkey pointの数を増やすことで,情報のないkey pointの検出を他のkey pointの検出により補うことで解決する.

Grid Guided Localization

ここでは物体のbounding boxの検出としてN\times N gridを検出することを考える.Figure 1には3x3の場合の例が載っていて,この場合には矩形の4角と各辺の中点,後中心の全部で9点のkey pointを検出することになる.今回はR-CNN型の検出器なので,このgridの検出はtwo stage目,つまりRoIpooling(ないしはRoIalign)をした後のrefinementの部分で行われ,key pointの検出に関してはheatmapでの出力となり,学習はkey pointがあるかないかの2値分類としてbinary cross-entropyの最小化を行う.出力のheatmapは各key point毎に1つ,つまり3x3のkey point検出の場合は全部で9つのheatmapを出力する.正解値に関しては論文中に"5 pixels in a cross shape are labeled as positive locations of the target grid point."とだけ書いてあり,十字のハードラベルを正解とするらしい?

N\times N点のkey pointが得られたら,それらを使って次のようにbounding box B=(x_l,y_u,x_r,y_b)を決定する.

\displaystyle
x_l=\frac{1}{N}\sum_{j\in E_1}x_jp_j,\:y_u=\frac{1}{N}\sum_{j\in E_2}y_jp_j\\ \displaystyle
x_r=\frac{1}{N}\sum_{j\in E_3}x_jp_j,\:y_b=\frac{1}{N}\sum_{j\in E_4}y_jp_j

x,yはkey pointの座標でpは推論結果の確率を表す.E_iはbounding boxの各辺上にあるkey pointの集合.なので矩形の中心のkey pointや4x4gridなどに拡張した際に(論文では3x3しかやってないが)出てくる矩形の内側のkey pointは矩形を決める際には寄与しない.なんの意味があって検出するんだということになるがそれは次のfeature fustionに利用される.

Grid Points Feature Fusion

Key pointベースの物体検出で問題となるのは画像上でkey pointが存在する位置に対象となる物体の情報が含まれていない場合があること.CornerNetではcorner poolingなるプーリングによって解決していたが,Grid R-CNNではその他のkey pointを利用して解決をする.

考え方としては,例えば3x3 gridの左上のkey pointを検出することを考えると,上中央と中央左のkey pointの座標から左上のkey pointの座標を得ることができる.なので各key point毎に独立な特徴マップを持っておいて,関係するkey point同士の特徴マップを融合しようというもの.Figure 3にコンセプトの例があり,さらにもう一つ遠いkey pointを利用する例(second order feature fusion)も示されている.

Key point毎の特徴マップをF_i,今注目しているkey pointの特徴を補正するのに使うkey point(source point)の集合をS_iとする.S_iに属するj番目のsource pointの特徴マップF_jに対してカーネルサイズ5x5の畳み込み層を3回適用することで情報の伝達を行う.この畳み込みをT_{j\rightarrow I}とする.T_{j\rightarrow i}から出力された特徴マップを使って次のように注目しているkey pointの特徴マップを変換する.

\displaystyle
F_i'=F_i+\sum_{j\in S_i}T_{j\rightarrow i}(F_j)

基本的に各key pointはregular gridな形の配置になるため,学習される畳み込み層T_{j\rightarrow i}は平行移動を表現するようなものになるはずというのが裏にある思いかと. ちなみにFirst order fusionの場合にはS_iに属するkey pointはgrid上におけるL_1距離が1のもので,second orderの場合には2といった形で変わる.

Extended Region Mapping

ここはtips的な部分で,簡単言えばregion proposal network (RPN)で得られた候補矩形がそもそも検出対象の物体をカバーしてなければheatmapベースの検出を行うことができないので大きめにRoIPoolingしようということ.じゃあどれくらい大きくするかというとこで,この論文では以下の式に従ってRoIPoolingする領域を決める.

\displaystyle
I'_x=P_x+\frac{4H_x-w_o}{2w_o}w_p\\ \displaystyle
I'_y=P_y+\frac{4H_y-h_o}{2h_o}h_p

I_x',I_y'は元画像上での位置.h_o,w_oは出力のマップの縦横サイズで,h_p,w_pRPNから得られた矩形の縦横サイズ,P_x,P_yは矩形の左上の座標.この計算に従ってRoIPoolingでクロップする領域を大きくすれば真の矩形との被り率が0.5を超える矩形なら対象となる物体全体をクロップすることが可能.

まとめ

Object detectionもkey pointが主流になりそうな感じ.ここら辺で一回実装して理解を深めたいところ.

CornerNetと比較してGrid R-CNNはcorner poolingみたいな実装めんどくさそうな演算がないのは実装するとき嬉しいところ.

Deformable ConvNets v2: More Deformable, Better Resultsを読んだのでメモ

はじめに

Deformable ConvNets v2: More Deformable, Better Resultsを読んだのでメモ.MSRAのインターンの成果でdeformable convolutionを改良したとのこと.著者は元の論文の人と一部同じ.

気持ち

Deformable convolutionはより柔軟に物体の構造を捉えることができるように提案された手法で,実際物体検出等のタスクで良好な結果を示した.ただ,deformable convolutionが参照している領域を可視化してみると,大体は単一の物体を参照するように畳み込みのカーネルが変形しているが,うまく物体をカバーできていない部分もあったとのこと.逆に言えばまだdeformable convolutionのポテンシャルがあるから,最大限に性能を引き出すために改善するという感じ.

Analysis of deformable convnet behavior

Deformable convnetがあまりうまく働かない状況を受容野の可視化を通して解析する.ただここでは受容野の可視化にeffective receptive fieldsという方法を使っていて,この手法を知らなかったで次回読む.また,deformable convolutionがどの領域からサンプリングしているかの可視化もeffective receptive fieldsを応用して行うことで解析に役立てる.中身は置いておいてeffective receptive fieldsは単純な受容野だけでなく受容野に含まれる領域がどれだけ出力に寄与しているかも含めて可視化する方法らしい.

可視化の結果はFigure 1にまとまっていて,結果から以下のことがわかる.

  1. 通常のConvNetsはある程度の範囲の幾何的な構造をモデル化できる.

  2. Deformable ConvNetsはConvNetsの幾何構造を捉える能力を劇的に高めている.ただし,サンプリングされたピクセルの識別への寄与率をみると前景と背景どちらの領域にも広がっていることがわかる.

  3. ConvNetsはカーネルの形状が固定なのでその重みを調整することでうまく幾何構造を捉えているが,これはカーネルのオフセットと重み両方を学習するDeformable ConvNetsでも成り立つはずである.

すなわちちゃんと興味対象に焦点を当てるようカーネルをdeformするようにすればより良くなるはずとのこと.論文では言及されてなかったが,おそらく幾何的な構造を捉える部分がoffsetの推論に押し付けられているのがよくないとの解釈で問題はないかと.

More Deformable ConvNets

Stacking More Deformable Conv Layers

可視化をして色々分析したわりに最初の改善案としてはdefomable convを導入しまくるというもの.元々のdeformable ConvNetsは元のネットワーク構造の一部のみをdeformConvに置き換えるというものだった.なのでもっとたくさん導入すれば幾何的な変形を捉える能力がより高まるはずという脳筋的発想.具体的にはResNet-50のconv3からconv5までのフィルタサイズ3\times 3の畳み込みを全てdeformable convlutionに置き換える.色々試したけどdeformable convolutionを導入し過ぎてもダメで,conv3からconv5に入れるのが良かったらしい.

Modulated Deformable Modules

さらなる表現力向上を目指して,modulation mechanismを導入する.このメカニズムによって入力の特徴を認識するためにオフセットを調整するだけでなく特徴のamplitudesを調整することも可能になる.特徴のamplitudeをゼロにすることで特定の領域から信号を受け取ることがなくなる.すなわち幾何構造を捉えるためにはoffsetだけでは難しいという状況にすることで構造のモデル化というタスクを分散させることが可能になるということ.

ここから定式化.畳み込みのカーネルK個の領域(3x3カーネルならK=9)からサンプリングを行うものとし,w_kp_kをそれぞれk番目のピクセルに対する重みとあらかじめ定義されたオフセットを表す.K=9の場合にはp_k\in\{(-1,-1),(-1,0),\dots,(1,1)\}になる.x(p),y(p)を入力の特徴マップxと出力の特徴マップyの位置pにおける値とする.modulated deformable convolutionは次のように表現される.

\displaystyle
y(p)=\sum_{k=1}^K w_k\cdot x(p+p_k+\Delta p_k)\cdot\Delta m_k

\Delta p_k,\Delta m_k\in[0,1]は位置kに対する学習可能なオフセットとmodulationでスカラー値.どちらの値も元のdeformable convolutionと同様,別なconvolution層の出力として得られる.また,基本的な計算方法もdeformable convolutionと変わらずbilinear interpolationを利用して計算する.オフセットとmodulationを出力する層の重みの初期値は0で初期化,すなわち初期段階は通常の畳み込みと等しい結果になるよう初期化する.また学習率はネットワークを学習するための学習率の0.1倍に設定すると良いとのこと.

RoIpoolingも同じようにmodulated deformable RoIpoolingへと次のように拡張する.

\displaystyle
y(k)=\sum_{j=1}^{n_k}x(p_{kj}+\Delta p_k)\cdot\Delta m_k/n_k

p_{kj}はRoIpoolingを行うときに定義されるグリッドのj番目のセル内の位置kピクセルを表す.n_kはグリッドに含まれるピクセル数で,n_kで正規化していることから平均プーリングになっている(RoIpoolingは元々はmax poolingだった気もする).

Modulateという言葉を使っているがニューラルネット界隈でよく使われる表現を使えばattentionを(厳密にはsoftmaxを使っていないのでちょっと違うかもしれないが)導入したということ.こういう言い回しも新規性ある風に見せる時に重要なのかなと思ったり.

R-CNN Feature Mimicking

Feature mimickingという闇深そうな技術がECCV2018で発表されていたよう.知らなかったのと今回の論文の本質とは無関係なためここでは割愛.Feature mimickingの論文は今度読んでみる予定.

まとめ

手法自体はattention的なものを導入して適用する層の数を増やしたというところ.精度は元のdeformable convolutionの適用数を増やすだけで元のdeformable convolutionより3%の向上があり,modulatedは1%弱しか寄与していない.というか元の論文でdeformable convolutionの適用数に関してそんなに実験してなかったのか.

Self-Supervised Generative Adversarial Networksを読んだのでメモ

はじめに

Self-Supervised Generative Adversarial Networksを読んだのでメモ。

気持ち

まず,GAN関係なしにNeural Netは識別タスクにおいて,識別の環境が動的に変化する(タスクが変化する)と前の識別境界を忘れるという問題がある.この問題はGANのdiscriminatorでも起こっていて,generatorが生成する画像が学習中動的に変化していくためdiscriminatorが以前作った識別境界を忘れ,それにより学習が不安定になるというのがこの論文で課題としているところ.ただ,この問題は条件付きの学習(真偽判定だけでなくラベルの識別を含む学習)によって回避できるという.この性質をラベルのないデータでも活用するためにself-supervised GANを提案するというもの.

key issue discriminator forgetting

GANの目的関数は以下で与えられる.

\displaystyle
V(G,D)=\mathbb{E}_{\mathbf{x}\sim P_{data}(\mathbf{x}}[\log P_D(S=1|\mathbf{x})]+\mathbb{E}_{\mathbf{x}\sim P_G(\mathbf{x})}[\log(1-P_D(S=0|\mathbf{x}))]

親の顔より見た式なのでnotationは割愛.ここで問題視しているのは学習中generatorが表現する分布P_G^{(t)}tt回目の更新の意味)が動的に変わっていくことでnon-stationary online learningになっていること.それによってdiscriminatorは更新の度に前回の分布で作った識別境界を忘れてしまい学習が不安定なる.例えばgeneratorは学習初期に物体の大まかな構造のみを再現した画像を生成するとすれば,discriminatorは大局的な構造の違いや局所的な構造の欠損に基づく識別境界を引く.学習が進みgeneratorがテクスチャのような局所的な構造を再現したデータを生成するようになったとすると,識別に有効な特徴量が変化するためdiscriminatorは今までと全く異なった(前のタスクを完全に無視した)識別境界を引く.これによりdiscriminatorから伝播してくるpenaltyの性質が変わりgeneratorは今まで再現できていた構造が崩れてしまう.これが不安定さの要因(この例は論文では言及されてない自分の解釈が一部入っているので注意).

別な考えとしては,最適な状態ではデータの分布とgeneratorの分布が一致するため,そのような状況ではもはやgeneratorは何のペナルティも受けない.すなわちgeneratorは意味のある表現を維持しようとする必要がなくなる.なので正則化項が入っているような場合ではgeneratorは獲得した表現を失う可能性もある.

実際,簡単な実験で忘却の性質を検証していた(Figure 2とFigure 3).

The self-supervised GAN

上で説明した忘却をいかにして防ぐかというのが課題.ここではself-supervisedを導入することでnon-stationaryからstationaryな課題へと変えることで改善をおこなう.今回はSoTAなself-supervised learningの方法であるimage rotationを利用.image rotationは画像を0°,90°,180°,270°と回転させて,その回転具合を識別させることでよい表現を獲得するself-supervised learningの手法.これを元の目的関数に組み合わせるだけという非常に単純なアプローチ.なので目的関数は次のようになる.

\displaystyle
L_G=-V(G,D)-\alpha\mathbb{E}_{\mathbf{x}\sim P_G}\mathbb{E}_{r\sim\mathcal{R}}[\log Q_D(R=r|\mathbf{x}^r)]\\ \displaystyle
L_D=V(G,D)-\beta\mathbb{E}_{\mathbf{x}\sim P_{data}}\mathbb{E}_{r\sim\mathcal{R}}[\log Q_D(R=r|\mathbf{x}^r)]

r\in\mathcal{R}は回転の角度を表し,\mathcal{R}=\{0,90,180,270\}となる.\mathbf{x}^rはデータをr度回転させることを意味する.ここで大切なのはdiscriminatorはrotationの判定に関して真のデータからのみ学習するということ.つまりここが問題をstationaryな学習にしている部分.

実践的な部分として上の定式化ではdiscriminatorとrotationの判定器は最終層のみ別で特徴抽出部は同一ネットワークで構成されている.また,基本的に\alpha\gt 0である限り真の分布とgeneratorの分布が一致するという保証はない.なので最終的に\alpha=0となるように減衰させていくと保証が得られる.

まとめ

手法はシンプルだが論文の実験においてはconditional GANと同じくらい安定した学習を実現している.研究の着眼点が良く考察もしっかりできていてこんな風に論文を書きたい.

TGANv2: Efficient Training of Large Models for Video Generation with Multiple Subsampling Layersを読んだのでメモ

はじめに

TGANv2: Efficient Training of Large Models for Video Generation with Multiple Subsampling Layersを読んだのでメモ.動画生成のGANは計算コストが大きくなってしまうため解像度の低い動画の生成しかできていなかった点を解決したという論文.

assumption

この論文では計算コストの主な原因はGeneratorの後段,出力層に近い部分にあるとしている.理由としては後段では空間的な解像度が高くなるため演算回数もメモリも大きくなってしまうためとのこと.また,GANは後段のレイヤー程線形に近い単純な表現になっていると仮定していて,その性質を利用して計算コストの高くなる後段の層を効率化するという.

Abstract map and subsampling layer

Abstract mapとsubsampling layerを導入する.Notationが面倒なので代わりにGANの目的関数を以下に.

\displaystyle
\mathbb{E}_{\mathbf{x}\sim p_d}[\ln D(\mathbf{x})]+\mathbb{E}_{\mathbf{z}\sim p_z}[\ln(1-D(G(\mathbf{z})))]

ここではgeneratorをstract block g^Aとrendering block g^Rの二つのブロックに分けて考える.Stract blockはabstract mapと呼ばれる特徴マップを計算し,rendering blockはabstract mapからデータをサンプリングする役割を持つ.簡単に言えば,generatorの出力層に近い方の畳み込みをrendering blockと呼んで残りをstract blockとしている感じ.なのでgenerator G(\mathbf{z})は次のように記述可能.

\displaystyle
\mathbf{x}=G(\mathbf{z})=\left(g^R\circ g^A\right)(\mathbf{z})

GANの計算のボトルネックはdiscriminator側にも存在し,生成されるデータの解像度が大きくなるとdiscriminatorの計算コストも上がる.そこでこの問題を解決するためにsubsampling layer \mathcal{S}^Gを導入する.subsampling layerはクロップ関数(画像を適当にクロップする演算)のような何かしらの関数で,abstract mapからランダムにサンプリングを行う.すると元の解像度より低い解像度のabstract mapが得られるので後段の計算コストを下げることができるというもの.このsubsampled data \mathbf{x}'を生成するgeneratorをG'とすれば次のように表現可能.

\displaystyle
\mathbf{x}'=G'(\mathbf{z})=\left(g^R\circ\mathcal{S}^G\circ g^A\right)(\mathbf{z})

Abstract mapだけでなく生成されるデータの解像度も当然落ちるので目的関数も次のように修正.

\displaystyle
\mathbb{E}_{\mathbf{x}\sim p_d}[\ln D'(\mathcal{S}^D(\mathbf{x}))]+\mathbb{E}_{\mathbf{z}\sim p_z}[\ln(1-D'(G'(\mathbf{z})))]

ここで\mathcal{S}^D\mathcal{S}^Gと同じようにデータをサブサンプルする関数.このように学習されたgeneratorが推論時にも同様にデータを生成するためにg^R,\mathcal{S}^G,\mathcal{S}^Dは次の3つの条件を満たす必要があるとのこと.

  1. g^Rは元のabstract mapと解像度の下がったabstract map両方を入力可能

  2. \mathcal{S}^G,\mathcal{S}^Dは入力に関して微分可能

  3. \mathcal{S}^G,\mathcal{S}^Dによってサンプリングされる領域は確率的に決まる

例えば,g^Rをfully-convolutionalなモデルにして,\mathcal{S}^G,\mathcal{S}^Dを固定ウィンドウ幅のクロップとすれば上の条件を満たす.

実践的な利点として,\mathcal{S}^G,\mathcal{S}^Dを導入すればめちゃくちゃ高次元のデータに対しても学習することが可能.もう一点の利点としてネットワーク側のパラメータも大きくすることができる.これは入力の空間的な解像度が下がればチャネル数を大きくしてもメモリ的に大丈夫ということかと.ただし,subsampling layerを一層だけしか使わなかった場合,使わないネットワークと比較して,同一学習回数においてはパフォーマンスが低下するとのこと.ただこれはabstract mapをマルチレベルにすることで回避可能らしい.

Training with multiple subsampling layers

現状,subsampling layerの導入によって計算コストと生成データの質のトレードオフが生まれた.つまり,subsampling layerによって得られたデータの解像度が小さいほど計算は効率的に行われるが,生成されるデータの質は下がる.ただし,多重解像度かもしくはmultiple frame ratesにすれば質を落とすことなく効率的にネットワークを学習可能.

まず,DCGANやSA-GANで使われるような一般的なgeneratorを考える.通常は最初に得られる特徴マップの解像度は低くチャネル数が多い.逆に出力層に近付くほど特徴マップの解像度は上がりチャネル数が少なくなる.これの意味するところは,入力に近い層ほど出力層付近に比べ学習パラメータ数が多く計算コストが小さいということ.なので入力層付近の層にsubsampling layerを入れても計算コストにおける利点が少ないことから後段の層に対して積極的にsubsampling layerを導入したい.

ここで,generatorが次のようにL個のrendering blocksとL個のabstract blocksとL-1層のsubsampling layerから構成されているとする.

\displaystyle
\mathbf{x}=\left(g_L^R\circ g_L^A\circ g_{L-1}^A\circ\dots\circ g_1^A\right)(\mathbf{z})

この演算は,l番目のgenerator(lフレーム目の生成過程)はl-1番目のgeneratorが生成したabstract mapに対しsubsampling layerを適用して得られたabstract mapを入力として受け取ることを意味する.つまり,陽に書き下せば次のようになる.

\displaystyle
G'_1=g_1^R\circ g_1^A\\ \displaystyle
G'_2=g_2^R\circ g_2^A\circ\left(\mathbf{S}_1^G\circ g_1^A\right)\\ \displaystyle
\vdots\\ \displaystyle
G_L'=g_L^R\circ g_L^A\circ\left(\mathbf{S}^G_{L-1}\circ g_{L-1}^A\right)\circ\dots\circ\left(\mathbf{S}_1^G\circ g_1^A\right)

この操作は計算コストの高い後段で何度もsubsamplingすることで効率的な計算を可能にしている.ちなみにFigure 2にモデルの全体像が書いてあるのでそこを見ると大体の計算の流れはわかる.

Multiple discriminators

一回のサンプリングで複数のサンプルが得られるためdiscriminatorも複数用意する必要がある.よってl番目のサンプルを受け取るdiscriminatorをD_l'とすればdiscriminatorのスコアは次のように表現可能.

\displaystyle
D'(\mathbf{x}_1',\dots,\mathbf{x}_L')=\sigma\left(\sum_{l=1}^LD'_l(\mathbf{x}_l')\right)

Adaptive batch reduction

基本的に学習はミニバッチ単位で行われるため,ここではsubsampling layerをバッチ方向にも拡張することでさらなる計算の効率化をする.ただ,Pix2PixやVid2Vvidのようなバッチサイズがめちゃくちゃ小さくてもうまくいく手法がある一方で,BigGANで言われているようにバッチサイズは大きい方がいいという議論もある.そこでこの論文では最初のサンプルの生成は多様性が必要だが,後段はそこまで多様なマップを生成する必要がない(多分abstract mapを引き継ぐ形になっているからという点が根拠かと)と仮定.その仮定のもと初期ブロックではlarge batchで学習を行い後段ではsmall batchになるようにバッチサイズのreduceを導入したとのこと.

まとめ

どう計算コストを減らしたかのみ気になったのでモデルや学習の詳細は割愛.クロップしたマップを次へ次へと渡すと生成される画像はどんどん拡大されていくような気もするがどうなんだろうか.ただ,クロップしても動くのはクロップ後の畳み込みの回数が少なくてそんな大局的な部分を見ないからというので納得はできる.Inception scoreは先行研究からそれなりのジャンプがあるスコアを叩き出しているのでかなり有効なよう.論文にも書いてあったがこの計算コストの減らし方自体は色々応用がききそう.