|
42 | 42 | import numpy as np |
43 | 43 |
|
44 | 44 | from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer |
45 | | -from pyspark.ml.classification import LogisticRegression |
| 45 | +from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier |
46 | 46 | from pyspark.ml.clustering import KMeans |
47 | 47 | from pyspark.ml.evaluation import RegressionEvaluator |
48 | 48 | from pyspark.ml.feature import * |
49 | 49 | from pyspark.ml.param import Param, Params, TypeConverters |
50 | 50 | from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed |
51 | | -from pyspark.ml.regression import LinearRegression |
| 51 | +from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor |
52 | 52 | from pyspark.ml.tuning import * |
53 | 53 | from pyspark.ml.util import keyword_only |
54 | 54 | from pyspark.ml.wrapper import JavaWrapper |
@@ -655,6 +655,42 @@ def test_nested_pipeline_persistence(self): |
655 | 655 | except OSError: |
656 | 656 | pass |
657 | 657 |
|
| 658 | + def test_decisiontree_classifier(self): |
| 659 | + dt = DecisionTreeClassifier(maxDepth=1) |
| 660 | + path = tempfile.mkdtemp() |
| 661 | + dtc_path = path + "/dtc" |
| 662 | + dt.save(dtc_path) |
| 663 | + dt2 = DecisionTreeClassifier.load(dtc_path) |
| 664 | + self.assertEqual(dt2.uid, dt2.maxDepth.parent, |
| 665 | + "Loaded DecisionTreeClassifier instance uid (%s) " |
| 666 | + "did not match Param's uid (%s)" |
| 667 | + % (dt2.uid, dt2.maxDepth.parent)) |
| 668 | + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], |
| 669 | + "Loaded DecisionTreeClassifier instance default params did not match " + |
| 670 | + "original defaults") |
| 671 | + try: |
| 672 | + rmtree(path) |
| 673 | + except OSError: |
| 674 | + pass |
| 675 | + |
| 676 | + def test_decisiontree_regressor(self): |
| 677 | + dt = DecisionTreeRegressor(maxDepth=1) |
| 678 | + path = tempfile.mkdtemp() |
| 679 | + dtr_path = path + "/dtr" |
| 680 | + dt.save(dtr_path) |
| 681 | + dt2 = DecisionTreeClassifier.load(dtr_path) |
| 682 | + self.assertEqual(dt2.uid, dt2.maxDepth.parent, |
| 683 | + "Loaded DecisionTreeRegressor instance uid (%s) " |
| 684 | + "did not match Param's uid (%s)" |
| 685 | + % (dt2.uid, dt2.maxDepth.parent)) |
| 686 | + self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth], |
| 687 | + "Loaded DecisionTreeRegressor instance default params did not match " + |
| 688 | + "original defaults") |
| 689 | + try: |
| 690 | + rmtree(path) |
| 691 | + except OSError: |
| 692 | + pass |
| 693 | + |
658 | 694 |
|
659 | 695 | class HasThrowableProperty(Params): |
660 | 696 |
|
|
0 commit comments