
Commit 1d6d7c8

Merge branch 'master' into receiver-scheduling
2 parents: 8f93c8d + a721ee5

File tree

300 files changed: +10690 -2707 lines


R/pkg/DESCRIPTION

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Collate:
     'client.R'
     'context.R'
     'deserialize.R'
+    'mllib.R'
     'serialize.R'
     'sparkR.R'
     'utils.R'

R/pkg/NAMESPACE

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ export("sparkR.init")
 export("sparkR.stop")
 export("print.jobj")
 
+# MLlib integration
+exportMethods("glm",
+              "predict")
+
 # Job group lifecycle management methods
 export("setJobGroup",
        "clearJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
@@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
 #' @rdname column
 #' @export
 setGeneric("upper", function(x) { standardGeneric("upper") })
+
+#' @rdname glm
+#' @export
+setGeneric("glm")

R/pkg/R/mllib.R

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib.R: Provides methods for MLlib integration
+
+#' @title S4 class that represents a PipelineModel
+#' @param model A Java object reference to the backing Scala PipelineModel
+#' @export
+setClass("PipelineModel", representation(model = "jobj"))
+
+#' Fits a generalized linear model
+#'
+#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
+#'
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~' and '+'.
+#' @param data DataFrame for training
+#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
+#' @param lambda Regularization parameter
+#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
+#' @return A fitted MLlib model
+#' @rdname glm
+#' @export
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- glm(Sepal_Length ~ Sepal_Width, data = df)
+#' }
+setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
+          function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
+            family <- match.arg(family)
+            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                 "fitRModelFormula", deparse(formula), data@sdf, family, lambda,
+                                 alpha)
+            return(new("PipelineModel", model = model))
+          })
+
+#' Make predictions from a model
+#'
+#' Makes predictions from a model produced by glm(), similarly to R's predict().
+#'
+#' @param object A fitted MLlib model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted values
+#' @rdname glm
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "PipelineModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
+          })
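
Together, these give a glm() that fits via the Scala SparkRWrappers helper and a predict() that is a thin wrapper over PipelineModel.transform(). A hedged end-to-end sketch, assuming a local master and the column renaming ("." becomes "_") that createDataFrame applies to iris:

sc <- sparkR.init(master = "local")
sqlContext <- sparkRSQL.init(sc)
df <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Length ~ Sepal_Width, data = df, family = "gaussian")
predictions <- predict(model, newData = df)  # DataFrame with a "prediction" column
showDF(select(predictions, "Sepal_Length", "prediction"))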

R/pkg/R/schema.R

Lines changed: 8 additions & 5 deletions
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
 #' @param ... further arguments passed to or from other methods
 print.structType <- function(x, ...) {
   cat("StructType\n",
-      sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
-                                                 "\", type = \"", field$dataType.toString(),
-                                                 "\", nullable = ", field$nullable(), "\n",
-                                                 sep = "") })
-      , sep = "")
+      sapply(x$fields(),
+             function(field) {
+               paste("|-", "name = \"", field$name(),
+                     "\", type = \"", field$dataType.toString(),
+                     "\", nullable = ", field$nullable(), "\n",
+                     sep = "")
+             }),
+      sep = "")
 }
 
 #' structField
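
This refactor is purely cosmetic; the printed output is unchanged. For reference, a sketch of what the printer produces (assuming SparkR's structType/structField constructors; the exact type strings come from the JVM side):

schema <- structType(structField("name", "string"), structField("age", "integer"))
print(schema)
# StructType
# |-name = "name", type = "StringType", nullable = TRUE
# |-name = "age", type = "IntegerType", nullable = TRUE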

R/pkg/R/utils.R

Lines changed: 22 additions & 11 deletions
@@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
       for (i in 1:nodeLen) {
         processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
       }
-    } else { # if node[[1]] is length of 1, check for some R special functions.
+    } else {
+      # if node[[1]] is length of 1, check for some R special functions.
       nodeChar <- as.character(node[[1]])
-      if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
+      if (nodeChar == "{" || nodeChar == "(") {
+        # Skip start symbol.
         for (i in 2:nodeLen) {
           processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
         }
       } else if (nodeChar == "<-" || nodeChar == "=" ||
-                   nodeChar == "<<-") { # Assignment Ops.
+                   nodeChar == "<<-") {
+        # Assignment Ops.
         defVar <- node[[2]]
         if (length(defVar) == 1 && typeof(defVar) == "symbol") {
           # Add the defined variable name into defVars.
@@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
         for (i in 3:nodeLen) {
           processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
         }
-      } else if (nodeChar == "function") { # Function definition.
+      } else if (nodeChar == "function") {
+        # Function definition.
         # Add parameter names.
         newArgs <- names(node[[2]])
         lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
         for (i in 3:nodeLen) {
           processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
         }
-      } else if (nodeChar == "$") { # Skip the field.
+      } else if (nodeChar == "$") {
+        # Skip the field.
         processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
       } else if (nodeChar == "::" || nodeChar == ":::") {
         processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
              (typeof(node) == "symbol" || typeof(node) == "language")) {
     # Base case: current AST node is a leaf node and a symbol or a function call.
     nodeChar <- as.character(node)
-    if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
+    if (!nodeChar %in% defVars$data) {
+      # Not a function parameter or local variable.
       func.env <- oldEnv
       topEnv <- parent.env(.GlobalEnv)
       # Search in function environment, and function's enclosing environments
@@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
       while (!identical(func.env, topEnv)) {
         # Namespaces other than "SparkR" will not be searched.
         if (!isNamespace(func.env) ||
-            (getNamespaceName(func.env) == "SparkR" &&
-            !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
+              (getNamespaceName(func.env) == "SparkR" &&
+                !(nodeChar %in% getNamespaceExports("SparkR")))) {
+          # Only include SparkR internals.
+
           # Set parameter 'inherits' to FALSE since we do not need to search in
           # attached package environments.
           if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
                        error = function(e) { FALSE })) {
             obj <- get(nodeChar, envir = func.env, inherits = FALSE)
-            if (is.function(obj)) { # If the node is a function call.
+            if (is.function(obj)) {
+              # If the node is a function call.
               funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
                                ifnotfound = list(list(NULL)))[[1]]
               found <- sapply(funcList, function(func) {
                 ifelse(identical(func, obj), TRUE, FALSE)
               })
-              if (sum(found) > 0) { # If function has been examined, ignore.
+              if (sum(found) > 0) {
+                # If function has been examined, ignore.
                 break
               }
               # Function has not been examined, record it and recursively clean its closure.
@@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
   # environment. First, function's arguments are added to defVars.
   defVars <- initAccumulator()
   argNames <- names(as.list(args(func)))
-  for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
+  for (i in 1:(length(argNames) - 1)) {
+    # Remove the ending NULL in pairlist.
     addItemToAccumulator(defVars, argNames[i])
   }
   # Recursively examine variables in the function body.
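
All of these utils.R hunks move trailing comments onto their own lines; behavior is unchanged. For context, a hedged sketch of what cleanClosure does with this machinery (cleanClosure is a SparkR internal, reached via ::: purely for illustration):

# processClosure walks the function's AST; free variables that resolve
# outside the function (e.g. globals) are copied into a fresh environment,
# so the function can be serialized to executors with its dependencies.
coef <- 2
scale2 <- function(x) { coef * x }
cleaned <- SparkR:::cleanClosure(scale2)
get("coef", envir = environment(cleaned))  # 2 -- captured by value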

R/pkg/inst/tests/test_mllib.R

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib functions")
+
+# Tests for MLlib functions in SparkR
+
+sc <- sparkR.init()
+
+sqlContext <- sparkRSQL.init(sc)
+
+test_that("glm and predict", {
+  training <- createDataFrame(sqlContext, iris)
+  test <- select(training, "Sepal_Length")
+  model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
+  prediction <- predict(model, test)
+  expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+})
+
+test_that("predictions match with native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  model <- glm(Sepal_Width ~ Sepal_Length, data = training)
+  vals <- collect(select(predict(model, training), "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
+})

bin/spark-shell

Lines changed: 2 additions & 2 deletions
@@ -47,11 +47,11 @@ function main() {
     # (see https://github.com/sbt/sbt/issues/562).
     stty -icanon min 1 -echo > /dev/null 2>&1
     export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
-    "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+    "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
     stty icanon echo > /dev/null 2>&1
   else
     export SPARK_SUBMIT_OPTS
-    "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
+    "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
   fi
 }
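
The only change here is passing --name "Spark shell" to spark-submit, so the REPL registers a readable application name in the Spark UI rather than the bare class name; since spark-submit parses options left to right, a user-supplied --name in "$@" should still override it.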

bin/spark-shell2.cmd

Lines changed: 1 addition & 1 deletion
@@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
   set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"
 
 :run_shell
-%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
+%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %*

build/sbt-launch-lib.bash

Lines changed: 6 additions & 2 deletions
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
     printf "Attempting to fetch sbt\n"
     JAR_DL="${JAR}.part"
     if [ $(command -v curl) ]; then
-      (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+      (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
+        (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
+        mv "${JAR_DL}" "${JAR}"
     elif [ $(command -v wget) ]; then
-      (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+      (wget --quiet ${URL1} -O "${JAR_DL}" ||\
+        (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
+        mv "${JAR_DL}" "${JAR}"
     else
       printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
       exit -1
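
Rationale for the download changes: without --fail, curl exits 0 on an HTTP error and writes the server's error page into ${JAR_DL}, which would then be renamed to a corrupt sbt jar; --location follows mirror redirects; and rm -f discards the partial file from a failed first attempt before the fallback URL is tried.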
