@@ -340,7 +340,8 @@ def _create_model(self, java_model):
340340 return CountVectorizerModel (java_model )
341341
342342
343- class CountVectorizerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
343+ class CountVectorizerModel (JavaModel , HasInputCol , HasOutputCol ,
344+ JavaMLReadable , JavaMLWritable ):
344345 """
345346 Model fitted by :py:class:`CountVectorizer`.
346347
@@ -635,7 +636,7 @@ def _create_model(self, java_model):
635636 return IDFModel (java_model )
636637
637638
638- class IDFModel (JavaModel , JavaMLReadable , JavaMLWritable ):
639+ class IDFModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
639640 """
640641 Model fitted by :py:class:`IDF`.
641642
@@ -713,7 +714,7 @@ def _create_model(self, java_model):
713714 return MaxAbsScalerModel (java_model )
714715
715716
716- class MaxAbsScalerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
717+ class MaxAbsScalerModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
717718 """
718719 .. note:: Experimental
719720
@@ -837,7 +838,7 @@ def _create_model(self, java_model):
837838 return MinMaxScalerModel (java_model )
838839
839840
840- class MinMaxScalerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
841+ class MinMaxScalerModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
841842 """
842843 Model fitted by :py:class:`MinMaxScaler`.
843844
@@ -1538,7 +1539,7 @@ def _create_model(self, java_model):
15381539 return StandardScalerModel (java_model )
15391540
15401541
1541- class StandardScalerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
1542+ class StandardScalerModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
15421543 """
15431544 Model fitted by :py:class:`StandardScaler`.
15441545
@@ -1626,7 +1627,8 @@ def _create_model(self, java_model):
16261627 return StringIndexerModel (java_model )
16271628
16281629
1629- class StringIndexerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
1630+ class StringIndexerModel (JavaModel , HasInputCol , HasOutputCol , HasHandleInvalid ,
1631+ JavaMLReadable , JavaMLWritable ):
16301632 """
16311633 Model fitted by :py:class:`StringIndexer`.
16321634
@@ -1996,7 +1998,7 @@ def _create_model(self, java_model):
19961998 return VectorIndexerModel (java_model )
19971999
19982000
1999- class VectorIndexerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2001+ class VectorIndexerModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
20002002 """
20012003 Model fitted by :py:class:`VectorIndexer`.
20022004
@@ -2134,6 +2136,15 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
21342136 >>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"])
21352137 >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
21362138 >>> model = word2Vec.fit(doc)
2139+ >>> estimator_paramMap = word2Vec.extractParamMap()
2140+ >>> model_paramMap = model.extractParamMap()
2141+ >>> all([estimator_paramMap[getattr(word2Vec, param.name)] == value
2142+ ... for param, value in model_paramMap.items()])
2143+ True
2144+ >>> all([param.parent == model.uid for param in model_paramMap])
2145+ True
2146+ >>> [param.name for param in model.params]
2147+ ['inputCol', 'maxIter', 'outputCol', 'seed', 'stepSize']
21372148 >>> model.getVectors().show()
21382149 +----+--------------------+
21392150 |word| vector|
@@ -2292,7 +2303,8 @@ def _create_model(self, java_model):
22922303 return Word2VecModel (java_model )
22932304
22942305
2295- class Word2VecModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2306+ class Word2VecModel (JavaModel , HasStepSize , HasMaxIter , HasSeed , HasInputCol ,
2307+ HasOutputCol , JavaMLReadable , JavaMLWritable ):
22962308 """
22972309 Model fitted by :py:class:`Word2Vec`.
22982310
@@ -2333,6 +2345,15 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
23332345 >>> df = spark.createDataFrame(data,["features"])
23342346 >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features")
23352347 >>> model = pca.fit(df)
2348+ >>> estimator_paramMap = pca.extractParamMap()
2349+ >>> model_paramMap = model.extractParamMap()
2350+ >>> all([estimator_paramMap[getattr(pca, param.name)] == value
2351+ ... for param, value in model_paramMap.items()])
2352+ True
2353+ >>> all([param.parent == model.uid for param in model_paramMap])
2354+ True
2355+ >>> [param.name for param in model.params]
2356+ ['inputCol', 'outputCol']
23362357 >>> model.transform(df).collect()[0].pca_features
23372358 DenseVector([1.648..., -4.013...])
23382359 >>> model.explainedVariance
@@ -2394,7 +2415,7 @@ def _create_model(self, java_model):
23942415 return PCAModel (java_model )
23952416
23962417
2397- class PCAModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2418+ class PCAModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
23982419 """
23992420 Model fitted by :py:class:`PCA`. Transforms vectors to a lower dimensional space.
24002421
@@ -2437,6 +2458,15 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
24372458 ... ], ["y", "x", "s"])
24382459 >>> rf = RFormula(formula="y ~ x + s")
24392460 >>> model = rf.fit(df)
2461+ >>> estimator_paramMap = rf.extractParamMap()
2462+ >>> model_paramMap = model.extractParamMap()
2463+ >>> all([estimator_paramMap[getattr(rf, param.name)] == value
2464+ ... for param, value in model_paramMap.items()])
2465+ True
2466+ >>> all([param.parent == model.uid for param in model_paramMap])
2467+ True
2468+ >>> [param.name for param in model.params]
2469+ ['featuresCol', 'labelCol']
24402470 >>> model.transform(df).show()
24412471 +---+---+---+---------+-----+
24422472 | y| x| s| features|label|
@@ -2554,7 +2584,7 @@ def __str__(self):
25542584 return "RFormula(%s) (uid=%s)" % (formulaStr , self .uid )
25552585
25562586
2557- class RFormulaModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2587+ class RFormulaModel (JavaModel , HasFeaturesCol , HasLabelCol , JavaMLReadable , JavaMLWritable ):
25582588 """
25592589 .. note:: Experimental
25602590
@@ -2586,6 +2616,15 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
25862616 ... ["features", "label"])
25872617 >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
25882618 >>> model = selector.fit(df)
2619+ >>> estimator_paramMap = selector.extractParamMap()
2620+ >>> model_paramMap = model.extractParamMap()
2621+ >>> all([estimator_paramMap[getattr(selector, param.name)] == value
2622+ ... for param, value in model_paramMap.items()])
2623+ True
2624+ >>> all([param.parent == model.uid for param in model_paramMap])
2625+ True
2626+ >>> [param.name for param in model.params]
2627+ ['featuresCol', 'labelCol', 'outputCol']
25892628 >>> model.transform(df).head().selectedFeatures
25902629 DenseVector([18.0])
25912630 >>> model.selectedFeatures
@@ -2710,7 +2749,8 @@ def _create_model(self, java_model):
27102749 return ChiSqSelectorModel (java_model )
27112750
27122751
2713- class ChiSqSelectorModel (JavaModel , JavaMLReadable , JavaMLWritable ):
2752+ class ChiSqSelectorModel (JavaModel , HasFeaturesCol , HasOutputCol , HasLabelCol ,
2753+ JavaMLReadable , JavaMLWritable ):
27142754 """
27152755 .. note:: Experimental
27162756
0 commit comments