diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index cacfdeee9c..bbb2289f27 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 74416ed0e2..80855340da 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, repack=True) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 830bb50dab..c78d786c75 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -18,6 +18,7 @@ import logging import os import re +import copy import sagemaker from sagemaker import ( @@ -57,6 +58,15 @@ def delete_model(self, *args, **kwargs) -> None: """Destroy resources associated with this model.""" +SCRIPT_PARAM_NAME = "sagemaker_program" +DIR_PARAM_NAME = "sagemaker_submit_directory" +CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level" +JOB_NAME_PARAM_NAME = "sagemaker_job_name" +MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers" +SAGEMAKER_REGION_PARAM_NAME = 
"sagemaker_region" +SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output" + + class Model(ModelBase): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" @@ -73,6 +83,12 @@ def __init__( enable_network_isolation=False, model_kms_key=None, image_config=None, + source_dir=None, + code_location=None, + entry_point=None, + container_log_level=logging.INFO, + dependencies=None, + git_config=None, ): """Initialize an SageMaker ``Model``. @@ -114,6 +130,124 @@ def __init__( model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). + source_dir (str): Path (absolute, relative or an S3 URI) to a directory + with any other training source code dependencies aside from the entry + point file (default: None). If ``source_dir`` is an S3 URI, it must + point to a tar.gz file. Structure within this directory are preserved + when training on Amazon SageMaker. If 'git_config' is provided, + 'source_dir' should be a relative location to a directory in the Git repo. + If the directory points to S3, no code will be uploaded and the S3 location + will be used instead. + + .. admonition:: Example + + With the following GitHub repo directory structure: + + >>> |----- README.md + >>> |----- src + >>> |----- inference.py + >>> |----- test.py + + You can assign entry_point='inference.py', source_dir='src'. + code_location (str): Name of the S3 bucket where custom code is + uploaded (default: None). If not specified, default bucket + created by ``sagemaker.session.Session`` is used. + entry_point (str): Path (absolute or relative) to the Python source + file which should be executed as the entry point to model + hosting (default: None). If ``source_dir`` is specified, + then ``entry_point`` must point to a file located at the root of + ``source_dir``. If 'git_config' is provided, 'entry_point' should + be a relative location to the Python source file in the Git repo. 
+ + Example: + With the following GitHub repo directory structure: + + >>> |----- README.md + >>> |----- src + >>> |----- inference.py + >>> |----- test.py + + You can assign entry_point='src/inference.py'. + container_log_level (int): Log level to use within the container + (default: logging.INFO). Valid values are defined in the Python + logging module. + dependencies (list[str]): A list of paths to directories (absolute + or relative) with any additional libraries that will be exported + to the container (default: []). The library folders will be + copied to SageMaker in the same folder where the entrypoint is + copied. If 'git_config' is provided, 'dependencies' should be a + list of relative locations to directories with any additional + libraries needed in the Git repo. If the ``source_dir`` points + to S3, no code will be uploaded and the S3 location will be used + instead. + + .. admonition:: Example + + The following call + + >>> Model(entry_point='inference.py', + ... dependencies=['my/libs/common', 'virtual-env']) + + results in the following inside the container: + + >>> $ ls + + >>> opt/ml/code + >>> |------ inference.py + >>> |------ common + >>> |------ virtual-env + + This is not supported with "local code" in Local Mode. + git_config (dict[str, str]): Git configurations used for cloning + files, including ``repo``, ``branch``, ``commit``, + ``2FA_enabled``, ``username``, ``password`` and ``token``. The + ``repo`` field is required. All other fields are optional. + ``repo`` specifies the Git repository where your training script + is stored. If you don't provide ``branch``, the default value + 'master' is used. If you don't provide ``commit``, the latest + commit in the specified branch is used. ..
admonition:: Example + + The following config: + + >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git', + >>> 'branch': 'test-branch-git-config', + >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'} + + results in cloning the repo specified in 'repo', then + checking out the specified branch, and then the specified + commit. + + ``2FA_enabled``, ``username``, ``password`` and ``token`` are + used for authentication. For GitHub (or other Git) accounts, set + ``2FA_enabled`` to 'True' if two-factor authentication is + enabled for the account, otherwise set it to 'False'. If you do + not provide a value for ``2FA_enabled``, a default value of + 'False' is used. CodeCommit does not support two-factor + authentication, so do not provide "2FA_enabled" with CodeCommit + repositories. + + For GitHub and other Git repos, when SSH URLs are provided, it + doesn't matter whether 2FA is enabled or disabled; you should + either have no passphrase for the SSH key pairs, or have the + ssh-agent configured so that you will not be prompted for SSH + passphrase when you do 'git clone' command with SSH URLs. When + HTTPS URLs are provided: if 2FA is disabled, then either token + or username+password will be used for authentication if provided + (token prioritized); if 2FA is enabled, only token will be used + for authentication if provided. If required authentication info + is not provided, python SDK will try to use local credentials + storage to authenticate. If that also fails, an error message + will be thrown. + + For CodeCommit repos, 2FA is not supported, so '2FA_enabled' + should not be provided. There is no token in CodeCommit, so + 'token' should not be provided either. When 'repo' is an SSH URL, + the requirements are the same as GitHub-like repos.
When 'repo' + is an HTTPS URL, username+password will be used for + authentication if they are provided; otherwise, python SDK will + try to use either CodeCommit credential helper or local + credential storage for authentication. + """ self.model_data = model_data self.image_uri = image_uri @@ -131,6 +265,24 @@ def __init__( self._enable_network_isolation = enable_network_isolation self.model_kms_key = model_kms_key self.image_config = image_config + self.entry_point = entry_point + self.source_dir = source_dir + self.dependencies = dependencies or [] + self.git_config = git_config + self.container_log_level = container_log_level + if code_location: + self.bucket, self.key_prefix = s3.parse_s3_url(code_location) + else: + self.bucket, self.key_prefix = None, None + if self.git_config: + updates = git_utils.git_clone_repo( + self.git_config, self.entry_point, self.source_dir, self.dependencies + ) + self.entry_point = updates["entry_point"] + self.source_dir = updates["source_dir"] + self.dependencies = updates["dependencies"] + self.uploaded_code = None + self.repacked_model_data = None def register( self, @@ -242,10 +394,90 @@ def prepare_container_def( Returns: dict: A container definition object usable with the CreateModel API. 
""" + deploy_key_prefix = fw_utils.model_code_key_prefix( + self.key_prefix, self.name, self.image_uri + ) + deploy_env = copy.deepcopy(self.env) + if self.source_dir or self.dependencies or self.entry_point or self.git_config: + if self.key_prefix or self.git_config: + self._upload_code(deploy_key_prefix, repack=False) + elif self.source_dir and self.entry_point: + self._upload_code(deploy_key_prefix, repack=True) + else: + self._upload_code(deploy_key_prefix, repack=False) + deploy_env.update(self._script_mode_env_vars()) return sagemaker.container_def( - self.image_uri, self.model_data, self.env, image_config=self.image_config + self.image_uri, self.model_data, deploy_env, image_config=self.image_config ) + def _upload_code(self, key_prefix: str, repack: bool = False) -> None: + """Uploads code to S3 to be used with script mode with SageMaker inference. + + Args: + key_prefix (str): The S3 key associated with the ``code_location`` parameter of the + ``Model`` class. + repack (bool): Optional. Set to ``True`` to indicate that the source code and model + artifact should be repackaged into a new S3 object. (default: False). 
+ """ + local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) + if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: + self.uploaded_code = None + elif not repack: + bucket = self.bucket or self.sagemaker_session.default_bucket() + self.uploaded_code = fw_utils.tar_and_upload_dir( + session=self.sagemaker_session.boto_session, + bucket=bucket, + s3_key_prefix=key_prefix, + script=self.entry_point, + directory=self.source_dir, + dependencies=self.dependencies, + ) + + if repack and self.model_data is not None and self.entry_point is not None: + if isinstance(self.model_data, sagemaker.workflow.properties.Properties): + # model is not yet there, defer repacking to later during pipeline execution + return + + bucket = self.bucket or self.sagemaker_session.default_bucket() + repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) + + utils.repack_model( + inference_script=self.entry_point, + source_directory=self.source_dir, + dependencies=self.dependencies, + model_uri=self.model_data, + repacked_model_uri=repacked_model_data, + sagemaker_session=self.sagemaker_session, + kms_key=self.model_kms_key, + ) + + self.repacked_model_data = repacked_model_data + self.uploaded_code = fw_utils.UploadedCode( + s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) + ) + + def _script_mode_env_vars(self): + """Placeholder docstring""" + script_name = None + dir_name = None + if self.uploaded_code: + script_name = self.uploaded_code.script_name + if self.enable_network_isolation(): + dir_name = "/opt/ml/model/code" + else: + dir_name = self.uploaded_code.s3_prefix + elif self.entry_point is not None: + script_name = self.entry_point + if self.source_dir is not None: + dir_name = "file://" + self.source_dir + + return { + SCRIPT_PARAM_NAME.upper(): script_name or str(), + DIR_PARAM_NAME.upper(): dir_name or str(), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): 
str(self.container_log_level), + SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, + } + def enable_network_isolation(self): """Whether to enable network isolation when creating this Model @@ -885,15 +1117,6 @@ def delete_model(self): self.sagemaker_session.delete_model(self.name) -SCRIPT_PARAM_NAME = "sagemaker_program" -DIR_PARAM_NAME = "sagemaker_submit_directory" -CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level" -JOB_NAME_PARAM_NAME = "sagemaker_job_name" -MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers" -SAGEMAKER_REGION_PARAM_NAME = "sagemaker_region" -SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output" - - class FrameworkModel(Model): """A Model for working with an SageMaker ``Framework``. @@ -1071,113 +1294,14 @@ def __init__( env=env, name=name, sagemaker_session=sagemaker_session, + source_dir=source_dir, + code_location=code_location, + entry_point=entry_point, + container_log_level=container_log_level, + dependencies=dependencies, + git_config=git_config, **kwargs, ) - self.entry_point = entry_point - self.source_dir = source_dir - self.dependencies = dependencies or [] - self.git_config = git_config - self.container_log_level = container_log_level - if code_location: - self.bucket, self.key_prefix = s3.parse_s3_url(code_location) - else: - self.bucket, self.key_prefix = None, None - if self.git_config: - updates = git_utils.git_clone_repo( - self.git_config, self.entry_point, self.source_dir, self.dependencies - ) - self.entry_point = updates["entry_point"] - self.source_dir = updates["source_dir"] - self.dependencies = updates["dependencies"] - self.uploaded_code = None - self.repacked_model_data = None - - def prepare_container_def(self, instance_type=None, accelerator_type=None): - """Return a container definition with framework configuration. - - Framework configuration is set in model environment variables. - This also uploads user-supplied code to S3. 
- - Args: - instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p2.xlarge'. - accelerator_type (str): The Elastic Inference accelerator type to - deploy to the instance for loading and making inferences to the - model. For example, 'ml.eia1.medium'. - - Returns: - dict[str, str]: A container definition object usable with the - CreateModel API. - """ - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) - self._upload_code(deploy_key_prefix) - deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) - return sagemaker.container_def(self.image_uri, self.model_data, deploy_env) - - def _upload_code(self, key_prefix, repack=False): - """Placeholder Docstring""" - local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) - if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: - self.uploaded_code = None - elif not repack: - bucket = self.bucket or self.sagemaker_session.default_bucket() - self.uploaded_code = fw_utils.tar_and_upload_dir( - session=self.sagemaker_session.boto_session, - bucket=bucket, - s3_key_prefix=key_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies, - settings=self.sagemaker_session.settings, - ) - - if repack and self.model_data is not None and self.entry_point is not None: - if isinstance(self.model_data, sagemaker.workflow.properties.Properties): - # model is not yet there, defer repacking to later during pipeline execution - return - - bucket = self.bucket or self.sagemaker_session.default_bucket() - repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) - - utils.repack_model( - inference_script=self.entry_point, - source_directory=self.source_dir, - dependencies=self.dependencies, - model_uri=self.model_data, - repacked_model_uri=repacked_model_data, - sagemaker_session=self.sagemaker_session, - 
kms_key=self.model_kms_key, - ) - - self.repacked_model_data = repacked_model_data - self.uploaded_code = fw_utils.UploadedCode( - s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) - ) - - def _framework_env_vars(self): - """Placeholder docstring""" - script_name = None - dir_name = None - if self.uploaded_code: - script_name = self.uploaded_code.script_name - if self.enable_network_isolation(): - dir_name = "/opt/ml/model/code" - else: - dir_name = self.uploaded_code.s3_prefix - elif self.entry_point is not None: - script_name = self.entry_point - if self.source_dir is not None: - dir_name = "file://" + self.source_dir - - return { - SCRIPT_PARAM_NAME.upper(): script_name or str(), - DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), - SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, - } class ModelPackage(Model): diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index aec5cd86da..df0dd31a28 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, self._is_mms_version()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 1568bb14ac..3a0c3a283c 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) 
self._upload_code(deploy_key_prefix, repack=self._is_mms_version()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 6a8e31fe19..8efb7480c9 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 7f0448c018..115e09a9c9 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations): ] deploy_env = dict(model.env) - deploy_env.update(model._framework_env_vars()) + deploy_env.update(model._script_mode_env_vars()) try: if model.model_server_workers: diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 49acc11074..08dc7f8899 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None): deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation()) deploy_env = dict(self.env) - deploy_env.update(self._framework_env_vars()) + 
deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c931c5bf2b..e1e9d69104 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -11,12 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import MagicMock import pytest from mock import Mock, patch import sagemaker -from sagemaker.model import Model +from sagemaker.model import FrameworkModel, Model MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -27,10 +28,39 @@ INSTANCE_TYPE = "ml.c4.4xlarge" ROLE = "some-role" +REGION = "us-west-2" +BUCKET_NAME = "some-bucket-name" +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +ENTRY_POINT_INFERENCE = "inference.py" -@pytest.fixture +SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz" +IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + + +class DummyFrameworkModel(FrameworkModel): + def __init__(self, **kwargs): + super(DummyFrameworkModel, self).__init__( + **kwargs, + ) + + +@pytest.fixture() def sagemaker_session(): - return Mock() + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + + return sms def test_prepare_container_def_with_model_data(): @@ -345,3 +375,75 @@ def test_delete_model_no_name(sagemaker_session): ): model.delete_model() 
sagemaker_session.delete_model.assert_not_called() + + +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +@patch("sagemaker.utils.repack_model") +def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session): + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=SCRIPT_URI, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert len(sagemaker_session.create_model.call_args_list) == 1 + assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1 + assert len(repack_model.call_args_list) == 1 + + generic_model_create_model_args = sagemaker_session.create_model.call_args_list + generic_model_endpoint_from_production_variants_args = ( + sagemaker_session.endpoint_from_production_variants.call_args_list + ) + generic_model_repack_model_args = repack_model.call_args_list + + sagemaker_session.create_model.reset_mock() + sagemaker_session.endpoint_from_production_variants.reset_mock() + repack_model.reset_mock() + + t = DummyFrameworkModel( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=SCRIPT_URI, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) + + assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list + assert ( + generic_model_endpoint_from_production_variants_args + == sagemaker_session.endpoint_from_production_variants.call_args_list + ) + assert generic_model_repack_model_args == repack_model.call_args_list + + +@patch("sagemaker.git_utils.git_clone_repo") +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { + 
"entry_point": "entry_point", + "source_dir": "/tmp/repo_dir/source_dir", + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], + } + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + model = Model( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + dependencies=dependencies, + git_config=git_config, + image_uri=IMAGE_URI, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) + assert model.entry_point == "entry_point" + assert model.source_dir == "/tmp/repo_dir/source_dir" + assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]