Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def read_version():
"pandas",
"black==19.3b0 ; python_version >= '3.6'",
"stopit==1.1.2",
"apache-airflow==1.10.5",
]
},
entry_points={"console_scripts": ["sagemaker=sagemaker.cli.main:main"]},
Expand Down
15 changes: 15 additions & 0 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ class constructor
del init_params["image"]
return init_params

def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
    """Calls _prepare_for_training. Used when setting up a workflow.

    This is a thin public entry point: every argument is forwarded, by
    keyword, to ``self._prepare_for_training`` and nothing else is done here.

    Args:
        records (:class:`~RecordSet`): The records to train this ``Estimator`` on.
        mini_batch_size (int or None): The size of each mini-batch to use when
            training. If ``None``, a default value will be used.
        job_name (str): Name of the training job to be created. If not
            specified, one is generated, using the base name given to the
            constructor if applicable.

    Returns:
        None: The result of ``_prepare_for_training`` is not propagated.
    """
    # Forward by keyword so the call does not depend on the positional order
    # of the parameters of ``_prepare_for_training``.
    self._prepare_for_training(
        records=records, mini_batch_size=mini_batch_size, job_name=job_name
    )

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
"""Set hyperparameters needed for training.

Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,16 @@ def enable_network_isolation(self):
"""
return False

def prepare_workflow_for_training(self, job_name=None):
    """Calls _prepare_for_training. Used when setting up a workflow.

    This is a thin public entry point: ``job_name`` is forwarded, by keyword,
    to ``self._prepare_for_training`` and nothing else is done here.

    Args:
        job_name (str): Name of the training job to be created. If not
            specified, one is generated, using the base name given to the
            constructor if applicable.

    Returns:
        None: The result of ``_prepare_for_training`` is not propagated.
    """
    # Forward by keyword so overriding implementations with extra leading
    # parameters still receive ``job_name`` in the right slot.
    self._prepare_for_training(job_name=job_name)

def _prepare_for_training(self, job_name=None):
"""Set any values in the estimator that need to be set before training.

Expand Down
26 changes: 25 additions & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
from sagemaker.transformer import Transformer
from sagemaker import utils
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

Expand Down Expand Up @@ -755,8 +756,31 @@ def transformer(
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
"""

role = role or self.role

if self.latest_training_job is None:
logging.warning(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
return Transformer(
self._current_job_name,
instance_count,
instance_type,
strategy=strategy,
assemble_with=assemble_with,
output_path=output_path,
output_kms_key=output_kms_key,
accept=accept,
max_concurrent_transforms=max_concurrent_transforms,
max_payload=max_payload,
env=env or {},
tags=tags,
base_transform_job_name=self.base_job_name,
volume_kms_key=volume_kms_key,
sagemaker_session=self.sagemaker_session,
)

model = self.create_model(
model_server_workers=model_server_workers,
role=role,
Expand Down
14 changes: 13 additions & 1 deletion src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ def prepare_framework(estimator, s3_operations):
if estimator.code_location is not None:
bucket, key = fw_utils.parse_s3_url(estimator.code_location)
key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz")
elif estimator.uploaded_code is not None:
bucket, key = fw_utils.parse_s3_url(estimator.uploaded_code.s3_prefix)
else:
bucket = estimator.sagemaker_session._default_bucket
key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz")

script = os.path.basename(estimator.entry_point)

if estimator.source_dir and estimator.source_dir.lower().startswith("s3://"):
code_dir = estimator.source_dir
estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
Expand Down Expand Up @@ -96,7 +100,7 @@ def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
estimator.mini_batch_size = mini_batch_size


def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None):
def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=None): # noqa: C901
"""Export Airflow base training config from an estimator

Args:
Expand Down Expand Up @@ -134,6 +138,13 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
dict: Training config that can be directly used by
SageMakerTrainingOperator in Airflow.
"""
if isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
estimator.prepare_workflow_for_training(
records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
)
else:
estimator.prepare_workflow_for_training(job_name=job_name)

default_bucket = estimator.sagemaker_session.default_bucket()
s3_operations = {}

Expand Down Expand Up @@ -528,6 +539,7 @@ def model_config_from_estimator(
model_server_workers=model_server_workers,
role=role,
vpc_config_override=vpc_config_override,
entry_point=estimator.entry_point,
)
else:
raise TypeError(
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ def create_model(
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
"""
role = role or self.role

# Remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("Removing unused entry_point argument: %s", str(kwargs["entry_point"]))
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}

return XGBoostModel(
self.model_data,
role,
Expand Down
Loading