diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 103be47caf..b651fa90eb 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -509,6 +509,7 @@ def __init__( command: List[str] = None, instance_count: Union[int, PipelineVariable] = None, instance_type: Union[str, PipelineVariable] = None, + entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, output_kms_key: Optional[Union[str, PipelineVariable]] = None, @@ -537,6 +538,9 @@ def __init__( a processing job with. instance_type (str or PipelineVariable): The type of EC2 instance to use for processing, for example, 'ml.c4.xlarge'. + entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the + processing job (default: None). This is in the form of a list of strings + that make a command. volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume to use for storing data during processing (default: 30). volume_kms_key (str or PipelineVariable): A KMS key for the processing @@ -572,6 +576,7 @@ def __init__( image_uri=image_uri, instance_count=instance_count, instance_type=instance_type, + entrypoint=entrypoint, volume_size_in_gb=volume_size_in_gb, volume_kms_key=volume_kms_key, output_kms_key=output_kms_key, @@ -845,14 +850,16 @@ def _set_entrypoint(self, command, user_script_name): Args: user_script_name (str): A filename with an extension. """ - user_script_location = str( - pathlib.PurePosixPath( - self._CODE_CONTAINER_BASE_PATH, - self._CODE_CONTAINER_INPUT_NAME, - user_script_name, + # Only set entrypoint if user hasn't provided one + if self.entrypoint is None: + user_script_location = str( + pathlib.PurePosixPath( + self._CODE_CONTAINER_BASE_PATH, + self._CODE_CONTAINER_INPUT_NAME, + user_script_name, + ) ) - ) - self.entrypoint = command + [user_script_location] + self.entrypoint = command + [user_script_location] class ProcessingJob(_Job): @@ -1434,6 +1441,7 @@ def __init__( py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, + entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, output_kms_key: Optional[Union[str, PipelineVariable]] = None, @@ -1471,6 +1479,9 @@ def __init__( command ([str]): The command to run, along with any command-line flags to *precede* the ```code script```. Example: ["python3", "-v"]. If not provided, ["python"] will be chosen (default: None). + entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the + processing job (default: None). This is in the form of a list of strings + that make a command. volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume to use for storing data during processing (default: 30). volume_kms_key (str or PipelineVariable): A KMS key for the processing volume @@ -1523,6 +1534,7 @@ def __init__( command=command, instance_count=instance_count, instance_type=instance_type, + entrypoint=entrypoint, volume_size_in_gb=volume_size_in_gb, volume_kms_key=volume_kms_key, output_kms_key=output_kms_key, @@ -2001,13 +2013,14 @@ def _set_entrypoint(self, command, user_script_name): command ([str]): Ignored in favor of self.framework_entrypoint_command user_script_name (str): A filename with an extension. """ - - user_script_location = str( - pathlib.PurePosixPath( - self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name + # Only set entrypoint if user hasn't provided one + if self.entrypoint is None: + user_script_location = str( + pathlib.PurePosixPath( + self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name + ) ) - ) - self.entrypoint = self.framework_entrypoint_command + [user_script_location] + self.entrypoint = self.framework_entrypoint_command + [user_script_location] def _create_and_upload_runproc( self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index e04e4ba65a..2048ad06c9 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -41,6 +41,7 @@ def __init__( py_version: str = "py3", # New kwarg image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, + entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, output_kms_key: Optional[Union[str, PipelineVariable]] = None, @@ -74,6 +75,7 @@ def __init__( py_version, image_uri, command, + entrypoint, volume_size_in_gb, volume_kms_key, output_kms_key,