Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
def getLink: String = $(link)

/**
* Param for link prediction (linear predictor) column name.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mention the default value in the doc.

* Default is empty, which means we do not output link prediction.
* @group param
*/
@Since("2.0.0")
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
"link prediction (linear predictor) column name")
setDefault(linkPredictionCol, "")

/** @group getParam */
@Since("2.0.0")
def getLinkPredictionCol: String = $(linkPredictionCol)

import GeneralizedLinearRegression._

@Since("2.0.0")
Expand All @@ -93,7 +107,12 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
s"with ${$(family)} family does not support ${$(link)} link function.")
}
super.validateAndTransformSchema(schema, fitting, featuresDataType)
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes predictionCol is always set. We can either keep this assumption or fix it.

Copy link
Contributor Author

@yanboliang yanboliang Apr 20, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not assume predictionCol is always set. If we set predictionCol = "", we will not output predictionCol in this PR. Because super.validateAndTransformSchema checked empty or not for predictionCol inside the function SchemaUtils.appendColumn.
It looks like a convention in ML code base. We output predictionCol default, users can disable it by setting predictionCol = "".

if ($(linkPredictionCol).nonEmpty) {
Copy link
Contributor Author

@yanboliang yanboliang Apr 20, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check $(linkPredictionCol).nonEmpty can be omitted, because we already have empty check inside the next line function SchemaUtils.appendColumn. But these kinds of code can help developers to understand the logic clearly and they exist at lots of place in the code base. If we would like to omit them, I can do the clean up in a separate PR.

SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
} else {
newSchema
}
}
}

Expand Down Expand Up @@ -196,6 +215,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")

/**
* Sets the link prediction (linear predictor) column name.
* @group setParam
*/
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)

override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
Expand Down Expand Up @@ -664,6 +690,13 @@ class GeneralizedLinearRegressionModel private[ml] (
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
with GeneralizedLinearRegressionBase with MLWritable {

/**
* Sets the link prediction (linear predictor) column name.
* @group setParam
*/
@Since("2.0.0")
def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)

import GeneralizedLinearRegression._

lazy val familyObj = Family.fromName($(family))
Expand All @@ -675,10 +708,35 @@ class GeneralizedLinearRegressionModel private[ml] (
lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)

override protected def predict(features: Vector): Double = {
val eta = BLAS.dot(features, coefficients) + intercept
val eta = predictLink(features)
familyAndLink.fitted(eta)
}

/**
* Calculate the link prediction (linear predictor) of the given instance.
*/
private def predictLink(features: Vector): Double = {
BLAS.dot(features, coefficients) + intercept
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
transformImpl(dataset)
}

override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictLinkUDF = udf { (features: Vector) => predictLink(features) }
var output = dataset
if ($(predictionCol).nonEmpty) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my previous comment about the assumption on predictionCol.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we checked $(predictionCol).isEmpty when transform schema, we also need to check it when actually transform the dataset.

output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
if ($(linkPredictionCol).nonEmpty) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
}
output.toDF
}

private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,24 @@ class GeneralizedLinearRegressionSuite
("inverse", datasetGaussianInverse))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
s"$link link and fitIntercept = $fitIntercept.")

val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gaussian family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with gaussian family, $link link and fitIntercept = $fitIntercept.")
}

idx += 1
}
Expand Down Expand Up @@ -358,21 +362,25 @@ class GeneralizedLinearRegressionSuite
("cloglog", datasetBinomial))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
model.coefficients(2), model.coefficients(3))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " +
s"$link link and fitIntercept = $fitIntercept.")

val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"binomial family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"binomial family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with binomial family, $link link and fitIntercept = $fitIntercept.")
}

idx += 1
}
Expand Down Expand Up @@ -427,20 +435,24 @@ class GeneralizedLinearRegressionSuite
("sqrt", datasetPoissonSqrt))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept.")

val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"poisson family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"poisson family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with poisson family, $link link and fitIntercept = $fitIntercept.")
}

idx += 1
}
Expand Down Expand Up @@ -495,20 +507,24 @@ class GeneralizedLinearRegressionSuite
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
.setFitIntercept(fitIntercept)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
s"$link link and fitIntercept = $fitIntercept.")

val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gamma family, $link link and fitIntercept = $fitIntercept.")
}
model.transform(dataset).select("features", "prediction", "linkPrediction").collect()
.foreach {
case Row(features: DenseVector, prediction1: Double, linkPrediction1: Double) =>
val eta = BLAS.dot(features, model.coefficients) + model.intercept
val prediction2 = familyLink.fitted(eta)
val linkPrediction2 = eta
assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
s"gamma family, $link link and fitIntercept = $fitIntercept.")
assert(linkPrediction1 ~= linkPrediction2 relTol 1E-5, "Link Prediction mismatch: " +
s"GLM with gamma family, $link link and fitIntercept = $fitIntercept.")
}

idx += 1
}
Expand Down