|
16 | 16 | import os |
17 | 17 | import pickle |
18 | 18 | import sys |
| 19 | +import time |
19 | 20 |
|
20 | 21 | import pytest |
21 | 22 |
|
@@ -301,20 +302,48 @@ def test_transform_byo_estimator(sagemaker_session): |
301 | 302 | assert tags == model_tags |
302 | 303 |
|
303 | 304 |
|
304 | | -def _create_transformer_and_transform_job( |
305 | | - estimator, |
306 | | - transform_input, |
307 | | - volume_kms_key=None, |
308 | | - input_filter=None, |
309 | | - output_filter=None, |
310 | | - join_source=None, |
311 | | -): |
312 | | - transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key) |
313 | | - transformer.transform( |
314 | | - transform_input, |
315 | | - content_type="text/csv", |
316 | | - input_filter=input_filter, |
317 | | - output_filter=output_filter, |
318 | | - join_source=join_source, |
319 | | - ) |
| 305 | +def test_stop_transform_job(sagemaker_session, mxnet_full_version): |
| 306 | + data_path = os.path.join(DATA_DIR, 'mxnet_mnist') |
| 307 | + script_path = os.path.join(data_path, 'mnist.py') |
| 308 | + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] |
| 309 | + |
| 310 | + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, |
| 311 | + train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, |
| 312 | + framework_version=mxnet_full_version) |
| 313 | + |
| 314 | + train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), |
| 315 | + key_prefix='integ-test-data/mxnet_mnist/train') |
| 316 | + test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), |
| 317 | + key_prefix='integ-test-data/mxnet_mnist/test') |
| 318 | + job_name = unique_name_from_base('test-mxnet-transform') |
| 319 | + |
| 320 | + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): |
| 321 | + mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) |
| 322 | + |
| 323 | + transform_input_path = os.path.join(data_path, 'transform', 'data.csv') |
| 324 | + transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' |
| 325 | + transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, |
| 326 | + key_prefix=transform_input_key_prefix) |
| 327 | + |
| 328 | + transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags) |
| 329 | + transformer.transform(transform_input, content_type='text/csv') |
| 330 | + |
| 331 | + time.sleep(15) |
| 332 | + |
| 333 | + latest_transform_job_name = transformer.latest_transform_job.name |
| 334 | + |
| 335 | + print('Attempting to stop {}'.format(latest_transform_job_name)) |
| 336 | + |
| 337 | + transformer.stop_transform_job() |
| 338 | + |
| 339 | + desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \ |
| 340 | + .describe_transform_job(TransformJobName=latest_transform_job_name) |
| 341 | + assert desc['TransformJobStatus'] == 'Stopping' |
| 342 | + |
| 343 | + |
| 344 | +def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, |
| 345 | + input_filter=None, output_filter=None, join_source=None): |
| 346 | + transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) |
| 347 | + transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, |
| 348 | + output_filter=output_filter, join_source=join_source) |
320 | 349 | return transformer |
0 commit comments