OpenAIのSpinning Upで強化学習を勉強してみた その5
はじめに
その5ということで今度はDeep Deterministic Policy Gradient(DDPG)をpytorchで実装する.
Deep Deterministic Policy Gradient
DDPGは今までと違いQ-learningの枠組みを取り入れた(論文の背景的にはQ-learningにpolicy gradientを取り柄れたという方が正確かも)アルゴリズム.DQNで有名なQ-learningは行動が離散でないと適用ができないという課題があったが,policyを同時に学習することで行動が連続な課題に対してもQ-learningを適用可能にしたというもの.
Q-learning
簡単にQ-learningについて述べる(といってもspinning upにはQ-learningの解説はないので完全に個人的な理解だが).Q関数は現在の状態においてある行動をとったときにどれくらいの価値があるかを図る関数で,価値が最大となる行動を選択することでエージェントを動かす.この時,行動が離散で高々有限個の行動しかとれない場合には全ての行動に対して価値を計算することで最適な行動を得ることができる.行動が連続になると全ての行動に対して価値を計算するわけにはいかなくなるため今回のような戦略が必要になるというもの.
Q関数自体はbellman方程式に従ってQ関数を推定するようなモデルを学習する.すなわち次の方程式を満たすような関数を学習する.
は状態において行動をとった時に起こりうる状態を表す.この方程式を満たすためには次の最小化問題を解けばいい.
この目的関数は mean-squared Bellman error (MSBE)と呼ばれる.本題のDDPGに戻れば,DDPGはQ-learningを連続な行動に対して適用できるよう,Q関数を最大とするような行動を返すpolicyを同時に学習する.具体的には次の目的関数を最小化する.
単に状態におけるQ関数を最大にする行動を選択する部分をというpolicyを使って求めるというもの.これにより,離散の時のような全ての行動に対する価値を陽に計算する必要がなくなるため連続な行動に対してもQ-learningを適用可能となる.
Policy自体はQ関数を最大にする行動を返すような関数であって欲しいため,次のような最大化の問題によって学習を行う.
なのでこのpolicyを学習する最大化問題とQ関数を学習する最小化問題を交互に最適化していけば良いpolicyを得ることができる.
注意としてこのDDPGは離散な行動の場合には利用することができない.これはQ関数を最大にする行動を返すようなpolicyを学習する際に,policyに対する勾配が(policyの出力をargmax等で離散表現に落とすため)計算不可能になってしまうため.ただ,離散の場合にはpolicyを導入することなく普通のQ-learningを行えばいいので特に問題はない気もする.
Tips
ここでは二つのテクニックを利用した学習法を用いている.
Replay Buffers
これは過去のエピソードを保存しておいて学習に利用することで過学習を防ぐとともに学習データのサンプリング効率を上げるというもの.
Target Networks
MSBEの最小化はbellman方程式の右辺をtargetとすれば,targetとtargetに近づけたい部分が同じパラメータを持っているため(すなわちとの差を最小にしたいのにどちらも同じパラメータを持つがいるため)不安定な最適化となる.そこでtarget用のネットワークを新たに用意することで学習の安定化を図ろうというのがこの方法.
具体的には学習したいQ関数と全く同じパラメータを持つtarget network を用意し,学習したいQ関数は通常通り勾配降下法で,target networkの方は次のような移動平均によってパラメータを計算する.
はハイパーパラメータで基本的に1に近い値を用いる.このtarget networkはpolicy側にも用意することができ,ここでは最終的なMSBEを次のような形で与えている.
Policyに関するtarget networkも移動平均によりパラメータを更新する.
Exploration vs. Exploitation
Q-learningはある意味で行動と状態を軸とするルックアップテーブルを埋めていく作業としてみることができる(と思う).なのでoff-policyで学習するとPolicyが決定的なものになってしまい多様な範囲を探索できなくなる恐れがある.なのでpolicyの出力にガウス雑音を加えることで行動を確率的なものとしてより広い範囲を探索できるようにするというもの.
実装
以下実装.spinning upのページにはアルゴリズムテーブルもあるので参考に.
"""core.py""" import torch import torch.nn as nn import copy class continuous_policy(nn.Module): def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)): super().__init__() layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()] for i in range(1, len(hidden_layer)): layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i])) layer.append(nn.ReLU()) layer.append(nn.Linear(hidden_layer[-1], act_dim)) layer.append(nn.Tanh()) self.policy = nn.Sequential(*layer) def forward(self, obs): return self.policy(obs) class q_function(nn.Module): def __init__(self, obs_dim, hidden_layer=(400,300)): super().__init__() layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()] for i in range(1, len(hidden_layer)): layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i])) layer.append(nn.ReLU()) layer.append(nn.Linear(hidden_layer[-1], 1)) self.policy = nn.Sequential(*layer) def forward(self, obs): return self.policy(obs) class actor_critic(nn.Module): def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2): super().__init__() self.policy = continuous_policy(act_dim, obs_dim, hidden_layer) self.q = q_function(obs_dim+act_dim, hidden_layer) self.act_limit = act_limit for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0) self.policy_targ = continuous_policy(act_dim, obs_dim, hidden_layer) self.q_targ = q_function(obs_dim+act_dim, hidden_layer) self.copy_param() def copy_param(self): self.policy_targ.load_state_dict(self.policy.state_dict()) self.q_targ.load_state_dict(self.q.state_dict()) def get_action(self, obs, noise_scale): pi = self.act_limit * self.policy(obs) pi += noise_scale * torch.randn_like(pi) pi.clamp_(max=self.act_limit, min=-self.act_limit) return pi.squeeze() def update_target(self, rho): # compute rho * targ_p + (1 - rho) * main_p for poly_p, poly_targ_p in zip(self.policy.parameters(), self.policy_targ.parameters()): poly_targ_p.data = rho * poly_targ_p.data + (1-rho) * poly_p.data for q_p, q_targ_p in zip(self.q.parameters(), self.q_targ.parameters()): q_targ_p.data = rho * q_targ_p.data + (1-rho) * q_p.data def compute_target(self, obs, gamma, rewards, done): # compute r + gamma * (1 - d) * Q(s', mu_targ(s')) pi = self.act_limit * self.policy_targ(obs) return (rewards + gamma * (1-done) * self.q_targ(torch.cat([obs, pi], -1)).squeeze()).detach() def q_function(self, obs, detach=True, action=None): # compute Q(s, a) or Q(s, mu(s)) if action is None: pi = self.act_limit * self.policy(obs) else: pi = action if detach: pi = pi.detach() return self.q(torch.cat([obs, pi], -1))
"""ddpg.py""" import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import gym, time import numpy as np from spinup.utils.logx import EpochLogger from core import actor_critic as ac class ReplayBuffer: def __init__(self, size): self.size, self.max_size = 0, size self.obs1_buf = [] self.obs2_buf = [] self.acts_buf = [] self.rews_buf = [] self.done_buf = [] def store(self, obs, act, rew, next_obs, done): self.obs1_buf.append(obs) self.obs2_buf.append(next_obs) self.acts_buf.append(act) self.rews_buf.append(rew) self.done_buf.append(int(done)) while len(self.obs1_buf) > self.max_size: self.obs1_buf.pop(0) self.obs2_buf.pop(0) self.acts_buf.pop(0) self.rews_buf.pop(0) self.done_buf.pop(0) self.size = len(self.obs1_buf) def sample_batch(self, batch_size=32): idxs = np.random.randint(low=0, high=self.size, size=(batch_size,)) obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs]) obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs]) acts = torch.FloatTensor([self.acts_buf[i] for i in idxs]) rews = torch.FloatTensor([self.rews_buf[i] for i in idxs]) done = torch.FloatTensor([self.done_buf[i] for i in idxs]) return [obs1, obs2, acts, rews, done] def ddpg(env_name, actor_critic_function, hidden_size, steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, polyak=0.995, pi_lr=1e-3, q_lr=1e-3, batch_size=100, start_steps=10000, act_noise=0.1, max_ep_len=1000, logger_kwargs=dict()): logger = EpochLogger(**logger_kwargs) logger.save_config(locals()) replay_buffer = ReplayBuffer(replay_size) env, test_env = gym.make(env_name), gym.make(env_name) obs_dim = env.observation_space.shape[0] act_dim = env.action_space.shape[0] act_limit = int(env.action_space.high[0]) actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit) q_optimizer = optim.Adam(actor_critic.q.parameters(), q_lr) policy_optimizer = optim.Adam(actor_critic.policy.parameters(), pi_lr) start_time = time.time() obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0 total_steps = steps_per_epoch * epochs for t in range(total_steps): if t > 50000: env.render() if t > start_steps: obs_tens = torch.from_numpy(obs).float().reshape(1,-1) act = actor_critic.get_action(obs_tens, act_noise).detach().numpy().reshape(-1) else: act = env.action_space.sample() obs2, ret, done, _ = env.step(act) ep_ret += ret ep_len += 1 done = False if ep_len==max_ep_len else done replay_buffer.store(obs, act, ret, obs2, done) obs = obs2 if done or (ep_len == max_ep_len): for _ in range(ep_len): obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size) # compute Q(s, a) q = actor_critic.q_function(obs1_tens, action=acts_tens) # compute r + gamma * (1 - d) * Q(s', mu_targ(s')) q_targ = actor_critic.compute_target(obs2_tens, gamma, rews_tens, done_tens) # compute (Q(s, a) - y(r, s', d))^2 q_loss = (q.squeeze()-q_targ).pow(2).mean() q_optimizer.zero_grad() q_loss.backward() q_optimizer.step() logger.store(LossQ=q_loss.item(), QVals=q.detach().numpy()) # compute Q(s, mu(s)) policy_loss = -actor_critic.q_function(obs1_tens, detach=False).mean() policy_optimizer.zero_grad() policy_loss.backward() policy_optimizer.step() logger.store(LossPi=policy_loss.item()) # compute rho * targ_p + (1 - rho) * main_p actor_critic.update_target(polyak) logger.store(EpRet=ep_ret, EpLen=ep_len) obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0 if t > 0 and t % steps_per_epoch == 0: epoch = t // steps_per_epoch # test_agent() logger.log_tabular('Epoch', epoch) logger.log_tabular('EpRet', with_min_and_max=True) # logger.log_tabular('TestEpRet', with_min_and_max=True) logger.log_tabular('EpLen', average_only=True) # logger.log_tabular('TestEpLen', average_only=True) logger.log_tabular('TotalEnvInteracts', t) logger.log_tabular('QVals', with_min_and_max=True) logger.log_tabular('LossPi', average_only=True) logger.log_tabular('LossQ', average_only=True) logger.log_tabular('Time', time.time()-start_time) logger.dump_tabular() if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--env', type=str, default='Pendulum-v0') parser.add_argument('--hid', type=int, default=300) parser.add_argument('--l', type=int, default=1) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--seed', '-s', type=int, default=0) parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--exp_name', type=str, default='ddpg') args = parser.parse_args() from spinup.utils.run_utils import setup_logger_kwargs logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed) ddpg(args.env, actor_critic_function=ac, hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs, logger_kwargs=logger_kwargs)
Pytorchを使った実装している上で踏んだバグ?があったので注意.具体的には,最初target networkを定義する時に元のネットワークをcopy.deepcopy()することでtarget networkを作ったが,こうすると学習パラメータがtorchのnn.Paramterからtorch.Tensorに置き換わるというもの.このせいでtarget networkのパラメータを更新することができなくなり学習が進まなかった(このバグに気づくのに半日とかした).なのでこのコードではstate_dictを利用してネットワークのパラメータをコピーしている.
まとめ
Q-learningとadvantage functionを使ったpolicy gradientはどっちの方が筋がいいとかあるんだろうか.ここまでの感じだとQ-learning側はいくつかのテクニックを用いないと学習が安定しない分policy gradientの方が手法的には筋がいい?