Skip to content

Commit cc10147

Browse files
committed
add constraint on mu for family
1 parent 97c3f6a commit cc10147

File tree

1 file changed

+42
-9
lines changed

1 file changed

+42
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 42 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -295,19 +295,25 @@ private[ml] abstract class Family(val link: Link) extends Serializable {
295295

296296
/** Weights for IRLS steps. */
297297
def weights(mu: Double): Double = {
298-
1.0 / (math.pow(this.link.deriv(mu), 2.0) * this.variance(mu))
298+
val x = clean(mu)
299+
1.0 / (math.pow(this.link.deriv(x), 2.0) * this.variance(x))
299300
}
300301

301302
/** The adjusted response variable. */
302303
def adjusted(y: Double, mu: Double, eta: Double): Double = {
303-
eta + (y - mu) * link.deriv(mu)
304+
val x = clean(mu)
305+
eta + (y - x) * link.deriv(x)
304306
}
305307

306308
/** Linear predictors based on given mu. */
307-
def predict(mu: Double): Double = this.link.link(mu)
309+
def predict(mu: Double): Double = this.link.link(clean(mu))
308310

309311
/** Fitted values based on linear predictors eta. */
310-
def fitted(eta: Double): Double = this.link.unlink(eta)
312+
def fitted(eta: Double): Double = clean(this.link.unlink(eta))
313+
314+
def clean(mu: Double): Double = mu
315+
316+
val epsilon: Double = 1E-16
311317
}
312318

313319
/**
@@ -317,7 +323,13 @@ private[ml] abstract class Family(val link: Link) extends Serializable {
317323
*/
318324
private[ml] class Gaussian(link: Link = Identity) extends Family(link) {
319325

320-
override def initialize(y: Double, weight: Double): Double = y
326+
override def initialize(y: Double, weight: Double): Double = {
327+
if (link == Log) {
328+
require(y > 0.0, "The response variable of Gaussian family with Log link " +
329+
s"should be positive, but got $y")
330+
}
331+
y
332+
}
321333

322334
def variance(mu: Double): Double = 1.0
323335
}
@@ -341,10 +353,23 @@ private[ml] object Gaussian {
341353
private[ml] class Binomial(link: Link = Logit) extends Family(link) {
342354

343355
override def initialize(y: Double, weight: Double): Double = {
344-
(weight * y + 0.5) / (weight + 1.0)
356+
val mu = (weight * y + 0.5) / (weight + 1.0)
357+
require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family " +
358+
s"should be in range (0, 1), but got $mu")
359+
mu
345360
}
346361

347-
override def variance(mu: Double): Double = mu * (1 - mu)
362+
override def variance(mu: Double): Double = mu * (1.0 - mu)
363+
364+
override def clean(mu: Double): Double = {
365+
if (mu < epsilon) {
366+
epsilon
367+
} else if (mu > 1.0 - epsilon) {
368+
1.0 - epsilon
369+
} else {
370+
mu
371+
}
372+
}
348373
}
349374

350375
private[ml] object Binomial {
@@ -365,7 +390,11 @@ private[ml] object Binomial {
365390
*/
366391
private[ml] class Poisson(link: Link = Log) extends Family(link) {
367392

368-
override def initialize(y: Double, weight: Double): Double = y + 0.1
393+
override def initialize(y: Double, weight: Double): Double = {
394+
require(y > 0.0, "The response variable of Poisson family " +
395+
s"should be positive, but got $y")
396+
y
397+
}
369398

370399
override def variance(mu: Double): Double = mu
371400
}
@@ -388,7 +417,11 @@ private[ml] object Poisson {
388417
*/
389418
private[ml] class Gamma(link: Link = Inverse) extends Family(link) {
390419

391-
override def initialize(y: Double, weight: Double): Double = y
420+
override def initialize(y: Double, weight: Double): Double = {
421+
require(y > 0.0, "The response variable of Gamma family " +
422+
s"should be positive, but got $y")
423+
y
424+
}
392425

393426
override def variance(mu: Double): Double = math.pow(mu, 2.0)
394427
}

0 commit comments

Comments
 (0)