OpenAIのSpinning Upで強化学習を勉強してみた その6
はじめに
その6ということで今度はTwin Delayed DDPG(TD3)をpytorchで実装する.
Twin Delayed DDPG
DDPGは基本的にはいいアルゴリズムだが,時たま学習が破綻する場合があるとのこと.その理由としてはQ関数が学習初期において過大評価を行なってしまい,そこに含まれる誤差がpolicyを悪い方向へと引っ張ってしまうことがあげられる.なのでTD3ではちょっとしたテクニックでこれを回避する(つまり裏を返せば理論的な大きな変更はない).
Trick One: Clipped Double-Q Learning
これはQ関数を二つ用意してしまおうというもの.二つのQ関数だからアルゴリズムにtwinがつくとかなんとか.
Trick Two: Delayed Policy Updates
Policyのパラメータ更新をQ関数を複数回更新した後に行おうというもの.ここの実装ではQ関数を2回更新するごとにpolicyを1回更新する.
Trick Three: Target Policy Smoothing
Target actionに対してノイズを与えるというもの.これによりpolicyがQ関数に含まれる誤差に引っ張られることを防ぐとのこと.Smoothというのはpolicyの出力の周りでQ関数が似たような値を出力できるようにという意味かと.
Key Equations
まずtrick threeについて,target policyの出力に対し以下のようにクリップしたガウスノイズを加える:
$$a'(s') = \mathrm{clip}\left(\mu_{\theta,\mathrm{targ}}(s') + \mathrm{clip}(\epsilon, -c, c),\ a_{low},\ a_{high}\right),\quad \epsilon \sim \mathcal{N}(0, \sigma)$$
このノイズのおかげでQ関数がシャープなピークを持つことができなくなるという気持ちが入っている.
さらにtargetの計算としてclipped double-Q learning,つまり用意された二つのQ関数のうち小さい方の値をtargetの算出に使う:
$$y(r, s', d) = r + \gamma (1 - d) \min_{i=1,2} Q_{\phi_{i,\mathrm{targ}}}(s', a'(s'))$$
このtargetを使って各Q関数を次の損失を最小化するよう学習する:
$$L(\phi_i, D) = \mathbb{E}_{(s,a,r,s',d)\sim D}\left[\left(Q_{\phi_i}(s, a) - y(r, s', d)\right)^2\right]$$
Policyはどちらか一方のQ関数を最大化するよう学習.ここでは $\mathbb{E}_{s \sim D}\left[Q_{\phi_1}(s, \mu_{\theta}(s))\right]$ を最大化するように学習.
実装
実装はDDPGとほぼ一緒.
"""core.py""" import torch import torch.nn as nn import copy class continuous_policy(nn.Module): def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)): super().__init__() layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()] for i in range(1, len(hidden_layer)): layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i])) layer.append(nn.ReLU()) layer.append(nn.Linear(hidden_layer[-1], act_dim)) layer.append(nn.Tanh()) self.policy = nn.Sequential(*layer) def forward(self, obs): return self.policy(obs) class q_function(nn.Module): def __init__(self, obs_dim, hidden_layer=(400,300)): super().__init__() layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()] for i in range(1, len(hidden_layer)): layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i])) layer.append(nn.ReLU()) layer.append(nn.Linear(hidden_layer[-1], 1)) self.policy = nn.Sequential(*layer) def forward(self, obs): return self.policy(obs) class actor_critic(nn.Module): def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2): super().__init__() self.policy = continuous_policy(act_dim, obs_dim, hidden_layer) self.q1 = q_function(obs_dim+act_dim, hidden_layer) self.q2 = q_function(obs_dim+act_dim, hidden_layer) self.act_limit = act_limit for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) self.policy_targ = continuous_policy(act_dim, obs_dim, hidden_layer) self.q1_targ = q_function(obs_dim+act_dim, hidden_layer) self.q2_targ = q_function(obs_dim+act_dim, hidden_layer) self.copy_param() def copy_param(self): self.policy_targ.load_state_dict(self.policy.state_dict()) self.q1_targ.load_state_dict(self.q1.state_dict()) self.q2_targ.load_state_dict(self.q2.state_dict()) # for m_targ, m_main in zip(self.policy_targ.modules(), self.policy.modules()): # if isinstance(m_targ, nn.Linear): # m_targ.weight.data = m_main.weight.data # m_targ.bias.data = m_main.bias.data # for m_targ, m_main in zip(self.q_targ.modules(), 
self.q.modules()): # if isinstance(m_targ, nn.Linear): # m_targ.weight.data = m_main.weight.data # m_targ.bias.data = m_main.bias.data def get_action(self, obs, noise_scale): pi = self.act_limit * self.policy(obs) pi += noise_scale * torch.randn_like(pi) pi.clamp_(max=self.act_limit, min=-self.act_limit) return pi.squeeze() def get_target_action(self, obs, noise_scale, clip_param): pi = self.act_limit * self.policy_targ(obs) eps = noise_scale * torch.randn_like(pi) eps.clamp_(max=clip_param, min=-clip_param) pi += eps pi.clamp_(max=self.act_limit, min=-self.act_limit) return pi.detach() def update_target(self, rho): # compute rho * targ_p + (1 - rho) * main_p for poly_p, poly_targ_p in zip(self.policy.parameters(), self.policy_targ.parameters()): poly_targ_p.data = rho * poly_targ_p.data + (1-rho) * poly_p.data for q_p, q_targ_p in zip(self.q1.parameters(), self.q1_targ.parameters()): q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data for q_p, q_targ_p in zip(self.q2.parameters(), self.q2_targ.parameters()): q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data def compute_target(self, obs, pi, gamma, rewards, done): # compute r + gamma * (1 - d) * Q(s', mu_targ(s')) q1 = self.q1_targ(torch.cat([obs, pi], -1)) q2 = self.q2_targ(torch.cat([obs, pi], -1)) q = torch.min(q1, q2) return (rewards + gamma * (1-done) * q.squeeze()).detach() def q_function(self, obs, detach=True, action=None): # compute Q(s, a) or Q(s, mu(s)) if action is None: pi = self.act_limit * self.policy(obs) else: pi = action if detach: pi = pi.detach() return self.q1(torch.cat([obs, pi], -1)).squeeze(), self.q2(torch.cat([obs, pi], -1)).squeeze()
"""main.py""" import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import gym, time import numpy as np from spinup.utils.logx import EpochLogger from core import actor_critic as ac class ReplayBuffer: def __init__(self, size): self.size, self.max_size = 0, size self.obs1_buf = [] self.obs2_buf = [] self.acts_buf = [] self.rews_buf = [] self.done_buf = [] def store(self, obs, act, rew, next_obs, done): self.obs1_buf.append(obs) self.obs2_buf.append(next_obs) self.acts_buf.append(act) self.rews_buf.append(rew) self.done_buf.append(int(done)) while len(self.obs1_buf) > self.max_size: self.obs1_buf.pop(0) self.obs2_buf.pop(0) self.acts_buf.pop(0) self.rews_buf.pop(0) self.done_buf.pop(0) self.size = len(self.obs1_buf) def sample_batch(self, batch_size=32): idxs = np.random.randint(low=0, high=self.size, size=(batch_size,)) obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs]) obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs]) acts = torch.FloatTensor([self.acts_buf[i] for i in idxs]) rews = torch.FloatTensor([self.rews_buf[i] for i in idxs]) done = torch.FloatTensor([self.done_buf[i] for i in idxs]) return [obs1, obs2, acts, rews, done] def td3(env_name, actor_critic_function, hidden_size, steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, act_noise=0.1, target_noise=0.2, noise_clip=0.5, max_ep_len=1000, policy_delay=2, logger_kwargs=dict()): logger = EpochLogger(**logger_kwargs) logger.save_config(locals()) replay_buffer = ReplayBuffer(replay_size) env, test_env = gym.make(env_name), gym.make(env_name) obs_dim = env.observation_space.shape[0] act_dim = env.action_space.shape[0] act_limit = int(env.action_space.high[0]) actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit) q1_optimizer = optim.Adam(actor_critic.q1.parameters(), q_lr) q2_optimizer = optim.Adam(actor_critic.q2.parameters(), q_lr) 
policy_optimizer = optim.Adam(actor_critic.policy.parameters(), pi_lr) start_time = time.time() obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0 total_steps = steps_per_epoch * epochs for t in range(total_steps): if t > 50000: env.render() if t > start_steps: obs_tens = torch.from_numpy(obs).float().reshape(1,-1) act = actor_critic.get_action(obs_tens, act_noise).detach().numpy().reshape(-1) else: act = env.action_space.sample() obs2, ret, done, _ = env.step(act) ep_ret += ret ep_len += 1 done = False if ep_len==max_ep_len else done replay_buffer.store(obs, act, ret, obs2, done) obs = obs2 if done or (ep_len == max_ep_len): for j in range(ep_len): obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size) # compute Q(s, a) q1, q2 = actor_critic.q_function(obs1_tens, action=acts_tens) # compute r + gamma * (1 - d) * Q(s', mu_targ(s')) pi_targ = actor_critic.get_target_action(obs2_tens, target_noise, noise_clip) q_targ = actor_critic.compute_target(obs2_tens, pi_targ, gamma, rews_tens, done_tens) # compute (Q(s, a) - y(r, s', d))^2 q_loss = (q1-q_targ).pow(2).mean() + (q2-q_targ).pow(2).mean() q1_optimizer.zero_grad() q2_optimizer.zero_grad() q_loss.backward() q1_optimizer.step() q2_optimizer.step() logger.store(LossQ=q_loss.item(), Q1Vals=q1.detach().numpy(), Q2Vals=q2.detach().numpy()) if j % policy_delay == 0: # compute Q(s, mu(s)) policy_loss, _ = actor_critic.q_function(obs1_tens, detach=False) policy_loss = -policy_loss.mean() policy_optimizer.zero_grad() policy_loss.backward() policy_optimizer.step() logger.store(LossPi=policy_loss.item()) # compute rho * targ_p + (1 - rho) * main_p actor_critic.update_target(polyak) logger.store(EpRet=ep_ret, EpLen=ep_len) obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0 if t > 0 and t % steps_per_epoch == 0: epoch = t // steps_per_epoch # test_agent() logger.log_tabular('Epoch', epoch) logger.log_tabular('EpRet', with_min_and_max=True) # 
logger.log_tabular('TestEpRet', with_min_and_max=True) logger.log_tabular('EpLen', average_only=True) # logger.log_tabular('TestEpLen', average_only=True) logger.log_tabular('TotalEnvInteracts', t) logger.log_tabular('Q1Vals', with_min_and_max=True) logger.log_tabular('Q2Vals', with_min_and_max=True) logger.log_tabular('LossPi', average_only=True) logger.log_tabular('LossQ', average_only=True) logger.log_tabular('Time', time.time()-start_time) logger.dump_tabular() if __name__ == '__main__': import argparse import argparse parser = argparse.ArgumentParser() parser.add_argument('--env', type=str, default='Pendulum-v0') parser.add_argument('--hid', type=int, default=300) parser.add_argument('--l', type=int, default=1) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--seed', '-s', type=int, default=0) parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--exp_name', type=str, default='td3') args = parser.parse_args() from spinup.utils.run_utils import setup_logger_kwargs logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) td3(args.env, actor_critic_function=ac, hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs, logger_kwargs=logger_kwargs)
まとめ
急に実践的な改良の多い手法がきてこういう論文もあるのかという感じ.試した環境が単純なものすぎてDDPGに対するアドバンテージは結果から感じられなかったけど実装は単純でやりたいこともわかりやすいといえばわかりやすかった.