Skip to content

Commit bad5c3e

Browse files
Fixed MLWritable and MLReadable to JavaMLWritable and JavaMLReadable
2 parents 540fe8e + 5dfc019 commit bad5c3e

File tree

199 files changed

+6328
-2191
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

199 files changed

+6328
-2191
lines changed

R/pkg/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Depends:
1111
R (>= 3.0),
1212
methods,
1313
Suggests:
14-
testthat
14+
testthat,
15+
e1071
1516
Description: R frontend for Spark
1617
License: Apache License (== 2.0)
1718
Collate:

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ exportMethods("glm",
1515
"predict",
1616
"summary",
1717
"kmeans",
18-
"fitted")
18+
"fitted",
19+
"naiveBayes")
1920

2021
# Job group lifecycle management methods
2122
export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,3 +1175,7 @@ setGeneric("kmeans")
11751175
#' @rdname fitted
11761176
#' @export
11771177
setGeneric("fitted")
1178+
1179+
#' @rdname naiveBayes
1180+
#' @export
1181+
setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })

R/pkg/R/mllib.R

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
#' @export
2323
setClass("PipelineModel", representation(model = "jobj"))
2424

25+
#' @title S4 class that represents a NaiveBayesModel
26+
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
27+
#' @export
28+
setClass("NaiveBayesModel", representation(jobj = "jobj"))
29+
2530
#' Fits a generalized linear model
2631
#'
2732
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -42,7 +47,7 @@ setClass("PipelineModel", representation(model = "jobj"))
4247
#' @rdname glm
4348
#' @export
4449
#' @examples
45-
#'\dontrun{
50+
#' \dontrun{
4651
#' sc <- sparkR.init()
4752
#' sqlContext <- sparkRSQL.init(sc)
4853
#' data(iris)
@@ -71,7 +76,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
7176
#' @rdname predict
7277
#' @export
7378
#' @examples
74-
#'\dontrun{
79+
#' \dontrun{
7580
#' model <- glm(y ~ x, trainingData)
7681
#' predicted <- predict(model, testData)
7782
#' showDF(predicted)
@@ -81,6 +86,26 @@ setMethod("predict", signature(object = "PipelineModel"),
8186
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
8287
})
8388

89+
#' Make predictions from a naive Bayes model
90+
#'
91+
#' Makes predictions from a model produced by naiveBayes(), similarly to R package e1071's predict.
92+
#'
93+
#' @param object A fitted naive Bayes model
94+
#' @param newData DataFrame for testing
95+
#' @return DataFrame containing predicted labels in a column named "prediction"
96+
#' @rdname predict
97+
#' @export
98+
#' @examples
99+
#' \dontrun{
100+
#' model <- naiveBayes(y ~ x, trainingData)
101+
#' predicted <- predict(model, testData)
102+
#' showDF(predicted)
103+
#'}
104+
setMethod("predict", signature(object = "NaiveBayesModel"),
105+
function(object, newData) {
106+
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
107+
})
108+
84109
#' Get the summary of a model
85110
#'
86111
#' Returns the summary of a model produced by glm(), similarly to R's summary().
@@ -97,7 +122,7 @@ setMethod("predict", signature(object = "PipelineModel"),
97122
#' @rdname summary
98123
#' @export
99124
#' @examples
100-
#'\dontrun{
125+
#' \dontrun{
101126
#' model <- glm(y ~ x, trainingData)
102127
#' summary(model)
103128
#'}
@@ -140,6 +165,35 @@ setMethod("summary", signature(object = "PipelineModel"),
140165
}
141166
})
142167

168+
#' Get the summary of a naive Bayes model
169+
#'
170+
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
171+
#'
172+
#' @param object A fitted naive Bayes model
173+
#' @return a list containing 'apriori', the label distribution, and 'tables', conditional
174+
#'   probabilities given the target label
175+
#' @rdname summary
176+
#' @export
177+
#' @examples
178+
#' \dontrun{
179+
#' model <- naiveBayes(y ~ x, trainingData)
180+
#' summary(model)
181+
#'}
182+
setMethod("summary", signature(object = "NaiveBayesModel"),
183+
function(object, ...) {
184+
jobj <- object@jobj
185+
features <- callJMethod(jobj, "features")
186+
labels <- callJMethod(jobj, "labels")
187+
apriori <- callJMethod(jobj, "apriori")
188+
apriori <- t(as.matrix(unlist(apriori)))
189+
colnames(apriori) <- unlist(labels)
190+
tables <- callJMethod(jobj, "tables")
191+
tables <- matrix(tables, nrow = length(labels))
192+
rownames(tables) <- unlist(labels)
193+
colnames(tables) <- unlist(features)
194+
return(list(apriori = apriori, tables = tables))
195+
})
196+
143197
#' Fit a k-means model
144198
#'
145199
#' Fit a k-means model, similarly to R's kmeans().
@@ -152,7 +206,7 @@ setMethod("summary", signature(object = "PipelineModel"),
152206
#' @rdname kmeans
153207
#' @export
154208
#' @examples
155-
#'\dontrun{
209+
#' \dontrun{
156210
#' model <- kmeans(x, centers = 2, algorithm="random")
157211
#'}
158212
setMethod("kmeans", signature(x = "DataFrame"),
@@ -173,7 +227,7 @@ setMethod("kmeans", signature(x = "DataFrame"),
173227
#' @rdname fitted
174228
#' @export
175229
#' @examples
176-
#'\dontrun{
230+
#' \dontrun{
177231
#' model <- kmeans(trainingData, 2)
178232
#' fitted.model <- fitted(model)
179233
#' showDF(fitted.model)
@@ -192,3 +246,30 @@ setMethod("fitted", signature(object = "PipelineModel"),
192246
stop(paste("Unsupported model", modelName, sep = " "))
193247
}
194248
})
249+
250+
#' Fit a Bernoulli naive Bayes model
251+
#'
252+
#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes(). Only
253+
#' categorical features are supported. The input should be a DataFrame of observations instead of a
254+
#' contingency table.
255+
#'
256+
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
257+
#' operators are supported, including '~', '.', ':', '+', and '-'.
258+
#' @param data DataFrame for training
259+
#' @param laplace Smoothing parameter
260+
#' @return a fitted naive Bayes model
261+
#' @rdname naiveBayes
262+
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
263+
#' @export
264+
#' @examples
265+
#' \dontrun{
266+
#' df <- createDataFrame(sqlContext, infert)
267+
#' model <- naiveBayes(education ~ ., df, laplace = 0)
268+
#'}
269+
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
270+
function(formula, data, laplace = 0, ...) {
271+
formula <- paste(deparse(formula), collapse = "")
272+
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
273+
formula, data@sdf, laplace)
274+
return(new("NaiveBayesModel", jobj = jobj))
275+
})

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,62 @@ test_that("kmeans", {
141141
cluster <- summary.model$cluster
142142
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
143143
})
144+
145+
test_that("naiveBayes", {
146+
# R code to reproduce the result.
147+
# We do not support instance weights yet. So we ignore the frequencies.
148+
#
149+
#' library(e1071)
150+
#' t <- as.data.frame(Titanic)
151+
#' t1 <- t[t$Freq > 0, -5]
152+
#' m <- naiveBayes(Survived ~ ., data = t1)
153+
#' m
154+
#' predict(m, t1)
155+
#
156+
# -- output of 'm'
157+
#
158+
# A-priori probabilities:
159+
# Y
160+
# No Yes
161+
# 0.4166667 0.5833333
162+
#
163+
# Conditional probabilities:
164+
# Class
165+
# Y 1st 2nd 3rd Crew
166+
# No 0.2000000 0.2000000 0.4000000 0.2000000
167+
# Yes 0.2857143 0.2857143 0.2857143 0.1428571
168+
#
169+
# Sex
170+
# Y Male Female
171+
# No 0.5 0.5
172+
# Yes 0.5 0.5
173+
#
174+
# Age
175+
# Y Child Adult
176+
# No 0.2000000 0.8000000
177+
# Yes 0.4285714 0.5714286
178+
#
179+
# -- output of 'predict(m, t1)'
180+
#
181+
# Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No
182+
#
183+
184+
t <- as.data.frame(Titanic)
185+
t1 <- t[t$Freq > 0, -5]
186+
df <- suppressWarnings(createDataFrame(sqlContext, t1))
187+
m <- naiveBayes(Survived ~ ., data = df)
188+
s <- summary(m)
189+
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
190+
expect_equal(sum(s$apriori), 1)
191+
expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
192+
p <- collect(select(predict(m, df), "prediction"))
193+
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
194+
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
195+
"Yes", "Yes", "No", "No"))
196+
197+
# Test e1071::naiveBayes
198+
if (requireNamespace("e1071", quietly = TRUE)) {
199+
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
200+
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
201+
}
202+
})

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1817,7 +1817,8 @@ test_that("approxQuantile() on a DataFrame", {
18171817

18181818
test_that("SQL error message is returned from JVM", {
18191819
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
1820-
expect_equal(grepl("Table not found: blah", retError), TRUE)
1820+
expect_equal(grepl("Table not found", retError), TRUE)
1821+
expect_equal(grepl("blah", retError), TRUE)
18211822
})
18221823

18231824
irisDF <- suppressWarnings(createDataFrame(sqlContext, iris))

core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.spark.executor.ShuffleWriteMetrics;
3333
import org.apache.spark.memory.MemoryConsumer;
3434
import org.apache.spark.memory.TaskMemoryManager;
35+
import org.apache.spark.serializer.SerializerManager;
3536
import org.apache.spark.storage.BlockManager;
3637
import org.apache.spark.unsafe.Platform;
3738
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -163,19 +164,22 @@ public final class BytesToBytesMap extends MemoryConsumer {
163164
private long peakMemoryUsedBytes = 0L;
164165

165166
private final BlockManager blockManager;
167+
private final SerializerManager serializerManager;
166168
private volatile MapIterator destructiveIterator = null;
167169
private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
168170

169171
public BytesToBytesMap(
170172
TaskMemoryManager taskMemoryManager,
171173
BlockManager blockManager,
174+
SerializerManager serializerManager,
172175
int initialCapacity,
173176
double loadFactor,
174177
long pageSizeBytes,
175178
boolean enablePerfMetrics) {
176179
super(taskMemoryManager, pageSizeBytes);
177180
this.taskMemoryManager = taskMemoryManager;
178181
this.blockManager = blockManager;
182+
this.serializerManager = serializerManager;
179183
this.loadFactor = loadFactor;
180184
this.loc = new Location();
181185
this.pageSizeBytes = pageSizeBytes;
@@ -209,6 +213,7 @@ public BytesToBytesMap(
209213
this(
210214
taskMemoryManager,
211215
SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
216+
SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
212217
initialCapacity,
213218
0.70,
214219
pageSizeBytes,
@@ -271,7 +276,7 @@ private void advanceToNextPage() {
271276
}
272277
try {
273278
Closeables.close(reader, /* swallowIOException = */ false);
274-
reader = spillWriters.getFirst().getReader(blockManager);
279+
reader = spillWriters.getFirst().getReader(serializerManager);
275280
recordsInPage = -1;
276281
} catch (IOException e) {
277282
// Scala iterator does not handle exception

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.spark.executor.ShuffleWriteMetrics;
3232
import org.apache.spark.memory.MemoryConsumer;
3333
import org.apache.spark.memory.TaskMemoryManager;
34+
import org.apache.spark.serializer.SerializerManager;
3435
import org.apache.spark.storage.BlockManager;
3536
import org.apache.spark.unsafe.Platform;
3637
import org.apache.spark.unsafe.array.LongArray;
@@ -51,6 +52,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
5152
private final RecordComparator recordComparator;
5253
private final TaskMemoryManager taskMemoryManager;
5354
private final BlockManager blockManager;
55+
private final SerializerManager serializerManager;
5456
private final TaskContext taskContext;
5557
private ShuffleWriteMetrics writeMetrics;
5658

@@ -78,14 +80,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
7880
public static UnsafeExternalSorter createWithExistingInMemorySorter(
7981
TaskMemoryManager taskMemoryManager,
8082
BlockManager blockManager,
83+
SerializerManager serializerManager,
8184
TaskContext taskContext,
8285
RecordComparator recordComparator,
8386
PrefixComparator prefixComparator,
8487
int initialSize,
8588
long pageSizeBytes,
8689
UnsafeInMemorySorter inMemorySorter) throws IOException {
8790
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
88-
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
91+
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
92+
pageSizeBytes, inMemorySorter);
8993
sorter.spill(Long.MAX_VALUE, sorter);
9094
// The external sorter will be used to insert records, in-memory sorter is not needed.
9195
sorter.inMemSorter = null;
@@ -95,18 +99,20 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
9599
public static UnsafeExternalSorter create(
96100
TaskMemoryManager taskMemoryManager,
97101
BlockManager blockManager,
102+
SerializerManager serializerManager,
98103
TaskContext taskContext,
99104
RecordComparator recordComparator,
100105
PrefixComparator prefixComparator,
101106
int initialSize,
102107
long pageSizeBytes) {
103-
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
108+
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
104109
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
105110
}
106111

107112
private UnsafeExternalSorter(
108113
TaskMemoryManager taskMemoryManager,
109114
BlockManager blockManager,
115+
SerializerManager serializerManager,
110116
TaskContext taskContext,
111117
RecordComparator recordComparator,
112118
PrefixComparator prefixComparator,
@@ -116,6 +122,7 @@ private UnsafeExternalSorter(
116122
super(taskMemoryManager, pageSizeBytes);
117123
this.taskMemoryManager = taskMemoryManager;
118124
this.blockManager = blockManager;
125+
this.serializerManager = serializerManager;
119126
this.taskContext = taskContext;
120127
this.recordComparator = recordComparator;
121128
this.prefixComparator = prefixComparator;
@@ -412,7 +419,7 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
412419
final UnsafeSorterSpillMerger spillMerger =
413420
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
414421
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
415-
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
422+
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
416423
}
417424
if (inMemSorter != null) {
418425
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -463,7 +470,7 @@ public long spill() throws IOException {
463470
}
464471
spillWriter.close();
465472
spillWriters.add(spillWriter);
466-
nextUpstream = spillWriter.getReader(blockManager);
473+
nextUpstream = spillWriter.getReader(serializerManager);
467474

468475
long released = 0L;
469476
synchronized (UnsafeExternalSorter.this) {
@@ -549,7 +556,7 @@ public UnsafeSorterIterator getIterator() throws IOException {
549556
} else {
550557
LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
551558
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
552-
queue.add(spillWriter.getReader(blockManager));
559+
queue.add(spillWriter.getReader(serializerManager));
553560
}
554561
if (inMemSorter != null) {
555562
queue.add(inMemSorter.getSortedIterator());

0 commit comments

Comments
 (0)