|
17 | 17 | import os |
18 | 18 | import subprocess |
19 | 19 | from time import sleep |
| 20 | +from sagemaker.fw_utils import UploadedCode |
| 21 | + |
20 | 22 |
|
21 | 23 | import pytest |
22 | 24 | from botocore.exceptions import ClientError |
@@ -3350,3 +3352,87 @@ def test_image_name_map(sagemaker_session): |
3350 | 3352 | ) |
3351 | 3353 |
|
3352 | 3354 | assert e.image_uri == IMAGE_URI |
| 3355 | + |
| 3356 | + |
| 3357 | +@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3") |
| 3358 | +def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): |
| 3359 | + patched_stage_user_code.return_value = UploadedCode( |
| 3360 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 3361 | + ) |
| 3362 | + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" |
| 3363 | + image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" |
| 3364 | + model_uri = "s3://someprefix2/models/model.tar.gz" |
| 3365 | + t = Estimator( |
| 3366 | + entry_point=SCRIPT_PATH, |
| 3367 | + role=ROLE, |
| 3368 | + sagemaker_session=sagemaker_session, |
| 3369 | + instance_count=INSTANCE_COUNT, |
| 3370 | + instance_type=INSTANCE_TYPE, |
| 3371 | + source_dir=script_uri, |
| 3372 | + image_uri=image_uri, |
| 3373 | + model_uri=model_uri, |
| 3374 | + ) |
| 3375 | + t.fit("s3://bucket/mydata") |
| 3376 | + |
| 3377 | + patched_stage_user_code.assert_called_once() |
| 3378 | + sagemaker_session.train.assert_called_once() |
| 3379 | + |
| 3380 | + |
| 3381 | +@patch("time.time", return_value=TIME) |
| 3382 | +@patch("sagemaker.estimator.tar_and_upload_dir") |
| 3383 | +def test_script_mode_estimator_same_calls_as_framework( |
| 3384 | + patched_tar_and_upload_dir, sagemaker_session |
| 3385 | +): |
| 3386 | + |
| 3387 | + patched_tar_and_upload_dir.return_value = UploadedCode( |
| 3388 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 3389 | + ) |
| 3390 | + sagemaker_session.boto_region_name = REGION |
| 3391 | + |
| 3392 | + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" |
| 3393 | + |
| 3394 | + instance_type = "ml.p2.xlarge" |
| 3395 | + instance_count = 1 |
| 3396 | + |
| 3397 | + model_uri = "s3://someprefix2/models/model.tar.gz" |
| 3398 | + training_data_uri = "s3://bucket/mydata" |
| 3399 | + |
| 3400 | + generic_estimator = Estimator( |
| 3401 | + entry_point=SCRIPT_PATH, |
| 3402 | + role=ROLE, |
| 3403 | + region=REGION, |
| 3404 | + sagemaker_session=sagemaker_session, |
| 3405 | + instance_count=instance_count, |
| 3406 | + instance_type=instance_type, |
| 3407 | + source_dir=script_uri, |
| 3408 | + image_uri=IMAGE_URI, |
| 3409 | + model_uri=model_uri, |
| 3410 | + environment={"USE_SMDEBUG": "0"}, |
| 3411 | + dependencies=[], |
| 3412 | + debugger_hook_config={}, |
| 3413 | + ) |
| 3414 | + generic_estimator.fit(training_data_uri) |
| 3415 | + |
| 3416 | + generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list |
| 3417 | + generic_estimator_train_args = sagemaker_session.train.call_args_list |
| 3418 | + |
| 3419 | + patched_tar_and_upload_dir.reset_mock() |
| 3420 | + sagemaker_session.train.reset_mock() |
| 3421 | + |
| 3422 | + framework_estimator = DummyFramework( |
| 3423 | + entry_point=SCRIPT_PATH, |
| 3424 | + role=ROLE, |
| 3425 | + region=REGION, |
| 3426 | + source_dir=script_uri, |
| 3427 | + instance_count=instance_count, |
| 3428 | + instance_type=instance_type, |
| 3429 | + sagemaker_session=sagemaker_session, |
| 3430 | + model_uri=model_uri, |
| 3431 | + dependencies=[], |
| 3432 | + debugger_hook_config={}, |
| 3433 | + ) |
| 3434 | + framework_estimator.fit(training_data_uri) |
| 3435 | + |
| 3436 | + assert len(generic_estimator_tar_and_upload_dir_args) == 1 |
| 3437 | + assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list |
| 3438 | + assert generic_estimator_train_args == sagemaker_session.train.call_args_list |
0 commit comments