In [1]:
from tf_agents.environments import suite_gym
# Load Breakout through the plain gym suite (no Atari preprocessing wrappers)
# so its raw specs can be inspected in the next few cells.
# NOTE: `env` is rebound to the preprocessed environment in a later cell
# (the one calling suite_atari.load).
env = suite_gym.load("BreakoutNoFrameskip-v4")
env

<tf_agents.environments.wrappers.TimeLimit at 0x7fd82fa4a190>

In [2]:
# Start a new episode and display the initial TimeStep
# (step_type, reward, discount, observation).
env.reset()

TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]], dtype=uint8))

In [3]:
# Spec of a single observation: shape, dtype and value bounds.
env.observation_spec()

BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255)

In [4]:
# Spec of the action the environment accepts: shape, dtype and value bounds.
env.action_spec()

BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=3)

In [5]:
# Combined spec of everything a TimeStep carries:
# step_type, reward, discount and observation.
env.time_step_spec()

TimeStep(step_type=ArraySpec(shape=(), dtype=dtype('int32'), name='step_type'), reward=ArraySpec(shape=(), dtype=dtype('float32'), name='reward'), discount=BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0), observation=BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255))

In [6]:
# Human-readable action names from the underlying gym environment
# (presumably listed in action-index order -- confirm against the gym docs).
env.gym.get_action_meanings()

['NOOP', 'FIRE', 'RIGHT', 'LEFT']

**Initialization of the Atari Environment**

In [7]:
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4

# Game to train on. Alternatives tried earlier:
# "PongNoFrameskip-v4", "MsPacman-v4".
environment_name = "BreakoutNoFrameskip-v4"

# Passed to suite_atari.load as the per-episode step limit.
max_episode_steps = 27000

# Gym-level wrappers handed to suite_atari.load, in this order.
atari_gym_wrappers = [AtariPreprocessing, FrameStack4]

# NOTE: this rebinds `env`, replacing the raw environment loaded in the first cell.
env = suite_atari.load(
    environment_name,
    max_episode_steps=max_episode_steps,
    gym_env_wrappers=atari_gym_wrappers,
)

In [8]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment
# Wrap the Python environment in a TFPyEnvironment; `tf_env` is what the
# Q-network, agent and drivers in the following cells are built against.
tf_env = TFPyEnvironment(env)

**Creation of the Q-Network**

In [9]:
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
# Cast observations to float and divide by 255 before they reach the conv
# layers (assumes observations are 0-255 pixel values -- TODO confirm for the
# wrapped environment).
# tf.float32 is used instead of np.float32 because numpy is never imported in
# this notebook, so referencing `np` raises NameError on a fresh kernel.
preprocessing_layer = tf.keras.layers.Lambda(
    lambda obs: tf.cast(obs, tf.float32) / 255.)

# One tuple per convolutional layer; per the TF-Agents QNetwork docs each
# tuple is (filters, kernel_size, stride).
conv_layer_parameters = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]

# Sizes of the fully connected layers that follow the conv layers.
fc_layer_params = [512]

q_net = QNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    preprocessing_layers=preprocessing_layer,
    conv_layer_params=conv_layer_parameters,
    fc_layer_params=fc_layer_params)

**Creation of the DQN Agent**

In [10]:
from tf_agents.agents.dqn.dqn_agent import DqnAgent
# Counter handed to the agent as its train-step counter; also drives the
# epsilon schedule below.
train_step = tf.Variable(0)

# Number of environment steps the collect driver runs per call
# (it is passed as the driver's `num_steps` in a later cell).
update_period = 4

# `learning_rate` is the supported argument name; the legacy `lr` alias is
# deprecated and rejected by newer Keras optimizer implementations.
optimizer = tf.keras.optimizers.RMSprop(
    learning_rate=2.5e-4,
    rho=0.95,
    momentum=0.0,
    epsilon=0.00001,
    centered=True)

# PolynomialDecay is a learning-rate schedule class, reused here to decay the
# exploration epsilon from 1.0 down to 0.01 over 250000 steps.
epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=250000,
    end_learning_rate=0.01)

agent = DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    target_update_period=2000,
    td_errors_loss_fn=tf.keras.losses.Huber(reduction="none"),
    gamma=0.99,
    train_step_counter=train_step,
    # A callable, so epsilon is re-evaluated from the current train_step
    # every time it is invoked.
    epsilon_greedy=lambda: epsilon_fn(train_step))
agent.initialize()

**Creation of the Replay Buffer**

In [11]:
from tf_agents.replay_buffers import tf_uniform_replay_buffer

# Maximum length passed to the replay buffer.
replay_buffer_max_length = 100000

# Buffer holding data shaped like the agent's collect_data_spec.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    max_length=replay_buffer_max_length,
    batch_size=tf_env.batch_size,
    data_spec=agent.collect_data_spec,
)

# Bound method used as a driver observer to write batches into the buffer.
replay_buffer_observer = replay_buffer.add_batch

**Definition of Metrics**

In [12]:
from tf_agents.metrics import tf_metrics

# One instance of each metric; they are passed to the collect driver as
# observers in a later cell.
training_metrics = [
    metric_class()
    for metric_class in (
        tf_metrics.NumberOfEpisodes,
        tf_metrics.EnvironmentSteps,
        tf_metrics.AverageReturnMetric,
        tf_metrics.AverageEpisodeLengthMetric,
    )
]

**Definition of Dynamic Step Driver**

In [13]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

# Observers handed to the driver: the replay-buffer writer first, then the metrics.
collect_observers = [replay_buffer_observer, *training_metrics]

# Driver that runs the agent's collect policy for `update_period` steps per call.
collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=collect_observers,
    num_steps=update_period,
)

**Initialization of the Replay Buffer**

In [14]:
class showProgress:
    """Callable observer that counts trajectories and prints a progress line.

    Trajectories whose ``is_boundary()`` is truthy are not counted. After
    every call, if the running count is a multiple of 100, the line
    ``<count>/<total>`` is printed, prefixed with a carriage return and
    without a trailing newline.
    """

    def __init__(self, total):
        """Store the target count ``total`` and start the counter at zero."""
        self.total = total
        self.counter = 0

    def __call__(self, trajectory):
        """Update the counter for ``trajectory`` and print progress if due."""
        counts_as_step = not trajectory.is_boundary()
        if counts_as_step:
            self.counter += 1
        # Nothing to report unless the count sits on a multiple of 100.
        if self.counter % 100 != 0:
            return
        print("\r{}/{}".format(self.counter, self.total), end="")

In [15]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy

# Number of random-policy steps used to pre-fill the replay buffer
# (original note: 80000 frames, 4 x 20000).
initial_collect_steps = 20000

# Random policy built from the environment's time-step and action specs.
initial_collect_policy = RandomTFPolicy(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
)

init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch, showProgress(initial_collect_steps)],
    num_steps=initial_collect_steps,
)
final_time_step, final_policy_state = init_driver.run()

20000/20000

In [16]:
# Dataset view over the replay buffer, requested with sample_batch_size=64
# and num_steps=2, read with 3 parallel calls.
sampled_batches = replay_buffer.as_dataset(
    num_parallel_calls=3,
    num_steps=2,
    sample_batch_size=64,
)

# Prefetch 3 elements ahead of the consumer.
dataset = sampled_batches.prefetch(3)

**Definition of TF Functions**

In [17]:
from tf_agents.utils.common import function

# Replace the two callables invoked on every training iteration with versions
# wrapped by tf_agents' `function` helper (presumably a tf.function-style
# wrapper that runs them as graphs -- confirm in tf_agents.utils.common).
collect_driver.run=function(collect_driver.run)
agent.train = function(agent.train)

**Training of the Agent**

In [18]:
def train_agent(number_of_iterations, render_after=20):
    """Run the collect/train loop for ``number_of_iterations`` iterations.

    Each iteration runs ``collect_driver`` once, draws the next batch from
    ``dataset``, trains ``agent`` on it and prints the iteration index and
    loss on a single, carriage-return-refreshed line.

    Relies on the notebook-level names ``agent``, ``tf_env``, ``dataset``,
    ``collect_driver`` and ``env``.

    Args:
        number_of_iterations: Number of collect/train iterations to run.
        render_after: Call ``env.render("human")`` on every iteration whose
            index is greater than this value. Defaults to 20 (the previously
            hard-coded threshold). Pass ``None`` to never render, e.g. on a
            headless machine.
    """
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    iterator = iter(dataset)
    for iteration in range(number_of_iterations):
        # Feed the previous time step and policy state back in so that
        # collection resumes from where the last call stopped.
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        # The second element of each dataset item (buffer info) is not used.
        trajectories, _ = next(iterator)
        train_loss = agent.train(trajectories)
        print("\r{} loss:{:.5f}".format(iteration, train_loss.loss.numpy()), end="")
        if render_after is not None and iteration > render_after:
            env.render("human")

In [None]:
# Launch training with a very large iteration budget; the recorded output
# below stops at iteration 5363.
train_agent(number_of_iterations=10000000)

5363 loss:0.00809