diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 6d4dedf86a..5e9c2098b9 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class FactorizationMachines(AmazonAlgorithmEstimatorBase): @@ -319,7 +323,13 @@ class FactorizationMachinesModel(Model): returns :class:`FactorizationMachinesPredictor`. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for FactorizationMachinesModel class. 
Args: @@ -343,6 +353,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=FactorizationMachines.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, FactorizationMachinesPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(FactorizationMachinesModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 8bc9103876..097f6b45dc 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -22,7 +24,9 @@ from sagemaker.model import Model from sagemaker.serializers import CSVSerializer from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class IPInsights(AmazonAlgorithmEstimatorBase): @@ -222,7 +226,13 @@ class IPInsightsModel(Model): Predictor that calculates anomaly scores for data points. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Creates object to get insights on S3 model data. 
Args: @@ -246,6 +256,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=IPInsights.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, IPInsightsPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(IPInsightsModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 286fe0c026..581e93e02a 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class KMeans(AmazonAlgorithmEstimatorBase): @@ -246,7 +250,13 @@ class KMeansModel(Model): Predictor to performs k-means cluster assignment. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for KMeansModel class. 
Args: @@ -270,6 +280,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=KMeans.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, KMeansPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(KMeansModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 10fe640b68..14ba404ebf 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class KNN(AmazonAlgorithmEstimatorBase): @@ -238,7 +242,13 @@ class KNNModel(Model): and returns :class:`KNNPredictor`. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Function to initialize KNNModel. 
Args: @@ -262,6 +272,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=KNN.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, KNNPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(KNNModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 2d7c4aa58b..4158b6cc27 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class LDA(AmazonAlgorithmEstimatorBase): @@ -220,7 +224,13 @@ class LDAModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for LDAModel class. 
Args: @@ -244,6 +254,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=LDA.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, LDAPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(LDAModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index e0a93c0120..d02ed2875f 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class LinearLearner(AmazonAlgorithmEstimatorBase): @@ -481,7 +485,13 @@ class LinearLearnerModel(Model): :class:`LinearLearnerPredictor` """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for LinearLearnerModel. 
Args: @@ -505,6 +515,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=LinearLearner.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, LinearLearnerPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(LinearLearnerModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 12f3fc635c..83c2f97348 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class NTM(AmazonAlgorithmEstimatorBase): @@ -249,7 +253,13 @@ class NTMModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for NTMModel class. 
Args: @@ -273,6 +283,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=NTM.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, NTMPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(NTMModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index bd34eb7d19..1fbd846cbf 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -20,7 +22,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable def _list_check_subset(valid_super_list): @@ -344,7 +348,13 @@ class Object2VecModel(Model): Predictor that calculates anomaly scores for datapoints. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for Object2VecModel class. 
Args: @@ -368,6 +378,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=Object2Vec.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, Predictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(Object2VecModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 93f8e25caa..e3127fd7a1 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class PCA(AmazonAlgorithmEstimatorBase): @@ -237,7 +241,13 @@ class PCAModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for PCAModel. 
Args: @@ -261,6 +271,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=PCA.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, PCAPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(PCAModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index a1c3e7d171..c38d75e3e4 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class RandomCutForest(AmazonAlgorithmEstimatorBase): @@ -209,7 +213,13 @@ class RandomCutForestModel(Model): Predictor that calculates anomaly scores for datapoints. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for RandomCutForestModel class. 
Args: @@ -233,6 +243,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=RandomCutForest.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, RandomCutForestPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(RandomCutForestModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/sparkml/model.py b/src/sagemaker/sparkml/model.py index f0c32fede8..527cae0957 100644 --- a/src/sagemaker/sparkml/model.py +++ b/src/sagemaker/sparkml/model.py @@ -13,8 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import Model, Predictor, Session, image_uris from sagemaker.serializers import CSVSerializer +from sagemaker.utils import pop_out_unused_kwarg +from sagemaker.workflow.entities import PipelineVariable framework_name = "sparkml-serving" @@ -71,7 +75,12 @@ class SparkMLModel(Model): """ def __init__( - self, model_data, role=None, spark_version="2.4", sagemaker_session=None, **kwargs + self, + model_data: Union[str, PipelineVariable], + role: Optional[str] = None, + spark_version: str = "2.4", + sagemaker_session: Optional[Session] = None, + **kwargs, ): """Initialize a SparkMLModel. 
@@ -104,6 +113,8 @@ def __init__(
         # boto_region_name
         region_name = (sagemaker_session or Session()).boto_region_name
         image_uri = image_uris.retrieve(framework_name, region_name, version=spark_version)
+        pop_out_unused_kwarg("predictor_cls", kwargs, SparkMLPredictor.__name__)
+        pop_out_unused_kwarg("image_uri", kwargs, image_uri)
         super(SparkMLModel, self).__init__(
             image_uri,
             model_data,
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 1998525a98..4365d22f2d 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -27,6 +27,7 @@
 import abc
 import uuid
 from datetime import datetime
+from typing import Optional
 
 import botocore
 from six.moves.urllib import parse
@@ -827,3 +828,26 @@ def construct_container_object(
     )
 
     return obj
+
+
+def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
+    """Pop out the unused keyword argument and give a warning.
+
+    If ``arg_name`` is not a key of ``kwargs`` nothing happens. Otherwise the
+    entry is removed from ``kwargs`` in place and a warning is logged.
+
+    Args:
+        arg_name (str): The name of the argument to be checked if it is unused.
+        kwargs (dict): The keyword argument dict; it is mutated in place.
+        override_val (str): The value used to override the unused argument. When
+            truthy, it is mentioned in the warning message (default: None).
+    """
+    if arg_name not in kwargs:
+        return
+    warn_msg = "{} supplied in kwargs will be ignored".format(arg_name)
+    if override_val:
+        warn_msg += " and further overridden with {}.".format(override_val)
+    # Log through a named logger: the module-level logging.warning() goes to the
+    # root logger and calls logging.basicConfig() when it has no handlers.
+    logging.getLogger(__name__).warning(warn_msg)
+    kwargs.pop(arg_name)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 4e6ba92730..5302e21fb8 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -761,3 +761,20 @@ def test_partition_by_region():
     assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov"
     assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso"
     assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b"
+
+
+def test_pop_out_unused_kwarg():
+    # The given arg_name is in kwargs: only that entry is removed
+    kwargs = dict(arg1=1, arg2=2)
+    sagemaker.utils.pop_out_unused_kwarg("arg1", kwargs)
+    assert kwargs == dict(arg2=2)
+
+    # The given arg_name is in kwargs and an override value is reported
+    kwargs = dict(arg1=1, arg2=2)
+    sagemaker.utils.pop_out_unused_kwarg("arg1", kwargs, "override")
+    assert kwargs == dict(arg2=2)
+
+    # The given arg_name is not in kwargs: nothing changes
+    kwargs = dict(arg1=1, arg2=2)
+    sagemaker.utils.pop_out_unused_kwarg("arg3", kwargs)
+    assert kwargs == dict(arg1=1, arg2=2)