From b9767ce5b2d8aa238976230eb1fa702eab1eb3d5 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Tue, 10 Oct 2023 14:19:44 +0200 Subject: [PATCH] fix: Fix writing into non-closed file with git clone command --- src/sagemaker/git_utils.py | 13 ++++--- tests/unit/test_git_utils.py | 66 +++++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index adde1b5585..49d151a00b 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import os +from pathlib import Path import subprocess import tempfile import warnings @@ -279,11 +280,13 @@ def _run_clone_command(repo_url, dest_dir): subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) elif repo_url.startswith("git@") or repo_url.startswith("ssh://"): try: - with tempfile.NamedTemporaryFile() as sshnoprompt: - with open(sshnoprompt.name, "w") as write_pipe: - write_pipe.write("ssh -oBatchMode=yes $@") - os.chmod(sshnoprompt.name, 0o511) - my_env["GIT_SSH"] = sshnoprompt.name + with tempfile.TemporaryDirectory() as tmp_dir: + custom_ssh_executable = Path(tmp_dir) / "ssh_batch" + with open(custom_ssh_executable, "w") as pipe: + print("#!/bin/sh", file=pipe) + print("ssh -oBatchMode=yes $@", file=pipe) + os.chmod(custom_ssh_executable, 0o511) + my_env["GIT_SSH"] = str(custom_ssh_executable) subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) except subprocess.CalledProcessError: del my_env["GIT_SSH"] diff --git a/tests/unit/test_git_utils.py b/tests/unit/test_git_utils.py index edf37fe812..03bbc1ebcd 100644 --- a/tests/unit/test_git_utils.py +++ b/tests/unit/test_git_utils.py @@ -14,6 +14,7 @@ import pytest import os +from pathlib import Path import subprocess from mock import patch, ANY @@ -34,10 +35,11 @@ @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) @patch("os.path.isdir", return_value=True) @patch("os.path.exists", return_value=True) -def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call): +def test_git_clone_repo_succeed(exists, isdir, isfile, tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" @@ -88,7 +90,8 @@ def test_git_clone_repo_git_argument_wrong_format(): ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_clone_fail(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_clone_fail(tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" @@ -103,7 +106,8 @@ def test_git_clone_repo_clone_fail(mkdtemp, check_call): side_effect=[True, subprocess.CalledProcessError(returncode=1, cmd="git checkout banana")], ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_branch_not_exist(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_branch_not_exist(tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" @@ -122,7 +126,8 @@ def test_git_clone_repo_branch_not_exist(mkdtemp, check_call): ], ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_commit_not_exist(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_commit_not_exist(tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" @@ -134,10 +139,11 @@ def test_git_clone_repo_commit_not_exist(mkdtemp, check_call): @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=False) @patch("os.path.isdir", return_value=True) @patch("os.path.exists", return_value=True) -def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, heck_call): +def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, tempdir, mkdtemp, heck_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point_that_does_not_exist" source_dir = "source_dir" @@ -149,10 +155,11 @@ def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, he @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) @patch("os.path.isdir", return_value=False) @patch("os.path.exists", return_value=True) -def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, check_call): +def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir_that_does_not_exist" @@ -164,10 +171,11 @@ def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, che @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) @patch("os.path.isdir", return_value=True) @patch("os.path.exists", side_effect=[True, False]) -def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, check_call): +def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, tempdir, mkdtemp, check_call): git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" @@ -179,8 +187,9 @@ def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, c @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_with_username_password_no_2fa(isfile, mkdtemp, check_call): +def test_git_clone_repo_with_username_password_no_2fa(isfile, tempdir, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -210,8 +219,9 @@ def test_git_clone_repo_with_username_password_no_2fa(isfile, mkdtemp, check_cal @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_with_token_no_2fa(isfile, mkdtemp, check_call): +def test_git_clone_repo_with_token_no_2fa(isfile, tempdir, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -236,8 +246,9 @@ def test_git_clone_repo_with_token_no_2fa(isfile, mkdtemp, check_call): @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_with_token_2fa(isfile, mkdtemp, check_call): +def test_git_clone_repo_with_token_2fa(isfile, tempdirm, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -264,8 +275,10 @@ def test_git_clone_repo_with_token_2fa(isfile, mkdtemp, check_call): @patch("subprocess.check_call") @patch("os.chmod") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_ssh(isfile, mkdtemp, chmod, check_call): +def test_git_clone_repo_ssh(isfile, tempdir, mkdtemp, chmod, check_call): + Path(REPO_DIR).mkdir(parents=True, exist_ok=True) git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} entry_point = "entry_point" ret = git_utils.git_clone_repo(git_config, entry_point) @@ -277,8 +290,11 @@ def test_git_clone_repo_ssh(isfile, mkdtemp, chmod, check_call): @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call): +def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided( + isfile, tempdir, mkdtemp, check_call +): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -309,8 +325,11 @@ def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(isfile, mkd @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call): +def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided( + isfile, tempdir, mkdtemp, check_call +): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -346,7 +365,8 @@ def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(isfile, mkdtem ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_with_username_and_password_wrong_creds(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_with_username_and_password_wrong_creds(tempdir, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -370,7 +390,8 @@ def test_git_clone_repo_with_username_and_password_wrong_creds(mkdtemp, check_ca ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_with_token_wrong_creds(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_with_token_wrong_creds(tempdir, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -393,7 +414,8 @@ def test_git_clone_repo_with_token_wrong_creds(mkdtemp, check_call): ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_with_and_token_2fa_wrong_creds(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_with_and_token_2fa_wrong_creds(tempdir, mkdtemp, check_call): git_config = { "repo": PRIVATE_GIT_REPO, "branch": PRIVATE_BRANCH, @@ -411,8 +433,11 @@ def test_git_clone_repo_with_and_token_2fa_wrong_creds(mkdtemp, check_call): @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) -def test_git_clone_repo_codecommit_https_with_username_and_password(isfile, mkdtemp, check_call): +def test_git_clone_repo_codecommit_https_with_username_and_password( + isfile, tempdir, mkdtemp, check_call +): git_config = { "repo": CODECOMMIT_REPO, "branch": CODECOMMIT_BRANCH, @@ -445,7 +470,9 @@ def test_git_clone_repo_codecommit_https_with_username_and_password(isfile, mkdt ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_codecommit_ssh_passphrase_required(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_codecommit_ssh_passphrase_required(tempdir, mkdtemp, check_call): + Path(REPO_DIR).mkdir(parents=True, exist_ok=True) git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH} entry_point = "entry_point" with pytest.raises(subprocess.CalledProcessError) as error: @@ -460,7 +487,8 @@ def test_git_clone_repo_codecommit_ssh_passphrase_required(mkdtemp, check_call): ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -def test_git_clone_repo_codecommit_https_creds_not_stored_locally(mkdtemp, check_call): +@patch("tempfile.TemporaryDirectory.__enter__", return_value=REPO_DIR) +def test_git_clone_repo_codecommit_https_creds_not_stored_locally(tempdir, mkdtemp, check_call): git_config = {"repo": CODECOMMIT_REPO, "branch": CODECOMMIT_BRANCH} entry_point = "entry_point" with pytest.raises(subprocess.CalledProcessError) as error: