Skip to content

Commit 0874ff3

Browse files
GayathriMuralimengxr
authored andcommitted
[SPARK-13949][ML][PYTHON] PySpark ml DecisionTreeClassifier, Regressor support export/import
## What changes were proposed in this pull request? Added MLReadable and MLWritable to Decision Tree Classifier and Regressor. Added doctests. ## How was this patch tested? Python Unit tests. Tests added to check persistence in DecisionTreeClassifier and DecisionTreeRegressor. Author: GayathriMurali <[email protected]> Closes apache#11892 from GayathriMurali/SPARK-13949.
1 parent 5850977 commit 0874ff3

File tree

3 files changed

+66
-6
lines changed

3 files changed

+66
-6
lines changed

python/pyspark/ml/classification.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams):
278278
@inherit_doc
279279
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
280280
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
281-
TreeClassifierParams, HasCheckpointInterval, HasSeed):
281+
TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
282+
JavaMLReadable):
282283
"""
283284
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
284285
learning algorithm for classification.
@@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
313314
>>> model.transform(test1).head().prediction
314315
1.0
315316
317+
>>> dtc_path = temp_path + "/dtc"
318+
>>> dt.save(dtc_path)
319+
>>> dt2 = DecisionTreeClassifier.load(dtc_path)
320+
>>> dt2.getMaxDepth()
321+
2
322+
>>> model_path = temp_path + "/dtc_model"
323+
>>> model.save(model_path)
324+
>>> model2 = DecisionTreeClassificationModel.load(model_path)
325+
>>> model.featureImportances == model2.featureImportances
326+
True
327+
316328
.. versionadded:: 1.4.0
317329
"""
318330

@@ -361,7 +373,7 @@ def _create_model(self, java_model):
361373

362374

363375
@inherit_doc
364-
class DecisionTreeClassificationModel(DecisionTreeModel):
376+
class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
365377
"""
366378
Model fitted by DecisionTreeClassifier.
367379

python/pyspark/ml/regression.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams):
389389
@inherit_doc
390390
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
391391
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
392-
HasSeed):
392+
HasSeed, JavaMLWritable, JavaMLReadable):
393393
"""
394394
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
395395
learning algorithm for regression.
@@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
413413
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
414414
>>> model.transform(test1).head().prediction
415415
1.0
416+
>>> dtr_path = temp_path + "/dtr"
417+
>>> dt.save(dtr_path)
418+
>>> dt2 = DecisionTreeRegressor.load(dtr_path)
419+
>>> dt2.getMaxDepth()
420+
2
421+
>>> model_path = temp_path + "/dtr_model"
422+
>>> model.save(model_path)
423+
>>> model2 = DecisionTreeRegressionModel.load(model_path)
424+
>>> model.numNodes == model2.numNodes
425+
True
426+
>>> model.depth == model2.depth
427+
True
416428
417429
.. versionadded:: 1.4.0
418430
"""
@@ -498,7 +510,7 @@ def __repr__(self):
498510

499511

500512
@inherit_doc
501-
class DecisionTreeRegressionModel(DecisionTreeModel):
513+
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
502514
"""
503515
Model fitted by DecisionTreeRegressor.
504516

python/pyspark/ml/tests.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@
4242
import numpy as np
4343

4444
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
45-
from pyspark.ml.classification import LogisticRegression
45+
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
4646
from pyspark.ml.clustering import KMeans
4747
from pyspark.ml.evaluation import RegressionEvaluator
4848
from pyspark.ml.feature import *
4949
from pyspark.ml.param import Param, Params, TypeConverters
5050
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
51-
from pyspark.ml.regression import LinearRegression
51+
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
5252
from pyspark.ml.tuning import *
5353
from pyspark.ml.util import keyword_only
5454
from pyspark.ml.wrapper import JavaWrapper
@@ -655,6 +655,42 @@ def test_nested_pipeline_persistence(self):
655655
except OSError:
656656
pass
657657

658+
def test_decisiontree_classifier(self):
659+
dt = DecisionTreeClassifier(maxDepth=1)
660+
path = tempfile.mkdtemp()
661+
dtc_path = path + "/dtc"
662+
dt.save(dtc_path)
663+
dt2 = DecisionTreeClassifier.load(dtc_path)
664+
self.assertEqual(dt2.uid, dt2.maxDepth.parent,
665+
"Loaded DecisionTreeClassifier instance uid (%s) "
666+
"did not match Param's uid (%s)"
667+
% (dt2.uid, dt2.maxDepth.parent))
668+
self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
669+
"Loaded DecisionTreeClassifier instance default params did not match " +
670+
"original defaults")
671+
try:
672+
rmtree(path)
673+
except OSError:
674+
pass
675+
676+
def test_decisiontree_regressor(self):
677+
dt = DecisionTreeRegressor(maxDepth=1)
678+
path = tempfile.mkdtemp()
679+
dtr_path = path + "/dtr"
680+
dt.save(dtr_path)
681+
dt2 = DecisionTreeClassifier.load(dtr_path)
682+
self.assertEqual(dt2.uid, dt2.maxDepth.parent,
683+
"Loaded DecisionTreeRegressor instance uid (%s) "
684+
"did not match Param's uid (%s)"
685+
% (dt2.uid, dt2.maxDepth.parent))
686+
self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
687+
"Loaded DecisionTreeRegressor instance default params did not match " +
688+
"original defaults")
689+
try:
690+
rmtree(path)
691+
except OSError:
692+
pass
693+
658694

659695
class HasThrowableProperty(Params):
660696

0 commit comments

Comments
 (0)