|
306 | 306 | "\n", |
307 | 307 | "s3_url = \"s3://{}/{}\".format(s3_bucket,job_name)\n", |
308 | 308 | "\n", |
309 | | - "\n", |
310 | 309 | "intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n", |
311 | 310 | "intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n", |
312 | 311 | "\n", |
|
333 | 332 | "outputs": [], |
334 | 333 | "source": [ |
335 | 334 | "recent_videos = wait_for_s3_object(\n", |
336 | | - " s3_bucket, intermediate_folder_key, tmp_dir, \n", |
337 | | - " fetch_only=(lambda obj: obj.key.endswith(\".mp4\") and obj.size>0), \n", |
338 | | - " limit=10, training_job_name=job_name)" |
| 335 | + " s3_bucket, intermediate_folder_key, tmp_dir, \n", |
| 336 | + " fetch_only=(lambda obj: obj.key.endswith(\".mp4\") and obj.size>0), \n", |
| 337 | + " limit=10, training_job_name=job_name)" |
339 | 338 | ] |
340 | 339 | }, |
341 | 340 | { |
|
366 | 365 | "%matplotlib inline\n", |
367 | 366 | "from sagemaker.analytics import TrainingJobAnalytics\n", |
368 | 367 | "\n", |
369 | | - "df = TrainingJobAnalytics(job_name, ['episode_reward_mean']).dataframe()\n", |
370 | | - "num_metrics = len(df)\n", |
371 | | - "if num_metrics == 0:\n", |
372 | | - " print(\"No algorithm metrics found in CloudWatch\")\n", |
| 368 | + "if not local_mode:\n", |
| 369 | + " df = TrainingJobAnalytics(job_name, ['episode_reward_mean']).dataframe()\n", |
| 370 | + " num_metrics = len(df)\n", |
| 371 | + " if num_metrics == 0:\n", |
| 372 | + " print(\"No algorithm metrics found in CloudWatch\")\n", |
| 373 | + " else:\n", |
| 374 | + " plt = df.plot(x='timestamp', y='value', figsize=(12,5), legend=True, style='b-')\n", |
| 375 | + " plt.set_ylabel('Mean reward per episode')\n", |
| 376 | + " plt.set_xlabel('Training time (s)')\n", |
373 | 377 | "else:\n", |
374 | | - " plt = df.plot(x='timestamp', y='value', figsize=(12,5), legend=True, style='b-')\n", |
375 | | - " plt.set_ylabel('Mean reward per episode')\n", |
376 | | - " plt.set_xlabel('Training time (s)')" |
| 378 | + " print(\"Can't plot metrics in local mode.\")" |
377 | 379 | ] |
378 | 380 | }, |
379 | 381 | { |
|
403 | 405 | "metadata": {}, |
404 | 406 | "outputs": [], |
405 | 407 | "source": [ |
406 | | - "model_tar_key = \"{}/output/model.tar.gz\".format(job_name)\n", |
| 408 | + "if local_mode:\n", |
| 409 | + " model_tar_key = \"{}/model.tar.gz\".format(job_name)\n", |
| 410 | + "else:\n", |
| 411 | + " model_tar_key = \"{}/output/model.tar.gz\".format(job_name)\n", |
| 412 | + " \n", |
407 | 413 | "local_checkpoint_dir = \"{}/model\".format(tmp_dir)\n", |
408 | 414 | "\n", |
409 | 415 | "wait_for_s3_object(s3_bucket, model_tar_key, tmp_dir, training_job_name=job_name) \n", |
|
0 commit comments