Takuma Yoneda

Policy Gradient Method

What is Policy Gradient?

Roughly speaking, reinforcement learning methods can be classified into value-based methods and policy-based methods. Value-based methods learn a value function from experience, and the agent chooses the action with the highest value (e.g., \hat{a} = \text{arg}\max_{a} Q(s, a)).
Policy-based methods learn a policy \pi(a|s) directly from experience. They naturally handle stochastic policies, since \pi(a|s) is defined as a distribution over actions, and they also work with continuous action spaces, which are tricky for value-based methods because of the maximization over a continuous action space.

The policy gradient method is one of the policy-based methods. It assumes the policy is differentiable with respect to its parameters, and iteratively updates the parameters using the gradient.

Why/When do we want Policy Gradient?

There are several cases where policy gradient has an advantage over other methods.

  • When the policy is easier to model than the value function.
    For example, in the game of Breakout (Atari), it might not be easy to predict the score from the display, whereas a simple policy that just follows the ball might work well.
  • When the optimal policy is not deterministic.
    For example, in the game of rock-paper-scissors, the optimal policy is to choose the hand uniformly at random.
    A purely greedy value-based method cannot represent this stochastic policy, since it always selects the action with the highest state-action value.
  • When the action space is continuous.
    There is no natural way to handle continuous actions in value-based methods, because taking the maximum of a value function over all possible actions is intractable in general. Policy gradient methods, on the other hand, model the policy directly and can therefore naturally handle continuous action spaces.
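
The rock-paper-scissors case above can be sketched in a few lines. This is only an illustration under assumed numbers: the value estimates, the softmax parameterization, and the names `greedy_action` and `softmax_policy` are hypothetical and not part of any particular library.

```python
import numpy as np

def greedy_action(q_values):
    """Value-based selection: always return the arg-max action."""
    return int(np.argmax(q_values))

def softmax_policy(preferences):
    """Policy-based selection: turn action preferences into a distribution."""
    z = preferences - np.max(preferences)  # shift for numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

# Hypothetical numbers for (rock, paper, scissors): every action looks equally good.
q_values = np.zeros(3)
preferences = np.zeros(3)  # parameters theta of a softmax policy

rng = np.random.default_rng(0)
greedy_choices = {greedy_action(q_values) for _ in range(100)}
probs = softmax_policy(preferences)
sampled_choices = {int(rng.choice(3, p=probs)) for _ in range(100)}

print("greedy picks:", greedy_choices)   # a single action, however often we ask
print("policy picks:", sampled_choices)  # actions drawn from the distribution
print("policy probs:", probs)            # uniform over the three hands
```

The greedy rule collapses to one hand even when all hands are equally valued, while the softmax policy can represent the uniform distribution directly.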

Basics of Policy Gradient

In this post, we focus on the non-discounted (i.e., \gamma = 1), episodic case (the episode terminates in a finite number of steps; cf. the continuing case, life-long learning, etc.), but the following explanation and proof apply to the continuing or discounted case with a few modifications.
As the name suggests, the policy gradient method updates the policy parameters by gradient ascent.
We assume the policy is differentiable[1].
Now, we define a function J(\theta) that represents the performance of the policy.
In the episodic case, it can be defined as the expected return of following the policy \pi_{\theta} from the initial state s_{0}, which is exactly the definition of the value function.

J(\theta) = E[\sum_{t=1}^{T}{r_t} | s_{0}, \pi_{\theta}] = V^{\pi_\theta} (s_{0}) \tag{1}

In what follows, we calculate the gradient of J(\theta). Once we have the gradient, the parameter update is straightforward:

\theta \leftarrow \theta + \alpha \nabla_{\theta} J(\theta)

where \alpha is the learning rate.
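
As a minimal sketch of this update rule, the loop below runs gradient ascent with a gradient function supplied by the caller. The objective used here is a stand-in (a simple concave function with a known gradient), not a real policy performance measure, and the names `gradient_ascent`, `grad_J`, and `target` are hypothetical.

```python
import numpy as np

def gradient_ascent(grad_J, theta0, alpha=0.1, num_steps=200):
    """Repeat theta <- theta + alpha * grad_J(theta) for a fixed number of steps."""
    theta = np.asarray(theta0, dtype=float)
    for _ in range(num_steps):
        theta = theta + alpha * grad_J(theta)
    return theta

# Stand-in objective (an assumption for illustration only):
# J(theta) = -||theta - target||^2, whose gradient is -2 * (theta - target).
target = np.array([1.0, -2.0])

def grad_J(theta):
    return -2.0 * (theta - target)

theta_star = gradient_ascent(grad_J, theta0=np.zeros(2))
print(theta_star)  # approaches the maximizer `target`
```

In policy gradient methods the only difference is where `grad_J` comes from: it is estimated from experience rather than given in closed form.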

How to calculate \nabla J(\theta)

All we need is the derivative of J(\theta) w.r.t. \theta. Why don't we just calculate it?
Well, it's not that easy. It is easy to calculate how the distribution over actions \pi_\theta(a|s) changes when we slightly modify \theta. However, as we saw in equation (1), J(\theta) also depends on the environment (i.e., how much more reward you would get by following the perturbed policy).
Since we do not know the environment dynamics, there is no obvious way to calculate the gradient.

But surprisingly, the Policy Gradient Theorem[2] shows that the gradient can be written in the following beautiful form.

\nabla J(\theta) \propto \sum_{s} \mu(s) \sum_{a}{\nabla \pi_{\theta}(a|s) \cdot Q^{\pi_{\theta}}(s, a)} \tag{2}

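where \mu(s) is the on-policy state distribution, which represents the fraction of time spent in state s under \pi_{\theta}.
Note that Q^{\pi_{\theta}}(s, a) is the expected return after taking action a in state s and following \pi_{\theta} afterwards, which can be estimated from samples in practice.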

In (Sutton et al., 2000), which presented the proof, the authors comment on this equation as follows:

their [sic] are no terms of the form \frac{\partial d^{\pi_{\theta}}(s)}{\partial\theta}: the effect of policy changes on the distribution of states does not appear. This is convenient for approximating the gradient by sampling.

We put the full proof at the end of this post. Here, we show a simpler proof for a one-step MDP, where the episode terminates immediately after choosing an action.
I hope this provides some intuition for the equation above.
Let d(s) be the probability distribution over initial states, and let R_{s}^{a} be the expected reward for taking action a in state s. Since we only consider one step,

\begin{align} J(\theta) & = E_{s \sim d(\cdot)}[E_{a \sim \pi_{\theta}(\cdot|s)}[R_{s}^{a}]] \\ & = \sum_{s} d(s)\sum_{a} \pi_{\theta} (a|s) \cdot R_{s}^{a} \\ \therefore \nabla J(\theta) & = \sum_{s} d(s)\sum_{a} \nabla \pi_{\theta} (a|s) \cdot R_{s}^{a} \end{align}

One can see the correspondence between this and equation (2): d(s) plays the role of \mu(s), and R_{s}^{a} plays the role of Q^{\pi_{\theta}}(s, a).
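
The one-step result can be checked numerically. The sketch below assumes a tabular softmax policy with one parameter per state-action pair, plus made-up values for d(s) and R_{s}^{a}; it compares \sum_{s} d(s)\sum_{a} \nabla \pi_{\theta}(a|s) R_{s}^{a} with a central finite-difference gradient of J(\theta). All function names are hypothetical.

```python
import numpy as np

def softmax(theta):
    """Row-wise softmax: pi[s, a] = pi_theta(a|s) for tabular parameters theta[s, a]."""
    z = theta - theta.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def J(theta, d, R):
    """J(theta) = sum_s d(s) sum_a pi(a|s) R[s, a]."""
    return float(np.sum(d[:, None] * softmax(theta) * R))

def analytic_gradient(theta, d, R):
    """sum_s d(s) sum_a grad pi(a|s) R[s, a], using the softmax Jacobian."""
    pi = softmax(theta)
    eye = np.eye(theta.shape[1])
    # dpi[s, a, b] = d pi(a|s) / d theta[s, b] = pi(a|s) * (1[a == b] - pi(b|s))
    dpi = pi[:, :, None] * (eye[None] - pi[:, None, :])
    return d[:, None] * np.einsum("sab,sa->sb", dpi, R)

def finite_difference_gradient(theta, d, R, eps=1e-6):
    """Central differences on J, one parameter at a time."""
    grad = np.zeros_like(theta)
    for idx in np.ndindex(*theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        grad[idx] = (J(theta + step, d, R) - J(theta - step, d, R)) / (2 * eps)
    return grad

# Made-up problem: 3 initial states, 4 actions.
rng = np.random.default_rng(0)
theta = rng.normal(size=(3, 4))
d = np.array([0.5, 0.3, 0.2])
R = rng.normal(size=(3, 4))

analytic = analytic_gradient(theta, d, R)
numeric = finite_difference_gradient(theta, d, R)
print(np.max(np.abs(analytic - numeric)))  # small if the formula is right
```

Only the policy is differentiated here; d and R enter as fixed weights, which is the same structure as in the general theorem.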

Interpretation of the equation

\nabla \pi_{\theta}(a|s) is the direction in parameter space in which the probability of choosing action a increases, and it is weighted by the value of that action.
Thus, the policy is updated so that it more often chooses actions for which the agent has experienced a large reward.
On top of that, the outer sum weights each state by the probability of visiting it.
This means that updates for frequently visited states carry more weight.

Another form of Policy Gradient Theorem

Equation (2) can also be written in the following form.

\begin{align} \nabla J(\theta) &\propto \sum_{s} \mu^{\pi_\theta}(s) \sum_{a}{\nabla \pi_{\theta}(a|s) \cdot Q^{\pi_{\theta}}(s, a)}\\ &= E_{s \sim \pi_{\theta}}[\sum_{a}{\nabla \pi_{\theta}(a|s) \cdot Q^{\pi_{\theta}}(s, a)}]\\ &= E_{s \sim \pi_{\theta}}[\sum_{a}{\pi_{\theta}(a|s) \cdot \frac{\nabla \pi_{\theta}(a|s)}{\pi_{\theta}(a|s)} \cdot Q^{\pi_{\theta}}(s, a)}]\\ &= E_{s \sim \pi_{\theta}}[E_{a \sim \pi_{\theta}(\cdot|s)}[\frac{\nabla \pi_{\theta}(a|s)}{\pi_{\theta}(a|s)} \cdot Q^{\pi_{\theta}}(s, a)]]\\ &= E_{s, a \sim \pi_{\theta}}[\frac{\nabla \pi_{\theta}(a|s)}{\pi_{\theta}(a|s)} \cdot Q^{\pi_{\theta}}(s, a)] \\ \therefore \nabla J(\theta) &\propto E_{s, a \sim \pi_{\theta}}[\frac{\nabla \pi_{\theta}(a|s)}{\pi_{\theta}(a|s)} \cdot Q^{\pi_{\theta}}(s, a)] \tag{3} \end{align}
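
Because equation (3) is an expectation over states and actions generated by the policy, it can be approximated by averaging over samples. The sketch below does this in a one-step setting, where the state distribution is a fixed d(s) and Q(s, a) is a made-up table; it relies on the identity \frac{\nabla \pi_{\theta}(a|s)}{\pi_{\theta}(a|s)} = \nabla \log \pi_{\theta}(a|s). The tabular softmax policy and all numbers are assumptions for illustration.

```python
import numpy as np

def softmax(theta):
    """Row-wise softmax: pi[s, a] = pi_theta(a|s)."""
    z = theta - theta.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n_states, n_actions = 2, 3
theta = rng.normal(size=(n_states, n_actions))
d = np.array([0.6, 0.4])                    # assumed state distribution
Q = rng.normal(size=(n_states, n_actions))  # assumed action values
pi = softmax(theta)
eye = np.eye(n_actions)

# Exact value of sum_s d(s) sum_a grad pi(a|s) Q(s, a).
dpi = pi[:, :, None] * (eye[None] - pi[:, None, :])
exact = d[:, None] * np.einsum("sab,sa->sb", dpi, Q)

# Sample s ~ d and a ~ pi(.|s).
num_samples = 200_000
states = rng.choice(n_states, size=num_samples, p=d)
actions = np.empty(num_samples, dtype=int)
for s in range(n_states):
    mask = states == s
    actions[mask] = rng.choice(n_actions, size=int(mask.sum()), p=pi[s])

# For a tabular softmax, grad log pi(a|s) w.r.t. theta[s, :] is onehot(a) - pi(.|s),
# and it is zero for the parameters of every other state.
score = eye[actions] - pi[states]
estimate = np.zeros_like(theta)
np.add.at(estimate, states, score * Q[states, actions][:, None])
estimate /= num_samples

print(np.max(np.abs(estimate - exact)))  # shrinks as num_samples grows
```

No sum over all actions is needed in the estimator: each sample contributes only the score of the action that was actually taken.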

Full derivation of Policy Gradient Theorem

Policy Gradient Theorem

Note: this proof is identical to the one shown on p. 325 of [Sutton and Barto, 2018].

\begin{align} \nabla v_\pi(s) &= \nabla [\sum_a \pi(a|s) q_\pi (s, a)], ~~ \text{for all}~~ s \in S \\ &= \sum_a [\nabla \pi(a|s) q_\pi (s, a) + \pi(a|s) \nabla{\color{green} q_\pi (s, a)}] ~~~~~~~~ \text{(product rule)}\\ &= \sum_a [\nabla \pi(a|s) q_\pi(s, a) + \pi(a|s) \nabla {\color{green} \sum_{s', r} p(s', r| s, a) (r + v_\pi(s'))}]\\ &= \sum_a [\nabla \pi(a|s) q_\pi(s, a) + \pi(a|s)\sum_{s'} p(s'|s, a) {\color{blue} \nabla v_{\pi} (s')}] \\ &= \sum_a [\nabla \pi(a|s)q_{\pi}(s, a) + \pi(a|s) \sum_{s'} p(s'|s,a) {\color{blue} \sum_{a'}[\nabla \pi (a'|s')q_\pi (s', a')+ \pi (a'|s')\sum_{s''}p(s''|s', a'){\color{purple} \nabla v_\pi (s'')} ]}] \\ &= \sum_a \nabla \pi(a|s)q_{\pi}(s, a) + \underline{\sum_a \pi(a|s) \sum_{s'} p(s'|s,a)} {\color{blue} \sum_{a'} \nabla \pi (a'|s')q_\pi (s', a')} \\ &~~~~~~ + \sum_a \pi(a|s) \sum_{s'} p(s'|s,a) {\color{blue} \sum_{a'} \pi (a'|s')\sum_{s''}p(s''|s', a')} {\color{purple} \sum_{a''} \nabla\pi(a''|s'') q_\pi(s'', a'')} + \cdots \\ &= \sum_a \nabla \pi(a|s)q_{\pi}(s, a) + \underline{\sum_{s'\in S} Pr(s\rightarrow s', 1, \pi)} \sum_{a'} \nabla \pi(a'|s') q_\pi(s', a') + \sum_{s'' \in S} Pr(s\rightarrow s'', 2, \pi) \sum_{a''} \nabla \pi(a''|s'') q_\pi(s'', a'') + \cdots \\ &= \sum_{x \in S} \sum_{k=0}^\infty Pr(s\rightarrow x, k, \pi) \sum_a \nabla \pi(a|x) q_\pi(x, a), \end{align}

where Pr(s\rightarrow x, k , \pi) is the probability of transitioning from state s to state x in k steps under policy \pi.
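
For a small tabular MDP, Pr(s\rightarrow x, k, \pi) can be computed as the k-th power of the state-to-state transition matrix under \pi. The sketch below assumes a made-up MDP with two non-terminal states and two actions, stored as a hypothetical array `p[s, a, x]`; rows that sum to less than 1 send the remaining probability to termination.

```python
import numpy as np

def policy_transition_matrix(pi, p):
    """P_pi[s, x] = sum_a pi(a|s) * p(x|s, a)."""
    return np.einsum("sa,sax->sx", pi, p)

def k_step_probability(pi, p, k):
    """Entry [s, x] is Pr(s -> x, k, pi), computed as the k-th matrix power of P_pi."""
    return np.linalg.matrix_power(policy_transition_matrix(pi, p), k)

# Made-up dynamics over the non-terminal states {0, 1}.
p = np.zeros((2, 2, 2))
p[0, 0] = [0.0, 1.0]  # state 0, action 0: move to state 1
p[0, 1] = [0.5, 0.0]  # state 0, action 1: stay w.p. 0.5, otherwise terminate
p[1, 0] = [0.0, 0.0]  # state 1, action 0: terminate
p[1, 1] = [1.0, 0.0]  # state 1, action 1: move to state 0

pi = np.full((2, 2), 0.5)  # uniform random policy

one_step = k_step_probability(pi, p, 1)
two_step = k_step_probability(pi, p, 2)
print(one_step)
print(two_step)
```

With k = 0 the result is the identity matrix, which matches the first term \sum_a \nabla \pi(a|s) q_\pi(s, a) of the expansion.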


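Applying this result to the initial state s_0 gives

\begin{align} \nabla J(\theta) &= \nabla v_\pi(s_0) \\ &= \sum_s (\sum_{k=0}^{\infty} Pr(s_0 \rightarrow s, k, \pi) ) \sum_a \nabla \pi(a|s) q_{\pi} (s, a) \\ &= \sum_s \eta(s) \sum_a \nabla\pi(a|s) q_\pi(s, a) \\ &= \sum_{s'} \eta(s') \sum_s \frac{\eta(s)}{\sum_{s'} \eta(s')} \sum_a \nabla \pi(a|s) q_\pi(s, a) \\ &= \sum_{s'} \eta(s') \sum_s \mu(s) \sum_a \nabla \pi(a|s) q_\pi(s, a) \\ &\propto \sum_s \mu(s) \sum_a \nabla \pi(a|s) q_\pi(s, a) \end{align}

where \eta(s) = \sum_{k=0}^{\infty} Pr(s_0 \rightarrow s, k, \pi) is the expected number of visits to state s in an episode, and \mu(s) = \eta(s) / \sum_{s'} \eta(s') is the on-policy distribution. The factor \sum_{s'} \eta(s') is the average episode length; it does not depend on s, so it is absorbed into the proportionality constant.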
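
The whole statement can be checked numerically on a small episodic MDP. The sketch below assumes a tabular softmax policy and made-up dynamics `p[s, a, x]` and expected rewards `r[s, a]` over two non-terminal states; it computes \eta, \mu, and q_\pi in closed form and compares \sum_s \eta(s) \sum_a \nabla\pi(a|s) q_\pi(s, a) with a finite-difference gradient of v_\pi(s_0). All names and numbers are hypothetical.

```python
import numpy as np

def softmax(theta):
    z = theta - theta.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def evaluate(theta, p, r, s0):
    """Return pi, v_pi, q_pi and the expected visit counts eta from s0."""
    pi = softmax(theta)
    P = np.einsum("sa,sax->sx", pi, p)       # P[s, x] = sum_a pi(a|s) p(x|s, a)
    r_pi = np.sum(pi * r, axis=1)            # expected one-step reward under pi
    I = np.eye(P.shape[0])
    v = np.linalg.solve(I - P, r_pi)         # v = sum_k P^k r_pi
    q = r + p @ v                            # q(s, a) = r(s, a) + sum_x p(x|s, a) v(x)
    eta = np.linalg.solve((I - P).T, I[s0])  # eta(s) = sum_k Pr(s0 -> s, k, pi)
    return pi, v, q, eta

def theorem_gradient(theta, p, r, s0):
    """sum_s eta(s) sum_a grad pi(a|s) q(s, a) for tabular softmax parameters."""
    pi, _, q, eta = evaluate(theta, p, r, s0)
    eye = np.eye(theta.shape[1])
    dpi = pi[:, :, None] * (eye[None] - pi[:, None, :])  # d pi(a|s) / d theta[s, b]
    return eta[:, None] * np.einsum("sab,sa->sb", dpi, q)

def finite_difference_gradient(theta, p, r, s0, eps=1e-6):
    """Central differences on J(theta) = v_pi(s0)."""
    grad = np.zeros_like(theta)
    for idx in np.ndindex(*theta.shape):
        step = np.zeros_like(theta)
        step[idx] = eps
        v_plus = evaluate(theta + step, p, r, s0)[1][s0]
        v_minus = evaluate(theta - step, p, r, s0)[1][s0]
        grad[idx] = (v_plus - v_minus) / (2 * eps)
    return grad

# Made-up episodic MDP; missing probability mass means termination.
p = np.zeros((2, 2, 2))
p[0, 0] = [0.0, 1.0]
p[0, 1] = [0.5, 0.0]
p[1, 0] = [0.0, 0.0]
p[1, 1] = [1.0, 0.0]
r = np.array([[0.0, 1.0], [2.0, -1.0]])
theta = np.array([[0.2, -0.1], [0.3, 0.5]])
s0 = 0

from_theorem = theorem_gradient(theta, p, r, s0)
from_differences = finite_difference_gradient(theta, p, r, s0)
eta = evaluate(theta, p, r, s0)[3]
mu = eta / eta.sum()  # on-policy distribution

print(np.max(np.abs(from_theorem - from_differences)))
print(mu)
```

Using \eta gives the gradient exactly; replacing \eta by \mu only rescales it by the average episode length, which is why the theorem is stated with \propto.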
  1. Even if it is not, the gradient can still be approximated by perturbing the parameters (finite differences), following the definition of the derivative.

  2. See the proof in Sutton and Barto (2018) or the original paper (Sutton et al., 2000).