-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Script mode support for Estimator class #2834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Script mode support for Estimator class #2834
Conversation
Codecov Report
@@ Coverage Diff @@
## master-jumpstart #2834 +/- ##
=================================================
Coverage 89.16% 89.17%
=================================================
Files 185 185
Lines 16047 16069 +22
=================================================
+ Hits 14308 14329 +21
- Misses 1739 1740 +1
Continue to review full report at Codecov.
|
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
JGuinegagne
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you be able to add unit tests to verify that:
git_configis also handled correctly and similarly b7DummyFrameworkandEstimatorhyperparametersare handled correctly (if not covered already)
src/sagemaker/estimator.py
Outdated
| return name_from_base(self.base_job_name) | ||
|
|
||
| @staticmethod | ||
| def _json_encode_hyperparameters(hyperparameters): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know you are just refactoring, but is it possible to add args & return typings?
src/sagemaker/estimator.py
Outdated
| self._prepare_debugger_for_training() | ||
| self._prepare_profiler_for_training() | ||
|
|
||
| def _script_mode_hyperparam_update(self, code_dir, script): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
|
|
||
| self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) | ||
|
|
||
| def _stage_user_code_in_s3(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
| self.uploaded_code = self._stage_user_code_in_s3() | ||
| code_dir = self.uploaded_code.s3_prefix | ||
| script = self.uploaded_code.script_name | ||
| def _script_mode_hyperparam_update(self, code_dir, script): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
| code_dir (str): The directory hosting the training scripts. | ||
| script (str): The relative filepath of the training entry-point script. | ||
| """ | ||
| hyperparams = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typing please:
hyperparams: Dict[str, str] = {}
tests/unit/test_estimator.py
Outdated
| @patch("sagemaker.estimator.Estimator._stage_user_code_in_s3") | ||
| def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): | ||
| patched_stage_user_code.return_value = UploadedCode( | ||
| s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: are you trying to be consistent with the rest of the module when you use the ""%(*args) pattern?
If not, could you please use f-string?
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
1441e9f to
7295190
Compare
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
JGuinegagne
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few non-blocking comments/suggestions.
src/sagemaker/estimator.py
Outdated
| try to use either CodeCommit credential helper or local | ||
| credential storage for authentication. | ||
| hyperparameters (dict): Dictionary containing the hyperparameters to | ||
| initialize this estimator with. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Default: None)
src/sagemaker/estimator.py
Outdated
| If not specified, the default ``code location`` is s3://output_bucket/job-name/. | ||
| entry_point (str): Path (absolute or relative) to the local Python | ||
| source file which should be executed as the entry point to | ||
| training. If ``source_dir`` is specified, then ``entry_point`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Default: None)
src/sagemaker/estimator.py
Outdated
| return name_from_base(self.base_job_name) | ||
|
|
||
| @staticmethod | ||
| def _json_encode_hyperparameters(hyperparameters: dict) -> dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you use Dict[str, Any] instead of dict for both the argument and the return type.
src/sagemaker/estimator.py
Outdated
| self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) | ||
|
|
||
| def _stage_user_code_in_s3(self) -> str: | ||
| """Upload the user training script to s3 and return the location. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ...and return the S3 URI.
| code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) | ||
|
|
||
| output_bucket, _ = parse_s3_url(self.output_path) | ||
| kms_key = self.output_kms_key if code_bucket == output_bucket else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
non-blocking question: i know you are just refactoring this code, but is this a concern? i.e. are we conforming to customer expectation if we do not use the "output" encryption key when the script gets uploaded to a different bucket than the output bucket?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea, this section was directly lifted from somewhere else in the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will need to dive into this.
src/sagemaker/estimator.py
Outdated
| return None | ||
|
|
||
| def set_hyperparameters(self, **kwargs): | ||
| """Sets hyperparameters.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's avoid repeating function name in docstring. How about:
"""Escape the dict argument as JSON, update the private hyperparameter attribute."""| repack=self.source_dir | ||
| and self.entry_point | ||
| and not (self.key_prefix or self.git_config), | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit for readability, i would vote for:
is_repack = self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
self._upload_code(deploy_key_prefix, repack=is_repack)
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Issue #, if available:
Description of changes:
This PR adds script mode support to the
EstimatorandEstimatorBaseclasses.This was basically done by adding the parameters
source_dir, git_config, hyperparameters, container_log_level, code_location, entry_point, dependenciesto theEstimatorandEstimatorBaseclasses.Testing done:
The changes to the
Estimatorclass do not break any existing unit tests. In addition, new unit tests were added to simulate the script mode use case for theEstimatorclass, and confirm that the calls tosagemaker.create_training_job()ands3are the same for theEstimatorclass andFrameworkclass when both use script mode. A test was also added to ensure that git support works with theEstimatorclass. Integration tests will be introduced in a subsequent PR.Merge Checklist
Put an
xin the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_baseto create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.