From 3228a2e20895473c3f13475eec9720d54bcd0a98 Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Thu, 17 Jun 2021 12:52:31 -0700 Subject: [PATCH 1/3] Correct type annotation for training step inputs --- src/sagemaker/workflow/steps.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 9e526e2bee..5ea5f1bf4b 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -16,7 +16,7 @@ import abc from enum import Enum -from typing import Dict, List +from typing import Dict, List, Union import attr @@ -145,7 +145,7 @@ def __init__( self, name: str, estimator: EstimatorBase, - inputs: TrainingInput = None, + inputs: Union[TrainingInput, dict, str] = None, cache_config: CacheConfig = None, depends_on: List[str] = None, ): @@ -157,7 +157,22 @@ def __init__( Args: name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. - inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. + inputs (str or dict or sagemaker.inputs.TrainingInput): Information + about the training data. This can be one of three types: + + * (str) the S3 location where training data is saved, or a file:// path in + local mode. + * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple + channels for training data, you can specify a dict mapping channel names to + strings or :func:`~sagemaker.inputs.TrainingInput` objects. + * (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources + that can provide additional information as well as the path to the training + dataset. + See :func:`sagemaker.inputs.TrainingInput` for full details. + * (sagemaker.session.FileSystemInput) - channel configuration for + a file system data source that can provide additional information as well as + the path to the training dataset. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` depends on From dd6e734383305f4b232c82010b55127aad7d131d Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Thu, 17 Jun 2021 13:58:29 -0700 Subject: [PATCH 2/3] Add missing type hint --- src/sagemaker/workflow/steps.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 5ea5f1bf4b..37efe9f34d 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -25,6 +25,7 @@ CreateModelInput, TrainingInput, TransformInput, + FileSystemInput ) from sagemaker.model import Model from sagemaker.processing import ( @@ -44,7 +45,6 @@ Properties, ) - class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): """Enum of step types.""" @@ -145,7 +145,7 @@ def __init__( self, name: str, estimator: EstimatorBase, - inputs: Union[TrainingInput, dict, str] = None, + inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, cache_config: CacheConfig = None, depends_on: List[str] = None, ): @@ -157,7 +157,8 @@ def __init__( Args: name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. - inputs (str or dict or sagemaker.inputs.TrainingInput): Information + inputs (str or dict or sagemaker.inputs.TrainingInput + or sagemaker.inputs.FileSystemInput): Information about the training data. This can be one of three types: * (str) the S3 location where training data is saved, or a file:// path in @@ -169,7 +170,7 @@ def __init__( that can provide additional information as well as the path to the training dataset. See :func:`sagemaker.inputs.TrainingInput` for full details. - * (sagemaker.session.FileSystemInput) - channel configuration for + * (sagemaker.inputs.FileSystemInput) - channel configuration for a file system data source that can provide additional information as well as the path to the training dataset. From 8ae60a1731fab16a18b4d769b129421b28540436 Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Thu, 17 Jun 2021 14:27:07 -0700 Subject: [PATCH 3/3] black-format --- src/sagemaker/workflow/steps.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 37efe9f34d..5e36392b70 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -21,12 +21,7 @@ import attr from sagemaker.estimator import EstimatorBase, _TrainingJob -from sagemaker.inputs import ( - CreateModelInput, - TrainingInput, - TransformInput, - FileSystemInput -) +from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput from sagemaker.model import Model from sagemaker.processing import ( ProcessingInput, @@ -45,6 +40,7 @@ Properties, ) + class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): """Enum of step types."""