3131from sagemaker .pytorch import defaults
3232from sagemaker .pytorch .model import PyTorchModel
3333from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
34+ from sagemaker .workflow import is_pipeline_variable
3435from sagemaker .workflow .entities import PipelineVariable
3536
3637logger = logging .getLogger ("sagemaker" )
@@ -51,7 +52,7 @@ def __init__(
5152 source_dir : Optional [Union [str , PipelineVariable ]] = None ,
5253 hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
5354 image_uri : Optional [Union [str , PipelineVariable ]] = None ,
54- distribution : Dict = None ,
55+ distribution : Optional [ Dict ] = None ,
5556 ** kwargs
5657 ):
5758 """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -224,7 +225,7 @@ def __init__(
224225 if distribution is not None :
225226 instance_type = self ._get_instance_type ()
226227 # remove "ml." prefix
227- if instance_type [:3 ] == "ml." :
228+ if not is_pipeline_variable ( instance_type ) and instance_type [:3 ] == "ml." :
228229 instance_type = instance_type [3 :]
229230 validate_distribution_instance (self .sagemaker_session , distribution , instance_type )
230231
0 commit comments