機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

OpenAIのSpinning Upで強化学習を勉強してみた その7

はじめに

その7ということで今度はSoft Actor-Critic(SAC)をpytorchで実装する.

Soft Actor-Critic

SACはTD3とほぼ同時期にpublishされた論文.内容の肝としてはDDPGをベースにentropy regularizationを加えたというもの.簡単に言ってしまえば報酬に対して確率的な方策のエントロピーを加えるというもの.なので価値関数とQ関数は次のように表現される.

\displaystyle
H(P)=\mathbb{E}_{x\sim P}[-\log P(x)]\\
V^\pi(s)=\mathbb{E}_{\tau\sim\pi}\left[\left.\sum_{t=0}^{\infty}\gamma^t\left(R(s_t,a_t,s_{t+1})+\alpha H(\pi(\cdot|s_t))\right)\right| s_0=s\right]\\
Q^\pi(s,a)=\mathbb{E}_{\tau\sim\pi}\left[\left.\sum_{t=0}^\infty\gamma^tR(s_t,a_t,s_{t+1})+\alpha\sum_{t=1}^\infty\gamma^tH(\pi(\cdot|s_t))\right| s_0=s,a_0=a\right]

\alpha \lt 0はハイパーパラメータ.\alphaを大きな値にするとエントロピーを大きくしようとするため方策はランダムな値を取りやすくなる.この式においては価値関数とQ関数の関係は以下のようになる.

\displaystyle
V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)]+\alpha H(\pi(\cdot|s))

なのでQ関数に関するBellman方程式も次のように書き変えられる.

\displaystyle
Q^\pi(s,a)=\mathbb{E}_{s'\sim P}[R(s,a,s')+\gamma(Q^\pi(s',a')+\alpha H(\pi(\cdot|s')))]=\mathbb{E}_{s'\sim P}[R(s,a,s')+\gamma V^\pi(s')]

この式を利用してSACではpolicyと二つのQ関数と一つの価値関数の学習を行う.Q関数が二つあるのはTD3と同様の理由.

Qの学習

Q関数の学習はDDPG同様target networkを利用してMSBEの最小化を行う.ただし,今回Q関数のBellman方程式は価値関数を使ってかけるためここでのtarget networkは価値関数のtarget networkになる.なので目的関数は次のようにかける.

\displaystyle
L(\phi_i,\mathcal{D})=\underset{(s,a,r,s',d)\sim\mathcal{D}}{\mathbb{E}}\left[\left(Q_{\phi_i}(s,a)-(r+\gamma(1-d)V_{\phi_\mathrm{targ}}(s'))\right)^2\right]

Target networkはDDPGと同様,移動平均でパラメータを計算.

Vの学習

価値関数の学習は以下の価値関数とQ関数の関係を利用する.

\displaystyle
V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)]+\alpha H(\pi(\cdot|s))=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)-\alpha\log\pi(a|s)]

ここでの期待値計算は確率的な方策からのサンプリングを使って次のように近似する.

\displaystyle
V^\pi(s)\approx Q^\pi(s,\tilde{a})-\alpha\log\pi(\tilde{a}|s),\ \tilde{a}\sim\pi(\cdot|s)

なので方策はサンプリングしやすい分布である必要がある.ここでのQ関数はTD3と同様に二つのQ関数の最小値として計算する(clipped couble-Q).なのでVに関する目的関数は次のようになる.

\displaystyle
L(\phi,\mathcal{D})=\underset{s\sim\mathcal{D},a\sim\pi_\theta}{\mathbb{E}}\left[\left(V_\phi(s)-\left(\min_{i=1,2}Q_{\phi_i}(s,\tilde{a})-\alpha\log\pi_\theta(\tilde{a}|s)\right)\right)^2\right]

実装上の注意としてはサンプリングによる近似を用いているためreplay bufferのactionは使わないということ.

Policyの学習

Policyは今までと同様Q関数を最大とする行動を返すように学習するが,今回はentropy regularizationが入っているため次の値の最大化として学習する.

\displaystyle
\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)-\alpha\log\pi(a|s)]

ここでの計算にもサンプリングが必要になるが,実装ではpolicyにガウス分布を仮定しているため,VAEでおなじみのreparameterization trickを使うことで計算可能.ただpolicyからのサンプリング部分は実装を見ればわかるが学習を安定させるためのいくつかのtipsがあるので注意.具体的にはガウス分布の平均や分散,選択される行動の値がぶっ飛んだ値にならないようtanhやclippingによって有限の値になるように抑えている.ここら辺は言葉よりも実装見た方が早いので細かいことは割愛.

諸々を端折って結論だけ書くとpolicyの目的関数は次のようになる.

\displaystyle
\max_\theta\underset{s\sim\mathcal{D},\xi\sim\mathcal{N}}[Q_{\phi_1}(s,\tilde{a}_\theta(s,\xi))-\alpha\log\pi_\theta(\tilde{a}_\theta(s,\xi)|s)]

ただし,\tilde{a}_\theta(s,\xi)はpolicyからreparameterization tricktanhやclippingによるまるめ込みを使ってサンプリングされたaction.

更新の順番としてはVとQの更新->更新されたQを使ったpolicyの更新->target networkの更新の順番.

実装

以下,実装.policyがgaussianに変わったのとサンプリングに関する細かいテクニック以外は特に変わったところはない.

"""core.py"""
import torch
import torch.nn as nn

import math

EPS = 1e-8
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class gaussian_policy(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(obs_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        self.main = nn.Sequential(*layer)
        self.mu   = nn.Linear(hidden_layer[-1], act_dim)
        self.log_std = nn.Sequential(
            nn.Linear(hidden_layer[-1], act_dim),
            nn.Tanh()
        )

    def forward(self, obs):
        f = self.main(obs)
        mu, log_std = self.mu(f), self.log_std(f)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
        
        pi = mu + torch.randn_like(mu) * log_std.exp()
        logp_pi = self.gaussian_likelihood(pi, mu, log_std)
        return mu, pi, logp_pi

    def gaussian_likelihood(self, pi, mu, log_std):
        return  torch.sum(-0.5 * (((pi-mu)/(log_std.exp()+EPS))**2 + 2 * log_std + math.log(2*math.pi)), 1)

class value_function(nn.Module):
    def __init__(self, inp_dim, hidden_layer=(400,300)):
        super().__init__()
        layer = [nn.Linear(inp_dim, hidden_layer[0]), nn.ReLU()]
        for i in range(1, len(hidden_layer)):
            layer.append(nn.Linear(hidden_layer[i-1], hidden_layer[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_layer[-1], 1))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        return self.policy(obs)

class actor_critic(nn.Module):
    def __init__(self, act_dim, obs_dim, hidden_layer=(400,300), act_limit=2):
        super().__init__()
        self.policy = gaussian_policy(act_dim, obs_dim, hidden_layer)

        self.q1 = value_function(obs_dim+act_dim, hidden_layer)
        self.q2 = value_function(obs_dim+act_dim, hidden_layer)
        self.v  = value_function(obs_dim, hidden_layer)
        self.act_limit = act_limit

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        self.v_targ = value_function(obs_dim, hidden_layer)

        self.copy_param()

    def pass_gradient_clip(self, x, low, high):
        clip_high = (x > high).float()
        clip_low  = (x < low).float()
        return x + ((high - x) * clip_high + (low - x) * clip_low).detach()

    def squashing(self, mu, pi, logp_pi):
        mu = mu.tanh()
        pi = pi.tanh()
        logp_pi = logp_pi - (self.pass_gradient_clip(1 - pi**2, 0, 1) + 1e-6).log().sum(1)
        return mu, pi, logp_pi

    def copy_param(self):
        self.v_targ.load_state_dict(self.v.state_dict())

    def get_action(self, obs):
        mu, pi, logp_pi = self.policy(obs)
        mu, pi, logp_pi = self.squashing(mu, pi, logp_pi)
        mu = self.act_limit * mu
        pi = self.act_limit * pi
        return mu, pi, logp_pi

    def update_target(self, rho):
        # compute rho * targ_p + (1 - rho) * main_p
        for v_p, v_targ_p in zip(self.v.parameters(), self.v_targ.parameters()):
            v_targ_p.data = rho * v_targ_p.data + (1-rho) * v_p.data

    def compute_v_target(self, obs, alpha):
        _, pi, logp = self.get_action(obs)
        q1, q2 = self.q1(torch.cat([obs, pi], 1)), self.q2(torch.cat([obs, pi], 1))
        q = torch.min(q1, q2).squeeze()
        return (q - alpha * logp.squeeze()).detach()

    def compute_q_target(self, obs, gamma, rewards, done):
        # compute r + gamma * (1 - d) * V(s')
        return (rewards + gamma * (1-done) * self.v_targ(obs).squeeze()).detach()

    def q_function(self, obs, pi):
        q1, q2 = self.q1(torch.cat([obs, pi], 1)), self.q2(torch.cat([obs, pi], 1))
        return q1.squeeze(), q2.squeeze()

    def q_function_w_entropy(self, obs, alpha):
        _, pi, logp_pi = self.get_action(obs)
        q1 = self.q1(torch.cat([obs, pi], 1)).squeeze()
        H = -logp_pi * alpha
        return q1 + H.squeeze()
"""sac.py"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gym, time
import numpy as np

from spinup.utils.logx import EpochLogger
from core import actor_critic as ac

class ReplayBuffer:
    def __init__(self, size):
        self.size, self.max_size = 0, size
        self.obs1_buf = []
        self.obs2_buf = []
        self.acts_buf = []
        self.rews_buf = []
        self.done_buf = []

    def store(self, obs, act, rew, next_obs, done):
        self.obs1_buf.append(obs)
        self.obs2_buf.append(next_obs)
        self.acts_buf.append(act)
        self.rews_buf.append(rew)
        self.done_buf.append(int(done))
        while len(self.obs1_buf) > self.max_size:
            self.obs1_buf.pop(0)
            self.obs2_buf.pop(0)
            self.acts_buf.pop(0)
            self.rews_buf.pop(0)
            self.done_buf.pop(0)

        self.size = len(self.obs1_buf)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(low=0, high=self.size, size=(batch_size,))
        obs1 = torch.FloatTensor([self.obs1_buf[i] for i in idxs])
        obs2 = torch.FloatTensor([self.obs2_buf[i] for i in idxs])
        acts = torch.FloatTensor([self.acts_buf[i] for i in idxs])
        rews = torch.FloatTensor([self.rews_buf[i] for i in idxs])
        done = torch.FloatTensor([self.done_buf[i] for i in idxs])
        return [obs1, obs2, acts, rews, done]

def ddpg(env_name, actor_critic_function, hidden_size,
        steps_per_epoch=5000, epochs=100, replay_size=int(1e6), gamma=0.99, 
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000, 
        max_ep_len=1000, logger_kwargs=dict()):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    replay_buffer = ReplayBuffer(replay_size)

    env, test_env = gym.make(env_name), gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    act_limit = int(env.action_space.high[0])

    actor_critic = actor_critic_function(act_dim, obs_dim, hidden_size, act_limit)

    value_optimizer = optim.Adam([
        {"params":actor_critic.q1.parameters()},
        {"params":actor_critic.q2.parameters()},
        {"params":actor_critic.v.parameters()}
    ], lr)
    policy_optimizer = optim.Adam(actor_critic.policy.parameters(), lr)

    start_time = time.time()

    obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_steps = steps_per_epoch * epochs

    for t in range(total_steps):
        if t > 50000:
            env.render()
        if t > start_steps:
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            _, act, _ = actor_critic.get_action(obs_tens)
            act = act.detach().numpy().reshape(-1)
        else:
            act = env.action_space.sample()

        obs2, ret, done, _ = env.step(act)

        ep_ret += ret
        ep_len += 1

        done = False if ep_len==max_ep_len else done

        replay_buffer.store(obs, act, ret, obs2, done)

        obs = obs2

        if done or (ep_len == max_ep_len):
            for _ in range(ep_len):
                obs1_tens, obs2_tens, acts_tens, rews_tens, done_tens = replay_buffer.sample_batch(batch_size)

                q_targ = actor_critic.compute_q_target(obs2_tens, gamma, rews_tens, done_tens)
                v_targ = actor_critic.compute_v_target(obs1_tens, alpha)

                q1_val, q2_val = actor_critic.q_function(obs1_tens, acts_tens)
                q_loss = 0.5 * (q_targ - q1_val).pow(2).mean() + 0.5 * (q_targ - q2_val).pow(2).mean()

                v_val  = actor_critic.v(obs1_tens).squeeze()
                v_loss = 0.5 * (v_targ - v_val).pow(2).mean()

                value_loss = q_loss + v_loss

                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

                policy_loss = -actor_critic.q_function_w_entropy(obs1_tens, alpha).mean()
                policy_optimizer.zero_grad()
                policy_loss.backward()
                policy_optimizer.step()


                logger.store(LossQ=q_loss.item(), Q1Vals=q1_val.detach().numpy(), Q2Vals=q2_val.detach().numpy())
                logger.store(LossV=v_loss.item(), VVals=v_val.detach().numpy())
                logger.store(LossPi=policy_loss.item())

                actor_critic.update_target(polyak)

            logger.store(EpRet=ep_ret, EpLen=ep_len)
            obs, ret, done, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if t > 0 and t % steps_per_epoch == 0:
            epoch = t // steps_per_epoch

            # test_agent()
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            # logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            # logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('VVals', with_min_and_max=True) 
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Pendulum-v0')
    parser.add_argument('--hid', type=int, default=300)
    parser.add_argument('--l', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='sac')
    args = parser.parse_args()

    from spinup.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    ddpg(args.env, actor_critic_function=ac,
        hidden_size=[args.hid]*args.l, gamma=args.gamma, epochs=args.epochs,
        logger_kwargs=logger_kwargs)

まとめ

これでspinning upにあるQ-learningに関するアルゴリズムの実装は終了.Policy gradient関連はTRPOがヘシアンの計算等を必要として実装までは面倒でやらなさそう.ひとまずspinning upでの強化学習勉強はここで終わる予定.

Spinning upの感想としてはよくまとまっていてpolicy gradientからq-learningまで非常にわかりやすく解説されている気がする.前に強化学習を本で勉強しようとした時には長ったらしい理論的な背景をガンガン説明されて辟易したが,spinning upはアルゴリズムの理解と実装に焦点を当てて理論的な点は飛ばしているので直感的な理解が非常にしやすかった.逆にいえば細かい理論は飛ばされているのでその辺に関する知識をつけたい人にはあまり意味のない内容かと.