From d71d8c35a1cd0ac03094b3424d1640e6e778802c Mon Sep 17 00:00:00 2001
From: Pavitrakumar P
Date: Mon, 9 Jan 2017 21:54:52 +0530
Subject: [PATCH 1/3] Update train.py

added Python 2.x compatibility
added a workaround for Breakout and Space Invaders to correct the number of moves

---
 a3c/train.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/a3c/train.py b/a3c/train.py
index 3d0d2b4..df1fa92 100644
--- a/a3c/train.py
+++ b/a3c/train.py
@@ -1,3 +1,5 @@
+from __future__ import print_function
+from __future__ import division
 from scipy.misc import imresize
 from skimage.color import rgb2gray
 from multiprocessing import *
@@ -151,7 +153,11 @@ def learn_proc(mem_queue, weight_dict):
     steps = args.steps
     # -----
     env = gym.make(args.game)
-    agent = LearningAgent(env.action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
+    #['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'] - default
+    #['NOOP', 'FIRE', 'RIGHT', 'LEFT'] - our workaround
+    if args.game == 'Breakout-v0' or args.game == 'SpaceInvaders-v0':
+        action_space = Discrete(4)
+    agent = LearningAgent(action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
     # -----
     if checkpoint > 0:
         print(' %5d> Loading weights from file' % (pid,))
@@ -264,8 +270,13 @@ def generate_experience_proc(mem_queue, weight_dict, no):
     batch_size = args.batch_size
     # -----
     env = gym.make(args.game)
-    agent = ActingAgent(env.action_space, n_step=args.n_step)
-
+    #['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'] - default
+    #['NOOP', 'FIRE', 'RIGHT', 'LEFT'] - our workaround
+    #this workaround does not mess with the internals of ALE and will work with any compilation of ALE
+    if args.game == 'Breakout-v0' or args.game == 'SpaceInvaders-v0':
+        action_space = Discrete(4)
+    agent = LearningAgent(action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
+    # -----
     if frames > 0:
         print(' %5d> Loaded weights from file' % (pid,))
         agent.load_net.load_weights('model-%s-%d.h5' % (args.game, frames))

From 95c9218f33425dc9b6ec2cc2ac1eaff4a15d3c67 Mon Sep 17 00:00:00 2001
From: Pavitrakumar P
Date: Sun, 15 Jan 2017 14:28:21 +0530
Subject: [PATCH 2/3] Update train.py

removed workarounds for Breakout and Space Invaders

---
 a3c/train.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/a3c/train.py b/a3c/train.py
index df1fa92..7d7ed8a 100644
--- a/a3c/train.py
+++ b/a3c/train.py
@@ -152,12 +152,7 @@ def learn_proc(mem_queue, weight_dict):
     checkpoint = args.checkpoint
     steps = args.steps
     # -----
-    env = gym.make(args.game)
-    #['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'] - default
-    #['NOOP', 'FIRE', 'RIGHT', 'LEFT'] - our workaround
-    if args.game == 'Breakout-v0' or args.game == 'SpaceInvaders-v0':
-        action_space = Discrete(4)
-    agent = LearningAgent(action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
+    agent = LearningAgent(env.action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
     # -----
     if checkpoint > 0:
         print(' %5d> Loading weights from file' % (pid,))
@@ -270,12 +265,7 @@ def generate_experience_proc(mem_queue, weight_dict, no):
     batch_size = args.batch_size
     # -----
     env = gym.make(args.game)
-    #['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'] - default
-    #['NOOP', 'FIRE', 'RIGHT', 'LEFT'] - our workaround
-    #this workaround does not mess with the internals of ALE and will work with any compilation of ALE
-    if args.game == 'Breakout-v0' or args.game == 'SpaceInvaders-v0':
-        action_space = Discrete(4)
-    agent = LearningAgent(action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
+    agent = ActingAgent(env.action_space, n_step=args.n_step)
     # -----
     if frames > 0:
         print(' %5d> Loaded weights from file' % (pid,))

From d32feea1626f80283fa76ac25398e35a5ae68ede Mon Sep 17 00:00:00 2001
From: Pavitrakumar P
Date: Sun, 15 Jan 2017 14:28:59 +0530
Subject: [PATCH 3/3] Update train.py

---
 a3c/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/a3c/train.py b/a3c/train.py
index 7d7ed8a..fd92e3d 100644
--- a/a3c/train.py
+++ b/a3c/train.py
@@ -152,6 +152,7 @@ def learn_proc(mem_queue, weight_dict):
     checkpoint = args.checkpoint
     steps = args.steps
     # -----
+    env = gym.make(args.game)
     agent = LearningAgent(env.action_space, batch_size=args.batch_size, swap_freq=args.swap_freq)
     # -----
     if checkpoint > 0:
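A note on the PATCH 1/3 workaround: Breakout-v0 and SpaceInvaders-v0 expose six ALE actions by default, and the patch narrows the agent to the first four (NOOP, FIRE, RIGHT, LEFT) by handing it a smaller Discrete space instead of env.action_space. Below is a minimal standalone sketch of the same idea, assuming gym's classic API. The "game" variable stands in for args.game, and the "from gym.spaces import Discrete" import is an assumption on my part: the hunks above call Discrete(4) without showing where it comes from, so train.py would need that import (or an equivalent) as well.

    import gym
    from gym.spaces import Discrete  # assumed import; Discrete(4) in the patch relies on it

    game = 'Breakout-v0'  # stands in for args.game; or 'SpaceInvaders-v0'
    env = gym.make(game)

    # Default action meanings for these two games, per the patch comments:
    # ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']
    action_space = env.action_space
    if game in ('Breakout-v0', 'SpaceInvaders-v0'):
        # Keep only the first four actions; indices 0-3 still map to
        # NOOP, FIRE, RIGHT and LEFT, so env.step(a) receives valid ids
        # and the redundant RIGHTFIRE/LEFTFIRE actions are never sampled.
        action_space = Discrete(4)

    # The narrowed space is then what the agents are built from, e.g.
    # agent = ActingAgent(action_space, n_step=args.n_step)

Because the check runs on the env name rather than on ALE's minimal action set, it works with any compilation of ALE, as the patch comment notes.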