-
Notifications
You must be signed in to change notification settings - Fork 89
feat: Support Placeholders with ModelStep #175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
204381a
aa65d27
91a7b54
584cb94
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -296,14 +296,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No | |
| model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here. | ||
| model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. | ||
| instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. | ||
| tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
| tags (list[dict] or Placeholders, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource. | ||
| parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None) | ||
| You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_. | ||
| """ | ||
| model_type = "FrameworkModel" if isinstance(model, FrameworkModel) else "Model" | ||
| if isinstance(model, FrameworkModel): | ||
ca-nguyen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri) | ||
| model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: didn't really have an objection to this parameter name since they do ultimately resolve to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was renamed to avoid confusion with the input |
||
| if model_name: | ||
| parameters['ModelName'] = model_name | ||
| model_parameters['ModelName'] = model_name | ||
| elif isinstance(model, Model): | ||
| parameters = { | ||
| model_parameters = { | ||
| 'ExecutionRoleArn': model.role, | ||
| 'ModelName': model_name or model.name, | ||
| 'PrimaryContainer': { | ||
|
|
@@ -315,13 +318,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No | |
| else: | ||
| raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__)) | ||
jormello marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if 'S3Operations' in parameters: | ||
| del parameters['S3Operations'] | ||
| if 'S3Operations' in model_parameters: | ||
| del model_parameters['S3Operations'] | ||
|
|
||
| if tags: | ||
| parameters['Tags'] = tags_dict_to_kv_list(tags) | ||
| model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) | ||
|
|
||
| kwargs[Field.Parameters.value] = parameters | ||
| if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): | ||
| # Update model parameters with input parameters | ||
| merge_dicts(model_parameters, kwargs[Field.Parameters.value]) | ||
|
|
||
| kwargs[Field.Parameters.value] = model_parameters | ||
|
|
||
| """ | ||
| Example resource arn: arn:aws:states:::sagemaker:createModel | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.