Machine Learning, Computer Vision, and the Like

Notes from my studies on CV and ML.

Studying Reinforcement Learning with OpenAI's Spinning Up, Part 6

Introduction

For part 6, this time I implement Twin Delayed DDPG (TD3) in PyTorch.

Twin Delayed DDPG

DDPG is fundamentally a good algorithm, but training reportedly breaks down from time to time. The reason given is that the Q-function overestimates values early in training, and the errors it contains drag the policy in a bad direction. TD3 therefore avoids this with a handful of small tricks (which is to say, there is no major theoretical change).

Trick One: Clipped Double-Q Learning

The idea here is simply to maintain two Q-functions. Having two Q-functions is apparently where the "twin" in the algorithm's name comes from.

Trick Two: Delayed Policy Updates

The idea is to update the policy parameters only after updating the Q-functions several times. In this implementation, the policy is updated once for every two Q-function updates, as in the sketch below.
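As a rough sketch of that schedule (the callables here are hypothetical stand-ins, not functions from the implementation further down):

def delayed_updates(num_updates, update_critics, update_actor, update_targets, policy_delay=2):
    # update_critics/update_actor/update_targets are placeholder callables
    for j in range(num_updates):
        update_critics()              # both Q-functions are updated every step
        if j % policy_delay == 0:     # the policy (and target networks) only every policy_delay steps
            update_actor()
            update_targets()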

Trick Three: Target Policy Smoothing

The idea is to add noise to the target action. This reportedly keeps the policy from being dragged around by errors contained in the Q-function. I take "smoothing" to mean making the Q-function output similar values in the neighborhood of the policy's output.

Key Equations

First, for trick three, Gaussian noise of the following form is added to the output of the target policy.

\displaystyle
a'(s')=\mathrm{clip}(\mu_{\theta_\mathrm{targ}}(s')+\mathrm{clip}(\epsilon,-c,c),a_{Low},a_{High}),\ \epsilon\sim\mathcal{N}(0,\sigma)

The intuition is that this noise prevents the Q-function from forming sharp peaks.
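In PyTorch this amounts to clipping a Gaussian noise tensor and adding it to the target policy's output before clamping to the action range. A minimal sketch (the function name and the sigma/c arguments are placeholders; the actual version is get_target_action in core.py below):

import torch

def smoothed_target_action(policy_targ, obs2, act_limit, sigma=0.2, c=0.5):
    a = act_limit * policy_targ(obs2)                       # mu_targ(s') scaled to the action range
    eps = torch.clamp(sigma * torch.randn_like(a), -c, c)   # clip(eps, -c, c) with eps ~ N(0, sigma)
    return torch.clamp(a + eps, -act_limit, act_limit)      # clip back into [a_Low, a_High]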

Next, the target is computed with clipped double-Q learning: of the two Q-functions, the smaller value is used to compute the target.

\displaystyle
y(r,s',d)=r+\gamma(1-d)\min_{i=1,2}Q_{\phi_i,\mathrm{targ}}(s',a'(s'))
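A minimal PyTorch sketch of this target computation (the function name and tensor shapes are assumptions; compute_target in core.py below does the same thing):

import torch

def td3_target(q1_targ, q2_targ, obs2, a2, rew, done, gamma=0.99):
    x = torch.cat([obs2, a2], dim=-1)
    q = torch.min(q1_targ(x), q2_targ(x))   # the smaller of the two target critics
    # y = r + gamma * (1 - d) * min_i Q_targ_i(s', a'(s')); detach so no gradient flows into the targets
    return (rew + gamma * (1 - done) * q.squeeze(-1)).detach()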

Using this target, each Q-function is trained to minimize the following loss.

\displaystyle
L(\phi_1,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_1}(s,a)-y(r,s',d)\right)^2\right]\\
L(\phi_2,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_2}(s,a)-y(r,s',d)\right)^2\right]
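Both critics are regressed onto the same target with a mean-squared error. A sketch (the function name is a placeholder; main.py below sums the two losses and backpropagates them together):

import torch.nn.functional as F

def critic_loss(q1, q2, y):
    # (Q_phi1(s,a) - y)^2 + (Q_phi2(s,a) - y)^2, averaged over the batch
    return F.mse_loss(q1, y) + F.mse_loss(q2, y)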

The policy is trained to maximize one of the two Q-functions; here it is trained to maximize Q_{\phi_1}.

\displaystyle
\max_\theta\mathbb{E}_{s\sim\mathcal{D}}[Q_{\phi_1}(s,\mu_\theta(s))]
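Since optimizers minimize, the implementation uses the negative of this objective as the policy loss. A sketch under that convention (the function name and arguments are placeholders):

import torch

def actor_loss(q1, policy, obs, act_limit):
    pi = act_limit * policy(obs)                      # mu_theta(s)
    return -q1(torch.cat([obs, pi], dim=-1)).mean()   # minimize -E[Q_phi1(s, mu_theta(s))]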

Implementation

The implementation is almost identical to DDPG.

"""core.py"""
import torch
import torch.nn as nn

class continuous_policy(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], act_dim))
        layer.append(nn.Tanh())
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class q_function(nn.Module):
    def __init__(self, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], 1))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class actor_critic(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2):
        super().__init__()
        self.policy = continuous_policy(act_dim, obs_dim, hidden_layer)

        self.q1 = q_function(obs_dim+act_dim, hidden_layer)
        self.q2 = q_function(obs_dim+act_dim, hidden_layer)
        self.act_limit = act_limit

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.policy_targ = continuous_policy(act_dim, obs_dim, hidden_layer)
        self.q1_targ = q_function(obs_dim+act_dim, hidden_layer)
        self.q2_targ = q_function(obs_dim+act_dim, hidden_layer)

        self.copy_param()

    def copy_param(self):
        self.policy_targ.load_state_dict(self.policy.state_dict())
        self.q1_targ.load_state_dict(self.q1.state_dict())
        self.q2_targ.load_state_dict(self.q2.state_dict())
        # for m_targ, m_main in zip(self.policy_targ.modules(), self.policy.modules()):
        #     if isinstance(m_targ, nn.Linear):
        #         m_targ.weight.data = m_main.weight.data
        #         m_targ.bias.data = m_main.bias.data

        # for m_targ, m_main in zip(self.q_targ.modules(), self.q.modules()):
        #     if isinstance(m_targ, nn.Linear):
        #         m_targ.weight.data = m_main.weight.data
        #         m_targ.bias.data = m_main.bias.data

    def get_action(self, obs, noise_scale):
        pi = self.act_limit * self.policy(obs)
        pi += noise_scale * torch.randn_like(pi)
        pi.clamp_(max=self.act_limit, min=-self.act_limit)
        return pi.squeeze()

    def get_target_action(self, obs, noise_scale, clip_param):
        pi = self.act_limit * self.policy_targ(obs)
        eps = noise_scale * torch.randn_like(pi)
        eps.clamp_(max=clip_param, min=-clip_param)
        pi += eps
        pi.clamp_(max=self.act_limit, min=-self.act_limit)
        return pi.detach()

    def update_target(self, rho):
        # compute rho * targ_p + (1 - rho) * main_p
        for poly_p, poly_targ_p in zip(self.policy.parameters(), self.policy_targ.parameters()):
            poly_targ_p.data = rho * poly_targ_p.data + (1-rho) * poly_p.data

        for q_p, q_targ_p in zip(self.q1.parameters(), self.q1_targ.parameters()):
            q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data

        for q_p, q_targ_p in zip(self.q2.parameters(), self.q2_targ.parameters()):
            q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data


    def compute_target(self, obs, pi, gamma, rewards, done):
        # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
        q1 = self.q1_targ(torch.cat([obs, pi], -1))
        q2 = self.q2_targ(torch.cat([obs, pi], -1))
        q  = torch.min(q1, q2)
        return (rewards + gamma * (1-done) * q.squeeze()).detach()

    def q_function(self, obs, detach=True, action=None):
        # compute Q(s, a) or Q(s, mu(s))
        if action is None:
            pi = self.act_limit * self.policy(obs)
        else:
            pi = action
        if detach:
            pi = pi.detach()
        return self.q1(torch.cat([obs, pi], -1)).squeeze(), self.q2(torch.cat([obs, pi], -1)).squeeze()
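A quick smoke test of core.py on Pendulum-sized inputs (obs_dim=3, act_dim=1, act_limit=2 for Pendulum-v0); this is purely illustrative and not part of the training script:

import torch
from core import actor_critic

ac = actor_critic(act_dim=1, obs_dim=3, hidden_layer=(400, 300), act_limit=2)
obs = torch.randn(1, 3)
print(ac.get_action(obs, noise_scale=0.1))                          # noisy exploration action
print(ac.get_target_action(obs, noise_scale=0.2, clip_param=0.5))   # smoothed target action
q1, q2 = ac.q_function(obs)                                         # Q(s, mu(s)) from both critics
print(q1, q2)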
"""main.py"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gym, time
import numpy as np

from spinup.utils.logx import EpochLogger
from core import actor_critic as ac

class ReplayBuffer:
    def __init__(self, size):
        self.size, self.max_size = 0, size
        self.obs1_buf = []
        self.obs2_buf = []
        self.acts_buf = []
        self.rews_buf = []
        self.done_buf = []

    def store(self, obs, act, rew, next_obs, done):
        self.obs1_buf.append(obs)
        self.obs2_buf.append(next_obs)
        self.acts_buf.append(act)
        self.rews_buf.append(rew)
        self.done_buf.append(int(done))
        while len(self.obs1_buf) > self.max_size:
            self.obs1_buf.pop(0)
            self.obs2_buf.pop(0)
            self.acts_buf.pop(0)
            self.rews_buf.pop(0)
            self.done_buf.pop(0)

        self.size = len(self.obs1_buf)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(low=0, high=self.size, size=(batch_size,))
        obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs])
        obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs])
        acts = torch.FloatTensor([self.acts_buf[i] for i in idxs])
        rews = torch.FloatTensor([self.rews_buf[i] for i in idxs])
        done = torch.FloatTensor([self.done_buf[i] for i in idxs])
        return [obs1, obs2, acts, rews, done]

def td3(env_name, actor_critic_function, hidden_size,
        steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, 
        act_noise=0.1, target_noise=0.2, noise_clip=0.5, max_ep_len=1000,
        policy_delay=2, logger_kwargs=dict()):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    replay_buffer = ReplayBuffer(replay_size)

    env, test_env = gym.make(env_name), gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_limit = float(env.action_space.high[0])  # upper bound of the action range (2.0 for Pendulum-v0)

    actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit)

    q1_optimizer = optim.Adam(actor_critic.q1.parameters(), q_lr)
    q2_optimizer = optim.Adam(actor_critic.q2.parameters(), q_lr)
    policy_optimizer = optim.Adam(actor_critic.policy.parameters(), pi_lr)

    start_time = time.time()

    obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    for t in range(total_steps):
        if t > 50000:
            env.render()
        if t > start_steps:
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            act = actor_critic.get_action(obs_tens, act_noise).detach().numpy().reshape(-1)
        else:
            act = env.action_space.sample()

        obs2, ret, done, _ = env.step(act)

        ep_ret += ret
        ep_len += 1

        done = False if ep_len==max_ep_len else done

        replay_buffer.store(obs, act, ret, obs2, done)

        obs = obs2

        if done or (ep_len == max_ep_len):
            for j in range(ep_len):
                obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size)
                # compute Q(s, a)
                q1, q2 = actor_critic.q_function(obs1_tens, action=acts_tens)
                # compute r + gamma * (1 - d) * Q(s', mu_targ(s'))
                pi_targ = actor_critic.get_target_action(obs2_tens, target_noise, noise_clip)
                q_targ = actor_critic.compute_target(obs2_tens, pi_targ, gamma, rews_tens, done_tens)
                # compute (Q(s, a) - y(r, s', d))^2
                q_loss = (q1-q_targ).pow(2).mean() + (q2-q_targ).pow(2).mean()

                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                q_loss.backward()
                q1_optimizer.step()
                q2_optimizer.step()

                logger.store(LossQ=q_loss.item(), Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy())

                if j % policy_delay == 0:
                    # compute Q(s, mu(s))
                    policy_loss, _ = actor_critic.q_function(obs1_tens, detach=False)
                    policy_loss = -policy_loss.mean()
                    policy_optimizer.zero_grad()
                    policy_loss.backward()
                    policy_optimizer.step()

                    logger.store(LossPi=policy_loss.item())

                    # compute rho * targ_p + (1 - rho) * main_p
                    actor_critic.update_target(polyak)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # test_agent()
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Pendulum-v0')
    parser.add_argument('--hid', type=int, default=300)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='td3')
    args = parser.parse_args()

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    td3(args.env, actor_critic_function=ac,
        hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs,
        logger_kwargs=logger_kwargs)
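Instead of going through the argparse entry point, td3 can also be called directly; a hypothetical session run from the directory containing core.py and main.py, assuming gym with Pendulum-v0 and the spinup logger are installed:

from main import td3
from core import actor_critic as ac

# Roughly what `python main.py` does with its default arguments
# (minus the spinup logger directory setup).
td3('Pendulum-v0', actor_critic_function=ac, hidden_size=[300],
    gamma=0.99, epochs=50)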

Summary

It was a bit unexpected to suddenly get a method built from this many practical refinements; a reminder that papers like this exist too. The environment I tried was too simple for the results to show any advantage over DDPG, but the implementation is simple and what the method is trying to do is easy to follow.