diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 9e526e2bee..5e36392b70 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -16,16 +16,12 @@ import abc from enum import Enum -from typing import Dict, List +from typing import Dict, List, Union import attr from sagemaker.estimator import EstimatorBase, _TrainingJob -from sagemaker.inputs import ( - CreateModelInput, - TrainingInput, - TransformInput, -) +from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput from sagemaker.model import Model from sagemaker.processing import ( ProcessingInput, @@ -145,7 +141,7 @@ def __init__( self, name: str, estimator: EstimatorBase, - inputs: TrainingInput = None, + inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, cache_config: CacheConfig = None, depends_on: List[str] = None, ): @@ -157,7 +153,23 @@ def __init__( Args: name (str): The name of the training step. estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance. - inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`. + inputs (str or dict or sagemaker.inputs.TrainingInput + or sagemaker.inputs.FileSystemInput): Information + about the training data. This can be one of three types: + + * (str) the S3 location where training data is saved, or a file:// path in + local mode. + * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) If using multiple + channels for training data, you can specify a dict mapping channel names to + strings or :func:`~sagemaker.inputs.TrainingInput` objects. + * (sagemaker.inputs.TrainingInput) - channel configuration for S3 data sources + that can provide additional information as well as the path to the training + dataset. + See :func:`sagemaker.inputs.TrainingInput` for full details. + * (sagemaker.inputs.FileSystemInput) - channel configuration for + a file system data source that can provide additional information as well as + the path to the training dataset. + cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` depends on