Skip to content

Commit d1aa55a

Browse files
author
Ilya Ganelin
committed
Resolved merge conflicts due to earlier patch.
2 parents cdce9d3 + 6bba750 commit d1aa55a

File tree

455 files changed

+18179
-8525
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

455 files changed

+18179
-8525
lines changed

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export("print.jobj")
1212

1313
# MLlib integration
1414
exportMethods("glm",
15-
"predict")
15+
"predict",
16+
"summary")
1617

1718
# Job group lifecycle management methods
1819
export("setJobGroup",

R/pkg/R/backend.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
110110

111111
# TODO: check the status code to output error information
112112
returnStatus <- readInt(conn)
113-
stopifnot(returnStatus == 0)
113+
if (returnStatus != 0) {
114+
stop(readString(conn))
115+
}
114116
readObject(conn)
115117
}

R/pkg/R/client.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack
4848
jars <- paste("--jars", jars)
4949
}
5050

51-
if (packages != "") {
51+
if (!identical(packages, "")) {
5252
packages <- paste("--packages", packages)
5353
}
5454

R/pkg/R/deserialize.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ readList <- function(con) {
102102

103103
readRaw <- function(con) {
104104
dataLen <- readInt(con)
105-
data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
105+
readBin(con, raw(), as.integer(dataLen), endian = "big")
106106
}
107107

108108
readRawLen <- function(con, dataLen) {
109-
data <- readBin(con, raw(), as.integer(dataLen), endian = "big")
109+
readBin(con, raw(), as.integer(dataLen), endian = "big")
110110
}
111111

112112
readDeserialize <- function(con) {

R/pkg/R/generics.R

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues")
254254

255255
# @rdname intersection
256256
# @export
257-
setGeneric("intersection", function(x, other, numPartitions = 1) {
258-
standardGeneric("intersection") })
257+
setGeneric("intersection",
258+
function(x, other, numPartitions = 1) {
259+
standardGeneric("intersection")
260+
})
259261

260262
# @rdname keys
261263
# @export
@@ -489,9 +491,7 @@ setGeneric("sample",
489491
#' @rdname sample
490492
#' @export
491493
setGeneric("sample_frac",
492-
function(x, withReplacement, fraction, seed) {
493-
standardGeneric("sample_frac")
494-
})
494+
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
495495

496496
#' @rdname saveAsParquetFile
497497
#' @export
@@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn
553553

554554
#' @rdname withColumnRenamed
555555
#' @export
556-
setGeneric("withColumnRenamed", function(x, existingCol, newCol) {
557-
standardGeneric("withColumnRenamed") })
556+
setGeneric("withColumnRenamed",
557+
function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") })
558558

559559

560560
###################### Column Methods ##########################

R/pkg/R/mllib.R

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
2727
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
2828
#'
2929
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
30-
#' operators are supported, including '~' and '+'.
30+
#' operators are supported, including '~', '+', '-', and '.'.
3131
#' @param data DataFrame for training
3232
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
3333
#' @param lambda Regularization parameter
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
7171
function(object, newData) {
7272
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
7373
})
74+
75+
#' Get the summary of a model
76+
#'
77+
#' Returns the summary of a model produced by glm(), similarly to R's summary().
78+
#'
79+
#' @param model A fitted MLlib model
80+
#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See
81+
#' summary.glm for more information.
82+
#' @rdname glm
83+
#' @export
84+
#' @examples
85+
#'\dontrun{
86+
#' model <- glm(y ~ x, trainingData)
87+
#' summary(model)
88+
#'}
89+
setMethod("summary", signature(object = "PipelineModel"),
90+
function(object) {
91+
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
92+
"getModelFeatures", object@model)
93+
weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
94+
"getModelWeights", object@model)
95+
coefficients <- as.matrix(unlist(weights))
96+
colnames(coefficients) <- c("Estimate")
97+
rownames(coefficients) <- unlist(features)
98+
return(list(coefficients = coefficients))
99+
})

R/pkg/R/pairRDD.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ setMethod("partitionBy",
202202

203203
packageNamesArr <- serialize(.sparkREnv$.packages,
204204
connection = NULL)
205-
broadcastArr <- lapply(ls(.broadcastNames), function(name) {
206-
get(name, .broadcastNames) })
205+
broadcastArr <- lapply(ls(.broadcastNames),
206+
function(name) { get(name, .broadcastNames) })
207207
jrdd <- getJRDD(x)
208208

209209
# We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])],

R/pkg/R/sparkR.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
connExists <- function(env) {
2323
tryCatch({
2424
exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]])
25-
}, error = function(err) {
25+
},
26+
error = function(err) {
2627
return(FALSE)
2728
})
2829
}
@@ -104,16 +105,13 @@ sparkR.init <- function(
104105
return(get(".sparkRjsc", envir = .sparkREnv))
105106
}
106107

107-
sparkMem <- Sys.getenv("SPARK_MEM", "1024m")
108108
jars <- suppressWarnings(normalizePath(as.character(sparkJars)))
109109

110110
# Classpath separator is ";" on Windows
111111
# URI needs four /// as from http://stackoverflow.com/a/18522792
112112
if (.Platform$OS.type == "unix") {
113-
collapseChar <- ":"
114113
uriSep <- "//"
115114
} else {
116-
collapseChar <- ";"
117115
uriSep <- "////"
118116
}
119117

@@ -156,7 +154,8 @@ sparkR.init <- function(
156154
.sparkREnv$backendPort <- backendPort
157155
tryCatch({
158156
connectBackend("localhost", backendPort)
159-
}, error = function(err) {
157+
},
158+
error = function(err) {
160159
stop("Failed to connect JVM\n")
161160
})
162161

@@ -267,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) {
267266
ssc <- callJMethod(sc, "sc")
268267
hiveCtx <- tryCatch({
269268
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
270-
}, error = function(err) {
269+
},
270+
error = function(err) {
271271
stop("Spark SQL is not built with Hive support")
272272
})
273273

R/pkg/inst/tests/test_client.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", {
3030
expect_equal(gsub("[[:space:]]", "", args),
3131
"")
3232
})
33+
34+
test_that("multiple packages don't produce a warning", {
35+
expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning()))
36+
})

R/pkg/inst/tests/test_mllib.R

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,27 @@ test_that("glm and predict", {
3535

3636
test_that("predictions match with native glm", {
3737
training <- createDataFrame(sqlContext, iris)
38-
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
38+
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
3939
vals <- collect(select(predict(model, training), "prediction"))
40-
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
41-
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
40+
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
41+
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
42+
})
43+
44+
test_that("dot minus and intercept vs native glm", {
45+
training <- createDataFrame(sqlContext, iris)
46+
model <- glm(Sepal_Width ~ . - Species + 0, data = training)
47+
vals <- collect(select(predict(model, training), "prediction"))
48+
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
49+
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
50+
})
51+
52+
test_that("summary coefficients match with native glm", {
53+
training <- createDataFrame(sqlContext, iris)
54+
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
55+
coefs <- as.vector(stats$coefficients)
56+
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
57+
expect_true(all(abs(rCoefs - coefs) < 1e-6))
58+
expect_true(all(
59+
as.character(stats$features) ==
60+
c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
4261
})

0 commit comments

Comments
 (0)