diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 75ef1f8f8e..e64086d953 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1103,6 +1103,27 @@ def wait_for_transform_job(self, job, poll=5): self._check_job_status(job, desc, "TransformJobStatus") return desc + def stop_transform_job(self, name): + """Stop the Amazon SageMaker batch transform job with the specified name. + + Args: + name (str): Name of the Amazon SageMaker batch transform job. + + Raises: + ClientError: If an error occurs while trying to stop the batch transform job. + """ + try: + LOGGER.info("Stopping transform job: %s", name) + self.sagemaker_client.stop_transform_job(TransformJobName=name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + # allow to pass if the job already stopped + if error_code == "ValidationException": + LOGGER.info("Transform job: %s is already stopped or not running.", name) + else: + LOGGER.error("Error occurred while attempting to stop transform job: %s.", name) + raise + def _check_job_status(self, job, desc, status_key_name): """Check to see if the job completed successfully and, if not, construct and raise a exceptions.UnexpectedStatusException. diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index d66b73b6a7..d5efe87cdf 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -229,6 +229,14 @@ def wait(self): self._ensure_last_transform_job() self.latest_transform_job.wait() + def stop_transform_job(self, wait=True): + """Stop the latest running batch transform job. 
+ """ + self._ensure_last_transform_job() + self.latest_transform_job.stop() + if wait: + self.latest_transform_job.wait() + def _ensure_last_transform_job(self): """Placeholder docstring""" if self.latest_transform_job is None: @@ -346,6 +354,10 @@ def start_new( def wait(self): self.sagemaker_session.wait_for_transform_job(self.job_name) + def stop(self): + """Stop this transform job through the SageMaker session.""" + self.sagemaker_session.stop_transform_job(name=self.job_name) + @staticmethod def _load_config(data, data_type, content_type, compression_type, split_type, transformer): """ diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index a909fa2bfe..93a4dd2f34 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -16,6 +16,7 @@ import os import pickle import sys +import time import pytest @@ -349,6 +350,54 @@ def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version, ) +def test_stop_transform_job(sagemaker_session, mxnet_full_version): + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + tags = [{"Key": "some-tag", "Value": "value-for-tag"}] + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + job_name = unique_name_from_base("test-mxnet-transform") + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({"train": train_input, "test": test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, "transform", "data.csv") + transform_input_key_prefix = 
"integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) + + transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags) + transformer.transform(transform_input, content_type="text/csv") + + time.sleep(15) + + latest_transform_job_name = transformer.latest_transform_job.name + + print("Attempting to stop {}".format(latest_transform_job_name)) + + transformer.stop_transform_job() + + desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=latest_transform_job_name + ) + assert desc["TransformJobStatus"] == "Stopped" + + def _create_transformer_and_transform_job( estimator, transform_input, diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 325f6536f1..6104f91789 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -449,3 +449,18 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session): transformer.transform(DATA, job_name="job-2") assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2") + + +def test_stop_transform_job(sagemaker_session, transformer): + sagemaker_session.stop_transform_job = Mock(name="stop_transform_job") + transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME) + + transformer.stop_transform_job() + + sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME) + + +def test_stop_transform_job_no_transform_job(transformer): + with pytest.raises(ValueError) as e: + transformer.stop_transform_job() + assert "No transform job available" in str(e)