機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

OpenAIのSpinning Upで強化学習を勉強してみた その4

はじめに

その4ということで具体的なアルゴリズムの実装をpytorchでしてみる.今回はalgorithms docsの中にあるVanilla Policy Gradient(VPG)の実装をする.

Vanilla Policy Gradient

VPGはSpinning Upのintroduction to rlのpart 3で扱ったsimplest policy gradientの拡張.Policy gradientは勾配法を用いて得られる報酬を最大化するようなpolicyを求めるというもの.具体的にはpolicyに関する勾配は次のように与えられる.

\displaystyle
\nabla_\theta J(\pi_\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\log\pi_\theta(a_t|s_t)\Phi_t\right]

具体的な導出とnotationはイントロのpart 3を参照.\Phi_tにはrewardやQ関数などが使われる.前回のsimplest policy gradientではrewardを用いていたが今回のVPGではadvantage functionを用いる.advantage functionは次の形で与えらえる.

\displaystyle
A^{\pi}(s_t,a_t)=Q^{\pi}(s_t,a_t)-V^\pi(s_t)

Simplest policy gradientのように\Phiにrewardを用いた場合,得られる勾配は実際に取られたactionの尤度を最大化するような勾配になっており,必ずしも最適なpolicyを学習するとは限らない.つまり,取られた行動が全体として報酬を減らすような行動だったとしてもその尤度を最大化するように学習してしまう.それに対しadvantage functionは良いpolicyを学習することができる.というのも,価値関数がQ関数をactionに対して周辺化したものであることを思い出せば,このadvantage functionは実際に取られたactionが行動の平均的な価値より良いか悪いかを表現することができるため,価値の低い行動をとった場合にはその尤度を下げることが可能.この尤度を下げるという勾配は\Phiにrewardを使った場合には得られない.

ここでの問題はadvantage functionをどのようにして得るかということ.advantage function(というかQ関数と価値関数)は陽に求めることができないため何らかの方法で推定する必要がある.ここではGeneralized Advantage Estimation (GAE)という方法を用いてadvantage functionを推定する.細かい導出は長くなるのと論文に丁寧に描かれているのでここでは省略し,次の最終的な形だけ.

\displaystyle
\hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)}:=-V(s_t)+\sum_{l=0}^\infty(\lambda\gamma)^lr_{t+l}

\lambda,\gammaは減衰係数でハイパーパラメータ.\lambdaが1の時には推定されるadvantage functionはhigh variance,low biasで0の時にはhigh bias,low varianceになる.基本的にはbiasのない値が欲しいためどちらの係数も1に近いものを選ぶ.

GAEの式を見ると価値関数が入っていて,やはりまだ陽に計算することができない.なのでこの価値関数をニューラルネットで推定しようというのがここでの解決方法.具体的には以下の最小化問題を解くことで価値関数を推定するニューラルネットを学習する.

\displaystyle
\phi_{k+1}=\underset{\phi}{\mathrm{argmin}}\frac{1}{|\mathcal{D}_k|T}\sum_{\tau\in\mathcal{D}_k}\sum_{t=0}^T\left(V_\phi(s_t)-\hat{R}_t\right)^2

\phiニューラルネットのパラメータで\mathcal{D}はエピソードの数.この最小化問題を利用して学習したニューラルネットを用いてadvantage functionを推定し,推定されたadvantage functionを使って次の勾配を使って勾配上昇法によりpolicyを学習する.ただ,実際にはpolicyとvalue functionは交互に学習していく.詳細はspinning upのページにあるPseudocodeを参照.

\displaystyle
\nabla_\theta J(\pi_\theta)=\frac{1}{|\mathcal{D}_k|}\sum_{\tau\in\mathcal{D}_k}\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|s_t)|_{\theta_k}\hat{A}_t

アルゴリズムの目的や式が直感的にもわかりやすい.このコードをpytorchで書くと次のような感じ.

"""core.py"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal

class categorical_policy(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=(64,64)):
        super().__init__()
        layer = [
            nn.Linear(obs_dim, hidden_size[0]),
            nn.ReLU()
        ]
        for i in range(1, len(hidden_size)):
            layer.append(nn.Linear(hidden_size[i-1], hidden_size[i]))
            layer.append(nn.ReLU())
        layer.append(nn.Linear(hidden_size[-1], act_dim))
        self.policy = nn.Sequential(*layer)

    def forward(self, obs):
        pi = self.policy(obs)
        return pi

class actor_critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=(64,64)):
        super().__init__()
        layer = [
            nn.Linear(obs_dim, hidden_size[0]),
            nn.ReLU()
        ]
        for i in range(1, len(hidden_size)):
            layer.append(nn.Linear(hidden_size[i-1], hidden_size[i]))
            layer.append(nn.ReLU())

        layer.append(nn.Linear(hidden_size[-1], 1))
        self.value_function = nn.Sequential(*layer)

        self.policy = categorical_policy(obs_dim, act_dim, hidden_size)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, obs):
        pi = self.policy(obs)
        a  = torch.multinomial(pi.softmax(1), 1).squeeze()
        log_pi = F.log_softmax(pi, 1)
        log_pi = (torch.eye(log_pi.shape[-1])[a] * F.log_softmax(pi, 1)).sum(1)

        v = self.value_function(obs).squeeze()
        return a, log_pi, v

    def value(self, obs):
        return self.value_function(obs).squeeze()

    def likelihood(self, obs, a):
        pi = self.policy(obs)
        log_pi = F.log_softmax(pi, 1)
        log_pi = (torch.eye(log_pi.shape[-1])[a] * F.log_softmax(pi, 1)).sum(1)
        return log_pi

def discount_cumsum(x, discount):
    """
    magic from rllab for computing discounted cumulative sums of vectors.
    input: 
        vector x, 
        [x0, 
         x1, 
         x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,  
         x1 + discount * x2,
         x2]
    """
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
"""vpg.py"""
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

import time, gym

from core import actor_critic as ac
from core import discount_cumsum

class VPGBuffer:
    """
    A buffer for storing trajectories experienced by a VPG agent interacting
    with the environment, and using Generalized Advantage Estimation (GAE-Lambda)
    for calculating the advantages of state-action pairs.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.act_buf = np.zeros([size], dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        """
        Append one timestep of agent-environment interaction to the buffer.
        """
        assert self.ptr < self.max_size     # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        """
        Call this at the end of a trajectory, or when one gets cut off
        by an epoch ending. This looks back in the buffer to where the
        trajectory started, and uses rewards and value estimates from
        the whole trajectory to compute advantage estimates with GAE-Lambda,
        as well as compute the rewards-to-go for each state, to use as
        the targets for the value function.
        The "last_val" argument should be 0 if the trajectory ended
        because the agent reached a terminal state (died), and otherwise
        should be V(s_T), the value function estimated for the last state.
        This allows us to bootstrap the reward-to-go calculation to account
        for timesteps beyond the arbitrary episode horizon (or epoch cutoff).
        """

        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        
        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        
        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        
        self.path_start_idx = self.ptr

    def get(self):
        """
        Call this at the end of an epoch to get all of the data from
        the buffer, with advantages appropriately normalized (shifted to have
        mean zero and std one). Also, resets some pointers in the buffer.
        """
        assert self.ptr == self.max_size    # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
        # the next two lines implement the advantage normalization trick
        adv_mean, adv_std = self.adv_buf.mean(), self.adv_buf.std()
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        return [self.obs_buf, self.act_buf, self.adv_buf, 
                self.ret_buf, self.logp_buf]

def vpg(env_name, actor_critic_func, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=50, gamma=0.99, pi_lr=3e-4,
        vf_lr=1e-3, train_v_iters=80, lam=0.97, max_ep_len=1000,
        save_freq=10):

    env = gym.make(env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    buf = VPGBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)

    actor_critic = actor_critic_func(obs_dim, act_dim, ac_kwargs["hidden_sizes"])

    torch.random.manual_seed(seed)
    np.random.seed(seed)

    optim_policy = optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    optim_value  = optim.Adam(actor_critic.value_function.parameters(), lr=vf_lr)

    def update():
        optim_policy.zero_grad()
        inputs = buf.get()

        obs_tens = torch.from_numpy(inputs[0]).float()
        act = torch.from_numpy(inputs[1]).long()

        log_pi = actor_critic.likelihood(obs_tens, act)

        pi_loss = -(log_pi * torch.from_numpy(inputs[2]).float()).mean()
        pi_loss.backward()
        optim_policy.step()

        for _ in range(train_v_iters):
            optim_value.zero_grad()
            v_t = actor_critic.value(obs_tens)
            v_loss  = (torch.from_numpy(inputs[3]).float() - v_t).pow(2).mean()
            v_loss.backward()

        optim_value.step()

        return pi_loss.item(), v_loss.item()

    start_time = time.time()
    obs, ret, kill, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            # if epoch > 40:
                # env.render()
            obs_tens = torch.from_numpy(obs).float().reshape(1,-1)
            act, log_pi, v_t = actor_critic(obs_tens)
            buf.store(obs, act[0].item(), ret, v_t.item(), log_pi.item())

            obs, ret, kill, _ = env.step(act[0].item())

            ep_ret += ret
            ep_len += 1

            terminal = kill or (ep_len == max_ep_len)
            if terminal or (t==steps_per_epoch-1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len)
                last_val = ret if kill else actor_critic.value(torch.from_numpy(obs).float()).item()
                buf.finish_path(last_val)
                print(ep_len)
                obs, ret, kill, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        pi_loss, v_loss = update()
        print("pi loss : ", pi_loss)
        print("v loss : ", v_loss)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='CartPole-v0')
    parser.add_argument('--hid', type=int, default=64)
    parser.add_argument('--l', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--cpu', type=int, default=4)
    parser.add_argument('--steps', type=int, default=4000)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--exp_name', type=str, default='vpg')
    args = parser.parse_args()
    vpg(args.env, actor_critic_func=ac,
        ac_kwargs=dict(hidden_sizes=[args.hid]*args.l), gamma=args.gamma, 
        seed=args.seed, steps_per_epoch=args.steps, epochs=args.epochs)

元のtensorflowのコードを流用して書いたのでだいぶ汚いコードになっているので注意(次の実装からはちゃんとやります).cartpoleでしか試していないが,simplest policy gradientに比べ学習自体は遅くなっている.実装ミスではなく学習パラメータが増えたためと思いたい…

まとめ

Advantage functionを推定するGAE自体はICLR2016と意外と新しくて驚き.論文ではspinning upのalgorithms docs内のTRPOとして実装されていたが,単純な形の実装に落として紹介しているあたり教材としてよく考えられているなという印象.強化学習初学者の自分にも気持ちや実装方法などがわかりやすくよかった.