Skip to content

Commit 03cd5ad

Browse files
author
Dewen Qi
committed
change: Implement test mechanism for Pipeline variables
annotations for processors + estimators / test mechanism change annotations for processors + estimators / test mechanism change remove debug print reformatting update TM and resolve all AIs and all untested subclasses Add ppl var annotation to all composite object for training Adjust tm for recent model changes
1 parent 3560d70 commit 03cd5ad

29 files changed

+3459
-45
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import json
1919
import logging
2020
import tempfile
21-
from typing import Union
2221

2322
from six.moves.urllib.parse import urlparse
2423

@@ -32,7 +31,6 @@
3231
from sagemaker.utils import sagemaker_timestamp
3332
from sagemaker.workflow.entities import PipelineVariable
3433
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
35-
from sagemaker.workflow.entities import PipelineVariable
3634
from sagemaker.workflow.parameters import ParameterBoolean
3735
from sagemaker.workflow import is_pipeline_variable
3836

src/sagemaker/amazon/hyperparameter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import json
1717
from sagemaker.workflow import is_pipeline_variable
1818

19-
from sagemaker.workflow import is_pipeline_variable
20-
2119

2220
class Hyperparameter(object):
2321
"""An algorithm hyperparameter with optional validation.

src/sagemaker/amazon/kmeans.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828
from sagemaker.workflow.entities import PipelineVariable
2929

30-
from sagemaker.workflow.entities import PipelineVariable
31-
3230

3331
class KMeans(AmazonAlgorithmEstimatorBase):
3432
"""An unsupervised learning algorithm that attempts to find discrete groupings within data.
@@ -76,7 +74,7 @@ def __init__(
7674
half_life_time_size: Optional[int] = None,
7775
epochs: Optional[int] = None,
7876
center_factor: Optional[int] = None,
79-
eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None,
77+
eval_metrics: Optional[List[str]] = None,
8078
**kwargs
8179
):
8280
"""A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`.

src/sagemaker/amazon/linear_learner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
from typing import Union, Optional
1718

1819
from sagemaker import image_uris
@@ -28,6 +29,8 @@
2829
from sagemaker.workflow.entities import PipelineVariable
2930
from sagemaker.workflow import is_pipeline_variable
3031

32+
logger = logging.getLogger(__name__)
33+
3134

3235
class LinearLearner(AmazonAlgorithmEstimatorBase):
3336
"""A supervised learning algorithms used for solving classification or regression problems.
@@ -437,9 +440,14 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
437440
# mini_batch_size can't be greater than number of records or training job fails
438441
if not mini_batch_size:
439442
if is_pipeline_variable(self.instance_count):
440-
raise ValueError(
441-
"instance_count can not be a pipeline variable when mini_batch_size is not given."
443+
logger.warning(
444+
"mini_batch_size is not given in .fit() and instance_count is a "
445+
"pipeline variable (%s) which is only parsed in execution time. "
446+
"Thus setting mini_batch_size to 1, as it can't be greater than "
447+
"number of records per instance_count, otherwise the training job fails.",
448+
type(self.instance_count),
442449
)
450+
mini_batch_size = 1
443451
else:
444452
mini_batch_size = min(
445453
self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count))

src/sagemaker/chainer/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from typing import Optional, Union, Dict
1717

1818
import logging
19-
from typing import Union, Optional
2019

2120
from sagemaker.estimator import Framework, EstimatorBase
2221
from sagemaker.fw_utils import (

src/sagemaker/clarify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
logger = logging.getLogger(__name__)
3838

3939

40-
class DataConfig: #TODO: add PipelineVariable to rest of fields
40+
class DataConfig: # TODO: add PipelineVariable to rest of fields
4141
"""Config object related to configurations of the input and output dataset."""
4242

4343
def __init__(
@@ -271,7 +271,7 @@ def get_config(self):
271271
return copy.deepcopy(self.analysis_config)
272272

273273

274-
class ModelConfig: # TODO add pipeline annotation
274+
class ModelConfig: # TODO add pipeline annotation
275275
"""Config object related to a model and its endpoint to be created."""
276276

277277
def __init__(

src/sagemaker/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
get_jumpstart_base_name_if_jumpstart_model,
5757
update_inference_tags_with_jumpstart_training_tags,
5858
)
59-
from sagemaker.debugger import RuleBase
6059
from sagemaker.local import LocalSession
6160
from sagemaker.model import (
6261
CONTAINER_LOG_LEVEL_PARAM_NAME,

src/sagemaker/huggingface/estimator.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import logging
1919
import re
20-
from typing import Optional, Union, Dict
2120

2221
from sagemaker.deprecations import renamed_kwargs
2322
from sagemaker.estimator import Framework, EstimatorBase
@@ -204,14 +203,8 @@ def __init__(
204203
f"Instead got {type(compiler_config)}"
205204
)
206205
raise ValueError(error_string)
207-
208-
compiler_config.validate(
209-
image_uri=image_uri,
210-
instance_type=instance_type,
211-
distribution=distribution,
212-
)
213-
214-
self.distribution = distribution or {}
206+
if compiler_config:
207+
compiler_config.validate(self)
215208
self.compiler_config = compiler_config
216209

217210
def _validate_args(self, image_uri):

src/sagemaker/inputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import, print_function
1515

1616
from typing import Union, Optional, List
17+
1718
import attr
1819

1920
from sagemaker.workflow.entities import PipelineVariable

src/sagemaker/mxnet/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from typing import Optional, Dict, Union
1717

1818
import logging
19-
from typing import Union, Optional, Dict
2019

2120
from packaging.version import Version
2221

0 commit comments

Comments
 (0)