From f64f5a98c733a9b3d7ca02a1af2a2085e3b2017e Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Fri, 2 Aug 2019 15:50:45 -0700 Subject: [PATCH] fix: allow Airflow enabled estimators to use absolute path entry_point --- src/sagemaker/workflow/airflow.py | 7 ++++++- tests/unit/test_airflow.py | 6 +++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 9b74580f19..5e55351bcd 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -45,7 +45,12 @@ def prepare_framework(estimator, s3_operations): code_dir = "s3://{}/{}".format(bucket, key) estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) s3_operations["S3Upload"] = [ - {"Path": estimator.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True} + { + "Path": estimator.source_dir or estimator.entry_point, + "Bucket": bucket, + "Key": key, + "Tar": True, + } ] estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index 2cd6bbc9af..71ceedaaf1 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -164,7 +164,7 @@ def test_byo_training_config_all_args(sagemaker_session): @patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_framework_training_config_required_args(sagemaker_session): tf = tensorflow.TensorFlow( - entry_point="{{ entry_point }}", + entry_point="/some/script.py", framework_version="1.10.0", training_steps=1000, evaluation_steps=100, @@ -206,7 +206,7 @@ def test_framework_training_config_required_args(sagemaker_session): "HyperParameters": { "sagemaker_submit_directory": '"s3://output/sagemaker-tensorflow-%s/source/sourcedir.tar.gz"' % TIME_STAMP, - "sagemaker_program": '"{{ entry_point }}"', + "sagemaker_program": '"script.py"', "sagemaker_enable_cloudwatch_metrics": "false", "sagemaker_container_log_level": "20", "sagemaker_job_name": '"sagemaker-tensorflow-%s"' % TIME_STAMP, @@ -219,7 +219,7 @@ def test_framework_training_config_required_args(sagemaker_session): "S3Operations": { "S3Upload": [ { - "Path": "{{ entry_point }}", + "Path": "/some/script.py", "Bucket": "output", "Key": "sagemaker-tensorflow-%s/source/sourcedir.tar.gz" % TIME_STAMP, "Tar": True,