@@ -40,6 +40,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
4040 HasRegParam , HasTol , HasElasticNetParam , HasFitIntercept ,
4141 HasStandardization , HasSolver , HasWeightCol , JavaMLWritable , JavaMLReadable ):
4242 """
43+ .. note:: Experimental
44+
4345 Linear regression.
4446
4547 The learning objective is to minimize the squared error, with regularization.
@@ -123,6 +125,8 @@ def _create_model(self, java_model):
123125
124126class LinearRegressionModel (JavaModel , JavaMLWritable , JavaMLReadable ):
125127 """
128+ .. note:: Experimental
129+
126130 Model fitted by LinearRegression.
127131
128132 .. versionadded:: 1.4.0
@@ -631,6 +635,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
631635 DecisionTreeParams , TreeRegressorParams , HasCheckpointInterval ,
632636 HasSeed , JavaMLWritable , JavaMLReadable , HasVarianceCol ):
633637 """
638+ .. note:: Experimental
639+
634640 `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
635641 learning algorithm for regression.
636642 It supports both continuous and categorical features.
@@ -713,7 +719,10 @@ def _create_model(self, java_model):
713719
714720@inherit_doc
715721class DecisionTreeModel (JavaModel ):
716- """Abstraction for Decision Tree models.
722+ """
723+ .. note:: Experimental
724+
725+ Abstraction for Decision Tree models.
717726
718727 .. versionadded:: 1.5.0
719728 """
@@ -736,7 +745,10 @@ def __repr__(self):
736745
737746@inherit_doc
738747class TreeEnsembleModels (JavaModel ):
739- """Represents a tree ensemble model.
748+ """
749+ .. note:: Experimental
750+
751+ Represents a tree ensemble model.
740752
741753 .. versionadded:: 1.5.0
742754 """
@@ -754,6 +766,8 @@ def __repr__(self):
754766@inherit_doc
755767class DecisionTreeRegressionModel (DecisionTreeModel , JavaMLWritable , JavaMLReadable ):
756768 """
769+ .. note:: Experimental
770+
757771 Model fitted by DecisionTreeRegressor.
758772
759773 .. versionadded:: 1.4.0
@@ -786,6 +800,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
786800 RandomForestParams , TreeRegressorParams , HasCheckpointInterval ,
787801 JavaMLWritable , JavaMLReadable ):
788802 """
803+ .. note:: Experimental
804+
789805 `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
790806 learning algorithm for regression.
791807 It supports both continuous and categorical features.
@@ -868,6 +884,8 @@ def _create_model(self, java_model):
868884
869885class RandomForestRegressionModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable ):
870886 """
887+ .. note:: Experimental
888+
871889 Model fitted by RandomForestRegressor.
872890
873891 .. versionadded:: 1.4.0
@@ -892,8 +910,10 @@ def featureImportances(self):
892910@inherit_doc
893911class GBTRegressor (JavaEstimator , HasFeaturesCol , HasLabelCol , HasPredictionCol , HasMaxIter ,
894912 GBTParams , HasCheckpointInterval , HasStepSize , HasSeed , JavaMLWritable ,
895- JavaMLReadable ):
913+ JavaMLReadable , TreeRegressorParams ):
896914 """
915+ .. note:: Experimental
916+
897917 `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
898918 learning algorithm for regression.
899919 It supports both continuous and categorical features.
@@ -904,6 +924,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
904924 ... (1.0, Vectors.dense(1.0)),
905925 ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
906926 >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
927+ >>> print(gbt.getImpurity())
928+ variance
907929 >>> model = gbt.fit(df)
908930 >>> model.featureImportances
909931 SparseVector(1, {0: 1.0})
@@ -940,19 +962,21 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
940962 def __init__ (self , featuresCol = "features" , labelCol = "label" , predictionCol = "prediction" ,
941963 maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
942964 maxMemoryInMB = 256 , cacheNodeIds = False , subsamplingRate = 1.0 ,
943- checkpointInterval = 10 , lossType = "squared" , maxIter = 20 , stepSize = 0.1 , seed = None ):
965+ checkpointInterval = 10 , lossType = "squared" , maxIter = 20 , stepSize = 0.1 , seed = None ,
966+ impurity = "variance" ):
944967 """
945968 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
946969 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
947970 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
948- checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
971+ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
972+ impurity="variance")
949973 """
950974 super (GBTRegressor , self ).__init__ ()
951975 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.regression.GBTRegressor" , self .uid )
952976 self ._setDefault (maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
953977 maxMemoryInMB = 256 , cacheNodeIds = False , subsamplingRate = 1.0 ,
954978 checkpointInterval = 10 , lossType = "squared" , maxIter = 20 , stepSize = 0.1 ,
955- seed = None )
979+ seed = None , impurity = "variance" )
956980 kwargs = self .__init__ ._input_kwargs
957981 self .setParams (** kwargs )
958982
@@ -961,12 +985,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
961985 def setParams (self , featuresCol = "features" , labelCol = "label" , predictionCol = "prediction" ,
962986 maxDepth = 5 , maxBins = 32 , minInstancesPerNode = 1 , minInfoGain = 0.0 ,
963987 maxMemoryInMB = 256 , cacheNodeIds = False , subsamplingRate = 1.0 ,
964- checkpointInterval = 10 , lossType = "squared" , maxIter = 20 , stepSize = 0.1 , seed = None ):
988+ checkpointInterval = 10 , lossType = "squared" , maxIter = 20 , stepSize = 0.1 , seed = None ,
989+ impuriy = "variance" ):
965990 """
966991 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
967992 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
968993 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
969- checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None)
994+ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
995+ impurity="variance")
970996 Sets params for Gradient Boosted Tree Regression.
971997 """
972998 kwargs = self .setParams ._input_kwargs
@@ -992,6 +1018,8 @@ def getLossType(self):
9921018
9931019class GBTRegressionModel (TreeEnsembleModels , JavaMLWritable , JavaMLReadable ):
9941020 """
1021+ .. note:: Experimental
1022+
9951023 Model fitted by GBTRegressor.
9961024
9971025 .. versionadded:: 1.4.0
@@ -1017,6 +1045,8 @@ def featureImportances(self):
10171045class AFTSurvivalRegression (JavaEstimator , HasFeaturesCol , HasLabelCol , HasPredictionCol ,
10181046 HasFitIntercept , HasMaxIter , HasTol , JavaMLWritable , JavaMLReadable ):
10191047 """
1048+ .. note:: Experimental
1049+
10201050 Accelerated Failure Time (AFT) Model Survival Regression
10211051
10221052 Fit a parametric AFT survival regression model based on the Weibull distribution
@@ -1157,6 +1187,8 @@ def getQuantilesCol(self):
11571187
11581188class AFTSurvivalRegressionModel (JavaModel , JavaMLWritable , JavaMLReadable ):
11591189 """
1190+ .. note:: Experimental
1191+
11601192 Model fitted by AFTSurvivalRegression.
11611193
11621194 .. versionadded:: 1.6.0
@@ -1204,6 +1236,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12041236 HasFitIntercept , HasMaxIter , HasTol , HasRegParam , HasWeightCol ,
12051237 HasSolver , JavaMLWritable , JavaMLReadable ):
12061238 """
1239+ .. note:: Experimental
1240+
12071241 Generalized Linear Regression.
12081242
12091243 Fit a Generalized Linear Model specified by giving a symbolic description of the linear
@@ -1320,6 +1354,8 @@ def getLink(self):
13201354
13211355class GeneralizedLinearRegressionModel (JavaModel , JavaMLWritable , JavaMLReadable ):
13221356 """
1357+ .. note:: Experimental
1358+
13231359 Model fitted by GeneralizedLinearRegression.
13241360
13251361 .. versionadded:: 2.0.0
0 commit comments