@@ -117,6 +117,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
117117 self .metric_definitions = metric_definitions
118118 self .model_uri = model_uri
119119 self .model_channel_name = model_channel_name
120+ self .code_uri = None
121+ self .code_channel_name = 'code'
120122
121123 if self .train_instance_type in ('local' , 'local_gpu' ):
122124 if self .train_instance_type == 'local_gpu' and self .train_instance_count > 1 :
@@ -773,9 +775,11 @@ class Framework(EstimatorBase):
773775 LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
774776 MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
775777 MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
778+ CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
776779
777780 def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
778- container_log_level = logging .INFO , code_location = None , image_name = None , dependencies = None , ** kwargs ):
781+ container_log_level = logging .INFO , code_location = None , image_name = None , dependencies = None ,
782+ enable_network_isolation = False , ** kwargs ):
779783 """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
780784
781785 Args:
@@ -784,6 +788,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
784788 source_dir (str): Path (absolute or relative) to a directory with any other training
785789 source code dependencies aside from the entry point file (default: None). The structure within this
786790 directory is preserved when training on Amazon SageMaker.
791+ hyperparameters (dict): Hyperparameters that will be used for training (default: None).
792+ The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
793+ For convenience, this accepts other types for keys and values, but ``str()`` will be called
794+ to convert them before training.
795+ enable_cloudwatch_metrics (bool): [DEPRECATED] CloudWatch metrics are now emitted by all SageMaker
796+ training jobs. This argument is currently ignored and will be removed in a future release.
797+ container_log_level (int): Log level to use within the container (default: logging.INFO).
798+ Valid values are defined in the Python logging module.
799+ code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
800+ The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
801+ If not specified, the default code location is s3://default_bucket/job-name/, and the code file
802+ uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
803+ image_name (str): An alternate image name to use instead of the official SageMaker image
804+ for the framework. This is useful to run one of the SageMaker-supported frameworks
805+ with an image containing custom dependencies.
787806 dependencies (list[str]): A list of paths to directories (absolute or relative) with
788807 any additional libraries that will be exported to the container (default: []).
789808 The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -800,21 +819,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
800819 >>> |------ common
801820 >>> |------ virtual-env
802821
803- hyperparameters (dict): Hyperparameters that will be used for training (default: None).
804- The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
805- For convenience, this accepts other types for keys and values, but ``str()`` will be called
806- to convert them before training.
807- enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
808- training jobs. This will be ignored for now and removed in a further release.
809- container_log_level (int): Log level to use within the container (default: logging.INFO).
810- Valid values are defined in the Python logging module.
811- code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
812- The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
813- If not specified, the default code location is s3://default_bucket/job-name/. And code file
814- uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
815- image_name (str): An alternate image name to use instead of the official Sagemaker image
816- for the framework. This is useful to run one of the Sagemaker supported frameworks
817- with an image containing custom dependencies.
822+ enable_network_isolation (bool): Specifies whether the container will run in network isolation mode. Network
823+ isolation mode restricts the container's access to outside networks (such as the internet). The container
824+ does not make any inbound or outbound network calls. If True, a channel named "code" will be created
825+ for the user entry script for training. The user entry script, files in source_dir (if specified), and
826+ dependencies will be uploaded to S3 as a tar archive. Also known as internet-free mode (default: False).
818827 **kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
819828 """
820829 super (Framework , self ).__init__ (** kwargs )
@@ -830,9 +839,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
830839 self .container_log_level = container_log_level
831840 self .code_location = code_location
832841 self .image_name = image_name
842+ self ._enable_network_isolation = enable_network_isolation
833843
834844 self ._hyperparameters = hyperparameters or {}
835845
846+ def enable_network_isolation (self ):
847+ """Return whether network isolation is enabled for this Estimator.
848+
849+ Returns:
850+ bool: The ``enable_network_isolation`` value that was passed to the constructor.
851+ """
852+ return self ._enable_network_isolation
853+
836854 def _prepare_for_training (self , job_name = None ):
837855 """Set hyperparameters needed for training. This method will also validate ``source_dir``.
838856
@@ -858,6 +876,11 @@ def _prepare_for_training(self, job_name=None):
858876
859877 code_dir = 'file://' + self .source_dir
860878 script = self .entry_point
879+ elif self .enable_network_isolation () and self .entry_point :
880+ self .uploaded_code = self ._stage_user_code_in_s3 ()
881+ code_dir = self .CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
882+ script = self .uploaded_code .script_name
883+ self .code_uri = self .uploaded_code .s3_prefix
861884 else :
862885 self .uploaded_code = self ._stage_user_code_in_s3 ()
863886 code_dir = self .uploaded_code .s3_prefix
@@ -881,12 +904,12 @@ def _stage_user_code_in_s3(self):
881904
882905 if self .code_location is None and local_mode :
883906 code_bucket = self .sagemaker_session .default_bucket ()
884- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
907+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
885908 kms_key = None
886909
887910 elif self .code_location is None :
888911 code_bucket , _ = parse_s3_url (self .output_path )
889- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
912+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
890913 kms_key = self .output_kms_key
891914 else :
892915 code_bucket , key_prefix = parse_s3_url (self .code_location )
0 commit comments