diff --git a/doc/overview.rst b/doc/overview.rst index 8a2f789252..f6da2c1ba9 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -84,6 +84,65 @@ For more `information >> |----- README.md + >>> |----- src + >>> |----- train.py + >>> |----- test.py + + You can assign entry_point='src/train.py'. + git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch' + and 'commit' (default: None). + 'branch' and 'commit' are optional. If 'branch' is not specified, the branch checked out by 'git clone' (the repo's default branch) will be used. If + 'commit' is not specified, the latest commit in the required branch will be used. + Example: + + The following config: + + >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git', + >>> 'branch': 'test-branch-git-config', + >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'} + + results in cloning the repo specified in 'repo', then checking out the 'test-branch-git-config' branch, and then checking out + the specified commit. source_dir (str): Path (absolute or relative) to a directory with any other training source code dependencies aside from the entry point file (default: None). Structure within this - directory are preserved when training on Amazon SageMaker. + directory is preserved when training on Amazon SageMaker. If 'git_config' is provided, + source_dir should be a relative location to a directory in the Git repo. + Example: + + With the following GitHub repo directory structure: + + >>> |----- README.md + >>> |----- src + >>> |----- train.py + >>> |----- test.py + + if you need 'train.py' as entry point and 'test.py' as training source code as well, you can + assign entry_point='train.py', source_dir='src'. hyperparameters (dict): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. 
For convenience, this accepts other types for keys and values, but ``str()`` will be called @@ -1006,6 +1047,7 @@ def __init__( ) ) self.entry_point = entry_point + self.git_config = git_config self.source_dir = source_dir self.dependencies = dependencies or [] if enable_cloudwatch_metrics: @@ -1038,6 +1080,14 @@ def _prepare_for_training(self, job_name=None): """ super(Framework, self)._prepare_for_training(job_name=job_name) + if self.git_config: + updates = git_utils.git_clone_repo( + self.git_config, self.entry_point, self.source_dir, self.dependencies + ) + self.entry_point = updates["entry_point"] + self.source_dir = updates["source_dir"] + self.dependencies = updates["dependencies"] + # validate source dir will raise a ValueError if there is something wrong with the # source directory. We are intentionally not handling it because this is a critical error. if self.source_dir and not self.source_dir.lower().startswith("s3://"): diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py new file mode 100644 index 0000000000..44c91d61ff --- /dev/null +++ b/src/sagemaker/git_utils.py @@ -0,0 +1,104 @@ +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import subprocess +import tempfile + + +def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): + """Git clone repo containing the training code and serving code. 
This method also validates ``git_config``, + and sets ``entry_point``, ``source_dir`` and ``dependencies`` to the right file or directory in the cloned repo. + + Args: + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` + and ``commit``. ``branch`` and ``commit`` are optional. If ``branch`` is not specified, the branch checked out by ``git clone`` + will be used. If ``commit`` is not specified, the latest commit in the required branch will be used. + entry_point (str): A relative location to the Python source file which should be executed as the entry point + to training or model hosting in the Git repo. + source_dir (str): A relative location to a directory with other training or model hosting source code + dependencies aside from the entry point file in the Git repo (default: None). Structure within this + directory is preserved when training on Amazon SageMaker. + dependencies (list[str]): A list of relative locations to directories with any additional libraries that will + be exported to the container in the Git repo (default: None, which is treated as an empty list). + + Raises: + CalledProcessError: If 1. failed to clone git repo + 2. failed to checkout the required branch + 3. failed to checkout the required commit + ValueError: If 1. entry point specified does not exist in the repo + 2. 
source dir or a dependency specified does not exist in the repo + + Returns: + dict: A dict that contains the updated values of entry_point, source_dir and dependencies + """ + _validate_git_config(git_config) + repo_dir = tempfile.mkdtemp() + subprocess.check_call(["git", "clone", git_config["repo"], repo_dir]) + + _checkout_branch_and_commit(git_config, repo_dir) + + ret = {"entry_point": entry_point, "source_dir": source_dir, "dependencies": dependencies} + # check if the cloned repo contains entry point, source directory and dependencies + if source_dir: + if not os.path.isdir(os.path.join(repo_dir, source_dir)): + raise ValueError("Source directory does not exist in the repo.") + if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)): + raise ValueError("Entry point does not exist in the repo.") + ret["source_dir"] = os.path.join(repo_dir, source_dir) + else: + if not os.path.isfile(os.path.join(repo_dir, entry_point)): + raise ValueError("Entry point does not exist in the repo.") + ret["entry_point"] = os.path.join(repo_dir, entry_point) + + ret["dependencies"] = [] + for path in dependencies or []:  # None (the declared default) means no dependencies + if not os.path.exists(os.path.join(repo_dir, path)): + raise ValueError("Dependency {} does not exist in the repo.".format(path)) + ret["dependencies"].append(os.path.join(repo_dir, path)) + return ret + + +def _validate_git_config(git_config): + """Check if a git_config param is valid. + + Args: + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` + and ``commit``. + + Raises: + ValueError: If git_config has no key 'repo'. Only the presence of the + key is checked; the format of git_config['repo'] is not validated here, + so a malformed URL is only detected when ``git clone`` fails. + """ + if "repo" not in git_config: + raise ValueError("Please provide a repo for git_config.") + + +def _checkout_branch_and_commit(git_config, repo_dir): + """Checkout the required branch and commit. + + Args: + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` + and ``commit``. 
+ repo_dir (str): the directory where the repo is cloned + + Raises: + CalledProcessError: If 1. failed to checkout the required branch + 2. failed to checkout the required commit + """ + if "branch" in git_config: + subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(repo_dir)) + if "commit" in git_config: + subprocess.check_call(args=["git", "checkout", git_config["commit"]], cwd=str(repo_dir)) diff --git a/tests/integ/test_git.py b/tests/integ/test_git.py new file mode 100644 index 0000000000..0f01455fe9 --- /dev/null +++ b/tests/integ/test_git.py @@ -0,0 +1,100 @@ +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os + +import numpy +import tempfile + +from tests.integ import lock as lock +from sagemaker.mxnet.estimator import MXNet +from sagemaker.pytorch.estimator import PyTorch +from tests.integ import DATA_DIR, PYTHON_VERSION + +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" + +# endpoint tests all use the same port, so we use this lock to prevent concurrent execution +LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock") + + +def test_git_support_with_pytorch(sagemaker_local_session): + script_path = "mnist.py" + data_path = os.path.join(DATA_DIR, "pytorch_mnist") + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + pytorch = PyTorch( + entry_point=script_path, + role="SageMakerRole", + source_dir="pytorch", + framework_version=PyTorch.LATEST_VERSION, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + git_config=git_config, + ) + + pytorch.fit({"training": "file://" + os.path.join(data_path, "training")}) + + with lock.lock(LOCK_PATH): + try: + predictor = pytorch.deploy(initial_instance_count=1, instance_type="local") + + data = numpy.zeros(shape=(1, 1, 28, 28)).astype(numpy.float32) + result = predictor.predict(data) + assert result is not None + finally: + predictor.delete_endpoint() + + +def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version): + script_path = "mnist.py" + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + dependencies = ["foo/bar.py"] + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + source_dir="mxnet", + dependencies=dependencies, + framework_version=MXNet.LATEST_VERSION, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + 
sagemaker_session=sagemaker_local_session, + git_config=git_config, + ) + + mx.fit( + { + "train": "file://" + os.path.join(data_path, "train"), + "test": "file://" + os.path.join(data_path, "test"), + } + ) + + files = [file for file in os.listdir(mx.source_dir)] + assert "some_file" in files + assert "mnist.py" in files + assert os.path.exists(mx.dependencies[0]) + + with lock.lock(LOCK_PATH): + try: + predictor = mx.deploy(initial_instance_count=1, instance_type="local") + + data = numpy.zeros(shape=(1, 1, 28, 28)) + result = predictor.predict(data) + assert result is not None + finally: + predictor.delete_endpoint() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 924c494cd3..2994f917b3 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -15,6 +15,7 @@ import logging import json import os +import subprocess from time import sleep import pytest @@ -47,6 +48,19 @@ JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] OUTPUT_PATH = "s3://bucket/prefix" +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" + +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} +INSTANCE_TYPE = "c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +ROLE = "DummyRole" +IMAGE_NAME = "fakeimage" +REGION = "us-west-2" +JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) +TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] +OUTPUT_PATH = "s3://bucket/prefix" DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} @@ -760,6 +774,252 @@ def test_prepare_for_training_force_name_generation(strftime, sagemaker_session) assert JOB_NAME == fw._current_job_name +@patch("sagemaker.git_utils.git_clone_repo") +def test_git_support_with_branch_and_commit_succeed(git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda 
gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + } + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + + +@patch("sagemaker.git_utils.git_clone_repo") +def test_git_support_with_branch_succeed(git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir, dependencies=None: { + "entry_point": "/tmp/repo_dir/source_dir/entry_point", + "source_dir": None, + "dependencies": None, + } + git_config = {"repo": GIT_REPO, "branch": BRANCH} + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + + +@patch("sagemaker.git_utils.git_clone_repo") +def test_git_support_with_dependencies_succeed(git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir, dependencies: { + "entry_point": "/tmp/repo_dir/source_dir/entry_point", + "source_dir": None, + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/foo/bar"], + } + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "source_dir/entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + dependencies=["foo", "foo/bar"], + role=ROLE, + sagemaker_session=sagemaker_session, + 
train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, ["foo", "foo/bar"]) + + +@patch("sagemaker.git_utils.git_clone_repo") +def test_git_support_without_branch_and_commit_succeed(git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir, dependencies=None: { + "entry_point": "/tmp/repo_dir/source_dir/entry_point", + "source_dir": None, + "dependencies": None, + } + git_config = {"repo": GIT_REPO} + entry_point = "source_dir/entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + + +def test_git_support_repo_not_provided(sagemaker_session): + git_config = {"branch": BRANCH, "commit": COMMIT} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + source_dir="source_dir", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(ValueError) as error: + fw.fit() + assert "Please provide a repo for git_config." 
in str(error) + + +def test_git_support_bad_repo_url_format(sagemaker_session): + git_config = {"repo": "hhttps://github.com/user/repo.git", "branch": BRANCH} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + source_dir="source_dir", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(subprocess.CalledProcessError) as error: + fw.fit() + assert "returned non-zero exit status" in str(error) + + +def test_git_support_git_clone_fail(sagemaker_session): + git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(subprocess.CalledProcessError) as error: + fw.fit() + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout branch-that-does-not-exist" + ), +) +def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): + git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(subprocess.CalledProcessError) as error: + fw.fit() + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout commit-sha-that-does-not-exist" + ), +) +def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): + 
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(subprocess.CalledProcessError) as error: + fw.fit() + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Entry point does not exist in the repo."), +) +def test_git_support_entry_point_not_exist(git_clone_repo, sagemaker_session): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + fw = DummyFramework( + entry_point="entry_point_that_does_not_exist", + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(ValueError) as error: + fw.fit() + assert "Entry point does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Source directory does not exist in the repo."), +) +def test_git_support_source_dir_not_exist(git_clone_repo, sagemaker_session): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + source_dir="source_dir_that_does_not_exist", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(ValueError) as error: + fw.fit() + assert "Source directory does not exist in the repo." 
in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Dependency no-such-dir does not exist in the repo."), +) +def test_git_support_dependencies_not_exist(git_clone_repo, sagemaker_session): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + fw = DummyFramework( + entry_point="entry_point", + git_config=git_config, + source_dir="source_dir", + dependencies=["foo", "no-such-dir"], + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(ValueError) as error: + fw.fit() + assert "Dependency no-such-dir does not exist in the repo." in str(error) + + @patch("time.strftime", return_value=TIMESTAMP) def test_init_with_source_dir_s3(strftime, sagemaker_session): fw = DummyFramework( @@ -1609,6 +1869,3 @@ def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session): '"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together' in str(error) ) - - -################################################################################# diff --git a/tests/unit/test_git_utils.py b/tests/unit/test_git_utils.py new file mode 100644 index 0000000000..02fb2f43e1 --- /dev/null +++ b/tests/unit/test_git_utils.py @@ -0,0 +1,164 @@ +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest +import subprocess +from mock import patch + +from sagemaker import git_utils + +REPO_DIR = "/tmp/repo_dir" +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + ret = git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + check_call.assert_any_call(["git", "clone", git_config["repo"], REPO_DIR]) + check_call.assert_any_call(args=["git", "checkout", BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", COMMIT], cwd=REPO_DIR) + mkdtemp.assert_called_once() + assert ret["entry_point"] == "entry_point" + assert ret["source_dir"] == "/tmp/repo_dir/source_dir" + assert ret["dependencies"] == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_repo_not_provided(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point_that_does_not_exist" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "Please provide a repo for git_config." 
in str(error) + + +@patch( + "subprocess.check_call", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(GIT_REPO, REPO_DIR) + ), +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_clone_fail(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "returned non-zero exit status" in str(error) + + +@patch( + "subprocess.check_call", + side_effect=[True, subprocess.CalledProcessError(returncode=1, cmd="git checkout banana")], +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_branch_not_exist(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "returned non-zero exit status" in str(error) + + +@patch( + "subprocess.check_call", + side_effect=[ + True, + True, + subprocess.CalledProcessError(returncode=1, cmd="git checkout {}".format(COMMIT)), + ], +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_commit_not_exist(exists, isdir, isfile, mkdtemp, check_call): + 
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "returned non-zero exit status" in str(error) + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=False) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point_that_does_not_exist" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "Entry point does not exist in the repo." in str(error) + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=False) +@patch("os.path.exists", return_value=True) +def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir_that_does_not_exist" + dependencies = ["foo", "bar"] + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "Source directory does not exist in the repo." 
in str(error) + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +@patch("os.path.isdir", return_value=True) +@patch("os.path.exists", side_effect=[True, False]) +def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, check_call): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "dep_that_does_not_exist"] + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "does not exist in the repo." in str(error)