Skip to content

Commit 6e5bb4b

Browse files
committed
merge fix
2 parents e5c9202 + a267c3e commit 6e5bb4b

File tree

6 files changed

+76
-23
lines changed

6 files changed

+76
-23
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def read_requirements(filename):
4848
# Declare minimal set for installation
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
51-
"boto3>=1.26.28,<2.0",
51+
"boto3>=1.26.131,<2.0",
5252
"cloudpickle==2.2.1",
5353
"google-pasta",
5454
"numpy>=1.9.0,<2.0",

src/sagemaker/remote_function/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def map(self, func, *iterables):
745745
futures = map(self.submit, itertools.repeat(func), *iterables)
746746
return [future.result() for future in futures]
747747

748-
def shutdown(self, wait=True):
748+
def shutdown(self):
749749
"""Prevent more function executions from being submitted to this executor."""
750750
with self._state_condition:
751751
self._shutdown = True
@@ -756,15 +756,15 @@ def shutdown(self, wait=True):
756756
self._state_condition.notify_all()
757757

758758
if self._workers is not None:
759-
self._workers.shutdown(wait)
759+
self._workers.shutdown(wait=True)
760760

761761
def __enter__(self):
762762
"""Create an executor instance and return it"""
763763
return self
764764

765765
def __exit__(self, exc_type, exc_val, exc_tb):
766766
"""Make sure the executor instance is shutdown."""
767-
self.shutdown(wait=False)
767+
self.shutdown()
768768
return False
769769

770770
@staticmethod

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code related to the ServerlessInferenceConfig class.
1414
15-
Codes are used for configuring async inference endpoint. Use it when deploying
15+
This code is used for configuring a serverless inference endpoint. Use it when deploying
1616
the model to the endpoints.
1717
"""
1818
from __future__ import print_function, absolute_import
19+
from typing import Optional
1920

2021

2122
class ServerlessInferenceConfig(object):
@@ -29,6 +30,7 @@ def __init__(
2930
self,
3031
memory_size_in_mb: int = 2048,
3132
max_concurrency: int = 5,
33+
provisioned_concurrency: Optional[int] = None,
3234
):
3335
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.
3436
@@ -40,9 +42,13 @@ def __init__(
4042
max_concurrency (int): Optional. The maximum number of concurrent invocations
4143
your serverless endpoint can process. If no value is provided, Amazon
4244
SageMaker will choose the default value for you. (Default: 5)
45+
provisioned_concurrency (int): Optional. The provisioned concurrency of your
46+
serverless endpoint. If no value is provided, Amazon SageMaker will not
47+
apply provisioned concurrency to your Serverless endpoint. (Default: None)
4348
"""
4449
self.memory_size_in_mb = memory_size_in_mb
4550
self.max_concurrency = max_concurrency
51+
self.provisioned_concurrency = provisioned_concurrency
4652

4753
def _to_request_dict(self):
4854
"""Generates a request dictionary using the parameters provided to the class."""
@@ -51,4 +57,7 @@ def _to_request_dict(self):
5157
"MaxConcurrency": self.max_concurrency,
5258
}
5359

60+
if self.provisioned_concurrency is not None:
61+
request_dict["ProvisionedConcurrency"] = self.provisioned_concurrency
62+
5463
return request_dict

tests/integ/test_serverless_inference.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,30 @@ def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_s
4444
pca.extra_components = 5
4545
pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name)
4646

47-
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
47+
serverless_name = unique_name_from_base("pca-serverless")
48+
with timeout_and_delete_endpoint_by_name(serverless_name, sagemaker_session):
4849

4950
predictor_serverless = pca.deploy(
50-
endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig()
51+
endpoint_name=serverless_name, serverless_inference_config=ServerlessInferenceConfig()
5152
)
5253

5354
result = predictor_serverless.predict(training_set[0][:5])
5455

5556
assert len(result) == 5
5657
for record in result:
5758
assert record.label["projection"] is not None
59+
60+
# Test out Serverless Provisioned Concurrency endpoint happy case
61+
serverless_pc_name = unique_name_from_base("pca-serverless-pc")
62+
with timeout_and_delete_endpoint_by_name(serverless_pc_name, sagemaker_session):
63+
64+
predictor_serverless_pc = pca.deploy(
65+
endpoint_name=serverless_pc_name,
66+
serverless_inference_config=ServerlessInferenceConfig(provisioned_concurrency=1),
67+
)
68+
69+
result = predictor_serverless_pc.predict(training_set[0][:5])
70+
71+
assert len(result) == 5
72+
for record in result:
73+
assert record.label["projection"] is not None

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
518518
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
519519
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
520520

521-
future_1.wait()
522-
future_2.wait()
523-
future_3.wait()
524-
future_4.wait()
525-
526521
mock_start.assert_has_calls(
527522
[
528523
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None),
@@ -531,6 +526,10 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
531526
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None),
532527
]
533528
)
529+
mock_job_1.describe.assert_called()
530+
mock_job_2.describe.assert_called()
531+
mock_job_3.describe.assert_called()
532+
mock_job_4.describe.assert_called()
534533

535534
assert future_1.done()
536535
assert future_2.done()
@@ -555,15 +554,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
555554
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
556555
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
557556

558-
future_1.wait()
559-
future_2.wait()
560-
561557
mock_start.assert_has_calls(
562558
[
563559
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info),
564560
call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info),
565561
]
566562
)
563+
mock_job_1.describe.assert_called()
564+
mock_job_2.describe.assert_called()
567565

568566
assert future_1.done()
569567
assert future_2.done()
@@ -573,15 +571,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
573571
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
574572
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
575573

576-
future_3.wait()
577-
future_4.wait()
578-
579574
mock_start.assert_has_calls(
580575
[
581576
call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info),
582577
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info),
583578
]
584579
)
580+
mock_job_3.describe.assert_called()
581+
mock_job_4.describe.assert_called()
585582

586583
assert future_3.done()
587584
assert future_4.done()
@@ -633,7 +630,7 @@ def test_executor_fails_to_start_job(mock_start, *args):
633630

634631
with pytest.raises(TypeError):
635632
future_1.result()
636-
future_2.wait()
633+
print(future_2._state)
637634
assert future_2.done()
638635

639636

@@ -698,8 +695,6 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args):
698695
# submit second job
699696
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
700697

701-
future_1.wait()
702-
future_2.wait()
703698
assert future_1.done()
704699
assert future_2.done()
705700

@@ -719,9 +714,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
719714
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
720715

721716
with pytest.raises(RuntimeError):
722-
future_1.result()
717+
future_1.done()
723718
with pytest.raises(RuntimeError):
724-
future_2.result()
719+
future_2.done()
725720

726721

727722
@pytest.mark.parametrize(

tests/unit/sagemaker/serverless/test_serverless_inference_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,54 @@
1616

1717
DEFAULT_MEMORY_SIZE_IN_MB = 2048
1818
DEFAULT_MAX_CONCURRENCY = 5
19+
DEFAULT_PROVISIONED_CONCURRENCY = 5
1920

2021
DEFAULT_REQUEST_DICT = {
2122
"MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB,
2223
"MaxConcurrency": DEFAULT_MAX_CONCURRENCY,
2324
}
2425

26+
PROVISIONED_CONCURRENCY_REQUEST_DICT = {
27+
"MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB,
28+
"MaxConcurrency": DEFAULT_MAX_CONCURRENCY,
29+
"ProvisionedConcurrency": DEFAULT_PROVISIONED_CONCURRENCY,
30+
}
31+
2532

2633
def test_init():
2734
serverless_inference_config = ServerlessInferenceConfig()
2835

2936
assert serverless_inference_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB
3037
assert serverless_inference_config.max_concurrency == DEFAULT_MAX_CONCURRENCY
3138

39+
serverless_provisioned_concurrency_inference_config = ServerlessInferenceConfig(
40+
provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
41+
)
42+
43+
assert (
44+
serverless_provisioned_concurrency_inference_config.memory_size_in_mb
45+
== DEFAULT_MEMORY_SIZE_IN_MB
46+
)
47+
assert (
48+
serverless_provisioned_concurrency_inference_config.max_concurrency
49+
== DEFAULT_MAX_CONCURRENCY
50+
)
51+
assert (
52+
serverless_provisioned_concurrency_inference_config.provisioned_concurrency
53+
== DEFAULT_PROVISIONED_CONCURRENCY
54+
)
55+
3256

3357
def test_to_request_dict():
3458
serverless_inference_config_dict = ServerlessInferenceConfig()._to_request_dict()
3559

3660
assert serverless_inference_config_dict == DEFAULT_REQUEST_DICT
61+
62+
serverless_provisioned_concurrency_inference_config_dict = ServerlessInferenceConfig(
63+
provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
64+
)._to_request_dict()
65+
66+
assert (
67+
serverless_provisioned_concurrency_inference_config_dict
68+
== PROVISIONED_CONCURRENCY_REQUEST_DICT
69+
)

0 commit comments

Comments
 (0)