Skip to content

Commit a42bf55

Browse files
BryanCutlermengxr
authored andcommitted
[SPARK-16079][PYSPARK][ML] Added missing import for DecisionTreeRegressionModel used in GBTClassificationModel
## What changes were proposed in this pull request? Fixed missing import for DecisionTreeRegressionModel used in GBTClassificationModel trees method. ## How was this patch tested? Local tests Author: Bryan Cutler <[email protected]> Closes #13787 from BryanCutler/pyspark-GBTClassificationModel-import-SPARK-16079.
1 parent 6daa8cf commit a42bf55

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

python/pyspark/ml/classification.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from pyspark import since, keyword_only
2222
from pyspark.ml import Estimator, Model
2323
from pyspark.ml.param.shared import *
24-
from pyspark.ml.regression import (
25-
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
24+
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
25+
RandomForestParams, TreeEnsembleModels, TreeEnsembleParams
2626
from pyspark.ml.util import *
2727
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
2828
from pyspark.ml.wrapper import JavaWrapper
@@ -798,6 +798,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
798798
True
799799
>>> model.treeWeights == model2.treeWeights
800800
True
801+
>>> model.trees
802+
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
801803
802804
.. versionadded:: 1.4.0
803805
"""

python/pyspark/ml/regression.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
994994
True
995995
>>> model.treeWeights == model2.treeWeights
996996
True
997+
>>> model.trees
998+
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
997999
9981000
.. versionadded:: 1.4.0
9991001
"""

0 commit comments

Comments
 (0)