Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def __init__(
command: List[str] = None,
instance_count: Union[int, PipelineVariable] = None,
instance_type: Union[str, PipelineVariable] = None,
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
volume_size_in_gb: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -537,6 +538,9 @@ def __init__(
a processing job with.
instance_type (str or PipelineVariable): The type of EC2 instance to use for
processing, for example, 'ml.c4.xlarge'.
entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the
processing job (default: None). This is in the form of a list of strings
that make a command.
volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume
to use for storing data during processing (default: 30).
volume_kms_key (str or PipelineVariable): A KMS key for the processing
Expand Down Expand Up @@ -572,6 +576,7 @@ def __init__(
image_uri=image_uri,
instance_count=instance_count,
instance_type=instance_type,
entrypoint=entrypoint,
volume_size_in_gb=volume_size_in_gb,
volume_kms_key=volume_kms_key,
output_kms_key=output_kms_key,
Expand Down Expand Up @@ -845,14 +850,16 @@ def _set_entrypoint(self, command, user_script_name):
Args:
user_script_name (str): A filename with an extension.
"""
user_script_location = str(
pathlib.PurePosixPath(
self._CODE_CONTAINER_BASE_PATH,
self._CODE_CONTAINER_INPUT_NAME,
user_script_name,
# Only set entrypoint if user hasn't provided one
if self.entrypoint is None:
user_script_location = str(
pathlib.PurePosixPath(
self._CODE_CONTAINER_BASE_PATH,
self._CODE_CONTAINER_INPUT_NAME,
user_script_name,
)
)
)
self.entrypoint = command + [user_script_location]
self.entrypoint = command + [user_script_location]


class ProcessingJob(_Job):
Expand Down Expand Up @@ -1434,6 +1441,7 @@ def __init__(
py_version: str = "py3",
image_uri: Optional[Union[str, PipelineVariable]] = None,
command: Optional[List[str]] = None,
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
volume_size_in_gb: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -1471,6 +1479,9 @@ def __init__(
command ([str]): The command to run, along with any command-line flags
to *precede* the ```code script```. Example: ["python3", "-v"]. If not
provided, ["python"] will be chosen (default: None).
entrypoint (list[str] or list[PipelineVariable]): The entrypoint for the
processing job (default: None). This is in the form of a list of strings
that make a command.
volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume
to use for storing data during processing (default: 30).
volume_kms_key (str or PipelineVariable): A KMS key for the processing volume
Expand Down Expand Up @@ -1523,6 +1534,7 @@ def __init__(
command=command,
instance_count=instance_count,
instance_type=instance_type,
entrypoint=entrypoint,
volume_size_in_gb=volume_size_in_gb,
volume_kms_key=volume_kms_key,
output_kms_key=output_kms_key,
Expand Down Expand Up @@ -2001,13 +2013,14 @@ def _set_entrypoint(self, command, user_script_name):
command ([str]): Ignored in favor of self.framework_entrypoint_command
user_script_name (str): A filename with an extension.
"""

user_script_location = str(
pathlib.PurePosixPath(
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
# Only set entrypoint if user hasn't provided one
if self.entrypoint is None:
user_script_location = str(
pathlib.PurePosixPath(
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
)
)
)
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
self.entrypoint = self.framework_entrypoint_command + [user_script_location]

def _create_and_upload_runproc(
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/pytorch/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
py_version: str = "py3", # New kwarg
image_uri: Optional[Union[str, PipelineVariable]] = None,
command: Optional[List[str]] = None,
entrypoint: Optional[List[Union[str, PipelineVariable]]] = None,
volume_size_in_gb: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
py_version,
image_uri,
command,
entrypoint,
volume_size_in_gb,
volume_kms_key,
output_kms_key,
Expand Down
Loading