Skip to content

Commit c917ad8

Browse files
authored
Merge branch 'master-jumpstart' into feat/hyperparameter-validation
2 parents 8f4aecf + d9d8c68 commit c917ad8

File tree

17 files changed

+893
-243
lines changed

17 files changed

+893
-243
lines changed

src/sagemaker/algorithm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
self.validate_train_spec()
175175
self.hyperparameter_definitions = self._parse_hyperparameters()
176176

177-
self.hyperparam_dict = {}
177+
self._hyperparameters = {}
178178
if hyperparameters:
179179
self.set_hyperparameters(**hyperparameters)
180180

@@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
215215
"""Placeholder docstring"""
216216
for k, v in kwargs.items():
217217
value = self._validate_and_cast_hyperparameter(k, v)
218-
self.hyperparam_dict[k] = value
218+
self._hyperparameters[k] = value
219219

220220
self._validate_and_set_default_hyperparameters()
221221

@@ -225,7 +225,7 @@ def hyperparameters(self):
225225
The fit() method, that does the model training, calls this method to
226226
find the hyperparameters you specified.
227227
"""
228-
return self.hyperparam_dict
228+
return self._hyperparameters
229229

230230
def training_image_uri(self):
231231
"""Returns the docker image to use for training.
@@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
464464
# Check if all the required hyperparameters are set. If there is a default value
465465
# for one, set it.
466466
for name, definition in self.hyperparameter_definitions.items():
467-
if name not in self.hyperparam_dict:
467+
if name not in self._hyperparameters:
468468
spec = definition["spec"]
469469
if "DefaultValue" in spec:
470-
self.hyperparam_dict[name] = spec["DefaultValue"]
470+
self._hyperparameters[name] = spec["DefaultValue"]
471471
elif "IsRequired" in spec and spec["IsRequired"]:
472472
raise ValueError("Required hyperparameter: %s is not set" % name)
473473

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717

18-
from sagemaker.estimator import Framework
18+
from sagemaker.estimator import Framework, EstimatorBase
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
@@ -158,7 +158,9 @@ def hyperparameters(self):
158158

159159
# remove unset keys.
160160
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
161-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
161+
hyperparameters.update(
162+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
163+
)
162164
return hyperparameters
163165

164166
def create_model(

src/sagemaker/chainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
168168
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
169169
self._upload_code(deploy_key_prefix)
170170
deploy_env = dict(self.env)
171-
deploy_env.update(self._framework_env_vars())
171+
deploy_env.update(self._script_mode_env_vars())
172172

173173
if self.model_server_workers:
174174
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/estimator.py

Lines changed: 409 additions & 102 deletions
Large diffs are not rendered by default.

src/sagemaker/huggingface/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818

1919
from sagemaker.deprecations import renamed_kwargs
20-
from sagemaker.estimator import Framework
20+
from sagemaker.estimator import Framework, EstimatorBase
2121
from sagemaker.fw_utils import (
2222
framework_name_from_image,
2323
warn_if_parameter_server_with_multi_gpu,
@@ -246,13 +246,13 @@ def hyperparameters(self):
246246
distribution=self.distribution
247247
)
248248
hyperparameters.update(
249-
Framework._json_encode_hyperparameters(distributed_training_hyperparameters)
249+
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
250250
)
251251

252252
if self.compiler_config:
253253
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
254254
hyperparameters.update(
255-
Framework._json_encode_hyperparameters(training_compiler_hyperparameters)
255+
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
256256
)
257257

258258
return hyperparameters

src/sagemaker/huggingface/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
273273
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
274274
self._upload_code(deploy_key_prefix, repack=True)
275275
deploy_env = dict(self.env)
276-
deploy_env.update(self._framework_env_vars())
276+
deploy_env.update(self._script_mode_env_vars())
277277

278278
if self.model_server_workers:
279279
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

0 commit comments

Comments
 (0)