31 changes: 22 additions & 9 deletions python/pyspark/ml/classification.py
@@ -264,7 +264,12 @@ def getFamily(self):
return self.getOrDefault(self.family)


class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
class LogisticRegressionModel(JavaModel, JavaClassificationModel, HasFeaturesCol,
HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization,
HasThresholds, HasWeightCol, HasAggregationDepth,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.

@@ -669,8 +674,11 @@ def _create_model(self, java_model):


@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, HasFeaturesCol,
HasLabelCol, HasPredictionCol, HasProbabilityCol,
HasRawPredictionCol, DecisionTreeParams,
TreeClassifierParams, HasCheckpointInterval, HasSeed,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.

@@ -798,8 +806,9 @@ def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)


class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, HasFeaturesCol,
HasLabelCol, HasPredictionCol, HasRawPredictionCol,
HasProbabilityCol, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestClassifier.

@@ -950,7 +959,8 @@ def getLossType(self):
return self.getOrDefault(self.lossType)


class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, HasFeaturesCol,
HasLabelCol, HasPredictionCol, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by GBTClassifier.
@@ -1105,7 +1115,9 @@ def getModelType(self):
return self.getOrDefault(self.modelType)


class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
class NaiveBayesModel(JavaModel, JavaClassificationModel, HasFeaturesCol, HasLabelCol,
HasPredictionCol, HasProbabilityCol, HasRawPredictionCol,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.

@@ -1304,8 +1316,9 @@ def getInitialWeights(self):
return self.getOrDefault(self.initialWeights)


class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel,
HasFeaturesCol, HasLabelCol, HasPredictionCol,
JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

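With these mixins, a fitted classification model carries the shared Params itself rather than exposing them only through its parent estimator. A minimal sketch of the intended usage, assuming the estimator's param values are propagated to the model (as the doctests added in feature.py verify); the training DataFrame df is illustrative, not part of this diff:

>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression(maxIter=5, regParam=0.01)
>>> model = lr.fit(df)  # df: assumed DataFrame with "label" and "features" columns
>>> model.getMaxIter()  # getter inherited from HasMaxIter via the new mixins
5
>>> model.getRegParam()
0.01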
11 changes: 7 additions & 4 deletions python/pyspark/ml/clustering.py
@@ -27,7 +27,8 @@
'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']


class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
class GaussianMixtureModel(JavaModel, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
HasProbabilityCol, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -181,7 +182,8 @@ def getK(self):
return self.getOrDefault(self.k)


class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasFeaturesCol,
HasPredictionCol, HasMaxIter, HasTol, HasSeed):
"""
Model fitted by KMeans.

@@ -324,7 +326,8 @@ def getInitSteps(self):
return self.getOrDefault(self.initSteps)


class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
class BisectingKMeansModel(JavaModel, HasFeaturesCol, HasPredictionCol, HasMaxIter,
HasSeed, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental

@@ -461,7 +464,7 @@ def _create_model(self, java_model):


@inherit_doc
class LDAModel(JavaModel):
class LDAModel(JavaModel, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval):
"""
.. note:: Experimental

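The clustering models follow the same pattern. A small sketch, with the data and column names assumed for illustration; the default shown comes from the shared param itself (HasFeaturesCol defaults to 'features'):

>>> from pyspark.ml.clustering import KMeans
>>> kmeans = KMeans(k=2, seed=1)
>>> model = kmeans.fit(df)  # df: assumed DataFrame with a 'features' vector column
>>> model.hasParam("featuresCol")  # True now that KMeansModel mixes in HasFeaturesCol
True
>>> model.getFeaturesCol()
'features'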
62 changes: 51 additions & 11 deletions python/pyspark/ml/feature.py
@@ -340,7 +340,8 @@ def _create_model(self, java_model):
return CountVectorizerModel(java_model)


class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol,
JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`CountVectorizer`.

@@ -635,7 +636,7 @@ def _create_model(self, java_model):
return IDFModel(java_model)


class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable):
class IDFModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`IDF`.

@@ -713,7 +714,7 @@ def _create_model(self, java_model):
return MaxAbsScalerModel(java_model)


class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class MaxAbsScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

@@ -837,7 +838,7 @@ def _create_model(self, java_model):
return MinMaxScalerModel(java_model)


class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class MinMaxScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`MinMaxScaler`.

@@ -1538,7 +1539,7 @@ def _create_model(self, java_model):
return StandardScalerModel(java_model)


class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class StandardScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`StandardScaler`.

@@ -1626,7 +1627,8 @@ def _create_model(self, java_model):
return StringIndexerModel(java_model)


class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class StringIndexerModel(JavaModel, HasInputCol, HasOutputCol, HasHandleInvalid,
JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`StringIndexer`.

@@ -1996,7 +1998,7 @@ def _create_model(self, java_model):
return VectorIndexerModel(java_model)


class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
class VectorIndexerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`VectorIndexer`.

@@ -2134,6 +2136,15 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
>>> model = word2Vec.fit(doc)
>>> estimator_paramMap = word2Vec.extractParamMap()
>>> model_paramMap = model.extractParamMap()
>>> all([estimator_paramMap[getattr(word2Vec, param.name)] == value
... for param, value in model_paramMap.items()])
True
>>> all([param.parent == model.uid for param in model_paramMap])
True
>>> [param.name for param in model.params]
['inputCol', 'maxIter', 'outputCol', 'seed', 'stepSize']
>>> model.getVectors().show()
+----+--------------------+
|word| vector|
@@ -2292,7 +2303,8 @@ def _create_model(self, java_model):
return Word2VecModel(java_model)


class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
class Word2VecModel(JavaModel, HasStepSize, HasMaxIter, HasSeed, HasInputCol,
HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`Word2Vec`.

@@ -2333,6 +2345,15 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
>>> df = spark.createDataFrame(data,["features"])
>>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
>>> model = pca.fit(df)
>>> estimator_paramMap = pca.extractParamMap()
>>> model_paramMap = model.extractParamMap()
>>> all([estimator_paramMap[getattr(pca, param.name)] == value
... for param, value in model_paramMap.items()])
True
>>> all([param.parent == model.uid for param in model_paramMap])
True
>>> [param.name for param in model.params]
['inputCol', 'outputCol']
>>> model.transform(df).collect()[0].pca_features
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
@@ -2394,7 +2415,7 @@ def _create_model(self, java_model):
return PCAModel(java_model)


class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable):
class PCAModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`PCA`. Transforms vectors to a lower dimensional space.

@@ -2437,6 +2458,15 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
... ], ["y", "x", "s"])
>>> rf = RFormula(formula="y ~ x + s")
>>> model = rf.fit(df)
>>> estimator_paramMap = rf.extractParamMap()
>>> model_paramMap = model.extractParamMap()
>>> all([estimator_paramMap[getattr(rf, param.name)] == value
... for param, value in model_paramMap.items()])
True
>>> all([param.parent == model.uid for param in model_paramMap])
True
>>> [param.name for param in model.params]
['featuresCol', 'labelCol']
>>> model.transform(df).show()
+---+---+---+---------+-----+
| y| x| s| features|label|
@@ -2554,7 +2584,7 @@ def __str__(self):
return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid)


class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
class RFormulaModel(JavaModel, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

@@ -2586,6 +2616,15 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
... ["features", "label"])
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
>>> model = selector.fit(df)
>>> estimator_paramMap = selector.extractParamMap()
>>> model_paramMap = model.extractParamMap()
>>> all([estimator_paramMap[getattr(selector, param.name)] == value
... for param, value in model_paramMap.items()])
True
>>> all([param.parent == model.uid for param in model_paramMap])
True
>>> [param.name for param in model.params]
['featuresCol', 'labelCol', 'outputCol']
>>> model.transform(df).head().selectedFeatures
DenseVector([18.0])
>>> model.selectedFeatures
@@ -2710,7 +2749,8 @@ def _create_model(self, java_model):
return ChiSqSelectorModel(java_model)


class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
class ChiSqSelectorModel(JavaModel, HasFeaturesCol, HasOutputCol, HasLabelCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

5 changes: 5 additions & 0 deletions python/pyspark/ml/param/__init__.py
@@ -336,6 +336,11 @@ def hasParam(self, paramName):
return isinstance(p, Param)
else:
raise TypeError("hasParam(): paramName must be a string")
try:
    param = self._resolveParam(paramName)
    return param in self.params
except:
    return False

Member: I don't think this code is reachable, is this necessary?

@since("1.4.0")
def getOrDefault(self, param):
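For context, the if/else above either returns or raises for every input, which is what the reachability comment refers to; a quick sketch of the behavior hasParam already has without the new block:

>>> from pyspark.ml.classification import LogisticRegression
>>> lr = LogisticRegression()
>>> lr.hasParam("maxIter")     # getattr finds a Param instance
True
>>> lr.hasParam("notAParam")   # getattr returns None, so isinstance is False
False
>>> lr.hasParam(42)            # non-string raises before the try block could run
Traceback (most recent call last):
    ...
TypeError: hasParam(): paramName must be a string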
6 changes: 3 additions & 3 deletions python/pyspark/ml/recommendation.py
@@ -26,8 +26,8 @@


@inherit_doc
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
JavaMLWritable, JavaMLReadable):
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol,
HasRegParam, HasSeed, JavaMLWritable, JavaMLReadable):
Member: If this is just a cleanup change, you might want to revert. It can be fixed another time and there's already a lot of changes here.

"""
Alternating Least Squares (ALS) matrix factorization.

@@ -333,7 +333,7 @@ def getFinalStorageLevel(self):
return self.getOrDefault(self.finalStorageLevel)


class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
class ALSModel(JavaModel, HasPredictionCol, JavaMLWritable, JavaMLReadable):
Member: I'm pretty sure ALS has more params, why only adding predictionCol?

"""
Model fitted by ALS.

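As the diff stands, HasPredictionCol is the only mixin added to ALSModel, so the fitted model exposes just that one shared param; a sketch under that assumption, with the ratings DataFrame illustrative:

>>> from pyspark.ml.recommendation import ALS
>>> als = ALS(maxIter=5, regParam=0.1, userCol="user", itemCol="item", ratingCol="rating")
>>> model = als.fit(ratings)  # ratings: assumed DataFrame of (user, item, rating) rows
>>> model.getPredictionCol()  # available via HasPredictionCol, defaults to 'prediction'
'prediction'
>>> model.getMaxIter()        # not available; HasMaxIter was not added to the model
Traceback (most recent call last):
    ...
AttributeError: 'ALSModel' object has no attribute 'getMaxIter'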