This article dives into the foundations of Monte Carlo Tree Search, a multi-armed-bandit-style technique for balancing exploration and exploitation while navigating a large and complex search space of possible decisions.
Recap: Reinforcement Learning Terminology
Before we start our discussion of tree search algorithms, let’s recap common RL terminology that will be used in the rest of this article. Every problem in RL starts with the “environment”. The environment is the space in which our agent can act. A “state” is the observable subset of the environment that the agent is given as input. Each “agent” that exists within an environment is defined by a “policy”, often written $\pi$, which defines what action it takes when given a state. The goal of this policy is to maximize the “reward” it collects from the environment, and thus to learn to take good actions.
Each action the agent produces is judged by some sort of “reward model” to produce a reward. The reward model can be a neural network, but it can also be a combination of heuristics.
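To make these terms concrete, here is a minimal sketch of the agent/environment loop. The `LineWorld` environment, its reward values, and the `policy` and `run_episode` helpers are hypothetical, made up only for illustration: the state is a position on a line, the policy picks a step, and a simple heuristic reward model scores each step.

```python
import random


class LineWorld:
    """Hypothetical toy environment: the agent walks along a line and must reach `goal`."""

    def __init__(self, goal=3):
        self.goal = goal
        self.position = 0

    def observe(self):
        # The state: the part of the environment the agent receives as input
        return self.position

    def step(self, action):
        # Apply the action (-1 or +1), then score it with a heuristic "reward model"
        self.position += action
        done = self.position == self.goal
        reward = 1.0 if done else -0.1  # bonus at the goal, small cost for every other step
        return reward, done


def policy(state):
    # A (stochastic) policy pi: given a state, choose an action.
    # This one ignores the state and simply moves right 80% of the time.
    return 1 if random.random() < 0.8 else -1


def run_episode(env, max_steps=50):
    # Repeatedly observe the state, act, and accumulate reward until the episode ends
    total_reward = 0.0
    for _ in range(max_steps):
        action = policy(env.observe())
        reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
```

Learning, in this picture, means adjusting `policy` so that `run_episode` returns more reward on average.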
Introduction to Game Trees
A game tree is a representation of a game such that every node is a game state and every edge is an action that transitions that game state into some related adjacent state.
For the following sections, we will use a tic-tac-toe game tree as our motivating example. If we were to attempt to design an agent that would beat a human at tic-tac-toe, how could we go about this?
Algorithm 1: Enumerative Search — Minimax
The core idea behind minimax is to play out all possible scenarios before making a move. In doing so, one chooses the branch that maximizes one’s own score while assuming the opponent will pick whatever minimizes it. We won’t go further into the minimax algorithm here, but [1] is a good resource to learn more.
With enough resources, one can play out every possible branch and play perfectly.
But even for a game as simple as tic-tac-toe, the tree of possibilities grows large quickly!
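A minimal minimax sketch for tic-tac-toe shows the game-tree view directly: each board is a node, and each empty cell is an edge to a child board. The board encoding (a 9-tuple of `"X"`, `"O"`, `" "`) and the function names are our own assumptions. The `lru_cache` decorator only avoids re-solving positions reached through different move orders; the search still evaluates every reachable position.

```python
from functools import lru_cache

# The eight winning lines on a 3x3 board (cells indexed 0..8, row-major)
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]


def winner(board):
    # Return "X" or "O" if that player has three in a row, else None
    for a, b, c in LINES:
        if board[a] != " " and board[a] == board[b] == board[c]:
            return board[a]
    return None


@lru_cache(maxsize=None)
def minimax(board, player):
    """Value of `board` for X (+1 win, 0 draw, -1 loss) with `player` to move and perfect play."""
    w = winner(board)
    if w is not None:
        return 1 if w == "X" else -1
    if " " not in board:
        return 0
    next_player = "O" if player == "X" else "X"
    # Every empty cell is an edge from this node to a child state
    values = [minimax(board[:i] + (player,) + board[i + 1:], next_player)
              for i in range(9) if board[i] == " "]
    # X picks the branch with the highest value, O the one with the lowest
    return max(values) if player == "X" else min(values)


EMPTY = (" ",) * 9
print(minimax(EMPTY, "X"))  # 0: perfect play from both sides ends in a draw
```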
Enumeration Doesn’t Scale
Readers might already have intuition about why enumerative search doesn’t scale; let’s formalize this. Imagine it takes an average of $d$ moves to complete a game, and that at each possible step there are an average of $b$ possible moves — this means that the size of the search space would be $O(b^d)$!
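For tic-tac-toe, with $d = 9$ and an average of $b \approx 5$ moves per turn, this comes to roughly $5^9 \approx 2$ million move sequences from the opening position. (That is an overestimate, since many games end before the board fills up; the exact number of distinct complete games is 255,168.)
If you think that’s a lot, for more complex games like chess and Go the corresponding numbers (counting board positions) are roughly $10^{40}$ or more and about $10^{170}$ respectively. The Go figure is vastly larger than the roughly $10^{80}$ atoms in the observable universe.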
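As a sanity check on these numbers, the sketch below (our own illustration, with the board stored as a list of `"X"`, `"O"`, `" "`) enumerates every complete tic-tac-toe game and compares the count against the rough $b^d$ estimate and the $9!$ bound.

```python
import math

# The eight winning lines on a 3x3 board (cells indexed 0..8, row-major)
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]


def winner(board):
    # Return "X" or "O" if that player has three in a row, else None
    for a, b, c in LINES:
        if board[a] != " " and board[a] == board[b] == board[c]:
            return board[a]
    return None


def count_games(board=None, player="X"):
    """Count every complete game (a move sequence ending in a win or a full board)."""
    if board is None:
        board = [" "] * 9
    if winner(board) is not None or " " not in board:
        return 1
    total = 0
    for i in range(9):
        if board[i] == " ":
            board[i] = player                      # make the move...
            total += count_games(board, "O" if player == "X" else "X")
            board[i] = " "                         # ...and undo it
    return total


print(count_games())       # 255168 complete games
print(5 ** 9)              # 1953125, the rough b^d estimate
print(math.factorial(9))   # 362880, the bound if no game ended early
```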
Algorithm 2: Reducing Computation with Heuristics
Since the search space is $O(b^d)$, there are only two ways to reduce computation: either
- we decrease the number of moves $b$ we allow ourselves to explore per step, or
- we stop exploring at a certain depth $d$ and instead estimate the scores of the rest of the tree.
In practice, these decisions are often made with heuristics. An example of this is Stockfish, at one point the world’s strongest chess engine, which uses a set of heuristics computed from the board position to prune the moves it searches.
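The second lever (stop at depth $d$ and estimate) can be sketched as depth-limited minimax on tic-tac-toe. The evaluation function here, open lines for X minus open lines for O, is a hypothetical heuristic chosen for illustration; it is not what Stockfish uses. The `counter` argument just tallies visited nodes so the savings are visible.

```python
# The eight winning lines on a 3x3 board (cells indexed 0..8, row-major)
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]


def winner(board):
    # Return "X" or "O" if that player has three in a row, else None
    for a, b, c in LINES:
        if board[a] != " " and board[a] == board[b] == board[c]:
            return board[a]
    return None


def heuristic(board):
    # Hypothetical evaluation: lines X could still complete minus lines O could, scaled to [-1, 1]
    x_open = sum(all(board[i] != "O" for i in line) for line in LINES)
    o_open = sum(all(board[i] != "X" for i in line) for line in LINES)
    return (x_open - o_open) / len(LINES)


def depth_limited(board, player, depth, counter):
    """Minimax that stops `depth` moves ahead and estimates the rest with `heuristic`."""
    counter[0] += 1  # tally visited nodes
    w = winner(board)
    if w is not None:
        return 1.0 if w == "X" else -1.0
    if " " not in board:
        return 0.0
    if depth == 0:
        return heuristic(board)  # stop exploring; estimate this subtree's score instead
    next_player = "O" if player == "X" else "X"
    values = [depth_limited(board[:i] + (player,) + board[i + 1:], next_player, depth - 1, counter)
              for i in range(9) if board[i] == " "]
    return max(values) if player == "X" else min(values)
```

From the empty board, `depth=2` visits only 82 nodes (1 + 9 + 72), at the price of trusting the heuristic wherever the search stops.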
Exploration / Exploitation Curves
Searching efficiently means balancing an exploration/exploitation tradeoff. In a game with many possible actions and many moves before the end, you can spend compute either exploring a subset of the game tree more deeply (exploitation) or exploring much more of the game tree more broadly (exploration).
Exploration lets us learn the values that we are uncertain about. Exploitation lets us focus more on the most promising parts of the game tree.
Algorithm 3: Monte Carlo Tree Search
The core idea behind Monte Carlo Tree Search is simple: the game tree is built incrementally and asymmetrically. On each iteration of the algorithm, we add a leaf to the game tree. Each node in the tree stores a value that estimates how promising it is.
The algorithm plays a lot of random games and keeps statistics on which of them were won. It keeps doing this for as long as there is time. Once done, it plays the move that has the highest win percentage.
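This “play lots of random games and keep score” idea can be sketched as flat Monte Carlo move selection, shown here for tic-tac-toe. It has no tree yet (that is what MCTS adds); the helper names, the 1 / 0.5 / 0 scoring for win / draw / loss, and the playout count are assumptions for illustration.

```python
import random

# The eight winning lines on a 3x3 board (cells indexed 0..8, row-major)
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
         (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)]


def winner(board):
    # Return "X" or "O" if that player has three in a row, else None
    for a, b, c in LINES:
        if board[a] != " " and board[a] == board[b] == board[c]:
            return board[a]
    return None


def legal_moves(board):
    return [i for i, cell in enumerate(board) if cell == " "]


def random_playout(board, player):
    """Play uniformly random moves to the end; return "X", "O", or None for a draw."""
    board = list(board)
    while winner(board) is None and " " in board:
        board[random.choice(legal_moves(board))] = player
        player = "O" if player == "X" else "X"
    return winner(board)


def flat_monte_carlo(board, player, playouts_per_move=200):
    """Return the move whose random playouts score best for `player`, and that score."""
    opponent = "O" if player == "X" else "X"
    best_move, best_rate = None, -1.0
    for move in legal_moves(board):
        child = list(board)
        child[move] = player
        score = 0.0
        for _ in range(playouts_per_move):
            result = random_playout(child, opponent)
            score += 1.0 if result == player else 0.5 if result is None else 0.0
        rate = score / playouts_per_move  # win percentage (draws count half)
        if rate > best_rate:
            best_move, best_rate = move, rate
    return best_move, best_rate
```

For example, with X holding cells 0 and 1 and O holding 3 and 4, `flat_monte_carlo(["X", "X", " ", "O", "O", " ", " ", " ", " "], "X")` returns move 2 with a rate of 1.0, since every playout after that move is already a win.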
Let’s motivate a couple of design decisions within Monte Carlo Tree Search before jumping into the algorithm itself.
Multi-Armed Bandits
One can liken the choice of which path to explore in MCTS to a multi-armed bandit problem. Imagine you have $k$ slot machines, each of which you can play independently of the others. Each machine has its own payoff distribution with a fixed, unknown mean.
How do you find the machine with the best payoff?
- Explore all machines
- Exploit promising machines more often
- Minimize the regret of playing poor machines
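A small simulation makes the bandit setup concrete. Below, each hypothetical machine pays 1 with a fixed probability, and an epsilon-greedy player (a simple strategy swapped in here for illustration, not the UCB method discussed next) explores a random machine 10% of the time and otherwise exploits the best-looking one. Regret is tracked as the expected payoff lost relative to always playing the best machine.

```python
import random


def pull(arm_means, arm):
    # Bernoulli slot machine: pays 1 with probability arm_means[arm], else 0
    return 1.0 if random.random() < arm_means[arm] else 0.0


def epsilon_greedy(arm_means, steps=5000, epsilon=0.1):
    k = len(arm_means)
    counts = [0] * k
    values = [0.0] * k  # running average payoff per arm
    regret = 0.0
    best_mean = max(arm_means)
    for _ in range(steps):
        if random.random() < epsilon:
            arm = random.randrange(k)                     # explore
        else:
            arm = max(range(k), key=lambda a: values[a])  # exploit
        reward = pull(arm_means, arm)
        counts[arm] += 1
        values[arm] += (reward - values[arm]) / counts[arm]
        regret += best_mean - arm_means[arm]              # expected loss vs. the best arm
    return counts, regret
```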
Upper Confidence Bounds (UCB)
In a multi-armed bandit problem, we have to ensure that the optimal machine is not missed just because a suboptimal arm happened to pay out well early on. So, we place an upper confidence bound on each arm’s estimated reward so far.
Then, we optimistically select the arm with the highest upper bound. The more often an arm is played, the tighter its bound becomes around its estimated mean.
It typically looks something like the following:
$$ A_t = \arg\max_{a} \; \overbrace{R_t(a)}^{\text{exploit}} + c \underbrace{\sqrt{\frac{\ln t}{N_t(a)}}}_{\text{explore}} $$
where $R_t(a)$ is the current estimated value of action $a$ at time $t$, $N_t(a)$ is how many times $a$ has been played so far, and $c$ controls how much weight exploration gets.
Think about it this way: if $t$ is large but $N_t(a)$ is small, the exploration term will be large, and the action is more likely to be explored even if its current estimate $R_t(a)$ is low.
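Here is a minimal sketch of UCB1 selection on the same kind of hypothetical Bernoulli machines, using the formula above with $c = \sqrt{2}$ (a common default, assumed here). Each arm is played once first so that $N_t(a) > 0$.

```python
import math
import random


def ucb1(arm_means, steps=5000, c=math.sqrt(2)):
    k = len(arm_means)
    counts = [0] * k    # N_t(a)
    values = [0.0] * k  # R_t(a): running average payoff per arm
    for t in range(1, steps + 1):
        if t <= k:
            arm = t - 1  # play every arm once before trusting the formula
        else:
            # argmax over exploit term + exploration bonus
            arm = max(range(k),
                      key=lambda a: values[a] + c * math.sqrt(math.log(t) / counts[a]))
        reward = 1.0 if random.random() < arm_means[arm] else 0.0
        counts[arm] += 1
        values[arm] += (reward - values[arm]) / counts[arm]
    return counts, values
```

Running `ucb1([0.2, 0.5, 0.8])` should concentrate most of the 5,000 pulls on the last arm, while the weaker arms are still tried occasionally as $\ln t$ grows.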
MCTS Algorithm Breakdown
The MCTS algorithm is made of four steps:
- Selection: Given a game tree, descend from the root to a leaf node according to some selection criterion.
- Expansion: Expand that node by adding a child from the set of actions permitted from its state.
- Simulation: From the new child, randomly simulate the rest of the game until it terminates, and get the reward signal from this terminal state.
- Backprop: Pass this reward value back up the path to the root node of the tree.
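Here is the reference code for the algorithm. It assumes the game’s rules are supplied through a small `game` object (`legal_actions`, `apply`, `is_terminal`, `reward`), that `apply` returns a new state, and that rewards are from one fixed player’s point of view (see Appendix A for two-player games):
```python
import math
import random


class Node:
    """A game state in the search tree, plus the statistics MCTS keeps for it."""

    def __init__(self, state, parent=None, action=None, untried_actions=()):
        self.state = state
        self.parent = parent
        self.action = action                  # action that led here from the parent
        self.children = []
        self.untried_actions = list(untried_actions)
        self.visits = 0
        self.wins = 0.0


def monte_carlo_tree_search(root_state, game, computation_budget=1000, c=math.sqrt(2)):
    """
    Monte Carlo Tree Search algorithm:
    1. Selection: Choose promising nodes using UCB1
    2. Expansion: Add a new child node
    3. Simulation: Play random moves until terminal state
    4. Backpropagation: Update statistics up the tree
    """
    def create_node(state, parent=None, action=None):
        untried = [] if game.is_terminal(state) else game.legal_actions(state)
        return Node(state, parent, action, untried)

    def ucb1(child, parent):
        # UCB1 = wins/visits + C * sqrt(ln(parent_visits)/visits)
        return (child.wins / child.visits
                + c * math.sqrt(math.log(parent.visits) / child.visits))

    def select(node):
        # Descend while the node is fully expanded, picking the child with highest UCB1
        while not node.untried_actions and node.children:
            node = max(node.children, key=lambda ch: ucb1(ch, node))
        return node

    def expand(node):
        # Create a new child node from an unexplored action
        if node.untried_actions:
            action = node.untried_actions.pop(random.randrange(len(node.untried_actions)))
            child = create_node(game.apply(node.state, action), parent=node, action=action)
            node.children.append(child)
            return child
        return node  # terminal node: nothing to expand

    def simulate(node):
        # Play random moves until reaching a terminal state; return its reward
        state = node.state
        while not game.is_terminal(state):
            state = game.apply(state, random.choice(game.legal_actions(state)))
        return game.reward(state)

    def backpropagate(node, result):
        # Update node statistics up the tree
        while node is not None:
            node.visits += 1
            node.wins += result
            node = node.parent

    # Main MCTS loop
    root = create_node(root_state)
    for _ in range(computation_budget):
        leaf = select(root)
        child = expand(leaf)
        result = simulate(child)
        backpropagate(child, result)
    # Play the most-visited root action (a common, robust choice)
    return max(root.children, key=lambda ch: ch.visits).action
```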
Step 1 — Selection
```python
def select(node):
    # Descend while the node is fully expanded, picking the child with the highest UCT value
    while not node.untried_actions and node.children:
        node = max(node.children, key=lambda ch: ucb1(ch, node))
    return node
```
The selection step uses a form of the upper confidence bound known as UCT, or Upper Confidence Bounds applied to Trees. We wish to calculate, with an upper confidence bound, which child we should choose when we are at some node $R$.
The formula looks something like this:
$$ UCT = \arg\max_j \; X(j) + c \cdot \sqrt{\frac{\ln N(R)}{N(j)}} $$
where $X(j)$ is the win ratio of child $j$ and $N(x)$ is the number of times node $x$ has been visited so far. The principle remains the same: if we’ve visited the parent often but some child $j_1$ rarely, then $j_1$ gets a large exploration term and a higher chance of being visited.
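A quick worked example of the UCT score, with made-up statistics: a parent $R$ visited 100 times, one child visited 90 times with 60 wins, and another visited only 10 times with 4 wins. The constant $c = \sqrt{2}$ is an assumption.

```python
import math


def uct_score(wins, visits, parent_visits, c=math.sqrt(2)):
    # X(j) + c * sqrt(ln N(R) / N(j)), with X(j) = wins / visits
    return wins / visits + c * math.sqrt(math.log(parent_visits) / visits)


often = uct_score(wins=60, visits=90, parent_visits=100)   # ≈ 0.99 (win ratio 0.67)
rarely = uct_score(wins=4, visits=10, parent_visits=100)   # ≈ 1.36 (win ratio 0.40)
```

Even though the rarely visited child has the lower win ratio, its larger exploration bonus makes it the one UCT selects next.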
Step 2 — Expansion
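```python
def expand(node):
    # Create a new child node from an unexplored action
    if node.untried_actions:
        action = node.untried_actions.pop(random.randrange(len(node.untried_actions)))
        child = create_node(game.apply(node.state, action), parent=node, action=action)
        node.children.append(child)
        return child
    return node  # terminal node: nothing to expand
```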
Once we have a node selected, we initialize a new leaf: we take an unexplored action from its state, apply it to get a new state, and attach that state to the tree as a child. The exception is a terminal node (e.g. a completed tic-tac-toe game), which has no actions left, in which case we simply return the node itself.
Step 3 — Simulation
```python
def simulate(node, num_rollouts=1):
    # Play `num_rollouts` random games from this node and average their results
    total = 0.0
    for _ in range(num_rollouts):
        state = node.state
        while not game.is_terminal(state):
            state = game.apply(state, random.choice(game.legal_actions(state)))
        total += game.reward(state)  # e.g. 1 win, 0.5 draw, 0 loss
    # Averaging keeps the result in [0, 1], so wins/visits stays a valid win ratio
    return total / num_rollouts
```
But just adding the node to the tree is not very useful— we need some information about the value of this new node! To do this, we play some number of random games starting from the node. Over enough random games, a statistical bias will (hopefully) emerge that will mark this node as a good or bad spot for the agent.
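To see why enough random games reveal a node’s value, model each rollout as a coin flip that wins with some unknown probability $p$ (a simplifying assumption; real rollouts come from the game). The empirical win rate over $n$ rollouts then has standard error $\sqrt{p(1-p)/n}$, which shrinks as $n$ grows. The function names below are our own.

```python
import math
import random


def estimate_win_rate(true_win_prob, num_rollouts):
    # Model each random rollout as a coin flip that wins with probability true_win_prob
    wins = sum(random.random() < true_win_prob for _ in range(num_rollouts))
    return wins / num_rollouts


def standard_error(p, n):
    # Typical distance between the estimated and the true win rate: sqrt(p(1-p)/n)
    return math.sqrt(p * (1 - p) / n)


random.seed(0)
for n in (10, 100, 10_000):
    print(n, round(estimate_win_rate(0.6, n), 3), round(standard_error(0.6, n), 3))
```

With $p = 0.6$, the standard error falls from about 0.15 at 10 rollouts to about 0.005 at 10,000.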
Step 4 — Backprop
```python
def backpropagate(node, result):
    # Update node statistics up the tree
    while node is not None:
        node.visits += 1
        node.wins += result
        node = node.parent
```
And of course, once we have the information we need, we need to store it in the new node we made, and inform all parent nodes about it, so that the next round of selections can be better informed!
AlphaZero: MCTS with Deep Learning
AlphaZero, created by the folks at DeepMind [6], builds on Monte Carlo Tree Search by pairing it with a neural network. It is designed for two-player, zero-sum games, and you might know its lineage from the headlines: its predecessor AlphaGo defeated Go champion Lee Sedol in 2016, and AlphaZero went on to reach superhuman play in Go, chess, and shogi. They even made a documentary about AlphaGo! [7]
There are two observations to be made about the previous sections:
- We’ve yet to talk about any neural network being a part of the system. The policy is defined by the tree samples themselves at the moment.
- We don’t know how to adapt MCTS to two-player games yet.
In AlphaZero, we’ll do both of these things. The goal of this section is to build up to AlphaZero, which will help us understand why its design decisions were made as they were.
MCTS for Two-Player Games
First, let’s resolve how to update MCTS for two-player games. The principle is the same as in minimax and alpha-beta search: player 1 tries to maximize the final reward, and player 2 tries to minimize it. What effectively changes?
For AlphaZero, it makes sense for each node to think that it is “player 1” if the rules are symmetrical for both players (since then the model we will train can benefit from the shared data). This makes implementation a bit tricky, in ways we will not get into at the moment. Details can be read in Appendix A.
How to Source Your Data
First, imagine no model. There can’t be a model without training data, of course. So we need training data! How can we collect it?
To decide what data we want, we need to decide what to train!
- We definitely need to have some sort of policy that accepts a given state and then outputs a distribution over the next actions (see the original RL setup).
- We also need a reward (value) model that looks at a state and tells us how likely it is to end in a win, so that we can compare states.
To train these two things — a policy head and a reward head — we need:
- the input (the state of the board) and
- the two outputs (some distribution over the next available actions and the final reward).
Let’s start our data collection at the easiest point: how about we just record some people playing games? That would work! For each game, we have a ton of states as well as the final reward (whether the player won or lost). We set our target action distribution to 1 at the chosen action and 0 everywhere else.
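Here is a minimal sketch of turning one recorded tic-tac-toe game into training examples with one-hot policy targets. The move-list format and the function name are hypothetical; the value target $z$ is given from the perspective of the player to move, matching the “each node thinks it is player 1” convention for AlphaZero.

```python
def make_examples(moves, winner):
    """
    Turn one recorded tic-tac-toe game into (state, policy_target, value_target) examples.
    moves:  cell indices 0-8 in the order they were played, X moving first.
    winner: "X", "O", or None for a draw.
    """
    examples = []
    board = [" "] * 9
    player = "X"
    for move in moves:
        target = [0.0] * 9
        target[move] = 1.0  # one-hot: 1 at the chosen action, 0 everywhere else
        # z is the final result from the perspective of the player to move
        z = 0.0 if winner is None else (1.0 if winner == player else -1.0)
        examples.append((tuple(board), target, z))
        board[move] = player
        player = "O" if player == "X" else "X"
    return examples
```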
This works, to an extent.
But getting data from people is slow and expensive. Is there something else? A convenient property of games is that they are self-verifying: the rules alone tell us who won. How about we pick a random next action for each state and let two random agents play each other? We’ll still have states, rewards, and our one-hot target action distribution.
This would work as well, to an extent.
Letting two random agents play each other misses something important: humans are clever decision makers, so human games are probably more interesting than random ones. We like interesting games because we need much less data to train a good model from them; past a point, each additional uninteresting game adds less and less value to the model. So we need a fast way to get interesting games.
How can we make interesting games by just letting machines play each other? Enter, our good friend MCTS! As we keep playing rollouts in MCTS, the tree starts learning how to play to maximize its own win (which is what humans do) and games start to get interesting! They again give us what we want — the chosen action, the state, and the final reward.
So using MCTS to play games against itself and collect data would work!
But we can do even better: with MCTS, we don’t need to create a one-hot target distribution. MCTS keeps visit counts for every available action at each state, and normalizing them gives a distribution over actions that reflects how promising the search found each one. So instead, we can train on state → (win/loss, MCTS action distribution), which gives a richer training signal and helps training go faster.
This idea is known as “self-play”, since the player is effectively playing against itself. Note that once we have a trained model, we’ll modify MCTS to include it.
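Converting visit counts into a training target can be sketched as follows. In AlphaGo Zero and AlphaZero the target is $\pi(a) \propto N(a)^{1/\tau}$ for a temperature $\tau$; the function name and the choice to treat $\tau = 0$ as “all mass on the most-visited action” are our own.

```python
def visit_distribution(visit_counts, temperature=1.0):
    """pi(a) proportional to N(a) ** (1 / temperature)."""
    if temperature == 0:
        # Greedy limit: all probability on the most-visited action (ties go to the first)
        best = max(range(len(visit_counts)), key=lambda a: visit_counts[a])
        return [1.0 if a == best else 0.0 for a in range(len(visit_counts))]
    weights = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]
```

Lower temperatures sharpen the target toward the most-visited action; `temperature=1.0` simply normalizes the raw counts.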
Train the model
So now we have tons of data; let’s train the model. The input is a stack of binary (1/0) image planes representing the board, fed into a CNN, and the output is a policy distribution $p$ and a value $v$. Given the MCTS distribution $\pi$ and the final win/loss reward $z$, we want to minimize the loss
$$ L = \text{MSELoss}(z, v) + \text{CrossEntropy}(p, \pi) + \text{L2Regularization}(\theta) $$
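For concreteness, here is the same loss written out for a single example in plain Python, assuming the AlphaZero-style form $(z - v)^2 - \pi^\top \log p + c\lVert\theta\rVert^2$ with a made-up regularization weight `c`. In practice you would use your framework’s built-in MSE and cross-entropy losses (or weight decay) instead.

```python
import math


def alphazero_style_loss(z, v, pi, p, theta, c=1e-4):
    value_loss = (z - v) ** 2                     # MSE between outcome z and predicted value v
    # Cross-entropy between the MCTS target pi and the policy p (0 * log 0 taken as 0)
    policy_loss = -sum(t * math.log(q) for t, q in zip(pi, p) if t > 0)
    l2 = c * sum(w * w for w in theta)            # L2 penalty on the parameters
    return value_loss + policy_loss + l2
```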
Inference with the Model
Now, we have a trained model! We want to see if it’s any good. How do we run inference with it?
Let’s start with the simplest thing: for each board state, give the model the state, then either choose the highest-rated action or sample from the action distribution. Repeat until the game is won or lost.
This would work!
But this would only give the model one shot at deciding what actions to take. Can we use more compute at inference time to make a smarter decision?
We could instead sample many games from the model and, at every stage, take the action those games chose most often. This is similar to majority voting: many samples vote for what they agree is the best action, which reduces variance.
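For a single state, the three options above (take the top action, sample one, or let many samples vote) can be sketched as follows. The probabilities stand in for a policy head’s output and are made up, as are the function names.

```python
import random
from collections import Counter


def greedy(probs):
    # Pick the highest-rated action
    return max(range(len(probs)), key=lambda a: probs[a])


def sample(probs):
    # Draw one action according to the policy's probabilities
    return random.choices(range(len(probs)), weights=probs)[0]


def majority_vote(probs, k=101):
    # Draw k actions and return the one chosen most often
    votes = Counter(sample(probs) for _ in range(k))
    return votes.most_common(1)[0][0]
```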
But we could yet again take this one step further— we use MCTS at inference time as well.
MCTS with the model
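Start with the root node. For selection, change nothing except adding a factor to the UCT equation that multiplies the exploration term by the policy head’s prior probability $p(j)$ for that move, so moves the network favors get explored more. The statistical mean value of the state $X$ remains unchanged. (The published AlphaZero rule, called PUCT, also replaces the log term with $\sqrt{N(R)}/(1 + N(j))$, but the idea is the same.)
$$ UCT = \arg\max_j \; X(j) + c \cdot p(j) \cdot \sqrt{\frac{\ln N(R)}{N(j)}} $$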
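Here is a sketch of the published PUCT variant used by AlphaZero, which uses $\sqrt{N(R)}/(1 + N(j))$ in place of the log term so that unvisited children (with $N(j) = 0$) still get a finite, prior-weighted score. The dict layout, `c_puct = 1.5`, and the statistics are assumptions for illustration; $N(R)$ is taken as the sum of the children’s visits.

```python
import math


def puct_select(children, c_puct=1.5):
    """
    children: dict mapping action -> (visits N, total value W, prior P).
    Score: Q + c_puct * P * sqrt(sum of N) / (1 + N), with Q = W / N (0 if unvisited).
    """
    parent_visits = sum(n for n, _, _ in children.values())

    def score(action):
        n, w, p = children[action]
        q = w / n if n > 0 else 0.0                                # mean value X(j)
        return q + c_puct * p * math.sqrt(parent_visits) / (1 + n)  # prior-weighted bonus

    return max(children, key=score)
```

With `{"A": (6, 3.6, 0.2), "B": (3, 1.5, 0.5), "C": (1, 0.0, 0.3)}`, child B wins: its mean value is a bit lower than A’s, but its high prior and low visit count give it a much larger bonus.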
Once a leaf node is reached, query the network for move probabilities $p$ and the value of the state $v$ (no random rollout is needed). Expand the leaf by creating its children, storing each move’s probability as that child’s prior, and store $v$ as the value of the state. This value is then backpropagated to all parent nodes.
Do this many times (AlphaGo Zero used 1,600 simulations per move; AlphaZero used 800). Now you have a Monte Carlo tree built by statistically sampling from the model’s distribution and trusting the model’s values. For each state, you can then choose the most-visited action (and therefore the one most likely to lead to a win according to our network).
Model MCTS back into self play
It’s time to come full circle. We have a pretty good model and a pretty good way to sample from it at inference time.
We want to use it to make an even better model. To do this, put the model in the self-play loop and use the updated UCT from above. Everything else remains the same, but this especially lets the search explore areas of the tree where the model’s value estimate and the actual win/loss differ a lot, making improvements more targeted and giving us more interesting games!
I hope this breakdown helped you realize just how much thought goes into something like AlphaZero. I’d personally put it right alongside transformers as one of the coolest ideas of post-modern machine learning!
[8] is a cheat sheet I refer back to quite often. I believe it captures what AlphaZero is quite concisely.
References
[1] Roberts, E. “The Minimax Algorithm.” Stanford University.
[2] Duvenaud, D. “Introduction to Monte Carlo Tree Search (MCTS).” Learning to Search.
[3] Rezazadeh, R. “Monte Carlo Search Tree: Report.” Stanford University.
[4] Unknown Author. “Monte Carlo Tree Search Tutorial.” AIOps.org.
[5] Silver, D., et al. “Mastering the Game of Go without Human Knowledge.” Nature 550, 354–359 (2017).
[6] Silver, D., et al. “Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm.” arXiv preprint arXiv:1712.01815.
[7] Brown, J. (2016). “AlphaGo – The Movie | Full Documentary.” YouTube.
[8] Unknown Author. “AlphaGo Zero Cheat Sheet.”
[9] Brown, J. (2018). “DeepMind AI – AlphaZero Explained.” YouTube.
Appendix A: Two-Player MCTS
The following are changes that need to be made to support two-player MCTS:
- Each node thinks that it is player 1 and will try to maximize its reward. It thinks that its children and its parent are player 2, and will thus try to minimize their rewards.
- In Selection— we update the UCT algorithm to take into account that the child is an opponent. So, instead of trying to find the maximum child, we’re trying to find the child with the minimum value. Typically, this is done by negating the child’s value before argmaxing.
- In backprop, each node knows that its parents are opponents, so we will negate the reward every time we backprop it up one layer (since our high reward is their low reward, and vice-versa).
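These changes can be sketched as follows. Values are stored from the viewpoint of the player to move at each node; selection negates the child’s mean value before the UCT argmax, and backpropagation flips the sign at every level. The `Node` class and function names are hypothetical.

```python
import math


class Node:
    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value_sum = 0.0  # summed from the viewpoint of the player to move here
        if parent is not None:
            parent.children.append(self)


def select_child(node, c=math.sqrt(2)):
    # The child is the opponent, so negate its mean value before taking the argmax
    return max(node.children,
               key=lambda ch: -ch.value_sum / ch.visits
               + c * math.sqrt(math.log(node.visits) / ch.visits))


def backpropagate(leaf, value):
    # `value` is the result from the viewpoint of the player to move at `leaf`
    node = leaf
    while node is not None:
        node.visits += 1
        node.value_sum += value
        value = -value  # one level up it is the opponent: our high reward is their low reward
        node = node.parent
```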