@@ -29,7 +29,8 @@ setClass("PipelineModel", representation(model = "jobj"))
2929# ' @param formula A symbolic description of the model to be fitted. Currently only a few formula
3030# ' operators are supported, including '~', '.', ':', '+', and '-'.
3131# ' @param data DataFrame for training
32- # ' @param family a description of the error distribution and link function to be used in the model..
32+ # ' @param family a description of the error distribution and link function to be used in the model,
33+ # ' as in [[https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html]]
3334# ' @param lambda Regularization parameter
3435# ' @param solver Currently only support "irls" which is also the default solver.
3536# ' @return a fitted MLlib model
@@ -45,12 +46,12 @@ setClass("PipelineModel", representation(model = "jobj"))
4546# ' summary(model)
4647# '}
4748setMethod ("glm ", signature(formula = "formula", family = "ANY", data = "DataFrame"),
48- function (formula , family = c( " gaussian" , " binomial " , " poisson " , " gamma " ), data ,
49- lambda = 0 , solver = " irls " ) {
50- family <- match.arg( family )
49+ function (formula , family = gaussian(), data , lambda = 0 , solver = " auto " ) {
50+ familyName <- family $ family
51+ linkName <- family $ link
5152 formula <- paste(deparse(formula ), collapse = " " )
5253 model <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
53- " fitGLM" , formula , data @ sdf , family , lambda , solver )
54+ " fitGLM" , formula , data @ sdf , familyName , linkName , lambda , solver )
5455 return (new(" PipelineModel" , model = model ))
5556 })
5657
@@ -117,11 +118,6 @@ setMethod("summary", signature(object = "PipelineModel"),
117118 colnames(coefficients ) <- c(" Estimate" )
118119 rownames(coefficients ) <- unlist(features )
119120 return (list (coefficients = coefficients ))
120- } else if (modelName == " GeneralizedLinearRegressionModel" ) {
121- coefficients <- as.matrix(unlist(coefficients ))
122- colnames(coefficients ) <- c(" Estimate" )
123- rownames(coefficients ) <- unlist(features )
124- return (list (coefficients = coefficients ))
125121 } else if (modelName == " KMeansModel" ) {
126122 modelSize <- callJStatic(" org.apache.spark.ml.api.r.SparkRWrappers" ,
127123 " getKMeansModelSize" , object @ model )
0 commit comments