Skip to content

Commit 1ce01d2

Browse files
authored
Merge pull request #581 from saurabh3949/master
Fix ray checkpointing issues
2 parents 3e82324 + 484a504 commit 1ce01d2

File tree

3 files changed

+24
-19
lines changed

3 files changed

+24
-19
lines changed

reinforcement_learning/rl_roboschool_ray/common/sagemaker_rl/configuration_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class ConfigurationList(object):
5-
"""Helper Object for converting CLI arguments (or SageMaker hyperparameters)
5+
"""Helper Object for converting CLI arguments (or SageMaker hyperparameters)
66
into Coach configuration.
77
"""
88

@@ -65,6 +65,8 @@ def _set_rl_property_value(self, obj, key, val, path=""):
6565
def _autotype(self, val):
6666
"""Converts string to an int or float as possible.
6767
"""
68+
if type(val) == bool:
69+
return val
6870
try:
6971
return int(val)
7072
except ValueError:
@@ -83,7 +85,7 @@ def _parse_type(self, key, val):
8385
Automatically detects ints and floats when possible.
8486
If the key takes the form "foo:bar" then it looks in ALLOWED_TYPES
8587
for an entry of bar, and instantiates one of those objects, passing
86-
val to the constructor. So if key="foo:EnvironmentSteps" then
88+
val to the constructor. So if key="foo:EnvironmentSteps" then
8789
"""
8890
val = self._autotype(val)
8991
if key.find(":") > 0:
@@ -93,5 +95,3 @@ def _parse_type(self, key, val):
9395
raise ValueError("Unrecognized object type %s. Allowed values are %s" % (obj_type, self.ALLOWED_TYPES.keys()))
9496
val = cls(val)
9597
return key, val
96-
97-

reinforcement_learning/rl_roboschool_ray/common/sagemaker_rl/ray_launcher.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def customize_experiment_config(self, config):
9898
# Set output dir to intermediate
9999
# TODO: move this to before customer-specified so they can override
100100
hyperparams_dict["rl.training.local_dir"] = INTERMEDIATE_DIR
101+
hyperparams_dict["rl.training.checkpoint_at_end"] = True
102+
hyperparams_dict["rl.training.checkpoint_freq"] = 10
101103
self.hyperparameters = ConfigurationList() # TODO: move to shared
102104
for name, value in hyperparams_dict.items():
103105
# self.map_hyperparameter(name, val) #TODO
@@ -106,10 +108,7 @@ def customize_experiment_config(self, config):
106108
self.hyperparameters.store(name, value)
107109
# else:
108110
# raise ValueError("Unknown hyperparameter %s" % name)
109-
110111
self.hyperparameters.apply_subset(config, "rl.")
111-
hyperparams_dict["rl.training.checkpoint_at_end"] = True
112-
hyperparams_dict["rl.training.checkpoint_freq"] = 10
113112
return config
114113

115114
def get_all_host_names(self):

reinforcement_learning/rl_roboschool_ray/rl_roboschool_ray.ipynb

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,6 @@
306306
"\n",
307307
"s3_url = \"s3://{}/{}\".format(s3_bucket,job_name)\n",
308308
"\n",
309-
"\n",
310309
"intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n",
311310
"intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n",
312311
"\n",
@@ -333,9 +332,9 @@
333332
"outputs": [],
334333
"source": [
335334
"recent_videos = wait_for_s3_object(\n",
336-
" s3_bucket, intermediate_folder_key, tmp_dir, \n",
337-
" fetch_only=(lambda obj: obj.key.endswith(\".mp4\") and obj.size>0), \n",
338-
" limit=10, training_job_name=job_name)"
335+
" s3_bucket, intermediate_folder_key, tmp_dir, \n",
336+
" fetch_only=(lambda obj: obj.key.endswith(\".mp4\") and obj.size>0), \n",
337+
" limit=10, training_job_name=job_name)"
339338
]
340339
},
341340
{
@@ -366,14 +365,17 @@
366365
"%matplotlib inline\n",
367366
"from sagemaker.analytics import TrainingJobAnalytics\n",
368367
"\n",
369-
"df = TrainingJobAnalytics(job_name, ['episode_reward_mean']).dataframe()\n",
370-
"num_metrics = len(df)\n",
371-
"if num_metrics == 0:\n",
372-
" print(\"No algorithm metrics found in CloudWatch\")\n",
368+
"if not local_mode:\n",
369+
" df = TrainingJobAnalytics(job_name, ['episode_reward_mean']).dataframe()\n",
370+
" num_metrics = len(df)\n",
371+
" if num_metrics == 0:\n",
372+
" print(\"No algorithm metrics found in CloudWatch\")\n",
373+
" else:\n",
374+
" plt = df.plot(x='timestamp', y='value', figsize=(12,5), legend=True, style='b-')\n",
375+
" plt.set_ylabel('Mean reward per episode')\n",
376+
" plt.set_xlabel('Training time (s)')\n",
373377
"else:\n",
374-
" plt = df.plot(x='timestamp', y='value', figsize=(12,5), legend=True, style='b-')\n",
375-
" plt.set_ylabel('Mean reward per episode')\n",
376-
" plt.set_xlabel('Training time (s)')"
378+
" print(\"Can't plot metrics in local mode.\")"
377379
]
378380
},
379381
{
@@ -403,7 +405,11 @@
403405
"metadata": {},
404406
"outputs": [],
405407
"source": [
406-
"model_tar_key = \"{}/output/model.tar.gz\".format(job_name)\n",
408+
"if local_mode:\n",
409+
" model_tar_key = \"{}/model.tar.gz\".format(job_name)\n",
410+
"else:\n",
411+
" model_tar_key = \"{}/output/model.tar.gz\".format(job_name)\n",
412+
" \n",
407413
"local_checkpoint_dir = \"{}/model\".format(tmp_dir)\n",
408414
"\n",
409415
"wait_for_s3_object(s3_bucket, model_tar_key, tmp_dir, training_job_name=job_name) \n",

0 commit comments

Comments
 (0)