In this notebook, we will explore the problem of balancing a cartpole using reinforcement learning techniques, with a particular focus on leveraging transformers to make decisions based on textual descriptions of the environment state.

Problem Statement: The cartpole balancing problem is a classic control task in reinforcement learning, where the objective is to balance a pole attached to a cart by applying forces to the cart. The system is inherently unstable, and without intervention the pole will fall. By interacting with the environment, the agent learns to take actions that keep the pole upright.

Goal: The main goal of this notebook is to demonstrate how transformers can be used to learn and make decisions based on textual descriptions of the cartpole's state. We will use a transformer model to predict the best action given a textual input representing the state. After training the model, we will integrate it with the cartpole environment to test its performance in balancing the pole.

DQN Implementation

Install necessary packages

!pip install wandb
!pip install transformers

Imports and environment setup

This block imports the required libraries, sets up the CartPole environment using OpenAI Gym, initializes matplotlib for visualization, and selects the appropriate device (GPU or CPU) for running the code.

import gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import wandb

# Create the CartPole environment
env = gym.make("CartPole-v1")

# Set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# Use GPU if available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DQN Agent and Replay Memory

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(nn.Module):
    """Feed-forward network that maps an observation to one Q-value per action."""

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
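
For reference, here is a small, illustrative usage of ReplayMemory showing the tensor shapes the rest of the code expects (a dummy transition only; real transitions come from the environment loop later):

# Illustrative example only: push one dummy transition and sample it back.
demo_memory = ReplayMemory(100)
dummy_state = torch.zeros(1, 4, device=device)       # a CartPole observation has 4 values
dummy_action = torch.tensor([[0]], device=device)    # action index, shape (1, 1)
dummy_next_state = torch.zeros(1, 4, device=device)
dummy_reward = torch.tensor([1.0], device=device)
demo_memory.push(dummy_state, dummy_action, dummy_next_state, dummy_reward)
print(demo_memory.sample(1))                          # a list with one Transition namedtuple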

Model Initialization and Hyperparameter Selection

# Hyperparameters
BATCH_SIZE = 128   # number of transitions sampled from the replay memory
GAMMA = 0.99       # discount factor
EPS_START = 0.9    # starting value of epsilon (exploration rate)
EPS_END = 0.05     # final value of epsilon
EPS_DECAY = 1000   # controls the rate of the exponential decay of epsilon
TAU = 0.005        # update rate of the target network
LR = 1e-4          # learning rate of the AdamW optimizer

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
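# Note: with Gym >= 0.26 (and Gymnasium), env.reset() returns an (observation, info)
# tuple; in that case unpack it as `state, info = env.reset()` before taking len().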
state = env.reset()
n_observations = len(state)

# Initialize policy and target networks
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Set up optimizer and replay memory
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

# Counter for the number of steps taken
steps_done = 0
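
TAU controls how quickly the target network tracks the policy network. As a rough sketch (the helper name soft_update is ours, and we assume the training loop applies it after each optimization step), the soft update TAU is typically used for looks like this:

# Hypothetical helper: soft update of the target network's weights,
# θ_target ← TAU * θ_policy + (1 - TAU) * θ_target
def soft_update(policy_net, target_net, tau=TAU):
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = policy_state[key] * tau + target_state[key] * (1 - tau)
    target_net.load_state_dict(target_state)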

Helper Functions for Action Selection/Optimization

This block defines two helper functions. The select_action function takes the current state and returns an action using an epsilon-greedy strategy, in which the probability of taking a random action decays over time. If a random sample exceeds the current epsilon threshold, the policy network chooses the action with the highest predicted Q-value; otherwise, an action is sampled uniformly at random from the action space.

The optimize_model function performs one optimization step on the policy network. It first checks whether the replay memory holds at least BATCH_SIZE transitions; if not, it returns without updating the network. Otherwise, it samples a batch of transitions, uses the policy network to compute the Q-values of the actions that were taken, and computes the expected state-action values as the reward plus the discounted next-state value estimated by the target network. It then computes the Huber loss between the two, backpropagates, clips the gradients, and takes an optimizer step to update the policy network's weights.

def select_action(state):
    global steps_done
    sample = random.random()
    # Epsilon decays exponentially from EPS_START towards EPS_END as steps_done grows
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the largest predicted Q-value
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action from the action space
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Convert a batch of Transitions into a Transition of batches
    batch = Transition(*zip(*transitions))

    # Mask of non-terminal next states (None marks the end of an episode)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a) for the actions that were actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) = max_a Q_target(s_{t+1}, a), with terminal states kept at zero
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss between predicted and expected Q-values
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clip gradients in place to stabilize training
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

Visualization

This block defines the plot_durations function, which visualizes training progress by plotting the episode durations as a line graph. If show_result is True, the plot is titled "Result" to mark the final plot after training has completed; otherwise the figure is cleared and titled "Training..." to indicate that training is still in progress.
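
A minimal sketch of plot_durations consistent with this description (episode_durations is assumed to be a list of episode lengths collected during training; the notebook's actual implementation may differ slightly):

episode_durations = []

def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()  # clear the previous frame while training is ongoing
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    plt.pause(0.001)  # give the plot a moment to update
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())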