@@ -295,19 +295,25 @@ private[ml] abstract class Family(val link: Link) extends Serializable {
295295
296296 /** Weights for IRLS steps. */
297297 def weights (mu : Double ): Double = {
298- 1.0 / (math.pow(this .link.deriv(mu), 2.0 ) * this .variance(mu))
298+ val x = clean(mu)
299+ 1.0 / (math.pow(this .link.deriv(x), 2.0 ) * this .variance(x))
299300 }
300301
301302 /** The adjusted response variable. */
302303 def adjusted (y : Double , mu : Double , eta : Double ): Double = {
303- eta + (y - mu) * link.deriv(mu)
304+ val x = clean(mu)
305+ eta + (y - x) * link.deriv(x)
304306 }
305307
306308 /** Linear predictors based on given mu. */
307- def predict (mu : Double ): Double = this .link.link(mu )
309+ def predict (mu : Double ): Double = this .link.link(clean(mu) )
308310
309311 /** Fitted values based on linear predictors eta. */
310- def fitted (eta : Double ): Double = this .link.unlink(eta)
312+ def fitted (eta : Double ): Double = clean(this .link.unlink(eta))
313+
314+ def clean (mu : Double ): Double = mu
315+
316+ val epsilon : Double = 1E-16
311317}
312318
313319/**
@@ -317,7 +323,13 @@ private[ml] abstract class Family(val link: Link) extends Serializable {
317323 */
318324private [ml] class Gaussian (link : Link = Identity ) extends Family (link) {
319325
320- override def initialize (y : Double , weight : Double ): Double = y
326+ override def initialize (y : Double , weight : Double ): Double = {
327+ if (link == Log ) {
328+ require(y > 0.0 , " The response variable of Gaussian family with Log link " +
329+ s " should be positive, but got $y" )
330+ }
331+ y
332+ }
321333
322334 def variance (mu : Double ): Double = 1.0
323335}
@@ -341,10 +353,23 @@ private[ml] object Gaussian {
341353private [ml] class Binomial (link : Link = Logit ) extends Family (link) {
342354
343355 override def initialize (y : Double , weight : Double ): Double = {
344- (weight * y + 0.5 ) / (weight + 1.0 )
356+ val mu = (weight * y + 0.5 ) / (weight + 1.0 )
357+ require(mu > 0.0 && mu < 1.0 , " The response variable of Binomial family" +
358+ s " should be in range (0, 1), but got $mu" )
359+ mu
345360 }
346361
347- override def variance (mu : Double ): Double = mu * (1 - mu)
362+ override def variance (mu : Double ): Double = mu * (1.0 - mu)
363+
364+ override def clean (mu : Double ): Double = {
365+ if (mu < epsilon) {
366+ epsilon
367+ } else if (mu > 1.0 - epsilon) {
368+ 1.0 - epsilon
369+ } else {
370+ mu
371+ }
372+ }
348373}
349374
350375private [ml] object Binomial {
@@ -365,7 +390,11 @@ private[ml] object Binomial {
365390 */
366391private [ml] class Poisson (link : Link = Log ) extends Family (link) {
367392
368- override def initialize (y : Double , weight : Double ): Double = y + 0.1
393+ override def initialize (y : Double , weight : Double ): Double = {
394+ require(y > 0.0 , " The response variable of Poisson family " +
395+ s " should be positive, but got $y" )
396+ y
397+ }
369398
370399 override def variance (mu : Double ): Double = mu
371400}
@@ -388,7 +417,11 @@ private[ml] object Poisson {
388417 */
389418private [ml] class Gamma (link : Link = Inverse ) extends Family (link) {
390419
391- override def initialize (y : Double , weight : Double ): Double = y
420+ override def initialize (y : Double , weight : Double ): Double = {
421+ require(y > 0.0 , " The response variable of Gamma family " +
422+ s " should be positive, but got $y" )
423+ y
424+ }
392425
393426 override def variance (mu : Double ): Double = math.pow(mu, 2.0 )
394427}
0 commit comments