From 636f74be7aa87f7f7cad76142007028e01071636 Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Tue, 7 Dec 2021 16:15:48 -0800 Subject: [PATCH 1/5] fix: Set ProcessingStep upload locations deterministically to avoid cache misses on pipeline upsert. Add a warning to cache-enabled TrainingSteps with profiling enabled --- src/sagemaker/workflow/steps.py | 32 +++ src/sagemaker/workflow/utilities.py | 21 ++ tests/unit/sagemaker/workflow/test_steps.py | 221 +++++++++++++++++--- 3 files changed, 248 insertions(+), 26 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 6975c6ca97..dd81553a02 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -14,8 +14,10 @@ from __future__ import absolute_import import abc +import warnings from enum import Enum from typing import Dict, List, Union +from urllib.parse import urlparse import attr @@ -270,6 +272,16 @@ def __init__( ) self.cache_config = cache_config + if self.cache_config is not None and not self.estimator.disable_profiler: + msg = ( + "Profiling is enabled on the provided estimator. " + "The default profiler rule includes a timestamp " + "which will change each time the pipeline is " + "upserted, causing cache misses. If profiling " + "is not needed, set disable_profiler to True on the estimator." + ) + warnings.warn(msg) + @property def arguments(self) -> RequestType: """The arguments dict that is used to call `create_training_job`. @@ -498,6 +510,7 @@ def __init__( self.job_arguments = job_arguments self.code = code self.property_files = property_files + self.job_name = None # Examine why run method in sagemaker.processing.Processor mutates the processor instance # by setting the instance's arguments attribute. Refactor Processor.run, if possible. @@ -508,6 +521,17 @@ def __init__( ) self.cache_config = cache_config + if code: + code_url = urlparse(code) + if code_url.scheme == "" or code_url.scheme == "file": + # By default, Processor will upload the local code to an S3 path + # containing a timestamp. This causes cache misses whenever a + # pipeline is updated, even if the underlying script hasn't changed. + # To avoid this, hash the contents of the script and include it + # in the job_name passed to the Processor, which will be used + # instead of the timestamped path. + self.job_name = self._generate_code_upload_path() + @property def arguments(self) -> RequestType: """The arguments dict that is used to call `create_processing_job`. @@ -516,6 +540,7 @@ def arguments(self) -> RequestType: ProcessingJobName and ExperimentConfig cannot be included in the arguments. 
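As a side note on the scheme check above: only code passed as a bare path or a file:// URI gets rewritten to a hashed job name; S3 URIs are left untouched. A minimal standalone sketch of that detection, using only the standard library (the is_local_code name and the example paths are illustrative, not part of the SDK):

    from urllib.parse import urlparse

    def is_local_code(code: str) -> bool:
        # Code with no scheme or a file:// scheme is a local script that would
        # otherwise be uploaded under a timestamped S3 prefix.
        scheme = urlparse(code).scheme
        return scheme in ("", "file")

    assert is_local_code("preprocess.py")                      # bare relative path
    assert is_local_code("file:///opt/ml/code/preprocess.py")  # explicit file URI
    assert not is_local_code("s3://my-bucket/preprocess.py")   # already in S3, left alone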
""" normalized_inputs, normalized_outputs = self.processor._normalize_args( + job_name=self.job_name, arguments=self.job_arguments, inputs=self.inputs, outputs=self.outputs, @@ -546,6 +571,13 @@ def to_request(self) -> RequestType: ] return request_dict + def _generate_code_upload_path(self) -> str: + """Generate an upload path for local processing scripts based on its contents""" + from sagemaker.workflow.utilities import hash_file + + code_hash = hash_file(self.code) + return f"{self.name}-{code_hash}"[:1024] + class TuningStep(ConfigurableRetryStep): """Tuning step for workflow.""" diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 069894d761..3e77465ff6 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import List, Sequence, Union +import hashlib from sagemaker.workflow.entities import ( Entity, @@ -37,3 +38,23 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R elif isinstance(entity, StepCollection): request_dicts.extend(entity.request_dicts()) return request_dicts + + +def hash_file(path: str) -> str: + """Get the MD5 hash of a file. + + Args: + path (str): The local path for the file. + Returns: + str: The MD5 hash of the file. + """ + BUF_SIZE = 65536 # read in 64KiB chunks + md5 = hashlib.md5() + with open(path, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + + return md5.hexdigest() diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 42c3bed7b6..8886dbf84f 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -16,6 +16,7 @@ import pytest import sagemaker import os +import warnings from mock import ( Mock, @@ -63,8 +64,7 @@ ) from tests.unit import DATA_DIR -SCRIPT_FILE = "dummy_script.py" -SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE) +DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") REGION = "us-west-2" BUCKET = "my-bucket" @@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client): ) +@pytest.fixture +def script_processor(sagemaker_session): + return ScriptProcessor( + role=ROLE, + image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", + command=["python3"], + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=100, + volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", + output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", + max_runtime_in_seconds=3600, + base_job_name="my_sklearn_processor", + env={"my_env_variable": "my_env_variable_value"}, + tags=[{"Key": "my-tag", "Value": "my-tag-value"}], + network_config=NetworkConfig( + subnets=["my_subnet_id"], + security_group_ids=["my_security_group_id"], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ), + sagemaker_session=sagemaker_session, + ) + + def test_custom_step(): step = CustomStep( name="MyStep", display_name="CustomStepDisplayName", description="CustomStepDescription" @@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session): training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5) training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500) estimator = TensorFlow( - entry_point=os.path.join(DATA_DIR, SCRIPT_FILE), + entry_point=DUMMY_SCRIPT_PATH, role=ROLE, model_dir=False, 
image_uri=IMAGE_URI, @@ -403,6 +428,101 @@ def test_training_step_tensorflow(sagemaker_session): assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} +def test_training_step_profiler_warning(sagemaker_session): + estimator = TensorFlow( + entry_point=DUMMY_SCRIPT_PATH, + role=ROLE, + model_dir=False, + image_uri=IMAGE_URI, + source_dir="s3://mybucket/source", + framework_version="2.4.1", + py_version="py37", + disable_profiler=False, + instance_count=1, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + debugger_hook_config=False, + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ) + + inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + with warnings.catch_warnings(record=True) as w: + TrainingStep( + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config + ) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Profiling is enabled on the provided estimator" in str(w[-1].message) + + +def test_training_step_no_profiler_warning(sagemaker_session): + estimator = TensorFlow( + entry_point=DUMMY_SCRIPT_PATH, + role=ROLE, + model_dir=False, + image_uri=IMAGE_URI, + source_dir="s3://mybucket/source", + framework_version="2.4.1", + py_version="py37", + disable_profiler=True, + instance_count=1, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + debugger_hook_config=False, + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ) + + inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + with warnings.catch_warnings(record=True) as w: + # profiler disabled, cache config not None + TrainingStep( + name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config + ) + assert len(w) == 0 + + with warnings.catch_warnings(record=True) as w: + # profiler enabled, cache config is None + estimator.disable_profiler = False + TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None) + assert len(w) == 0 + + +def test_training_step_profiler_not_explicitly_enabled(sagemaker_session): + estimator = TensorFlow( + entry_point=DUMMY_SCRIPT_PATH, + role=ROLE, + model_dir=False, + image_uri=IMAGE_URI, + source_dir="s3://mybucket/source", + framework_version="2.4.1", + py_version="py37", + instance_count=1, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + debugger_hook_config=False, + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ) + + inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + step = TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs) + step_request = step.to_request() + assert step_request["Arguments"]["ProfilerRuleConfigurations"] is None + + def test_processing_step(sagemaker_session): processing_input_data_uri_parameter = ParameterString( name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest" @@ -473,28 +593,42 @@ def test_processing_step(sagemaker_session): @patch("sagemaker.processing.ScriptProcessor._normalize_args") -def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session): - processor 
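The warning assertions above use warnings.catch_warnings(record=True); the same check can also be written with pytest.warns. A self-contained sketch with a stand-in helper (warn_if_profiling_enabled is invented here to mimic the TrainingStep constructor check; it is not the real class):

    import warnings
    import pytest

    def warn_if_profiling_enabled(disable_profiler, cache_config):
        # Stand-in for the check TrainingStep.__init__ performs above.
        if cache_config is not None and not disable_profiler:
            warnings.warn("Profiling is enabled on the provided estimator.", UserWarning)

    def test_profiler_warning_with_pytest_warns():
        with pytest.warns(UserWarning, match="Profiling is enabled"):
            warn_if_profiling_enabled(disable_profiler=False, cache_config={"enable_caching": True})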
= ScriptProcessor( - role=ROLE, - image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", - command=["python3"], - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=100, - volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", - output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", - max_runtime_in_seconds=3600, - base_job_name="my_sklearn_processor", - env={"my_env_variable": "my_env_variable_value"}, - tags=[{"Key": "my-tag", "Value": "my-tag-value"}], - network_config=NetworkConfig( - subnets=["my_subnet_id"], - security_group_ids=["my_security_group_id"], - enable_network_isolation=True, - encrypt_inter_container_traffic=True, - ), - sagemaker_session=sagemaker_session, +def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor): + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + inputs = [ + ProcessingInput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + outputs = [ + ProcessingOutput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + step = ProcessingStep( + name="MyProcessingStep", + processor=script_processor, + code=DUMMY_SCRIPT_PATH, + inputs=inputs, + outputs=outputs, + job_arguments=["arg1", "arg2"], + cache_config=cache_config, + ) + mock_normalize_args.return_value = [step.inputs, step.outputs] + step.to_request() + mock_normalize_args.assert_called_with( + job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db", + arguments=step.job_arguments, + inputs=step.inputs, + outputs=step.outputs, + code=step.code, ) + + +@patch("sagemaker.processing.ScriptProcessor._normalize_args") +def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, script_processor): cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") inputs = [ ProcessingInput( @@ -510,8 +644,8 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) ] step = ProcessingStep( name="MyProcessingStep", - processor=processor, - code="foo.py", + processor=script_processor, + code="s3://foo", inputs=inputs, outputs=outputs, job_arguments=["arg1", "arg2"], @@ -520,6 +654,7 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) mock_normalize_args.return_value = [step.inputs, step.outputs] step.to_request() mock_normalize_args.assert_called_with( + job_name=None, arguments=step.job_arguments, inputs=step.inputs, outputs=step.outputs, @@ -527,6 +662,40 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session) ) +@patch("sagemaker.processing.ScriptProcessor._normalize_args") +def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor): + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + inputs = [ + ProcessingInput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + outputs = [ + ProcessingOutput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + step = ProcessingStep( + name="MyProcessingStep", + processor=script_processor, + inputs=inputs, + outputs=outputs, + job_arguments=["arg1", "arg2"], + cache_config=cache_config, + ) + mock_normalize_args.return_value = [step.inputs, step.outputs] + step.to_request() + mock_normalize_args.assert_called_with( + job_name=None, + arguments=step.job_arguments, + inputs=step.inputs, + outputs=step.outputs, + 
code=None, + ) + + def test_create_model_step(sagemaker_session): model = Model( image_uri=IMAGE_URI, From df2143ffae92396cb3e1d6dc9a6870012acb1c76 Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Tue, 7 Dec 2021 16:27:02 -0800 Subject: [PATCH 2/5] Remove old test --- tests/unit/sagemaker/workflow/test_steps.py | 26 --------------------- 1 file changed, 26 deletions(-) diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 8886dbf84f..3c2adc7bd9 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -497,32 +497,6 @@ def test_training_step_no_profiler_warning(sagemaker_session): assert len(w) == 0 -def test_training_step_profiler_not_explicitly_enabled(sagemaker_session): - estimator = TensorFlow( - entry_point=DUMMY_SCRIPT_PATH, - role=ROLE, - model_dir=False, - image_uri=IMAGE_URI, - source_dir="s3://mybucket/source", - framework_version="2.4.1", - py_version="py37", - instance_count=1, - instance_type="ml.p3.16xlarge", - sagemaker_session=sagemaker_session, - hyperparameters={ - "batch-size": 500, - "epochs": 5, - }, - debugger_hook_config=False, - distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, - ) - - inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") - step = TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs) - step_request = step.to_request() - assert step_request["Arguments"]["ProfilerRuleConfigurations"] is None - - def test_processing_step(sagemaker_session): processing_input_data_uri_parameter = ParameterString( name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest" From 81a453b04ddef301d096db43cf8d35dd01497ee4 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Thu, 4 Nov 2021 10:04:29 -0700 Subject: [PATCH 3/5] feature: allow conditional parellel builds (#2727) --- ci-scripts/queue_build.py | 150 ++++++++++++++++++++++++-------------- 1 file changed, 94 insertions(+), 56 deletions(-) diff --git a/ci-scripts/queue_build.py b/ci-scripts/queue_build.py index de781be7b1..fcff0b9a9b 100644 --- a/ci-scripts/queue_build.py +++ b/ci-scripts/queue_build.py @@ -23,34 +23,26 @@ ).get_caller_identity()["Account"] bucket_name = "sagemaker-us-west-2-%s" % account +MAX_IN_PROGRESS_BUILDS = 3 +INTERVAL_BETWEEN_CONCURRENT_RUNS = 15 # minutes +CLEAN_UP_TICKETS_OLDER_THAN = 8 # hours + def queue_build(): - build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID")) - source_version = re.sub( - "[_/]", - "-", - os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"), - ) ticket_number = int(1000 * time.time()) - filename = "%s_%s_%s" % (ticket_number, build_id, source_version) - - print("Created queue ticket %s" % ticket_number) - - _write_ticket(filename) files = _list_tickets() - _cleanup_tickets_older_than_8_hours(files) - _wait_for_other_builds(files, ticket_number) + _cleanup_tickets_older_than(files) + _wait_for_other_builds(ticket_number) def _build_info_from_file(file): - filename = file.key.split("/")[1] + filename = file.key.split("/")[2] ticket_number, build_id, source_version = filename.split("_") return int(ticket_number), build_id, source_version -def _wait_for_other_builds(files, ticket_number): - newfiles = list(filter(lambda file: not _file_older_than(file), files)) - sorted_files = list(sorted(newfiles, key=lambda y: y.key)) +def _wait_for_other_builds(ticket_number): + sorted_files = 
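The index change in _build_info_from_file (split("/")[1] to split("/")[2]) follows from the new key layout, which nests ticket files under a status folder: ci-integ-queue/<status>/<ticket>_<build_id>_<source_version>. A quick sketch with a made-up key; note the build id and source version have already had "_" and "/" replaced with "-", so splitting on "_" yields exactly three parts:

    key = "ci-integ-queue/waiting/1638920148000_codebuild-build-id_codebuild-source-version"

    status = key.split("/")[1]
    filename = key.split("/")[2]
    ticket_number, build_id, source_version = filename.split("_")

    assert status == "waiting"
    assert int(ticket_number) == 1638920148000
    assert build_id == "codebuild-build-id"
    assert source_version == "codebuild-source-version"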
_list_tickets() print("build queue status:") print() @@ -58,33 +50,76 @@ def _wait_for_other_builds(files, ticket_number): for order, file in enumerate(sorted_files): file_ticket_number, build_id, source_version = _build_info_from_file(file) print( - "%s -> %s %s, ticket number: %s" % (order, build_id, source_version, file_ticket_number) + "%s -> %s %s, ticket number: %s status: %s" + % (order, build_id, source_version, file_ticket_number, file.key.split("/")[1]) ) + print() + build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID")) + source_version = re.sub( + "[_/]", + "-", + os.environ.get("CODEBUILD_SOURCE_VERSION", "CODEBUILD-SOURCE-VERSION"), + ) + filename = "%s_%s_%s" % (ticket_number, build_id, source_version) + s3_file_obj = _write_ticket(filename, status="waiting") + print("Build %s waiting to be scheduled" % filename) + + while True: + _cleanup_tickets_with_terminal_states() + waiting_tickets = _list_tickets("waiting") + if waiting_tickets: + first_waiting_ticket_number, _, _ = _build_info_from_file(_list_tickets("waiting")[0]) + else: + first_waiting_ticket_number = ticket_number + + if ( + len(_list_tickets(status="in-progress")) < 3 + and last_in_progress_elapsed_time_check() + and first_waiting_ticket_number == ticket_number + ): + # put the build in progress + print("Scheduling build %s for running.." % filename) + s3_file_obj.delete() + _write_ticket(filename, status="in-progress") + break + else: + # wait + time.sleep(30) - for file in sorted_files: - file_ticket_number, build_id, source_version = _build_info_from_file(file) - if file_ticket_number == ticket_number: +def last_in_progress_elapsed_time_check(): + in_progress_tickets = _list_tickets("in-progress") + if not in_progress_tickets: + return True + last_in_progress_ticket, _, _ = _build_info_from_file(_list_tickets("in-progress")[-1]) + _elapsed_time = int(1000 * time.time()) - last_in_progress_ticket + last_in_progress_elapsed_time = int(_elapsed_time / (1000 * 60)) # in minutes + return last_in_progress_elapsed_time > INTERVAL_BETWEEN_CONCURRENT_RUNS - break - else: - while True: - client = boto3.client("codebuild") - response = client.batch_get_builds(ids=[build_id]) - build_status = response["builds"][0]["buildStatus"] - - if build_status == "IN_PROGRESS": - print( - "waiting on build %s %s %s" % (build_id, source_version, file_ticket_number) - ) - time.sleep(30) - else: - print("build %s finished, deleting lock" % build_id) - file.delete() - break - - -def _cleanup_tickets_older_than_8_hours(files): + +def _cleanup_tickets_with_terminal_states(): + files = _list_tickets() + build_ids = [] + for file in files: + _, build_id, _ = _build_info_from_file(file) + build_ids.append(build_id) + + client = boto3.client("codebuild") + response = client.batch_get_builds(ids=build_ids) + + for file, build_details in zip(files, response["builds"]): + _, _build_id_from_file, _ = _build_info_from_file(file) + build_status = build_details["buildStatus"] + + if build_status != "IN_PROGRESS" and _build_id_from_file == build_details["id"]: + print( + "Build %s in terminal state: %s, deleting lock" + % (_build_id_from_file, build_status) + ) + file.delete() + + +def _cleanup_tickets_older_than(files): oldfiles = list(filter(_file_older_than, files)) for file in oldfiles: print("object %s older than 8 hours. 
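last_in_progress_elapsed_time_check works because ticket numbers are millisecond timestamps, so the difference from the current time just needs scaling back to minutes before comparing it with INTERVAL_BETWEEN_CONCURRENT_RUNS. A standalone sketch of that arithmetic (the timestamps are invented for illustration):

    import time

    INTERVAL_BETWEEN_CONCURRENT_RUNS = 15  # minutes, as defined above

    def minutes_since(ticket_number_ms: int) -> int:
        elapsed_ms = int(1000 * time.time()) - ticket_number_ms
        return int(elapsed_ms / (1000 * 60))

    # A build that went in-progress 20 minutes ago clears the spacing requirement;
    # one that started 5 minutes ago does not.
    twenty_min_ago = int(1000 * (time.time() - 20 * 60))
    five_min_ago = int(1000 * (time.time() - 5 * 60))
    assert minutes_since(twenty_min_ago) > INTERVAL_BETWEEN_CONCURRENT_RUNS
    assert minutes_since(five_min_ago) <= INTERVAL_BETWEEN_CONCURRENT_RUNS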
Deleting" % file.key) @@ -92,31 +127,34 @@ def _cleanup_tickets_older_than_8_hours(files): return files -def _list_tickets(): +def _list_tickets(status=None): s3 = boto3.resource("s3") bucket = s3.Bucket(bucket_name) - objects = [file for file in bucket.objects.filter(Prefix="ci-lock/")] - files = list(filter(lambda x: x != "ci-lock/", objects)) - return files + prefix = "ci-integ-queue/{}/".format(status) if status else "ci-integ-queue/" + objects = [file for file in bucket.objects.filter(Prefix=prefix)] + files = list(filter(lambda x: x != prefix, objects)) + sorted_files = list(sorted(files, key=lambda y: y.key)) + return sorted_files def _file_older_than(file): - timelimit = 1000 * 60 * 60 * 8 - + timelimit = 1000 * 60 * 60 * CLEAN_UP_TICKETS_OLDER_THAN file_ticket_number, build_id, source_version = _build_info_from_file(file) + return int(1000 * time.time()) - file_ticket_number > timelimit - return int(time.time()) - file_ticket_number > timelimit - - -def _write_ticket(ticket_number): - if not os.path.exists("ci-lock"): - os.mkdir("ci-lock") +def _write_ticket(filename, status="waiting"): + file_path = "ci-integ-queue/{}".format(status) + if not os.path.exists(file_path): + os.makedirs(file_path) - filename = "ci-lock/" + ticket_number - with open(filename, "w") as file: - file.write(ticket_number) - boto3.Session().resource("s3").Object(bucket_name, filename).upload_file(filename) + file_full_path = file_path + "/" + filename + with open(file_full_path, "w") as file: + file.write(filename) + s3_file_obj = boto3.Session().resource("s3").Object(bucket_name, file_full_path) + s3_file_obj.upload_file(file_full_path) + print("Build %s is now in state %s" % (filename, status)) + return s3_file_obj if __name__ == "__main__": From 565e70e17b2fc49458ef2a02c559affca2766623 Mon Sep 17 00:00:00 2001 From: Basil Beirouti Date: Mon, 6 Dec 2021 14:56:50 -0800 Subject: [PATCH 4/5] fix endpoint bug (#2772) Co-authored-by: Basil Beirouti --- src/sagemaker/session.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 26eba556f5..828371c6dc 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3556,19 +3556,17 @@ def endpoint_from_production_variants( Returns: str: The name of the created ``Endpoint``. 
""" - if not _deployment_entity_exists( - lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name) - ): - config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} - tags = _append_project_tags(tags) - if tags: - config_options["Tags"] = tags - if kms_key: - config_options["KmsKeyId"] = kms_key - if data_capture_config_dict is not None: - config_options["DataCaptureConfig"] = data_capture_config_dict - - self.sagemaker_client.create_endpoint_config(**config_options) + config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants} + tags = _append_project_tags(tags) + if tags: + config_options["Tags"] = tags + if kms_key: + config_options["KmsKeyId"] = kms_key + if data_capture_config_dict is not None: + config_options["DataCaptureConfig"] = data_capture_config_dict + + self.sagemaker_client.create_endpoint_config(**config_options) + return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait) def expand_role(self, role): From 34dd43a3d105867ee803ad86415b6e67dc2b64b6 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Tue, 7 Dec 2021 13:45:57 -0800 Subject: [PATCH 5/5] fix: local mode - support relative file structure (#2768) --- src/sagemaker/local/image.py | 9 ++++--- src/sagemaker/local/utils.py | 8 +++---- tests/unit/test_local_utils.py | 44 ++++++++++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index f0a3ed8579..7a10eeacc6 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -277,7 +277,8 @@ def serve(self, model_dir, environment): script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] parsed_uri = urlparse(script_dir) if parsed_uri.scheme == "file": - volumes.append(_Volume(parsed_uri.path, "/opt/ml/code")) + host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + volumes.append(_Volume(host_dir, "/opt/ml/code")) # Update path to mount location environment = environment.copy() environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code" @@ -495,7 +496,8 @@ def _prepare_training_volumes( training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME]) parsed_uri = urlparse(training_dir) if parsed_uri.scheme == "file": - volumes.append(_Volume(parsed_uri.path, "/opt/ml/code")) + host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + volumes.append(_Volume(host_dir, "/opt/ml/code")) # Also mount a directory that all the containers can access. 
volumes.append(_Volume(shared_dir, "/opt/ml/shared")) @@ -504,7 +506,8 @@ def _prepare_training_volumes( parsed_uri.scheme == "file" and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters ): - intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate") + dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + intermediate_dir = os.path.join(dir_path, "output", "intermediate") if not os.path.exists(intermediate_dir): os.makedirs(intermediate_dir) volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate")) diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 5a8ce03282..352b7ec387 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -64,7 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session): """ parsed_uri = urlparse(destination) if parsed_uri.scheme == "file": - recursive_copy(source, parsed_uri.path) + dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + recursive_copy(source, dir_path) final_uri = destination elif parsed_uri.scheme == "s3": bucket = parsed_uri.netloc @@ -116,9 +117,8 @@ def get_child_process_ids(pid): (List[int]): Child process ids """ cmd = f"pgrep -P {pid}".split() - output, err = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ).communicate() + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, err = process.communicate() if err: return [] pids = [int(pid) for pid in output.decode("utf-8").split()] diff --git a/tests/unit/test_local_utils.py b/tests/unit/test_local_utils.py index 6384515622..be54d00a19 100644 --- a/tests/unit/test_local_utils.py +++ b/tests/unit/test_local_utils.py @@ -12,18 +12,31 @@ # language governing permissions and limitations under the License. 
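Splitting the Popen(...) call from .communicate() mainly makes get_child_process_ids easier to unit test: the patched Popen only needs to return an object with a communicate method, roughly as below (a minimal sketch using unittest.mock rather than the project's mock package; the PIDs are made up):

    import subprocess
    from unittest.mock import Mock, patch

    def child_pids(pid):
        # Mirrors the pgrep -P pattern above: list direct children, or [] on stderr output.
        process = subprocess.Popen(
            ["pgrep", "-P", str(pid)], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        output, err = process.communicate()
        if err:
            return []
        return [int(p) for p in output.decode("utf-8").split()]

    with patch("subprocess.Popen") as m_popen:
        m_popen.return_value = Mock(**{"communicate.return_value": (b"101\n102\n", b"")})
        assert child_pids(1) == [101, 102]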
from __future__ import absolute_import +import os import pytest from mock import patch, Mock import sagemaker.local.utils +@patch("sagemaker.local.utils.os.path") +@patch("sagemaker.local.utils.os") +def test_copy_directory_structure(m_os, m_os_path): + m_os_path.exists.return_value = False + sagemaker.local.utils.copy_directory_structure("/tmp/", "code/") + m_os.makedirs.assert_called_with("/tmp/", "code/") + + @patch("shutil.rmtree", Mock()) @patch("sagemaker.local.utils.recursive_copy") def test_move_to_destination_local(recursive_copy): # local files will just be recursively copied - sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None) - recursive_copy.assert_called_with("/tmp/data", "/target/dir/") + # given absolute path + sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir", "job", None) + recursive_copy.assert_called_with("/tmp/data", "/target/dir") + # given relative path + sagemaker.local.utils.move_to_destination("/tmp/data", "file://root/target/dir", "job", None) + recursive_copy.assert_called_with("/tmp/data", os.path.abspath("./root/target/dir")) @patch("shutil.rmtree", Mock()) @@ -52,3 +65,30 @@ def test_move_to_destination_s3(recursive_copy): def test_move_to_destination_illegal_destination(): with pytest.raises(ValueError): sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None) + + +@patch("sagemaker.local.utils.os.path") +@patch("sagemaker.local.utils.copy_tree") +def test_recursive_copy(copy_tree, m_os_path): + m_os_path.isdir.return_value = True + sagemaker.local.utils.recursive_copy("source", "destination") + copy_tree.assert_called_with("source", "destination") + + +@patch("sagemaker.local.utils.os") +@patch("sagemaker.local.utils.get_child_process_ids") +def test_kill_child_processes(m_get_child_process_ids, m_os): + m_get_child_process_ids.return_value = ["child_pids"] + sagemaker.local.utils.kill_child_processes("pid") + m_os.kill.assert_called_with("child_pids", 15) + + +@patch("sagemaker.local.utils.subprocess") +def test_get_child_process_ids(m_subprocess): + cmd = "pgrep -P pid".split() + process_mock = Mock() + attrs = {"communicate.return_value": (b"\n", False), "returncode": 0} + process_mock.configure_mock(**attrs) + m_subprocess.Popen.return_value = process_mock + sagemaker.local.utils.get_child_process_ids("pid") + m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)
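For readers reproducing recursive_copy outside the SDK: the copy_tree patched in test_recursive_copy (presumably distutils.dir_util.copy_tree) copies into an already-existing destination, and on Python 3.8+ a rough equivalent is shutil.copytree with dirs_exist_ok=True. A throwaway sketch with temporary directories (file names invented):

    import os
    import shutil
    import tempfile

    src = tempfile.mkdtemp()
    dst = tempfile.mkdtemp()  # already exists, like a previously used output dir
    open(os.path.join(src, "model.tar.gz"), "w").close()

    # Copy into the existing destination without raising, mirroring copy_tree's behavior.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    assert os.path.exists(os.path.join(dst, "model.tar.gz"))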