Skip to content

Commit cdce9d3

Browse files
author
Ilya Ganelin
committed
Fixed merge conflict
2 parents 035f537 + 662d60d commit cdce9d3

File tree

351 files changed

+12240
-3706
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

351 files changed

+12240
-3706
lines changed

R/pkg/DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Collate:
2929
'client.R'
3030
'context.R'
3131
'deserialize.R'
32+
'mllib.R'
3233
'serialize.R'
3334
'sparkR.R'
3435
'utils.R'

R/pkg/NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ export("sparkR.init")
1010
export("sparkR.stop")
1111
export("print.jobj")
1212

13+
# MLlib integration
14+
exportMethods("glm",
15+
"predict")
16+
1317
# Job group lifecycle management methods
1418
export("setJobGroup",
1519
"clearJobGroup",
@@ -22,6 +26,7 @@ exportMethods("arrange",
2226
"collect",
2327
"columns",
2428
"count",
29+
"crosstab",
2530
"describe",
2631
"distinct",
2732
"dropna",

R/pkg/R/DataFrame.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,3 +1554,31 @@ setMethod("fillna",
15541554
}
15551555
dataFrame(sdf)
15561556
})
1557+
1558+
#' crosstab
1559+
#'
1560+
#' Computes a pair-wise frequency table of the given columns. Also known as a contingency
1561+
#' table. The number of distinct values for each column should be less than 1e4. At most 1e6
1562+
#' non-zero pair frequencies will be returned.
1563+
#'
1564+
#' @param col1 name of the first column. Distinct items will make the first item of each row.
1565+
#' @param col2 name of the second column. Distinct items will make the column names of the output.
1566+
#' @return a local R data.frame representing the contingency table. The first column of each row
1567+
#' will be the distinct values of `col1` and the column names will be the distinct values
1568+
#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no
1569+
#' occurrences will have zero as their counts.
1570+
#'
1571+
#' @rdname statfunctions
1572+
#' @export
1573+
#' @examples
1574+
#' \dontrun{
1575+
#' df <- jsonFile(sqlCtx, "/path/to/file.json")
1576+
#' ct = crosstab(df, "title", "gender")
1577+
#' }
1578+
setMethod("crosstab",
1579+
signature(x = "DataFrame", col1 = "character", col2 = "character"),
1580+
function(x, col1, col2) {
1581+
statFunctions <- callJMethod(x@sdf, "stat")
1582+
sct <- callJMethod(statFunctions, "crosstab", col1, col2)
1583+
collect(dataFrame(sct))
1584+
})

R/pkg/R/generics.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") })
5959
# @export
6060
setGeneric("countByValue", function(x) { standardGeneric("countByValue") })
6161

62+
# @rdname statfunctions
63+
# @export
64+
setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") })
65+
6266
# @rdname distinct
6367
# @export
6468
setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") })
@@ -661,3 +665,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
661665
#' @rdname column
662666
#' @export
663667
setGeneric("upper", function(x) { standardGeneric("upper") })
668+
669+
#' @rdname glm
670+
#' @export
671+
setGeneric("glm")

R/pkg/R/mllib.R

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# mllib.R: Provides methods for MLlib integration
19+
20+
#' @title S4 class that represents a PipelineModel
21+
#' @param model A Java object reference to the backing Scala PipelineModel
22+
#' @export
23+
setClass("PipelineModel", representation(model = "jobj"))
24+
25+
#' Fits a generalized linear model
26+
#'
27+
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
28+
#'
29+
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
30+
#' operators are supported, including '~' and '+'.
31+
#' @param data DataFrame for training
32+
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
33+
#' @param lambda Regularization parameter
34+
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
35+
#' @return a fitted MLlib model
36+
#' @rdname glm
37+
#' @export
38+
#' @examples
39+
#'\dontrun{
40+
#' sc <- sparkR.init()
41+
#' sqlContext <- sparkRSQL.init(sc)
42+
#' data(iris)
43+
#' df <- createDataFrame(sqlContext, iris)
44+
#' model <- glm(Sepal_Length ~ Sepal_Width, df)
45+
#'}
46+
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
47+
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
48+
family <- match.arg(family)
49+
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
50+
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
51+
alpha)
52+
return(new("PipelineModel", model = model))
53+
})
54+
55+
#' Make predictions from a model
56+
#'
57+
#' Makes predictions from a model produced by glm(), similarly to R's predict().
58+
#'
59+
#' @param model A fitted MLlib model
60+
#' @param newData DataFrame for testing
61+
#' @return DataFrame containing predicted values
62+
#' @rdname glm
63+
#' @export
64+
#' @examples
65+
#'\dontrun{
66+
#' model <- glm(y ~ x, trainingData)
67+
#' predicted <- predict(model, testData)
68+
#' showDF(predicted)
69+
#'}
70+
setMethod("predict", signature(object = "PipelineModel"),
71+
function(object, newData) {
72+
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
73+
})

R/pkg/R/schema.R

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@ structType.structField <- function(x, ...) {
6969
#' @param ... further arguments passed to or from other methods
7070
print.structType <- function(x, ...) {
7171
cat("StructType\n",
72-
sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
73-
"\", type = \"", field$dataType.toString(),
74-
"\", nullable = ", field$nullable(), "\n",
75-
sep = "") })
76-
, sep = "")
72+
sapply(x$fields(),
73+
function(field) {
74+
paste("|-", "name = \"", field$name(),
75+
"\", type = \"", field$dataType.toString(),
76+
"\", nullable = ", field$nullable(), "\n",
77+
sep = "")
78+
}),
79+
sep = "")
7780
}
7881

7982
#' structField

R/pkg/R/utils.R

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
390390
for (i in 1:nodeLen) {
391391
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
392392
}
393-
} else { # if node[[1]] is length of 1, check for some R special functions.
393+
} else {
394+
# if node[[1]] is length of 1, check for some R special functions.
394395
nodeChar <- as.character(node[[1]])
395-
if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
396+
if (nodeChar == "{" || nodeChar == "(") {
397+
# Skip start symbol.
396398
for (i in 2:nodeLen) {
397399
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
398400
}
399401
} else if (nodeChar == "<-" || nodeChar == "=" ||
400-
nodeChar == "<<-") { # Assignment Ops.
402+
nodeChar == "<<-") {
403+
# Assignment Ops.
401404
defVar <- node[[2]]
402405
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
403406
# Add the defined variable name into defVars.
@@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
408411
for (i in 3:nodeLen) {
409412
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
410413
}
411-
} else if (nodeChar == "function") { # Function definition.
414+
} else if (nodeChar == "function") {
415+
# Function definition.
412416
# Add parameter names.
413417
newArgs <- names(node[[2]])
414418
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
415419
for (i in 3:nodeLen) {
416420
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
417421
}
418-
} else if (nodeChar == "$") { # Skip the field.
422+
} else if (nodeChar == "$") {
423+
# Skip the field.
419424
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
420425
} else if (nodeChar == "::" || nodeChar == ":::") {
421426
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
@@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
429434
(typeof(node) == "symbol" || typeof(node) == "language")) {
430435
# Base case: current AST node is a leaf node and a symbol or a function call.
431436
nodeChar <- as.character(node)
432-
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
437+
if (!nodeChar %in% defVars$data) {
438+
# Not a function parameter or local variable.
433439
func.env <- oldEnv
434440
topEnv <- parent.env(.GlobalEnv)
435441
# Search in function environment, and function's enclosing environments
@@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
439445
while (!identical(func.env, topEnv)) {
440446
# Namespaces other than "SparkR" will not be searched.
441447
if (!isNamespace(func.env) ||
442-
(getNamespaceName(func.env) == "SparkR" &&
443-
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
448+
(getNamespaceName(func.env) == "SparkR" &&
449+
!(nodeChar %in% getNamespaceExports("SparkR")))) {
450+
# Only include SparkR internals.
451+
444452
# Set parameter 'inherits' to FALSE since we do not need to search in
445453
# attached package environments.
446454
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
447455
error = function(e) { FALSE })) {
448456
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
449-
if (is.function(obj)) { # If the node is a function call.
457+
if (is.function(obj)) {
458+
# If the node is a function call.
450459
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
451460
ifnotfound = list(list(NULL)))[[1]]
452461
found <- sapply(funcList, function(func) {
453462
ifelse(identical(func, obj), TRUE, FALSE)
454463
})
455-
if (sum(found) > 0) { # If function has been examined, ignore.
464+
if (sum(found) > 0) {
465+
# If function has been examined, ignore.
456466
break
457467
}
458468
# Function has not been examined, record it and recursively clean its closure.
@@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
495505
# environment. First, function's arguments are added to defVars.
496506
defVars <- initAccumulator()
497507
argNames <- names(as.list(args(func)))
498-
for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
508+
for (i in 1:(length(argNames) - 1)) {
509+
# Remove the ending NULL in pairlist.
499510
addItemToAccumulator(defVars, argNames[i])
500511
}
501512
# Recursively examine variables in the function body.

R/pkg/inst/tests/test_mllib.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
library(testthat)
19+
20+
context("MLlib functions")
21+
22+
# Tests for MLlib functions in SparkR
23+
24+
sc <- sparkR.init()
25+
26+
sqlContext <- sparkRSQL.init(sc)
27+
28+
test_that("glm and predict", {
29+
training <- createDataFrame(sqlContext, iris)
30+
test <- select(training, "Sepal_Length")
31+
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
32+
prediction <- predict(model, test)
33+
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
34+
})
35+
36+
test_that("predictions match with native glm", {
37+
training <- createDataFrame(sqlContext, iris)
38+
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
39+
vals <- collect(select(predict(model, training), "prediction"))
40+
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
41+
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
42+
})

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,19 @@ test_that("fillna() on a DataFrame", {
987987
expect_identical(expected, actual)
988988
})
989989

990+
test_that("crosstab() on a DataFrame", {
991+
rdd <- lapply(parallelize(sc, 0:3), function(x) {
992+
list(paste0("a", x %% 3), paste0("b", x %% 2))
993+
})
994+
df <- toDF(rdd, list("a", "b"))
995+
ct <- crosstab(df, "a", "b")
996+
ordered <- ct[order(ct$a_b),]
997+
row.names(ordered) <- NULL
998+
expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0),
999+
stringsAsFactors = FALSE, row.names = NULL)
1000+
expect_identical(expected, ordered)
1001+
})
1002+
9901003
unlink(parquetPath)
9911004
unlink(jsonPath)
9921005
unlink(jsonPathNa)

bin/spark-shell

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ function main() {
4747
# (see https://github.com/sbt/sbt/issues/562).
4848
stty -icanon min 1 -echo > /dev/null 2>&1
4949
export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix"
50-
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
50+
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
5151
stty icanon echo > /dev/null 2>&1
5252
else
5353
export SPARK_SUBMIT_OPTS
54-
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@"
54+
"$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@"
5555
fi
5656
}
5757

0 commit comments

Comments
 (0)