|
17 | 17 | import os |
18 | 18 | import subprocess |
19 | 19 | from time import sleep |
| 20 | +from sagemaker.fw_utils import UploadedCode |
| 21 | + |
20 | 22 |
|
21 | 23 | import pytest |
22 | 24 | from botocore.exceptions import ClientError |
@@ -3350,3 +3352,112 @@ def test_image_name_map(sagemaker_session): |
3350 | 3352 | ) |
3351 | 3353 |
|
3352 | 3354 | assert e.image_uri == IMAGE_URI |
| 3355 | + |
| 3356 | + |
| 3357 | +@patch("sagemaker.git_utils.git_clone_repo") |
| 3358 | +def test_git_support_with_branch_and_commit_succeed_estimator_class( |
| 3359 | + git_clone_repo, sagemaker_session |
| 3360 | +): |
| 3361 | + git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { |
| 3362 | + "entry_point": "/tmp/repo_dir/entry_point", |
| 3363 | + "source_dir": None, |
| 3364 | + "dependencies": None, |
| 3365 | + } |
| 3366 | + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} |
| 3367 | + entry_point = "entry_point" |
| 3368 | + fw = Estimator( |
| 3369 | + entry_point=entry_point, |
| 3370 | + git_config=git_config, |
| 3371 | + role=ROLE, |
| 3372 | + sagemaker_session=sagemaker_session, |
| 3373 | + instance_count=INSTANCE_COUNT, |
| 3374 | + instance_type=INSTANCE_TYPE, |
| 3375 | + image_uri=IMAGE_URI, |
| 3376 | + ) |
| 3377 | + fw.fit() |
| 3378 | + git_clone_repo.assert_called_once_with(git_config, entry_point, None, None) |
| 3379 | + |
| 3380 | + |
| 3381 | +@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3") |
| 3382 | +def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): |
| 3383 | + patched_stage_user_code.return_value = UploadedCode( |
| 3384 | + s3_prefix="s3://bucket/key", script_name="script_name" |
| 3385 | + ) |
| 3386 | + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" |
| 3387 | + image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" |
| 3388 | + model_uri = "s3://someprefix2/models/model.tar.gz" |
| 3389 | + t = Estimator( |
| 3390 | + entry_point=SCRIPT_PATH, |
| 3391 | + role=ROLE, |
| 3392 | + sagemaker_session=sagemaker_session, |
| 3393 | + instance_count=INSTANCE_COUNT, |
| 3394 | + instance_type=INSTANCE_TYPE, |
| 3395 | + source_dir=script_uri, |
| 3396 | + image_uri=image_uri, |
| 3397 | + model_uri=model_uri, |
| 3398 | + ) |
| 3399 | + t.fit("s3://bucket/mydata") |
| 3400 | + |
| 3401 | + patched_stage_user_code.assert_called_once() |
| 3402 | + sagemaker_session.train.assert_called_once() |
| 3403 | + |
| 3404 | + |
| 3405 | +@patch("time.time", return_value=TIME) |
| 3406 | +@patch("sagemaker.estimator.tar_and_upload_dir") |
| 3407 | +def test_script_mode_estimator_same_calls_as_framework( |
| 3408 | + patched_tar_and_upload_dir, sagemaker_session |
| 3409 | +): |
| 3410 | + |
| 3411 | + patched_tar_and_upload_dir.return_value = UploadedCode( |
| 3412 | + s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
| 3413 | + ) |
| 3414 | + sagemaker_session.boto_region_name = REGION |
| 3415 | + |
| 3416 | + script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz" |
| 3417 | + |
| 3418 | + instance_type = "ml.p2.xlarge" |
| 3419 | + instance_count = 1 |
| 3420 | + |
| 3421 | + model_uri = "s3://someprefix2/models/model.tar.gz" |
| 3422 | + training_data_uri = "s3://bucket/mydata" |
| 3423 | + |
| 3424 | + generic_estimator = Estimator( |
| 3425 | + entry_point=SCRIPT_PATH, |
| 3426 | + role=ROLE, |
| 3427 | + region=REGION, |
| 3428 | + sagemaker_session=sagemaker_session, |
| 3429 | + instance_count=instance_count, |
| 3430 | + instance_type=instance_type, |
| 3431 | + source_dir=script_uri, |
| 3432 | + image_uri=IMAGE_URI, |
| 3433 | + model_uri=model_uri, |
| 3434 | + environment={"USE_SMDEBUG": "0"}, |
| 3435 | + dependencies=[], |
| 3436 | + debugger_hook_config={}, |
| 3437 | + ) |
| 3438 | + generic_estimator.fit(training_data_uri) |
| 3439 | + |
| 3440 | + generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list |
| 3441 | + generic_estimator_train_args = sagemaker_session.train.call_args_list |
| 3442 | + |
| 3443 | + patched_tar_and_upload_dir.reset_mock() |
| 3444 | + sagemaker_session.train.reset_mock() |
| 3445 | + |
| 3446 | + framework_estimator = DummyFramework( |
| 3447 | + entry_point=SCRIPT_PATH, |
| 3448 | + role=ROLE, |
| 3449 | + region=REGION, |
| 3450 | + source_dir=script_uri, |
| 3451 | + instance_count=instance_count, |
| 3452 | + instance_type=instance_type, |
| 3453 | + sagemaker_session=sagemaker_session, |
| 3454 | + model_uri=model_uri, |
| 3455 | + dependencies=[], |
| 3456 | + debugger_hook_config={}, |
| 3457 | + ) |
| 3458 | + framework_estimator.fit(training_data_uri) |
| 3459 | + |
| 3460 | + assert len(generic_estimator_tar_and_upload_dir_args) == 1 |
| 3461 | + assert len(generic_estimator_train_args) == 1 |
| 3462 | + assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list |
| 3463 | + assert generic_estimator_train_args == sagemaker_session.train.call_args_list |
0 commit comments