In this notebook, we will explore the problem of balancing a cartpole using reinforcement learning techniques, with a particular focus on leveraging transformers to make decisions based on textual descriptions of the environment state.
Problem Statement: The cartpole balancing problem is a classic control task in reinforcement learning, where the objective is to balance a pole attached to a cart by applying forces to the cart. The system is inherently unstable: without intervention, the pole will fall. By interacting with the environment, the agent learns to take appropriate actions to keep the pole upright.
Goal: The main goal of this notebook is to demonstrate how transformers can be used to learn and make decisions based on textual descriptions of the cartpole's state. We will use a transformer model to predict the best action given a textual input representing the state. After training the model, we will integrate it with the cartpole environment to test its performance in balancing the pole.
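To make this idea concrete before any training code, here is a hypothetical sketch of how a raw CartPole state vector could be rendered as text for a transformer. The helper name state_to_text and the exact phrasing are illustrative assumptions; the encoding actually used later in this notebook may differ.
# Hypothetical helper (illustrative only): render the four CartPole state
# variables as a short textual description for a transformer.
def state_to_text(state):
    cart_pos, cart_vel, pole_angle, pole_ang_vel = state
    return (f"cart position {cart_pos:.2f}, cart velocity {cart_vel:.2f}, "
            f"pole angle {pole_angle:.2f}, pole angular velocity {pole_ang_vel:.2f}")

print(state_to_text([0.01, -0.02, 0.03, 0.04]))
# -> cart position 0.01, cart velocity -0.02, pole angle 0.03, pole angular velocity 0.04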
!pip install wandb
!pip install transformers
This block imports the required libraries, sets up the CartPole environment using OpenAI Gym, initializes matplotlib for visualization, and selects the appropriate device (GPU or CPU) for running the code.
import gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import wandb
# Create the CartPole environment
env = gym.make("CartPole-v1")
# Set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()
# Use GPU if available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Transition = namedtuple('Transition',
('state', 'action', 'next_state', 'reward'))
class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
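As a quick, standalone illustration (not part of the training loop), a transition can be pushed into and sampled from the replay memory like this; the tensors below are placeholders with the shapes the later code expects.
# Illustrative usage of ReplayMemory with placeholder tensors
demo_memory = ReplayMemory(100)
s = torch.zeros(1, 4)                         # state: batch of 1, 4 features
a = torch.tensor([[0]], dtype=torch.long)     # action index
s_next = torch.zeros(1, 4)                    # next state (None if the episode ended)
r = torch.tensor([1.0])                       # reward
demo_memory.push(s, a, s_next, r)
print(len(demo_memory))        # 1
print(demo_memory.sample(1))   # a list containing one Transition namedtuple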
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
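For CartPole, this network maps a 4-dimensional state to 2 Q-values, one per action (push left, push right). A quick illustrative shape check:
# Illustrative only: a batch of one 4-dimensional state yields two Q-values.
demo_net = DQN(4, 2)
q_values = demo_net(torch.zeros(1, 4))
print(q_values.shape)  # torch.Size([1, 2])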
# Hyperparameters
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state = env.reset()
n_observations = len(state)
# Initialize policy and target networks
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
# Set up optimizer and replay memory
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)
# Counter for the number of steps taken
steps_done = 0
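Note that TAU is not used in this block; it governs the soft update of the target network toward the policy network, which the training loop (shown later) typically performs after each optimization step. A minimal sketch of that update, assuming the networks defined above:
# Sketch of a soft target-network update controlled by TAU (illustrative;
# the training loop later in the notebook performs an update of this form).
def soft_update(policy_net, target_net, tau=TAU):
    target_state_dict = target_net.state_dict()
    policy_state_dict = policy_net.state_dict()
    for key in policy_state_dict:
        target_state_dict[key] = policy_state_dict[key] * tau + target_state_dict[key] * (1 - tau)
    target_net.load_state_dict(target_state_dict)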
This block defines two helper functions. The select_action function takes the current state as input and returns an action to take. It uses an epsilon-greedy strategy, where the probability of taking a random action decreases over time. If a random sample is greater than the current threshold, the function uses the policy network to select the action with the highest expected reward; otherwise, it randomly selects an action from the action space.

The optimize_model function performs one step of the optimization process on the policy network. It first checks whether the replay memory contains enough transitions to sample a batch of size BATCH_SIZE; if not, it returns without updating the network. Otherwise, it samples a batch of transitions, computes the state-action values for the current states and selected actions, and computes the expected state-action values using the target network. It then computes the Huber loss between the two sets of values and uses the optimizer to backpropagate and update the policy network's weights.
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the highest predicted Q-value
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Convert a batch of Transitions into a Transition of batches
    batch = Transition(*zip(*transitions))
    # Mask of transitions whose next state is not terminal
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Q(s, a) for the actions actually taken, from the policy network
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # max_a' Q_target(s', a') for non-terminal next states; zero for terminal ones
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Huber loss between predicted and expected state-action values
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to stabilize training
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
This block defines the plot_durations function, which is used to visualize training progress. It creates a figure and plots the episode durations as a line graph. If the show_result argument is True, the function sets the plot title to "Result" to indicate that this is the final plot after training has completed; if show_result is False, it clears the plot and sets the title to "Training..." to indicate that training is still ongoing.
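A minimal sketch consistent with this description, assuming episode lengths are collected in a list named episode_durations; the actual cell may differ in details (for example, by also plotting a moving average).
# Sketch only: plot episode durations, switching the title between
# "Training..." and "Result" based on show_result.
episode_durations = []  # appended to by the training loop

def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()          # clear the previous frame while training
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    plt.pause(0.001)       # give the figure time to render
    if is_ipython and not show_result:
        display.display(plt.gcf())
        display.clear_output(wait=True)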