Skip to content

Commit 05c11f4

Browse files
author
Evan Chen
committed
Copied parameters over from Estimator to Transformer. Estimator UID is now being copied correctly to the Transformer model objects and params; working on doctests. Changed the way parameters are copied from the Estimator to the Transformer. Checkpoint: switching back to the inheritance method. Working on doctests. Implemented doctests for Recommendation, Clustering, Classification (except RandomForestClassifier), Evaluation, Tuning, and Regression (except RandomRegression). Ready for code review. Code review changeset #1.
1 parent 8bfc3b7 commit 05c11f4

File tree

7 files changed

+116
-44
lines changed

7 files changed

+116
-44
lines changed

python/pyspark/ml/classification.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,12 @@ def getFamily(self):
264264
return self.getOrDefault(self.family)
265265

266266

267-
class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
267+
class LogisticRegressionModel(JavaModel, JavaClassificationModel, HasFeaturesCol,
268+
HasLabelCol, HasPredictionCol, HasMaxIter,
269+
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
270+
HasElasticNetParam, HasFitIntercept, HasStandardization,
271+
HasThresholds, HasWeightCol, HasAggregationDepth,
272+
JavaMLWritable, JavaMLReadable):
268273
"""
269274
Model fitted by LogisticRegression.
270275
@@ -669,8 +674,11 @@ def _create_model(self, java_model):
669674

670675

671676
@inherit_doc
672-
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
673-
JavaMLReadable):
677+
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, HasFeaturesCol,
678+
HasLabelCol, HasPredictionCol, HasProbabilityCol,
679+
HasRawPredictionCol, DecisionTreeParams,
680+
TreeClassifierParams, HasCheckpointInterval, HasSeed,
681+
JavaMLWritable, JavaMLReadable):
674682
"""
675683
Model fitted by DecisionTreeClassifier.
676684
@@ -798,8 +806,9 @@ def _create_model(self, java_model):
798806
return RandomForestClassificationModel(java_model)
799807

800808

801-
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
802-
JavaMLReadable):
809+
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, HasFeaturesCol,
810+
HasLabelCol, HasPredictionCol, HasRawPredictionCol,
811+
HasProbabilityCol, JavaMLWritable, JavaMLReadable):
803812
"""
804813
Model fitted by RandomForestClassifier.
805814
@@ -950,7 +959,8 @@ def getLossType(self):
950959
return self.getOrDefault(self.lossType)
951960

952961

953-
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
962+
class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, HasFeaturesCol,
963+
HasLabelCol, HasPredictionCol, JavaMLWritable,
954964
JavaMLReadable):
955965
"""
956966
Model fitted by GBTClassifier.
@@ -1105,7 +1115,9 @@ def getModelType(self):
11051115
return self.getOrDefault(self.modelType)
11061116

11071117

1108-
class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
1118+
class NaiveBayesModel(JavaModel, JavaClassificationModel, HasFeaturesCol, HasLabelCol,
1119+
HasPredictionCol, HasProbabilityCol, HasRawPredictionCol,
1120+
JavaMLWritable, JavaMLReadable):
11091121
"""
11101122
Model fitted by NaiveBayes.
11111123
@@ -1304,8 +1316,9 @@ def getInitialWeights(self):
13041316
return self.getOrDefault(self.initialWeights)
13051317

13061318

1307-
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, JavaMLWritable,
1308-
JavaMLReadable):
1319+
class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel,
1320+
HasFeaturesCol, HasLabelCol, HasPredictionCol,
1321+
JavaMLWritable, JavaMLReadable):
13091322
"""
13101323
.. note:: Experimental
13111324

python/pyspark/ml/clustering.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']
2828

2929

30-
class GaussianMixtureModel(JavaModel, JavaMLWritable, JavaMLReadable):
30+
class GaussianMixtureModel(JavaModel, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
31+
HasProbabilityCol, JavaMLWritable, JavaMLReadable):
3132
"""
3233
.. note:: Experimental
3334
@@ -181,7 +182,8 @@ def getK(self):
181182
return self.getOrDefault(self.k)
182183

183184

184-
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
185+
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable, HasFeaturesCol,
186+
HasPredictionCol, HasMaxIter, HasTol, HasSeed):
185187
"""
186188
Model fitted by KMeans.
187189
@@ -324,7 +326,8 @@ def getInitSteps(self):
324326
return self.getOrDefault(self.initSteps)
325327

326328

327-
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
329+
class BisectingKMeansModel(JavaModel, HasFeaturesCol, HasPredictionCol, HasMaxIter,
330+
HasSeed, JavaMLWritable, JavaMLReadable):
328331
"""
329332
.. note:: Experimental
330333
@@ -461,7 +464,7 @@ def _create_model(self, java_model):
461464

462465

463466
@inherit_doc
464-
class LDAModel(JavaModel):
467+
class LDAModel(JavaModel, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInterval):
465468
"""
466469
.. note:: Experimental
467470

python/pyspark/ml/feature.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def _create_model(self, java_model):
340340
return CountVectorizerModel(java_model)
341341

342342

343-
class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
343+
class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol,
344+
JavaMLReadable, JavaMLWritable):
344345
"""
345346
Model fitted by :py:class:`CountVectorizer`.
346347
@@ -635,7 +636,7 @@ def _create_model(self, java_model):
635636
return IDFModel(java_model)
636637

637638

638-
class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable):
639+
class IDFModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
639640
"""
640641
Model fitted by :py:class:`IDF`.
641642
@@ -713,7 +714,7 @@ def _create_model(self, java_model):
713714
return MaxAbsScalerModel(java_model)
714715

715716

716-
class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
717+
class MaxAbsScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
717718
"""
718719
.. note:: Experimental
719720
@@ -837,7 +838,7 @@ def _create_model(self, java_model):
837838
return MinMaxScalerModel(java_model)
838839

839840

840-
class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
841+
class MinMaxScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
841842
"""
842843
Model fitted by :py:class:`MinMaxScaler`.
843844
@@ -1538,7 +1539,7 @@ def _create_model(self, java_model):
15381539
return StandardScalerModel(java_model)
15391540

15401541

1541-
class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
1542+
class StandardScalerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
15421543
"""
15431544
Model fitted by :py:class:`StandardScaler`.
15441545
@@ -1626,7 +1627,8 @@ def _create_model(self, java_model):
16261627
return StringIndexerModel(java_model)
16271628

16281629

1629-
class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
1630+
class StringIndexerModel(JavaModel, HasInputCol, HasOutputCol, HasHandleInvalid,
1631+
JavaMLReadable, JavaMLWritable):
16301632
"""
16311633
Model fitted by :py:class:`StringIndexer`.
16321634
@@ -1996,7 +1998,7 @@ def _create_model(self, java_model):
19961998
return VectorIndexerModel(java_model)
19971999

19982000

1999-
class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
2001+
class VectorIndexerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
20002002
"""
20012003
Model fitted by :py:class:`VectorIndexer`.
20022004
@@ -2134,6 +2136,15 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
21342136
>>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
21352137
>>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
21362138
>>> model = word2Vec.fit(doc)
2139+
>>> estimator_paramMap = word2Vec.extractParamMap()
2140+
>>> model_paramMap = model.extractParamMap()
2141+
>>> all([estimator_paramMap[getattr(word2Vec, param.name)] == value
2142+
... for param, value in model_paramMap.items()])
2143+
True
2144+
>>> all([param.parent == model.uid for param in model_paramMap])
2145+
True
2146+
>>> [param.name for param in model.params]
2147+
['inputCol', 'maxIter', 'outputCol', 'seed', 'stepSize']
21372148
>>> model.getVectors().show()
21382149
+----+--------------------+
21392150
|word| vector|
@@ -2292,7 +2303,8 @@ def _create_model(self, java_model):
22922303
return Word2VecModel(java_model)
22932304

22942305

2295-
class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
2306+
class Word2VecModel(JavaModel, HasStepSize, HasMaxIter, HasSeed, HasInputCol,
2307+
HasOutputCol, JavaMLReadable, JavaMLWritable):
22962308
"""
22972309
Model fitted by :py:class:`Word2Vec`.
22982310
@@ -2333,6 +2345,15 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
23332345
>>> df = spark.createDataFrame(data,["features"])
23342346
>>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
23352347
>>> model = pca.fit(df)
2348+
>>> estimator_paramMap = pca.extractParamMap()
2349+
>>> model_paramMap = model.extractParamMap()
2350+
>>> all([estimator_paramMap[getattr(pca, param.name)] == value
2351+
... for param, value in model_paramMap.items()])
2352+
True
2353+
>>> all([param.parent == model.uid for param in model_paramMap])
2354+
True
2355+
>>> [param.name for param in model.params]
2356+
['inputCol', 'outputCol']
23362357
>>> model.transform(df).collect()[0].pca_features
23372358
DenseVector([1.648..., -4.013...])
23382359
>>> model.explainedVariance
@@ -2394,7 +2415,7 @@ def _create_model(self, java_model):
23942415
return PCAModel(java_model)
23952416

23962417

2397-
class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable):
2418+
class PCAModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
23982419
"""
23992420
Model fitted by :py:class:`PCA`. Transforms vectors to a lower dimensional space.
24002421
@@ -2437,6 +2458,15 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
24372458
... ], ["y", "x", "s"])
24382459
>>> rf = RFormula(formula="y ~ x + s")
24392460
>>> model = rf.fit(df)
2461+
>>> estimator_paramMap = rf.extractParamMap()
2462+
>>> model_paramMap = model.extractParamMap()
2463+
>>> all([estimator_paramMap[getattr(rf, param.name)] == value
2464+
... for param, value in model_paramMap.items()])
2465+
True
2466+
>>> all([param.parent == model.uid for param in model_paramMap])
2467+
True
2468+
>>> [param.name for param in model.params]
2469+
['featuresCol', 'labelCol']
24402470
>>> model.transform(df).show()
24412471
+---+---+---+---------+-----+
24422472
| y| x| s| features|label|
@@ -2554,7 +2584,7 @@ def __str__(self):
25542584
return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid)
25552585

25562586

2557-
class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
2587+
class RFormulaModel(JavaModel, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
25582588
"""
25592589
.. note:: Experimental
25602590
@@ -2586,6 +2616,15 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
25862616
... ["features", "label"])
25872617
>>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
25882618
>>> model = selector.fit(df)
2619+
>>> estimator_paramMap = selector.extractParamMap()
2620+
>>> model_paramMap = model.extractParamMap()
2621+
>>> all([estimator_paramMap[getattr(selector, param.name)] == value
2622+
... for param, value in model_paramMap.items()])
2623+
True
2624+
>>> all([param.parent == model.uid for param in model_paramMap])
2625+
True
2626+
>>> [param.name for param in model.params]
2627+
['featuresCol', 'labelCol', 'outputCol']
25892628
>>> model.transform(df).head().selectedFeatures
25902629
DenseVector([18.0])
25912630
>>> model.selectedFeatures
@@ -2710,7 +2749,8 @@ def _create_model(self, java_model):
27102749
return ChiSqSelectorModel(java_model)
27112750

27122751

2713-
class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
2752+
class ChiSqSelectorModel(JavaModel, HasFeaturesCol, HasOutputCol, HasLabelCol,
2753+
JavaMLReadable, JavaMLWritable):
27142754
"""
27152755
.. note:: Experimental
27162756

python/pyspark/ml/param/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ def hasParam(self, paramName):
336336
return isinstance(p, Param)
337337
else:
338338
raise TypeError("hasParam(): paramName must be a string")
339+
try:
340+
param = self._resolveParam(paramName)
341+
return param in self.params
342+
except:
343+
return False
339344

340345
@since("1.4.0")
341346
def getOrDefault(self, param):

python/pyspark/ml/recommendation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727

2828
@inherit_doc
29-
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
30-
JavaMLWritable, JavaMLReadable):
29+
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol,
30+
HasRegParam, HasSeed, JavaMLWritable, JavaMLReadable):
3131
"""
3232
Alternating Least Squares (ALS) matrix factorization.
3333
@@ -333,7 +333,7 @@ def getFinalStorageLevel(self):
333333
return self.getOrDefault(self.finalStorageLevel)
334334

335335

336-
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
336+
class ALSModel(JavaModel, HasPredictionCol, JavaMLWritable, JavaMLReadable):
337337
"""
338338
Model fitted by ALS.
339339

0 commit comments

Comments
 (0)