Skip to content

Commit 4d4a546

Browse files
authored
Merge branch 'master' into mwfongAWS-SM-sdk
2 parents 4978670 + 239fc9d commit 4d4a546

File tree

4 files changed

+162
-30
lines changed

4 files changed

+162
-30
lines changed

doc/remote_function/sagemaker.remote_function.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ Remote function classes and methods specification
88
.. automethod:: sagemaker.remote_function.client.remote
99

1010

11-
RemoteExcutor
12-
-------------
11+
RemoteExecutor
12+
--------------
1313

1414
.. autoclass:: sagemaker.remote_function.RemoteExecutor
1515
:members:

src/sagemaker/lambda_helper.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def __init__(
3636
timeout: int = 120,
3737
memory_size: int = 128,
3838
runtime: str = "python3.8",
39+
vpc_config: dict = None,
40+
architectures: list = None,
41+
environment: dict = None,
42+
layers: list = None,
3943
):
4044
"""Constructs a Lambda instance.
4145
@@ -66,6 +70,11 @@ def __init__(
6670
timeout (int): Timeout of the Lambda function in seconds. Default is 120 seconds.
6771
memory_size (int): Memory of the Lambda function in megabytes. Default is 128 MB.
6872
runtime (str): Runtime of the Lambda function. Default is set to python3.8.
73+
vpc_config (dict): VPC to deploy the Lambda function to. Default is None.
74+
architectures (list): Which architecture to deploy to. Valid Values are
75+
'x86_64' and 'arm64', default is None.
76+
environment (dict): Environment Variables for the Lambda function. Default is None.
77+
layers (list): List of Lambda layers for the Lambda function. Default is None.
6978
"""
7079
self.function_arn = function_arn
7180
self.function_name = function_name
@@ -78,6 +87,10 @@ def __init__(
7887
self.timeout = timeout
7988
self.memory_size = memory_size
8089
self.runtime = runtime
90+
self.vpc_config = vpc_config
91+
self.environment = environment
92+
self.architectures = architectures
93+
self.layers = layers
8194

8295
if function_arn is None and function_name is None:
8396
raise ValueError("Either function_arn or function_name must be provided.")
@@ -127,6 +140,10 @@ def create(self):
127140
Code=code,
128141
Timeout=self.timeout,
129142
MemorySize=self.memory_size,
143+
VpcConfig=self.vpc_config,
144+
Environment=self.environment,
145+
Architectures=self.architectures,
146+
Layers=self.layers,
130147
)
131148
return response
132149
except ClientError as e:
@@ -146,6 +163,7 @@ def update(self):
146163
response = lambda_client.update_function_code(
147164
FunctionName=self.function_name or self.function_arn,
148165
ZipFile=_zip_lambda_code(self.script),
166+
Architectures=self.architectures,
149167
)
150168
else:
151169
bucket = self.s3_bucket or self.session.default_bucket()
@@ -168,6 +186,7 @@ def update(self):
168186
zipped_code_dir=self.zipped_code_dir,
169187
s3_bucket=bucket,
170188
),
189+
Architectures=self.architectures,
171190
)
172191
return response
173192
except ClientError as e:

src/sagemaker/remote_function/client.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def remote(
8787
This decorator wraps the annotated code and runs it as a new SageMaker job synchronously
8888
with the provided runtime settings.
8989
90-
Unless mentioned otherwise, the decorator first looks up the value from the SageMaker
90+
If a parameter value is not set, the decorator first looks up the value from the SageMaker
9191
configuration file. If no value is specified in the configuration file or no configuration file
9292
is found, the decorator selects the default as specified below. For more information, see
9393
`Configuring and using defaults with the SageMaker Python SDK <https://sagemaker.readthedocs.io/
@@ -131,7 +131,7 @@ def remote(
131131
annotated with the remote decorator is invoked using the Python runtime available
132132
in the system path.
133133
134-
* The parameter dependencies is set to auto_capture. SageMaker will automatically
134+
* The parameter dependencies is set to ``auto_capture``. SageMaker will automatically
135135
generate an env_snapshot.yml corresponding to the current active conda environment’s
136136
snapshot. You do not need to provide a dependencies file. The following conditions
137137
apply:
@@ -204,7 +204,7 @@ def remote(
204204
downloaded in the previous runs.
205205
206206
max_retry_attempts (int): The max number of times the job is retried on
207-
```InternalServerFailure``` Error from SageMaker service. Defaults to 1.
207+
``InternalServerFailure`` Error from SageMaker service. Defaults to 1.
208208
209209
max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After
210210
this specified amount of time, SageMaker terminates the job regardless of its current
@@ -475,10 +475,10 @@ def __init__(
475475
):
476476
"""Constructor for RemoteExecutor
477477
478-
Unless mentioned otherwise, the constructor first looks up the value from the SageMaker
479-
configuration file. If no value is specified in the configuration file or no configuration
480-
file is found, the constructor selects the default as specified below. For more
481-
information, see `Configuring and using defaults with the SageMaker Python SDK
478+
If a parameter value is not set, the constructor first looks up the value from the
479+
SageMaker configuration file. If no value is specified in the configuration file or
480+
no configuration file is found, the constructor selects the default as specified below.
481+
For more information, see `Configuring and using defaults with the SageMaker Python SDK
482482
<https://sagemaker.readthedocs.io/en/stable/overview.html
483483
#configuring-and-using-defaults-with-the-sagemaker-python-sdk>`_.
484484
@@ -520,7 +520,7 @@ def __init__(
520520
annotated with the remote decorator is invoked using the Python runtime available
521521
in the system path.
522522
523-
* The parameter dependencies is set to auto_capture. SageMaker will automatically
523+
* The parameter dependencies is set to ``auto_capture``. SageMaker will automatically
524524
generate an env_snapshot.yml corresponding to the current active conda environment’s
525525
snapshot. You do not need to provide a dependencies file. The following conditions
526526
apply:
@@ -595,7 +595,7 @@ def __init__(
595595
max_parallel_jobs (int): Maximum number of jobs that run in parallel. Defaults to 1.
596596
597597
max_retry_attempts (int): The max number of times the job is retried on
598-
```InternalServerFailure``` Error from SageMaker service. Defaults to 1.
598+
``InternalServerFailure`` Error from SageMaker service. Defaults to 1.
599599
600600
max_runtime_in_seconds (int): The upper limit in seconds to be used for training. After
601601
this specified amount of time, SageMaker terminates the job regardless of its current
@@ -1012,7 +1012,8 @@ def wait(
10121012
timeout (int): Timeout in seconds to wait until the job is completed before it is
10131013
stopped. Defaults to ``None``.
10141014
1015-
Returns: None
1015+
Returns:
1016+
None
10161017
"""
10171018

10181019
with self._condition:
@@ -1022,14 +1023,15 @@ def wait(
10221023
if self._state == _RUNNING:
10231024
self._job.wait(timeout=timeout)
10241025

1025-
def cancel(self):
1026+
def cancel(self) -> bool:
10261027
"""Cancel the function execution.
10271028
10281029
This method prevents the SageMaker job being created or stops the underlying SageMaker job
10291030
early if it is already in progress.
10301031
1031-
Returns: ``True`` if the underlying SageMaker job created as a result of the remote function
1032-
run is cancelled.
1032+
Returns:
1033+
``True`` if the underlying SageMaker job created as a result of the remote function
1034+
run is cancelled.
10331035
"""
10341036
with self._condition:
10351037
if self._state == _FINISHED:
@@ -1042,18 +1044,30 @@ def cancel(self):
10421044
self._state = _CANCELLED
10431045
return True
10441046

1045-
def running(self):
1046-
"""Returns ``True`` if the underlying sagemaker job is still running."""
1047+
def running(self) -> bool:
1048+
"""Check if the underlying SageMaker job is running.
1049+
1050+
Returns:
1051+
``True`` if the underlying SageMaker job is still running. ``False``, otherwise.
1052+
"""
10471053
with self._condition:
10481054
return self._state == _RUNNING
10491055

1050-
def cancelled(self):
1051-
"""Returns ``True`` if the underlying sagemaker job was cancelled. ``False``, otherwise."""
1056+
def cancelled(self) -> bool:
1057+
"""Check if the underlying SageMaker job was cancelled.
1058+
1059+
Returns:
1060+
``True`` if the underlying SageMaker job was cancelled. ``False``, otherwise.
1061+
"""
10521062
with self._condition:
10531063
return self._state == _CANCELLED
10541064

1055-
def done(self):
1056-
"""Returns ``True`` if the underlying sagemaker job finished running."""
1065+
def done(self) -> bool:
1066+
"""Check if the underlying SageMaker job is finished.
1067+
1068+
Returns:
1069+
``True`` if the underlying SageMaker job finished running. ``False``, otherwise.
1070+
"""
10571071
with self._condition:
10581072
if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [
10591073
"Completed",
@@ -1068,7 +1082,7 @@ def done(self):
10681082
return False
10691083

10701084

1071-
def get_future(job_name, sagemaker_session=None):
1085+
def get_future(job_name, sagemaker_session=None) -> Future:
10721086
"""Get a future object with information about a job with the given job_name.
10731087
10741088
Args:

0 commit comments

Comments
 (0)