Skip to content

Commit 8b3dd3e

Browse files
committed
extract link and family name in R
1 parent a5d203e commit 8b3dd3e

File tree

2 files changed

+12
-39
lines changed

2 files changed

+12
-39
lines changed

R/pkg/R/mllib.R

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ setClass("PipelineModel", representation(model = "jobj"))
2929
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
3030
#' operators are supported, including '~', '.', ':', '+', and '-'.
3131
#' @param data DataFrame for training
32-
#' @param family a description of the error distribution and link function to be used in the model..
32+
#' @param family a description of the error distribution and link function to be used in the model,
33+
#' as in \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}
3334
#' @param lambda Regularization parameter
3435
#' @param solver Currently only support "irls" which is also the default solver.
3536
#' @return a fitted MLlib model
@@ -45,12 +46,12 @@ setClass("PipelineModel", representation(model = "jobj"))
4546
#' summary(model)
4647
#'}
4748
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
48-
function(formula, family = c("gaussian", "binomial", "poisson", "gamma"), data,
49-
lambda = 0, solver = "irls") {
50-
family <- match.arg(family)
49+
function(formula, family = gaussian(), data, lambda = 0, solver = "auto") {
50+
familyName <- family$family
51+
linkName <- family$link
5152
formula <- paste(deparse(formula), collapse = "")
5253
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
53-
"fitGLM", formula, data@sdf, family, lambda, solver)
54+
"fitGLM", formula, data@sdf, familyName, linkName, lambda, solver)
5455
return(new("PipelineModel", model = model))
5556
})
5657

@@ -117,11 +118,6 @@ setMethod("summary", signature(object = "PipelineModel"),
117118
colnames(coefficients) <- c("Estimate")
118119
rownames(coefficients) <- unlist(features)
119120
return(list(coefficients = coefficients))
120-
} else if (modelName == "GeneralizedLinearRegressionModel") {
121-
coefficients <- as.matrix(unlist(coefficients))
122-
colnames(coefficients) <- c("Estimate")
123-
rownames(coefficients) <- unlist(features)
124-
return(list(coefficients = coefficients))
125121
} else if (modelName == "KMeansModel") {
126122
modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
127123
"getKMeansModelSize", object@model)

mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,16 @@ private[r] object SparkRWrappers {
3131
value: String,
3232
df: DataFrame,
3333
family: String,
34+
link: String,
3435
lambda: Double,
3536
solver: String): PipelineModel = {
36-
if (solver.trim != "irls") throw new SparkException("Currently only support irls")
37-
3837
val formula = new RFormula().setFormula(value)
39-
val regex = "^\\s*(\\w+)\\s*(\\(\\s*link\\s*=\\s*\"(\\w+)\"\\s*\\))?\\s*$".r
40-
val estimator = family match {
41-
case regex(familyName, group2, linkName) =>
42-
val estimator = new GeneralizedLinearRegression()
43-
.setFamily(familyName)
44-
.setRegParam(lambda)
45-
.setFitIntercept(formula.hasIntercept)
46-
if (linkName != null) estimator.setLink(linkName)
47-
estimator
48-
case _ => throw new SparkException(s"Could not parse family: $family")
49-
}
38+
val estimator = new GeneralizedLinearRegression()
39+
.setFamily(family)
40+
.setRegParam(lambda)
41+
.setFitIntercept(formula.hasIntercept)
5042

43+
if (link != null) estimator.setLink(link)
5144
val pipeline = new Pipeline().setStages(Array(formula, estimator))
5245
pipeline.fit(df)
5346
}
@@ -117,12 +110,6 @@ private[r] object SparkRWrappers {
117110
}
118111
case m: KMeansModel =>
119112
m.clusterCenters.flatMap(_.toArray)
120-
case m: GeneralizedLinearRegressionModel =>
121-
if (m.getFitIntercept) {
122-
Array(m.intercept) ++ m.coefficients.toArray
123-
} else {
124-
m.coefficients.toArray
125-
}
126113
}
127114
}
128115

@@ -183,14 +170,6 @@ private[r] object SparkRWrappers {
183170
val attrs = AttributeGroup.fromStructField(
184171
m.summary.predictions.schema(m.summary.featuresCol))
185172
attrs.attributes.get.map(_.name.get)
186-
case m: GeneralizedLinearRegressionModel =>
187-
val attrs = AttributeGroup.fromStructField(
188-
m.summary.predictions.schema(m.summary.featuresCol))
189-
if (m.getFitIntercept) {
190-
Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
191-
} else {
192-
attrs.attributes.get.map(_.name.get)
193-
}
194173
}
195174
}
196175

@@ -202,8 +181,6 @@ private[r] object SparkRWrappers {
202181
"LogisticRegressionModel"
203182
case m: KMeansModel =>
204183
"KMeansModel"
205-
case m: GeneralizedLinearRegressionModel =>
206-
"GeneralizedLinearRegressionModel"
207184
}
208185
}
209186
}

0 commit comments

Comments
 (0)