# Patch summary (text before the first "diff --git" header is ignored by patch tools):
#   * image.py: adds the SAGEMAKER_LOCAL_SELINUX_ENABLED env flag (SELINUX_ENABLED); when it
#     is set and platform.system() is "Linux", _Volume appends ":z" to the volume mapping;
#     retrieve_artifacts strips a trailing ":z" before splitting each mapping.
#   * test_local_image.py: strips ":z" before volume assertions and adds two _Volume tests.
#     SELINUX_ENABLED is patched with the literal True (a Mock object is always truthy, so
#     Mock(return_value=...) would not exercise the flag's value).
# NOTE(review): line breaks and context-line indentation were reconstructed from a
# whitespace-collapsed copy -- confirm against the target files before applying.
diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py
index adba12d39e..36c27d6d63 100644
--- a/src/sagemaker/local/image.py
+++ b/src/sagemaker/local/image.py
@@ -49,6 +49,13 @@
 TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
 S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
 
+# SELinux Enabled
+SELINUX_ENABLED = os.environ.get("SAGEMAKER_LOCAL_SELINUX_ENABLED", "False").lower() in [
+    "1",
+    "true",
+    "yes",
+]
+
 logger = logging.getLogger(__name__)
 
 
@@ -349,6 +356,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
         # Gather the artifacts from all nodes into artifacts/model and artifacts/output
         for host in self.hosts:
             volumes = compose_data["services"][str(host)]["volumes"]
+            volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
             for volume in volumes:
                 if re.search(r"^[A-Za-z]:", volume):
                     unit, host_dir, container_dir = volume.split(":")
@@ -887,10 +895,14 @@ def __init__(self, host_dir, container_dir=None, channel=None):
 
         self.container_dir = container_dir if container_dir else "/opt/ml/input/data/" + channel
         self.host_dir = host_dir
+        map_format = "{}:{}"
+        if platform.system() == "Linux" and SELINUX_ENABLED:
+            # Support mounting shared volumes in SELinux enabled hosts
+            map_format += ":z"
         if platform.system() == "Darwin" and host_dir.startswith("/var"):
             self.host_dir = os.path.join("/private", host_dir)
 
-        self.map = "{}:{}".format(self.host_dir, self.container_dir)
+        self.map = map_format.format(self.host_dir, self.container_dir)
 
 
 def _stream_output(process):
diff --git a/tests/unit/sagemaker/local/test_local_image.py b/tests/unit/sagemaker/local/test_local_image.py
index f7632a748d..2de10445c7 100644
--- a/tests/unit/sagemaker/local/test_local_image.py
+++ b/tests/unit/sagemaker/local/test_local_image.py
@@ -30,7 +30,7 @@
 from mock import patch, Mock, MagicMock
 
 import sagemaker
-from sagemaker.local.image import _SageMakerContainer, _aws_credentials
+from sagemaker.local.image import _SageMakerContainer, _Volume, _aws_credentials
 
 REGION = "us-west-2"
 BUCKET_NAME = "mybucket"
@@ -513,6 +513,7 @@ def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
         assert config["services"][h]["image"] == image
         assert config["services"][h]["command"] == "train"
         volumes = config["services"][h]["volumes"]
+        volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
         assert "%s:/opt/ml/code" % "/tmp/code" in volumes
         assert "%s:/opt/ml/shared" % shared_folder_path in volumes
 
@@ -564,9 +565,26 @@ def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagem
         assert config["services"][h]["image"] == image
         assert config["services"][h]["command"] == "train"
         volumes = config["services"][h]["volumes"]
+        volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
         assert "%s:/opt/ml/output/intermediate" % intermediate_folder_path in volumes
 
 
+@patch("platform.system", Mock(return_value="Linux"))
+@patch("sagemaker.local.image.SELINUX_ENABLED", True)
+def test_container_selinux_has_label(tmpdir):
+    volume = _Volume(str(tmpdir), "/opt/ml/model")
+
+    assert volume.map.endswith(":z")
+
+
+@patch("platform.system", Mock(return_value="Darwin"))
+@patch("sagemaker.local.image.SELINUX_ENABLED", True)
+def test_container_has_selinux_no_label(tmpdir):
+    volume = _Volume(str(tmpdir), "/opt/ml/model")
+
+    assert not volume.map.endswith(":z")
+
+
 def test_container_has_gpu_support(tmpdir, sagemaker_session):
     instance_count = 1
     image = "my-image"
@@ -650,6 +668,7 @@ def test_serve_local_code(tmpdir, sagemaker_session):
             assert config["services"][h]["command"] == "serve"
 
             volumes = config["services"][h]["volumes"]
+            volumes = [v[:-2] if v.endswith(":z") else v for v in volumes]
             assert "%s:/opt/ml/code" % "/tmp/code" in volumes
             assert (
                 "SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/code"