@@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
4343 with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
4444 with HasSolver with Logging {
4545
46+ import GeneralizedLinearRegression ._
47+
4648 /**
4749 * Param for the name of family which is a description of the error distribution
4850 * to be used in the model.
@@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val family: Param[String] = {
  // Sort the supported names once and reuse them for both the help text and the
  // validator. The names come from a Set, so without sorting the order in which
  // they are listed in the documentation string is not guaranteed to be stable.
  val options = supportedFamilyNames.toSeq.sorted
  new Param[String](this, "family",
    "The name of family which is a description of the error distribution to be used in the " +
      s"model. Supported options: ${options.mkString(", ")}.",
    ParamValidators.inArray[String](options.toArray))
}
5961
6062 /** @group getParam */
6163 @ Since (" 2.0.0" )
@@ -71,29 +73,32 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val link: Param[String] = {
  // Sort the supported names so the help text lists them in a stable, readable
  // order rather than in the iteration order of the underlying Set, and reuse
  // the same sequence for the validator.
  val options = supportedLinkNames.toSeq.sorted
  new Param[String](this, "link", "The name of link function " +
    "which provides the relationship between the linear predictor and the mean of the " +
    // Terminated with a period to match the wording of the `family` param.
    s"distribution function. Supported options: ${options.mkString(", ")}.",
    ParamValidators.inArray[String](options.toArray))
}
7778
/**
 * Returns the current value of the `link` param.
 *
 * @group getParam
 */
@Since("2.0.0")
def getLink: String = {
  $(link)
}
8182
/**
 * Param holding the name of the link prediction (linear predictor) column.
 * No default is set, which means we do not output link prediction.
 *
 * @group param
 */
@Since("2.0.0")
final val linkPredictionCol: Param[String] =
  new Param[String](this, "linkPredictionCol", "link prediction (linear predictor) column name")
9292
/**
 * Returns the current value of the `linkPredictionCol` param.
 *
 * NOTE(review): this param has no default, so this presumably fails when the
 * param was never set -- confirm against `Params.$`.
 *
 * @group getParam
 */
@Since("2.0.0")
def getLinkPredictionCol: String = {
  $(linkPredictionCol)
}
9696
/**
 * Checks whether we should output link prediction, i.e. whether
 * `linkPredictionCol` is defined and holds a non-empty column name.
 */
private[regression] def hasLinkPredictionCol: Boolean =
  // Only read the value once we know the param is defined.
  if (isDefined(linkPredictionCol)) $(linkPredictionCol).nonEmpty else false
101+
97102 import GeneralizedLinearRegression ._
98103
99104 @ Since (" 2.0.0" )
@@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
107112 s " with ${$(family)} family does not support ${$(link)} link function. " )
108113 }
109114 val newSchema = super .validateAndTransformSchema(schema, fitting, featuresDataType)
110- if ($(linkPredictionCol).nonEmpty ) {
115+ if (hasLinkPredictionCol ) {
111116 SchemaUtils .appendColumn(newSchema, $(linkPredictionCol), DoubleType )
112117 } else {
113118 newSchema
@@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
205210 /**
206211 * Sets the value of param [[weightCol ]].
207212 * If this is not set or empty, we treat all instance weights as 1.0.
208- * Default is empty , so all instances have weight one.
213+ * Default is not set , so all instances have weight one.
209214 *
210215 * @group setParam
211216 */
@@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
214219
215220 /**
216221 * Sets the solver algorithm used for optimization.
217- * Currently only support "irls" which is also the default solver.
222+ * Currently only supports "irls" which is also the default solver.
218223 *
219224 * @group setParam
220225 */
@@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
239244 }
240245 val familyAndLink = new FamilyAndLink (familyObj, linkObj)
241246
242- val numFeatures = dataset.select(col($(featuresCol))).limit(1 ).rdd
243- .map { case Row (features : Vector ) =>
244- features.size
245- }.first()
247+ val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector ](0 ).size
246248 if (numFeatures > WeightedLeastSquares .MAX_NUM_FEATURES ) {
247249 val msg = " Currently, GeneralizedLinearRegression only supports number of features" +
248250 s " <= ${WeightedLeastSquares .MAX_NUM_FEATURES }. Found $numFeatures in the input dataset. "
@@ -294,25 +296,25 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
294296 override def load (path : String ): GeneralizedLinearRegression = super .load(path)
295297
/**
 * All (family, link) combinations that GeneralizedLinearRegression supports.
 * Each row below lists the links accepted for one family.
 */
private[regression] lazy val supportedFamilyAndLinkPairs = Set(
  (Gaussian, Identity), (Gaussian, Log), (Gaussian, Inverse),
  (Binomial, Logit), (Binomial, Probit), (Binomial, CLogLog),
  (Poisson, Log), (Poisson, Identity), (Poisson, Sqrt),
  (Gamma, Inverse), (Gamma, Identity), (Gamma, Log)
)
303305
/** Names of all families that appear in [[supportedFamilyAndLinkPairs]]. */
private[regression] lazy val supportedFamilyNames =
  supportedFamilyAndLinkPairs.map { case (family, _) => family.name }
306308
/** Names of all link functions that appear in [[supportedFamilyAndLinkPairs]]. */
private[regression] lazy val supportedLinkNames =
  supportedFamilyAndLinkPairs.map { case (_, link) => link.name }
309311
/**
 * Small positive constant (1e-16).
 * NOTE(review): its call sites are not visible in this excerpt -- presumably a
 * numerical guard used by the family/link implementations; confirm.
 */
private[regression] val epsilon: Double = 1e-16
311313
312314 /**
313315 * Wrapper of family and link combination used in the model.
314316 */
315- private [ml ] class FamilyAndLink (val family : Family , val link : Link ) extends Serializable {
317+ private [regression ] class FamilyAndLink (val family : Family , val link : Link ) extends Serializable {
316318
317319 /** Linear predictor based on given mu. */
318320 def predict (mu : Double ): Double = link.link(family.project(mu))
@@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
359361 *
360362 * @param name the name of the family.
361363 */
362- private [ml ] abstract class Family (val name : String ) extends Serializable {
364+ private [regression ] abstract class Family (val name : String ) extends Serializable {
363365
364366 /** The default link instance of this family. */
365367 val defaultLink : Link
@@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
391393 def project (mu : Double ): Double = mu
392394 }
393395
394- private [ml ] object Family {
396+ private [regression ] object Family {
395397
396398 /**
397399 * Gets the [[Family ]] object from its name.
@@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
412414 * Gaussian exponential family distribution.
413415 * The default link for the Gaussian family is the identity link.
414416 */
415- private [ml ] object Gaussian extends Family (" gaussian" ) {
417+ private [regression ] object Gaussian extends Family (" gaussian" ) {
416418
417419 val defaultLink : Link = Identity
418420
@@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
448450 * Binomial exponential family distribution.
449451 * The default link for the Binomial family is the logit link.
450452 */
451- private [ml ] object Binomial extends Family (" binomial" ) {
453+ private [regression ] object Binomial extends Family (" binomial" ) {
452454
453455 val defaultLink : Link = Logit
454456
@@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
492494 * Poisson exponential family distribution.
493495 * The default link for the Poisson family is the log link.
494496 */
495- private [ml ] object Poisson extends Family (" poisson" ) {
497+ private [regression ] object Poisson extends Family (" poisson" ) {
496498
497499 val defaultLink : Link = Log
498500
@@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
533535 * Gamma exponential family distribution.
534536 * The default link for the Gamma family is the inverse link.
535537 */
536- private [ml ] object Gamma extends Family (" gamma" ) {
538+ private [regression ] object Gamma extends Family (" gamma" ) {
537539
538540 val defaultLink : Link = Inverse
539541
@@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
578580 *
579581 * @param name the name of link function.
580582 */
581- private [ml ] abstract class Link (val name : String ) extends Serializable {
583+ private [regression ] abstract class Link (val name : String ) extends Serializable {
582584
583585 /** The link function. */
584586 def link (mu : Double ): Double
@@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
590592 def unlink (eta : Double ): Double
591593 }
592594
593- private [ml ] object Link {
595+ private [regression ] object Link {
594596
595597 /**
596598 * Gets the [[Link ]] object from its name.
@@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
611613 }
612614 }
613615
614- private [ml ] object Identity extends Link (" identity" ) {
616+ private [regression ] object Identity extends Link (" identity" ) {
615617
616618 override def link (mu : Double ): Double = mu
617619
@@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
620622 override def unlink (eta : Double ): Double = eta
621623 }
622624
623- private [ml ] object Logit extends Link (" logit" ) {
625+ private [regression ] object Logit extends Link (" logit" ) {
624626
625627 override def link (mu : Double ): Double = math.log(mu / (1.0 - mu))
626628
@@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
629631 override def unlink (eta : Double ): Double = 1.0 / (1.0 + math.exp(- 1.0 * eta))
630632 }
631633
632- private [ml ] object Log extends Link (" log" ) {
634+ private [regression ] object Log extends Link (" log" ) {
633635
634636 override def link (mu : Double ): Double = math.log(mu)
635637
@@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
638640 override def unlink (eta : Double ): Double = math.exp(eta)
639641 }
640642
641- private [ml ] object Inverse extends Link (" inverse" ) {
643+ private [regression ] object Inverse extends Link (" inverse" ) {
642644
643645 override def link (mu : Double ): Double = 1.0 / mu
644646
@@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
647649 override def unlink (eta : Double ): Double = 1.0 / eta
648650 }
649651
650- private [ml ] object Probit extends Link (" probit" ) {
652+ private [regression ] object Probit extends Link (" probit" ) {
651653
652654 override def link (mu : Double ): Double = dist.Gaussian (0.0 , 1.0 ).icdf(mu)
653655
@@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
658660 override def unlink (eta : Double ): Double = dist.Gaussian (0.0 , 1.0 ).cdf(eta)
659661 }
660662
661- private [ml ] object CLogLog extends Link (" cloglog" ) {
663+ private [regression ] object CLogLog extends Link (" cloglog" ) {
662664
663665 override def link (mu : Double ): Double = math.log(- 1.0 * math.log(1 - mu))
664666
@@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
667669 override def unlink (eta : Double ): Double = 1.0 - math.exp(- 1.0 * math.exp(eta))
668670 }
669671
670- private [ml ] object Sqrt extends Link (" sqrt" ) {
672+ private [regression ] object Sqrt extends Link (" sqrt" ) {
671673
672674 override def link (mu : Double ): Double = math.sqrt(mu)
673675
@@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] (
732734 if ($(predictionCol).nonEmpty) {
733735 output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
734736 }
735- if ($(linkPredictionCol).nonEmpty ) {
737+ if (hasLinkPredictionCol ) {
736738 output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
737739 }
738740 output.toDF()
@@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
860862 */
861863 @ Since (" 2.0.0" )
862864 val predictionCol : String = {
863- if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != " " ) {
865+ if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty ) {
864866 origModel.getPredictionCol
865867 } else {
866868 " prediction_" + java.util.UUID .randomUUID.toString
0 commit comments