@@ -59,6 +59,16 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
5959 ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
6060 >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
6161 >>> model = lr.fit(df)
62+ >>> emap = lr.extractParamMap()
63+ >>> mmap = model.extractParamMap()
64+ >>> all([emap[getattr(lr, param.name)] == value for (param, value) in mmap.items()])
65+ True
66+ >>> all([param.parent == model.uid for param in mmap])
67+ True
68+ >>> [param.name for param in model.params] # doctest: +NORMALIZE_WHITESPACE
69+ ['elasticNetParam', 'featuresCol', 'fitIntercept', 'labelCol', 'maxIter',
70+ 'predictionCol', 'probabilityCol', 'rawPredictionCol', 'regParam',
71+ 'standardization', 'thresholds', 'tol']
6272 >>> model.coefficients
6373 DenseVector([5.5...])
6474 >>> model.intercept
@@ -206,7 +216,10 @@ def _checkThresholdConsistency(self):
206216 " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
207217
208218
209- class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
219+ class LogisticRegressionModel(JavaModel, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
220+ HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
221+ HasElasticNetParam, HasFitIntercept, HasStandardization,
222+ HasThresholds, JavaMLWritable, JavaMLReadable):
210223 """
211224 Model fitted by LogisticRegression.
212225
@@ -504,6 +517,16 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
504517 >>> td = si_model.transform(df)
505518 >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
506519 >>> model = dt.fit(td)
520+ >>> emap = dt.extractParamMap()
521+ >>> mmap = model.extractParamMap()
522+ >>> all([emap[getattr(dt, param.name)] == value for (param, value) in mmap.items()])
523+ True
524+ >>> all([param.parent == model.uid for param in mmap])
525+ True
526+ >>> [param.name for param in model.params] # doctest: +NORMALIZE_WHITESPACE
527+ ['cacheNodeIds', 'checkpointInterval', 'featuresCol', 'impurity', 'labelCol',
528+ 'maxBins', 'maxDepth', 'maxMemoryInMB', 'minInfoGain', 'minInstancesPerNode',
529+ 'predictionCol', 'probabilityCol', 'rawPredictionCol', 'seed']
507530 >>> model.numNodes
508531 3
509532 >>> model.depth
@@ -581,7 +604,11 @@ def _create_model(self, java_model):
581604
582605
583606@inherit_doc
584- class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
607+ class DecisionTreeClassificationModel(DecisionTreeModel, HasFeaturesCol, HasLabelCol,
608+ HasPredictionCol, HasProbabilityCol, HasRawPredictionCol,
609+ DecisionTreeParams, TreeClassifierParams,
610+ HasCheckpointInterval, HasSeed, JavaMLWritable,
611+ JavaMLReadable):
585612 """
586613 Model fitted by DecisionTreeClassifier.
587614
@@ -633,6 +660,14 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
633660 >>> td = si_model.transform(df)
634661 >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
635662 >>> model = rf.fit(td)
663+ >>> emap = rf.extractParamMap()
664+ >>> mmap = model.extractParamMap()
665+ >>> all([emap[getattr(rf, param.name)] == value for (param, value) in mmap.items()])
666+ True
667+ >>> all([param.parent == model.uid for param in mmap])
668+ True
669+ >>> [param.name for param in model.params]
670+ ['featuresCol', 'labelCol', 'predictionCol', 'probabilityCol', 'rawPredictionCol']
636671 >>> model.featureImportances
637672 SparseVector(1, {0: 1.0})
638673 >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
@@ -706,7 +741,9 @@ def _create_model(self, java_model):
706741 return RandomForestClassificationModel(java_model)
707742
708743
709- class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
744+ class RandomForestClassificationModel(TreeEnsembleModels, HasFeaturesCol, HasLabelCol,
745+ HasPredictionCol, HasRawPredictionCol, HasProbabilityCol,
746+ JavaMLWritable, JavaMLReadable):
710747 """
711748 Model fitted by RandomForestClassifier.
712749
@@ -750,6 +787,14 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
750787 >>> td = si_model.transform(df)
751788 >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
752789 >>> model = gbt.fit(td)
790+ >>> emap = gbt.extractParamMap()
791+ >>> mmap = model.extractParamMap()
792+ >>> all([emap[getattr(gbt, param.name)] == value for (param, value) in mmap.items()])
793+ True
794+ >>> all([param.parent == model.uid for param in mmap])
795+ True
796+ >>> [param.name for param in model.params]
797+ ['featuresCol', 'labelCol', 'predictionCol']
753798 >>> model.featureImportances
754799 SparseVector(1, {0: 1.0})
755800 >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
@@ -835,7 +880,8 @@ def getLossType(self):
835880 return self.getOrDefault(self.lossType)
836881
837882
838- class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
883+ class GBTClassificationModel(TreeEnsembleModels, HasFeaturesCol, HasLabelCol, HasPredictionCol,
884+ JavaMLWritable, JavaMLReadable):
839885 """
840886 Model fitted by GBTClassifier.
841887
@@ -879,6 +925,14 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
879925 ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
880926 >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
881927 >>> model = nb.fit(df)
928+ >>> emap = nb.extractParamMap()
929+ >>> mmap = model.extractParamMap()
930+ >>> all([emap[getattr(nb, param.name)] == value for (param, value) in mmap.items()])
931+ True
932+ >>> all([param.parent == model.uid for param in mmap])
933+ True
934+ >>> [param.name for param in model.params]
935+ ['featuresCol', 'labelCol', 'predictionCol', 'probabilityCol', 'rawPredictionCol']
882936 >>> model.pi
883937 DenseVector([-0.51..., -0.91...])
884938 >>> model.theta
@@ -978,7 +1032,8 @@ def getModelType(self):
9781032 return self.getOrDefault(self.modelType)
9791033
9801034
981- class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
1035+ class NaiveBayesModel(JavaModel, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
1036+ HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
9821037 """
9831038 Model fitted by NaiveBayes.
9841039
@@ -1019,6 +1074,14 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
10191074 ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
10201075 >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=123)
10211076 >>> model = mlp.fit(df)
1077+ >>> emap = mlp.extractParamMap()
1078+ >>> mmap = model.extractParamMap()
1079+ >>> all([emap[getattr(mlp, param.name)] == value for (param, value) in mmap.items()])
1080+ True
1081+ >>> all([param.parent == model.uid for param in mmap])
1082+ True
1083+ >>> [param.name for param in model.params]
1084+ ['featuresCol', 'labelCol', 'predictionCol']
10221085 >>> model.layers
10231086 [2, 5, 2]
10241087 >>> model.weights.size
@@ -1033,7 +1096,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
10331096 |[1.0,0.0]| 1.0|
10341097 |[0.0,0.0]| 0.0|
10351098 +---------+----------+
1036- ...
1099+
10371100 >>> mlp_path = temp_path + "/mlp"
10381101 >>> mlp.save(mlp_path)
10391102 >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
@@ -1118,7 +1181,8 @@ def getBlockSize(self):
11181181 return self.getOrDefault(self.blockSize)
11191182
11201183
1121- class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
1184+ class MultilayerPerceptronClassificationModel(JavaModel, HasFeaturesCol, HasLabelCol,
1185+ HasPredictionCol, JavaMLWritable, JavaMLReadable):
11221186 """
11231187 Model fitted by MultilayerPerceptronClassifier.
11241188
@@ -1184,6 +1248,14 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
11841248 >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
11851249 >>> ovr = OneVsRest(classifier=lr)
11861250 >>> model = ovr.fit(df)
1251+ >>> emap = ovr.extractParamMap()
1252+ >>> mmap = model.extractParamMap()
1253+ >>> all([emap[getattr(ovr, param.name)] == value for (param, value) in mmap.items()])
1254+ True
1255+ >>> all([param.parent == model.uid for param in mmap])
1256+ True
1257+ >>> [param.name for param in model.params]
1258+ ['classifier', 'featuresCol', 'labelCol', 'predictionCol']
11871259 >>> [x.coefficients for x in model.models]
11881260 [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])]
11891261 >>> [x.intercept for x in model.models]
@@ -1463,7 +1535,9 @@ def _to_java(self):
14631535 temp_path = tempfile.mkdtemp()
14641536 globs['temp_path'] = temp_path
14651537 try:
1466- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
1538+ (failure_count, test_count) = doctest.testmod(
1539+ globs=globs,
1540+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
14671541 sc.stop()
14681542 finally:
14691543 from shutil import rmtree
0 commit comments