From 071f29f998f86a5a05744c703bf6a9a2384c3805 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 21 Oct 2015 23:48:43 -0700 Subject: [PATCH 001/324] Add support for colnames, colnames<-, coltypes<- --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/DataFrame.R | 53 +++++++++++++++++++++++++++++++- R/pkg/R/generics.R | 12 ++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++ 4 files changed, 90 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 52f7a0106aae6..202ace07bf431 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -26,6 +26,8 @@ exportMethods("arrange", "attach", "cache", "collect", + "colnames", + "coltypes<-", "columns", "count", "cov", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 993be82a47f75..bc4c61d67421c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -239,7 +239,7 @@ setMethod("dtypes", #' #' @rdname columns #' @name columns -#' @aliases names +#' @aliases names colnames #' @export #' @examples #'\dontrun{ @@ -276,6 +276,57 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @name coltypes +#' @param x (DataFrame) +#' @return value (character) A character vector with the target column types for the given DataFrame +#' @rdname coltypes +#' @aliases coltypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' coltypes(df) <- c("string", "integer") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0 || length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + cast(col, value[i]) + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4a419f785e92c..08f25b2cd0d01 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -399,6 +399,18 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) +#' @rdname colnames +#' @export +setGeneric("colnames", function(x) { standardGeneric("colnames") }) + +#' @rdname colnames<- +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes<- +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 67d8b23cd7b8d..3020776e3c6f8 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -598,6 +598,30 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form expect_equal(testNames[2], "name") }) +test_that("names() colnames() set the column names", { + df <- jsonFile(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") +}) + +test_that("coltypes() set the column types", { + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("string", "integer") + expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + expect_error(coltypes(df) <- c("string"), + "Length of type vector should match the number of columns for DataFrame") +}) + test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) From c03b6d11589102b91f08728519e8520025db91e1 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 22 Oct 2015 03:59:26 -0700 Subject: [PATCH 002/324] [SPARK-11121][CORE] Correct the TaskLocation type Correct the logic to return an `HDFSCacheTaskLocation` instance when the input `str` is an in-memory location. Author: zhichao.li Closes #9096 from zhichao-li/uselessBranch.
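For readers skimming the diff below: the one-line fix changes which `TaskLocation` subclass is constructed once the `hdfs_cache_` tag has been stripped. Here is a minimal, self-contained sketch of the corrected parsing rule; the names `Location` and `parse` are illustrative stand-ins rather than Spark's actual `TaskLocation` API, and the real implementation also handles an `executor_` tag that is omitted here.

```scala
// Sketch of the corrected rule: a string tagged with "hdfs_cache_" denotes a
// block cached in HDFS memory on that host; an untagged string is a plain host.
sealed trait Location { def host: String }
case class HostLocation(host: String) extends Location
case class HdfsCacheLocation(host: String) extends Location

object Location {
  private val inMemoryTag = "hdfs_cache_"

  def parse(str: String): Location = {
    val hstr = str.stripPrefix(inMemoryTag)
    if (hstr == str) {
      HostLocation(str)        // no tag present: an ordinary host location
    } else {
      HdfsCacheLocation(hstr)  // tag present: this is the branch the patch fixes
    }
  }
}

object LocationExample extends App {
  // Mirrors the assertions added to TaskSetManagerSuite in this patch.
  assert(Location.parse("host1") == HostLocation("host1"))
  assert(Location.parse("hdfs_cache_host1") == HdfsCacheLocation("host1"))
}
```

Before the fix, both branches produced a host-only location, silently discarding the HDFS-cache locality preference; the new `TaskSetManagerSuite` assertions exercise exactly these two branches.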
--- .../org/apache/spark/scheduler/TaskLocation.scala | 2 +- .../apache/spark/scheduler/TaskSetManagerSuite.scala | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index da07ce2c6ea49..1b65926f5c749 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -67,7 +67,7 @@ private[spark] object TaskLocation { if (hstr.equals(str)) { new HostTaskLocation(str) } else { - new HostTaskLocation(hstr) + new HDFSCacheTaskLocation(hstr) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f0eadf240943e..695523cc8aa3a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -759,9 +759,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, - Seq(HostTaskLocation("host1")), - Seq(HostTaskLocation("host2")), - Seq(HDFSCacheTaskLocation("host3"))) + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) @@ -776,6 +776,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.myLocalityLevels.sameElements(Array(ANY))) } + test("Test TaskLocation for different host type.") { + assert(TaskLocation("host1") === HostTaskLocation("host1")) + assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) From 94e2064fa1b04c05c805d9175c7c78bf583db5c6 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Thu, 22 Oct 2015 09:34:07 -0700 Subject: [PATCH 003/324] [SPARK-11244][SPARKR] sparkR.stop() should remove SQLContext SparkR should remove `.sparkRSQLsc` and `.sparkRHivesc` when `sparkR.stop()` is called. Otherwise even when SparkContext is reinitialized, `sparkRSQL.init` returns the stale copy of the object and complains: ```r sc <- sparkR.init("local") sqlContext <- sparkRSQL.init(sc) sparkR.stop() sc <- sparkR.init("local") sqlContext <- sparkRSQL.init(sc) sqlContext ``` producing ```r Error in callJMethod(x, "getClass") : Invalid jobj 1. If SparkR was restarted, Spark operations need to be re-executed. ``` I have added the check and removal only when SparkContext itself is initialized. I have also added corresponding test for this fix. Let me know if you want me to move the test to SQL test suite instead. p.s. I tried lint-r but ended up a lots of errors on existing code. Author: Forest Fang Closes #9205 from saurfang/sparkR.stop. 
--- R/pkg/R/sparkR.R | 8 ++++++++ R/pkg/inst/tests/test_context.R | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 9cf2f1a361cf2..043b0057bd04a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -39,6 +39,14 @@ sparkR.stop <- function() { sc <- get(".sparkRjsc", envir = env) callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) + + if (exists(".sparkRSQLsc", envir = env)) { + rm(".sparkRSQLsc", envir = env) + } + + if (exists(".sparkRHivesc", envir = env)) { + rm(".sparkRHivesc", envir = env) + } } if (exists(".backendLaunched", envir = env)) { diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index 513bbc8e62059..e99815ed1562c 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -26,6 +26,16 @@ test_that("repeatedly starting and stopping SparkR", { } }) +test_that("repeatedly starting and stopping SparkR SQL", { + for (i in 1:4) { + sc <- sparkR.init() + sqlContext <- sparkRSQL.init(sc) + df <- createDataFrame(sqlContext, data.frame(a = 1:20)) + expect_equal(count(df), 20) + sparkR.stop() + } +}) + test_that("rdd GC across sparkR.stop", { sparkR.stop() sc <- sparkR.init() # sc should get id 0 From f6d06adf05afa9c5386dc2396c94e7a98730289f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 22 Oct 2015 09:46:30 -0700 Subject: [PATCH 004/324] [SPARK-10708] Consolidate sort shuffle implementations There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Given that these now provide the same set of functionality, now that UnsafeShuffleManager supports large records, I think that we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and should merge the two managers together. Author: Josh Rosen Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations. 
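To make the consolidated behaviour easier to follow before diving into the diff, here is a compact, self-contained sketch of how the merged manager chooses a write path for a shuffle. The names `ShuffleSpec` and `choosePath` are illustrative only; in the patch itself this decision is split across `SortShuffleManager.registerShuffle`, `SortShuffleWriter.shouldBypassMergeSort`, and `SortShuffleManager.canUseSerializedShuffle`.

```scala
object ShufflePathSelection {
  sealed trait WritePath
  case object BypassMergeSort  extends WritePath // one file per reduce partition, then concatenated
  case object SerializedSort   extends WritePath // former "tungsten-sort": sorts serialized records
  case object DeserializedSort extends WritePath // classic ExternalSorter path

  /** Simplified stand-in for the relevant properties of a ShuffleDependency. */
  case class ShuffleSpec(
      numPartitions: Int,
      mapSideCombine: Boolean,
      hasAggregator: Boolean,
      serializerSupportsRelocation: Boolean)

  def choosePath(spec: ShuffleSpec, bypassMergeThreshold: Int = 200): WritePath = {
    // PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
    val maxSerializedModePartitions = 16777216
    if (!spec.mapSideCombine && spec.numPartitions <= bypassMergeThreshold) {
      BypassMergeSort
    } else if (spec.serializerSupportsRelocation && !spec.hasAggregator &&
        spec.numPartitions <= maxSerializedModePartitions) {
      SerializedSort
    } else {
      DeserializedSort
    }
  }

  def main(args: Array[String]): Unit = {
    // A small non-aggregated shuffle takes the bypass path even if its
    // serializer would also qualify for the serialized path.
    println(choosePath(ShuffleSpec(100, mapSideCombine = false, hasAggregator = false,
      serializerSupportsRelocation = true)))   // BypassMergeSort
    println(choosePath(ShuffleSpec(5000, mapSideCombine = false, hasAggregator = false,
      serializerSupportsRelocation = true)))   // SerializedSort
    println(choosePath(ShuffleSpec(5000, mapSideCombine = true, hasAggregator = true,
      serializerSupportsRelocation = true)))   // DeserializedSort
  }
}
```

Because the bypass check comes first in `registerShuffle`, the per-partition-file path still wins for small, non-aggregated shuffles; the serialized path applies only when the serializer supports relocation of serialized objects, no aggregator is defined, and there are at most 16,777,216 output partitions.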
--- .../sort/BypassMergeSortShuffleWriter.java | 106 +++++-- .../{unsafe => sort}/PackedRecordPointer.java | 2 +- .../ShuffleExternalSorter.java} | 28 +- .../ShuffleInMemorySorter.java} | 16 +- .../ShuffleSortDataFormat.java} | 8 +- .../shuffle/sort/SortShuffleFileWriter.java | 53 ---- .../shuffle/{unsafe => sort}/SpillInfo.java | 4 +- .../{unsafe => sort}/UnsafeShuffleWriter.java | 12 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../shuffle/sort/SortShuffleManager.scala | 175 +++++++++-- .../shuffle/sort/SortShuffleWriter.scala | 28 +- .../shuffle/unsafe/UnsafeShuffleManager.scala | 202 ------------- .../spark/util/collection/ChainedBuffer.scala | 146 ---------- .../util/collection/ExternalSorter.scala | 35 +-- .../PartitionedSerializedPairBuffer.scala | 273 ------------------ .../PackedRecordPointerSuite.java | 5 +- .../ShuffleInMemorySorterSuite.java} | 16 +- .../UnsafeShuffleWriterSuite.java | 10 +- .../org/apache/spark/SortShuffleSuite.scala | 65 +++++ .../spark/scheduler/DAGSchedulerSuite.scala | 6 +- .../BypassMergeSortShuffleWriterSuite.scala | 64 ++-- .../SortShuffleManagerSuite.scala} | 30 +- .../shuffle/sort/SortShuffleWriterSuite.scala | 45 --- .../shuffle/unsafe/UnsafeShuffleSuite.scala | 102 ------- .../util/collection/ChainedBufferSuite.scala | 144 --------- ...PartitionedSerializedPairBufferSuite.scala | 148 ---------- docs/configuration.md | 7 +- project/MimaExcludes.scala | 9 +- .../apache/spark/sql/execution/Exchange.scala | 23 +- .../execution/UnsafeRowSerializerSuite.scala | 9 +- 30 files changed, 456 insertions(+), 1317 deletions(-) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointer.java (98%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleExternalSorter.java => sort/ShuffleExternalSorter.java} (95%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorter.java => sort/ShuffleInMemorySorter.java} (88%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleSortDataFormat.java => sort/ShuffleSortDataFormat.java} (86%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/SpillInfo.java (90%) rename core/src/main/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriter.java (98%) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/PackedRecordPointerSuite.java (96%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe/UnsafeShuffleInMemorySorterSuite.java => sort/ShuffleInMemorySorterSuite.java} (87%) rename core/src/test/java/org/apache/spark/shuffle/{unsafe => sort}/UnsafeShuffleWriterSuite.java (98%) rename core/src/test/scala/org/apache/spark/shuffle/{unsafe/UnsafeShuffleManagerSuite.scala => sort/SortShuffleManagerSuite.scala} (80%) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala delete mode 100644 
core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index f5d80bbcf3557..ee82d679935c0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,21 +21,30 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.annotation.Nullable; +import scala.None$; +import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -62,7 +71,7 @@ *

* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final BlockManager blockManager; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; private final Serializer serializer; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; public BypassMergeSortShuffleWriter( - SparkConf conf, BlockManager blockManager, - Partitioner partitioner, - ShuffleWriteMetrics writeMetrics, - Serializer serializer) { + IndexShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); - this.numPartitions = partitioner.numPartitions(); this.blockManager = blockManager; - this.partitioner = partitioner; - this.writeMetrics = writeMetrics; - this.serializer = serializer; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.serializer = Serializer.getSerializer(dep.serializer()); + this.shuffleBlockResolver = shuffleBlockResolver; } @Override - public void insertAll(Iterator> records) throws IOException { + public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -124,13 +154,24 @@ public void insertAll(Iterator> records) throws IOException { for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } + + partitionLengths = + writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - @Override - public long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException { + @VisibleForTesting + long[] getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the 
per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { @@ -165,18 +206,33 @@ public long[] writePartitionedFile( } @Override - public void stop() throws IOException { - if (partitionWriters != null) { - try { - for (DiskBlockObjectWriter writer : partitionWriters) { - // This method explicitly does _not_ throw exceptions: - File file = writer.revertPartialWritesAndClose(); - if (!file.delete()) { - logger.error("Error while deleting file {}", file.getAbsolutePath()); + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; } } - } finally { - partitionWriters = null; + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return None$.empty(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index 4ee6a82c0423e..c11711966fa8c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java similarity index 95% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index e73ba39468828..85fdaa8115fa3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.File; @@ -48,7 +48,7 @@ *

* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then * written to a single output file (or multiple files, if we've spilled). The format of the output * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are @@ -59,9 +59,9 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class UnsafeShuffleExternalSorter { +final class ShuffleExternalSorter { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; @@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter { private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; + private long numRecordsInsertedSinceLastSpill = 0; + + /** Force this sorter to spill when there are this many elements in memory. For testing only */ + private final long numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; - public UnsafeShuffleExternalSorter( + public ShuffleExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, @@ -117,6 +121,8 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.numElementsForSpillThreshold = + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; @@ -140,7 +146,8 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(initialSize); + numRecordsInsertedSinceLastSpill = 0; } /** @@ -166,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } // This call performs the actual sort. 
- final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this @@ -406,6 +413,10 @@ public void insertRecord( int lengthInBytes, int partitionId) throws IOException { + if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + spill(); + } + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; @@ -453,6 +464,7 @@ public void insertRecord( recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); + numRecordsInsertedSinceLastSpill += 1; } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java similarity index 88% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 5bab501da9364..a8dee6c6101c1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Comparator; import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleInMemorySorter { +final class ShuffleInMemorySorter { private final Sorter sorter; private static final class SortComparator implements Comparator { @@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pointerArrayInsertPosition = 0; - public UnsafeShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.sorter = new Sorter(ShuffleSortDataFormat.INSTANCE); } public void expandPointerArray() { @@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) { /** * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. */ - public static final class UnsafeShuffleSorterIterator { + public static final class ShuffleSorterIterator { private final long[] pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, long[] pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -117,8 +117,8 @@ public void loadNext() { /** * Return an iterator over record pointers in sorted order. 
*/ - public UnsafeShuffleSorterIterator getSortedIterator() { + public ShuffleSorterIterator getSortedIterator() { sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java similarity index 86% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index a66d74ee44782..8a1e5aec6ff0e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import org.apache.spark.util.collection.SortDataFormat; -final class UnsafeShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { - public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); - private UnsafeShuffleSortDataFormat() { } + private ShuffleSortDataFormat() { } @Override public PackedRecordPointer getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java deleted file mode 100644 index 656ea0401a144..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort; - -import java.io.File; -import java.io.IOException; - -import scala.Product2; -import scala.collection.Iterator; - -import org.apache.spark.annotation.Private; -import org.apache.spark.TaskContext; -import org.apache.spark.storage.BlockId; - -/** - * Interface for objects that {@link SortShuffleWriter} uses to write its output files. - */ -@Private -public interface SortShuffleFileWriter { - - void insertAll(Iterator> records) throws IOException; - - /** - * Write all the data added into this shuffle sorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. - * - * @param blockId block ID to write to. 
The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException; - - void stop() throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java similarity index 90% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 7bac0dc0bbeb6..df9f7b7abe028 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.File; import org.apache.spark.storage.TempShuffleBlockId; /** - * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + * Metadata for a block of data written by {@link ShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index fdb309e365f69..e8f050cb2dab1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.*; @@ -80,7 +80,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; @Nullable private MapStatus mapStatus; - @Nullable private UnsafeShuffleExternalSorter sorter; + @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ @@ -104,15 +104,15 @@ public UnsafeShuffleWriter( IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, - UnsafeShuffleHandle handle, + SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -195,7 +195,7 @@ public void write(scala.collection.Iterator> records) throws IOEx private void open() throws IOException { assert (sorter == null); - sorter = new UnsafeShuffleExternalSorter( + sorter = new ShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index c32998345145a..704158bfc7643 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -330,7 +330,7 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 9df4e551669cc..1105167d39d8d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark._ +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ +/** + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * Sort-based shuffle has two different write paths for producing its map output files: + * + * - Serialized sorting: used when all three of the following conditions hold: + * 1. The shuffle dependency specifies no aggregation or output ordering. + * 2. 
The shuffle serializer supports relocation of serialized values (this is currently + * supported by KryoSerializer and Spark SQL's custom serializers). + * 3. The shuffle produces fewer than 16777216 output partitions. + * - Deserialized sorting: used to handle all other cases. + * + * ----------------------- + * Serialized sorting mode + * ----------------------- + * + * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the + * shuffle writer and are buffered in a serialized form during sorting. This write path implements + * several optimizations: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on these optimizations, see SPARK-7081. + */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { if (!conf.getBoolean("spark.shuffle.spill", true)) { @@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } - private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + /** + * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + */ + private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. @@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) + if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. 
+ new BypassMergeSortShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } } /** @@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - // We currently use the same block store shuffle fetcher as the hash-based shuffle. new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] - shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) - new SortShuffleWriter( - shuffleBlockResolver, baseShuffleHandle, mapId, context) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + numMapsForShuffle.putIfAbsent( + handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val env = SparkEnv.get + handle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + bypassMergeSortHandle, + mapId, + context, + env.conf) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + } } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shuffleMapNumber.containsKey(shuffleId)) { - val numMaps = shuffleMapNumber.remove(shuffleId) - (0 until numMaps).map{ mapId => + Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - indexShuffleBlockResolver - } - /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() } } + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * buffering map outputs in a serialized form. This is an extreme defensive programming measure, + * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. 
+ * */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. + */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + log.debug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. + */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5865e7640c1cf..bbd9c1ab53cd8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: SortShuffleFileWriter[K, V] = null + private var sorter: ExternalSorter[K, V, _] = null // Are we in the process of stopping? 
Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C]( require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - } else if (SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, - writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side @@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C]( } private[spark] object SortShuffleWriter { - def shouldBypassMergeSort( - conf: SparkConf, - numPartitions: Int, - aggregator: Option[Aggregator[_, _, _]], - keyOrdering: Option[Ordering[_]]): Boolean = { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") + false + } else { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + dep.partitioner.numPartitions <= bypassMergeThreshold + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala deleted file mode 100644 index 75f22f642b9d1..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe - -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. - */ -private[spark] class UnsafeShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object UnsafeShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. - */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. - */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - val shufId = dependency.shuffleId - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") - false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") - false - } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") - true - } - } -} - -/** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - No individual record is larger than 128 MB when serialized. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. - * In sort-based shuffle, incoming records are sorted according to their target partition ids, then - * written to a single map output file. Reducers fetch contiguous regions of this file in order to - * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged - * to produce the final output file. 
- * - * UnsafeShuffleManager optimizes this process in several ways: - * - * - Its sort operates on serialized binary data rather than Java objects, which reduces memory - * consumption and GC overheads. This optimization requires the record serializer to have certain - * properties to allow serialized records to be re-ordered without requiring deserialization. - * See SPARK-4550, where this optimization was first proposed and implemented, for more details. - * - * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts - * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per - * record in the sorting array, this fits more of the array into cache. - * - * - The spill merging procedure operates on blocks of serialized records that belong to the same - * partition and does not need to deserialize records during the merge. - * - * - When the spill compression codec supports concatenation of compressed data, the spill merge - * simply concatenates the serialized and compressed spill partitions to produce the final output - * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used - * and avoids the need to allocate decompression or copying buffers during the merge. - * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. - */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + - "manager; its optimized shuffles will continue to spill to disk when necessary.") - } - - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) - private[this] val shufflesThatFellBackToSortShuffle = - Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) - private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() - - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - sortShuffleManager.getReader(handle, startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. 
*/ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { - handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] => - numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) - val env = SparkEnv.get - new UnsafeShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - context.taskMemoryManager(), - env.shuffleMemoryManager, - unsafeShuffleHandle, - mapId, - context, - env.conf) - case other => - shufflesThatFellBackToSortShuffle.add(handle.shuffleId) - sortShuffleManager.getWriter(handle, mapId, context) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { - sortShuffleManager.unregisterShuffle(shuffleId) - } else { - Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - } - - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - sortShuffleManager.shuffleBlockResolver - } - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = { - sortShuffleManager.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala deleted file mode 100644 index ae60f3b0cb555..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - -/** - * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The - * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts - * of memory and needing to copy the full contents. The disadvantage is that the contents don't - * occupy a contiguous segment of memory. - */ -private[spark] class ChainedBuffer(chunkSize: Int) { - - private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros( - java.lang.Long.highestOneBit(chunkSize)) - assert((1 << chunkSizeLog2) == chunkSize, - s"ChainedBuffer chunk size $chunkSize must be a power of two") - private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]() - private var _size: Long = 0 - - /** - * Feed bytes from this buffer into a DiskBlockObjectWriter. - * - * @param pos Offset in the buffer to read from. - * @param os OutputStream to read into. 
- * @param len Number of bytes to read. - */ - def read(pos: Long, os: OutputStream, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size ${_size} of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - os.write(chunks(chunkIndex), posInChunk, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Read bytes from this buffer into a byte array. - * - * @param pos Offset in the buffer to read from. - * @param bytes Byte array to read into. - * @param offs Offset in the byte array to read to. - * @param len Number of bytes to read. - */ - def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos + len > _size) { - throw new IndexOutOfBoundsException( - s"Read of $len bytes at position $pos would go past size of buffer") - } - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toRead: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead) - written += toRead - chunkIndex += 1 - posInChunk = 0 - } - } - - /** - * Write bytes from a byte array into this buffer. - * - * @param pos Offset in the buffer to write to. - * @param bytes Byte array to write from. - * @param offs Offset in the byte array to write from. - * @param len Number of bytes to write. - */ - def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = { - if (pos > _size) { - throw new IndexOutOfBoundsException( - s"Write at position $pos starts after end of buffer ${_size}") - } - // Grow if needed - val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt - while (endChunkIndex >= chunks.length) { - chunks += new Array[Byte](chunkSize) - } - - var chunkIndex: Int = (pos >> chunkSizeLog2).toInt - var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt - var written: Int = 0 - while (written < len) { - val toWrite: Int = math.min(len - written, chunkSize - posInChunk) - System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite) - written += toWrite - chunkIndex += 1 - posInChunk = 0 - } - - _size = math.max(_size, pos + len) - } - - /** - * Total size of buffer that can be written to without allocating additional memory. - */ - def capacity: Long = chunks.size.toLong * chunkSize - - /** - * Size of the logical buffer. - */ - def size: Long = _size -} - -/** - * Output stream that writes to a ChainedBuffer. 
- */ -private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream { - private var pos: Long = 0 - - override def write(b: Int): Unit = { - throw new UnsupportedOperationException() - } - - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - chainedBuffer.write(pos, bytes, offs, len) - pos += len - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 749be34d8e8fd..c48c453a90d01 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we - * don't. Inside these buffers, we sort elements by partition ID and then possibly also by key. + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. * To avoid calling the partitioner multiple times with each key, we store the partition ID * alongside each record. * @@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C]( ordering: Option[Ordering[K]] = None, serializer: Option[Serializer] = None) extends Logging - with Spillable[WritablePartitionedPairCollection[K, C]] - with SortShuffleFileWriter[K, V] { + with Spillable[WritablePartitionedPairCollection[K, C]] { private val conf = SparkEnv.get.conf @@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C]( if (shouldPartition) partitioner.get.getPartition(key) else 0 } - // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class. - // As a sanity check, make sure that we're not handling a shuffle which should use that path. - if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) { - throw new IllegalArgumentException("ExternalSorter should not be used to handle " - + " a sort that the BypassMergeSortShuffleWriter should handle") - } - private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager private val ser = Serializer.getSerializer(serializer) @@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C]( // grow internal data structures by growing + copying every time the number of objects doubles. 
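The ChainedBuffer removed above requires its chunk size to be a power of two so that a logical offset can be split into a chunk index and an offset within that chunk using shifts rather than division. A self-contained illustration of that addressing (plain Scala, not Spark code):

    object ChunkAddressingSketch {
      // chunkSizeLog2 = log2(chunkSize); mirrors the shift arithmetic in ChainedBuffer.read/write.
      def locate(pos: Long, chunkSizeLog2: Int): (Int, Int) = {
        val chunkIndex = (pos >> chunkSizeLog2).toInt
        val posInChunk = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
        (chunkIndex, posInChunk)
      }

      def main(args: Array[String]): Unit = {
        // With 8-byte chunks (chunkSizeLog2 = 3), offset 11 falls in chunk 1 at offset 3.
        val (chunk, off) = locate(11L, 3)
        assert(chunk == 1 && off == 3)
        // Offset 8 is the first byte of chunk 1.
        val (chunk2, off2) = locate(8L, 3)
        assert(chunk2 == 1 && off2 == 0)
        println("chunk addressing checks passed")
      }
    }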
private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) - private val useSerializedPairBuffer = - ordering.isEmpty && - conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) && - ser.supportsRelocationOfSerializedObjects - private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB - private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = { - if (useSerializedPairBuffer) { - new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance) - } else { - new PartitionedPairBuffer[K, C] - } - } // Data structures to store in-memory objects before we spill. Depending on whether we have an // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we // store them in an array buffer. private var map = new PartitionedAppendOnlyMap[K, C] - private var buffer = newBuffer() + private var buffer = new PartitionedPairBuffer[K, C] // Total spilling statistics private var _diskBytesSpilled = 0L @@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C]( */ private[spark] def numSpills: Int = spills.size - override def insertAll(records: Iterator[Product2[K, V]]): Unit = { + def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C]( } else { estimatedSize = buffer.estimateSize() if (maybeSpill(buffer, estimatedSize)) { - buffer = newBuffer() + buffer = new PartitionedPairBuffer[K, C] } } @@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C]( * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ - override def writePartitionedFile( + def writePartitionedFile( blockId: BlockId, context: TaskContext, outputFile: File): Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala deleted file mode 100644 index 87a786b02d651..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
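With the serialized-pair-buffer branch gone, the insertAll / maybeSpill code above always refills with a fresh PartitionedPairBuffer after a spill. A rough standalone sketch of that accumulate, spill, and reset loop, with the Spark collections replaced by a toy buffer and a made-up size estimate (illustrative only, not Spark code):

    import scala.collection.mutable.ArrayBuffer

    object SpillLoopSketch {
      final class Buffer {
        val records = ArrayBuffer.empty[(Int, Int)]
        def estimateSize(): Long = records.length.toLong * 16L  // crude stand-in for SizeTracker
      }

      def insertAll(input: Iterator[(Int, Int)], spillThresholdBytes: Long): Int = {
        var buffer = new Buffer
        var spills = 0
        for (kv <- input) {
          buffer.records += kv
          if (buffer.estimateSize() >= spillThresholdBytes) {
            spills += 1          // a real sorter sorts by partition id and writes a spill file here
            buffer = new Buffer  // mirrors `buffer = new PartitionedPairBuffer[K, C]` above
          }
        }
        spills
      }

      def main(args: Array[String]): Unit =
        println(insertAll(Iterator.tabulate(100)(i => (i, i)), spillThresholdBytes = 160L))  // prints 10
    }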
- */ - -package org.apache.spark.util.collection - -import java.io.InputStream -import java.nio.IntBuffer -import java.util.Comparator - -import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.DiskBlockObjectWriter -import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ - -/** - * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes - * its records upon insert and stores them as raw bytes. - * - * We use two data-structures to store the contents. The serialized records are stored in a - * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a - * metadata buffer that stores pointers into the data buffer as well as the partition ID of each - * record. Each entry in the metadata buffer takes up a fixed amount of space. - * - * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not - * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can - * happen without following any pointers, which should minimize cache misses. - * - * Currently, only sorting by partition is supported. - * - * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across - * two integers: - * - * +-------------+------------+------------+-------------+ - * | keyStart | keyValLen | partitionId | - * +-------------+------------+------------+-------------+ - * - * The buffer can support up to `536870911 (2 ^ 29 - 1)` records. - * - * @param metaInitialRecords The initial number of entries in the metadata buffer. - * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records. - * @param serializerInstance the serializer used for serializing inserted records. 
- */ -private[spark] class PartitionedSerializedPairBuffer[K, V]( - metaInitialRecords: Int, - kvBlockSize: Int, - serializerInstance: SerializerInstance) - extends WritablePartitionedPairCollection[K, V] with SizeTracker { - - if (serializerInstance.isInstanceOf[JavaSerializerInstance]) { - throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" + - " Java-serialized objects.") - } - - require(metaInitialRecords <= MAXIMUM_RECORDS, - s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records") - private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE) - - private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize) - private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer) - private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream) - - def insert(partition: Int, key: K, value: V): Unit = { - if (metaBuffer.position == metaBuffer.capacity) { - growMetaBuffer() - } - - val keyStart = kvBuffer.size - kvSerializationStream.writeKey[Any](key) - kvSerializationStream.writeValue[Any](value) - kvSerializationStream.flush() - val keyValLen = (kvBuffer.size - keyStart).toInt - - // keyStart, a long, gets split across two ints - metaBuffer.put(keyStart.toInt) - metaBuffer.put((keyStart >> 32).toInt) - metaBuffer.put(keyValLen) - metaBuffer.put(partition) - } - - /** Double the size of the array because we've reached capacity */ - private def growMetaBuffer(): Unit = { - if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) { - throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records") - } - val newCapacity = - if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) { - // Overflow - MAXIMUM_META_BUFFER_CAPACITY - } else { - metaBuffer.capacity * 2 - } - val newMetaBuffer = IntBuffer.allocate(newCapacity) - newMetaBuffer.put(metaBuffer.array) - metaBuffer = newMetaBuffer - } - - /** Iterate through the data in a given order. For this class this is not really destructive. 
*/ - override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) - : Iterator[((Int, K), V)] = { - sort(keyComparator) - val is = orderedInputStream - val deserStream = serializerInstance.deserializeStream(is) - new Iterator[((Int, K), V)] { - var metaBufferPos = 0 - def hasNext: Boolean = metaBufferPos < metaBuffer.position - def next(): ((Int, K), V) = { - val key = deserStream.readKey[Any]().asInstanceOf[K] - val value = deserStream.readValue[Any]().asInstanceOf[V] - val partition = metaBuffer.get(metaBufferPos + PARTITION) - metaBufferPos += RECORD_SIZE - ((partition, key), value) - } - } - } - - override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity - - override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) - : WritablePartitionedIterator = { - sort(keyComparator) - new WritablePartitionedIterator { - // current position in the meta buffer in ints - var pos = 0 - - def writeNext(writer: DiskBlockObjectWriter): Unit = { - val keyStart = getKeyStartPos(metaBuffer, pos) - val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN) - pos += RECORD_SIZE - kvBuffer.read(keyStart, writer, keyValLen) - writer.recordWritten() - } - def nextPartition(): Int = metaBuffer.get(pos + PARTITION) - def hasNext(): Boolean = pos < metaBuffer.position - } - } - - // Visible for testing - def orderedInputStream: OrderedInputStream = { - new OrderedInputStream(metaBuffer, kvBuffer) - } - - private def sort(keyComparator: Option[Comparator[K]]): Unit = { - val comparator = if (keyComparator.isEmpty) { - new Comparator[Int]() { - def compare(partition1: Int, partition2: Int): Int = { - partition1 - partition2 - } - } - } else { - throw new UnsupportedOperationException() - } - - val sorter = new Sorter(new SerializedSortDataFormat) - sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator) - } -} - -private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer) - extends InputStream { - - import PartitionedSerializedPairBuffer._ - - private var metaBufferPos = 0 - private var kvBufferPos = - if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0 - - override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length) - - override def read(bytes: Array[Byte], offs: Int, len: Int): Int = { - if (metaBufferPos >= metaBuffer.position) { - return -1 - } - val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) - - (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt - val toRead = math.min(bytesRemainingInRecord, len) - kvBuffer.read(kvBufferPos, bytes, offs, toRead) - if (toRead == bytesRemainingInRecord) { - metaBufferPos += RECORD_SIZE - if (metaBufferPos < metaBuffer.position) { - kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos) - } - } else { - kvBufferPos += toRead - } - toRead - } - - override def read(): Int = { - throw new UnsupportedOperationException() - } -} - -private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] { - - private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE) - - /** Return the sort key for the element at the given index. */ - override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = { - metaBuffer.get(pos * RECORD_SIZE + PARTITION) - } - - /** Swap two elements. 
*/ - override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = { - val iOff = pos0 * RECORD_SIZE - val jOff = pos1 * RECORD_SIZE - System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE) - System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE) - System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE) - } - - /** Copy a single element from src(srcPos) to dst(dstPos). */ - override def copyElement( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE) - } - - /** - * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. - * Overlapping ranges are allowed. - */ - override def copyRange( - src: IntBuffer, - srcPos: Int, - dst: IntBuffer, - dstPos: Int, - length: Int): Unit = { - val srcOff = srcPos * RECORD_SIZE - val dstOff = dstPos * RECORD_SIZE - System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length) - } - - /** - * Allocates a Buffer that can hold up to 'length' elements. - * All elements of the buffer should be considered invalid until data is explicitly copied in. - */ - override def allocate(length: Int): IntBuffer = { - IntBuffer.allocate(length * RECORD_SIZE) - } -} - -private object PartitionedSerializedPairBuffer { - val KEY_START = 0 // keyStart, a long, gets split across two ints - val KEY_VAL_LEN = 2 - val PARTITION = 3 - val RECORD_SIZE = PARTITION + 1 // num ints of metadata - - val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1 - val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4 - - def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = { - val lower32 = metaBuffer.get(metaBufferPos + KEY_START) - val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1) - (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL) - } -} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java similarity index 96% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 934b7e03050b6..232ae4d926bcd 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -15,8 +15,9 @@ * limitations under the License. 
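getKeyStartPos above reassembles a 64-bit key offset that insert() stored as two 32-bit ints. A small worked example of that split-and-join, using the same masking (plain Scala, not Spark code):

    object SplitLongSketch {
      // insert(): the long is written low word first, then high word.
      def split(keyStart: Long): (Int, Int) = (keyStart.toInt, (keyStart >> 32).toInt)

      // getKeyStartPos(): mask the low word to avoid sign extension, then or in the high word.
      def join(lower32: Int, upper32: Int): Long =
        (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)

      def main(args: Array[String]): Unit = {
        val offsets = Seq(0L, 42L, 0xFFFFFFFFL, (1L << 35) + 7L)
        for (off <- offsets) {
          val (lo, hi) = split(off)
          assert(join(lo, hi) == off, s"round-trip failed for $off")
        }
        println("round-trip ok for " + offsets.mkString(", "))
      }
    }

The 0xFFFFFFFFL case is the interesting one: without the & 0xFFFFFFFFL mask, the sign-extended low word would corrupt the upper 32 bits of the reassembled offset.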
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; +import org.apache.spark.shuffle.sort.PackedRecordPointer; import org.junit.Test; import static org.junit.Assert.*; @@ -24,7 +25,7 @@ import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; +import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; public class PackedRecordPointerSuite { diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 40fefe2c9d140..1ef3c5ff64bac 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Arrays; import java.util.Random; @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -public class UnsafeShuffleInMemorySorterSuite { +public class ShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; @@ -40,8 +40,8 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -62,7 +62,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); - final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -79,7 +79,7 @@ public void testBasicSorting() throws Exception { } // Sort the records - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int prevPartitionId = -1; Arrays.sort(dataToSort); for (int i = 0; i < dataToSort.length; i++) { @@ -103,7 +103,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { @@ -112,7 
+112,7 @@ public void testSortingManyNumbers() throws Exception { } Arrays.sort(numbersToSort); int[] sorterResult = new int[numbersToSort.length]; - UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); int j = 0; while (iter.hasNext()) { iter.loadNext(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java similarity index 98% rename from core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java rename to core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index d218344cd4520..29d9823b1f71b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.*; import java.nio.ByteBuffer; @@ -23,7 +23,6 @@ import scala.*; import scala.collection.Iterator; -import scala.reflect.ClassTag; import scala.runtime.AbstractFunction1; import com.google.common.collect.Iterators; @@ -56,6 +55,7 @@ import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import org.apache.spark.storage.*; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; @@ -204,7 +204,7 @@ private UnsafeShuffleWriter createWriter( shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, conf @@ -461,7 +461,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); - final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); @@ -516,7 +516,7 @@ public void testPeakMemoryUsed() throws Exception { shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle<>(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf); diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 63358172ea1f4..b8ab227517cc4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -17,13 +17,78 @@ package org.apache.spark +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + class SortShuffleSuite extends 
ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. + private var tempDir: File = _ + override def beforeAll() { conf.set("spark.shuffle.manager", "sort") } + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + conf.set("spark.local.dir", tempDir.getAbsolutePath) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the new serialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") { + sc = new SparkContext("local", "test", conf) + // Create a shuffled RDD and verify that it actually uses the old deserialized map output path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(conf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep)) + ensureFilesAreCleanedUp(shuffledRdd) + } + + private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = { + def getAllFiles: Set[File] = + FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5b01ddb298c39..3816b8c4a09aa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1062,10 +1062,10 @@ class DAGSchedulerSuite */ test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val firstShuffleId = firstShuffleDep.shuffleId val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) submit(reduceRdd, Array(0)) @@ -1175,7 +1175,7 @@ class 
DAGSchedulerSuite */ test("register map outputs correctly after ExecutorLost and task Resubmitted") { val firstRDD = new MyRDD(sc, 3, Nil) - val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2)) val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) submit(reduceRdd, Array(0)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 341f56df2dafc..b92a302806f76 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _ private var taskMetrics: TaskMetrics = _ - private var shuffleWriteMetrics: ShuffleWriteMetrics = _ private var tempDir: File = _ private var outputFile: File = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] - private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0) - private val serializer: Serializer = new JavaSerializer(conf) + private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) - shuffleWriteMetrics = new ShuffleWriteMetrics taskMetrics = new TaskMetrics - taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) MockitoAnnotations.initMocks(this) + shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( + shuffleId = 0, + numMaps = 2, + dependency = dependency + ) + when(dependency.partitioner).thenReturn(new HashPartitioner(7)) + when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + 
conf ) - writer.insertAll(Iterator.empty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === 0) + writer.write(Iterator.empty) + writer.stop( /* success = */ true) + assert(writer.getPartitionLengths.sum === 0) assert(outputFile.exists()) assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === 0) assert(shuffleWriteMetrics.shuffleRecordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) @@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) - writer.insertAll(records) + writer.write(records) + writer.stop( /* success = */ true) assert(temporaryFilesCreated.nonEmpty) - val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile) - assert(partitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.sum === outputFile.length()) assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length()) assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) @@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( - new SparkConf(loadDefaults = false), blockManager, - new HashPartitioner(7), - shuffleWriteMetrics, - serializer + blockResolver, + shuffleHandle, + 0, // MapId + taskContext, + conf ) intercept[SparkException] { - writer.insertAll((0 until 100000).iterator.map(i => { + writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte })) } assert(temporaryFilesCreated.nonEmpty) - writer.stop() + writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala similarity index 80% rename from core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 6727934d8c7ca..8744a072cb3f6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
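The rewritten tests above drive the writer through its public write/stop lifecycle instead of the old insertAll/writePartitionedFile calls, and they pass an explicit success flag to stop(). A generic sketch of that contract (the Writer trait below is a stand-in for illustration, not Spark's ShuffleWriter):

    object WriterLifecycleSketch {
      trait Writer[K, V] {
        def write(records: Iterator[(K, V)]): Unit
        def stop(success: Boolean): Unit  // success = false should clean up partial output
      }

      // Typical call pattern: report success only if write() completed without throwing,
      // so that stop(false) can delete temporary files after a failure.
      def runMapTask[K, V](writer: Writer[K, V], records: Iterator[(K, V)]): Unit = {
        var success = false
        try {
          writer.write(records)
          success = true
        } finally {
          writer.stop(success)
        }
      }
    }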
*/ -package org.apache.spark.shuffle.unsafe +package org.apache.spark.shuffle.sort import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -29,9 +29,9 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are * performed in other suites. */ -class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { +class SortShuffleManagerSuite extends SparkFunSuite with Matchers { - import UnsafeShuffleManager.canUseUnsafeShuffle + import SortShuffleManager.canUseSerializedShuffle private class RuntimeExceptionAnswer extends Answer[Object] { override def answer(invocation: InvocationOnMock): Object = { @@ -55,10 +55,10 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { dep } - test("supported shuffle dependencies") { + test("supported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, @@ -68,7 +68,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) when(rangePartitioner.numPartitions).thenReturn(2) - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = rangePartitioner, serializer = kryo, keyOrdering = None, @@ -77,7 +77,7 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // Shuffles with key orderings are supported as long as no aggregator is specified - assert(canUseUnsafeShuffle(shuffleDep( + assert(canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), @@ -87,12 +87,12 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { } - test("unsupported shuffle dependencies") { + test("unsupported shuffle dependencies for serialized shuffle") { val kryo = Some(new KryoSerializer(new SparkConf())) val java = Some(new JavaSerializer(new SparkConf())) // We only support serializers that support object relocation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = java, keyOrdering = None, @@ -100,9 +100,11 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles with more than 16 million output partitions - assert(!canUseUnsafeShuffle(shuffleDep( - partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + // The serialized shuffle path do not support shuffles with more than 16 million output + // partitions, due to a limitation in its sorter implementation. 
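The MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE constant referenced in this test comes from the packed 8-byte record pointer used by the serialized sorter: 24 bits hold the partition id and the remaining 40 bits address the record, so at most 1 << 24 = 16777216 partitions can be encoded. A sketch of that packing (the exact split of the 40 address bits is simplified away here; illustrative, not Spark code):

    object PartitionIdPackingSketch {
      val PartitionIdBits = 24
      val MaxPartitionId: Int = (1 << PartitionIdBits) - 1    // 16777215
      val MaxPartitions: Int = MaxPartitionId + 1             // 16777216

      def pack(partitionId: Int, recordAddress: Long): Long = {
        require(partitionId <= MaxPartitionId,
          s"partition id $partitionId does not fit in $PartitionIdBits bits")
        (partitionId.toLong << 40) | (recordAddress & ((1L << 40) - 1))
      }

      def partitionOf(packed: Long): Int = (packed >>> 40).toInt

      def main(args: Array[String]): Unit = {
        val packed = pack(partitionId = 12345, recordAddress = 987654321L)
        assert(partitionOf(packed) == 12345)
        println(s"serialized path supports at most $MaxPartitions partitions")
      }
    }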
+ assert(!canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner( + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1), serializer = kryo, keyOrdering = None, aggregator = None, @@ -110,14 +112,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers { ))) // We do not support shuffles that perform aggregation - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = None, aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), mapSideCombine = false ))) - assert(!canUseUnsafeShuffle(shuffleDep( + assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, keyOrdering = Some(mock(classOf[Ordering[Any]])), diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala deleted file mode 100644 index 34b4984f12c09..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.sort - -import org.mockito.Mockito._ - -import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite} - -class SortShuffleWriterSuite extends SparkFunSuite { - - import SortShuffleWriter._ - - test("conditions for bypassing merge-sort") { - val conf = new SparkConf(loadDefaults = false) - val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS) - val ord = implicitly[Ordering[Int]] - - // Numbers of partitions that are above and below the default bypassMergeThreshold - val FEW_PARTITIONS = 50 - val MANY_PARTITIONS = 10000 - - // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high - assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None)) - assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None)) - - // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord))) - assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None)) - } -} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala deleted file mode 100644 index 259020a2ddc34..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.unsafe - -import java.io.File - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.util.Utils - -class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. - - override def beforeAll() { - conf.set("spark.shuffle.manager", "tungsten-sort") - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new KryoSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } - - test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { - val tmpDir = Utils.createTempDir() - try { - val myConf = conf.clone() - .set("spark.local.dir", tmpDir.getAbsolutePath) - sc = new SparkContext("local", "test", myConf) - // Create a shuffled RDD and verify that it will actually use the old SortShuffle path - val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) - val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) - .setSerializer(new JavaSerializer(myConf)) - val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - 
assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) - def getAllFiles: Set[File] = - FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet - val filesBeforeShuffle = getAllFiles - // Force the shuffle to be performed - shuffledRdd.count() - // Ensure that the shuffle actually created files that will need to be cleaned up - val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle - filesCreatedByShuffle.map(_.getName) should be - Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") - // Check that the cleanup actually removes the files - sc.env.blockManager.master.removeShuffle(0, blocking = true) - for (file <- filesCreatedByShuffle) { - assert (!file.exists(), s"Shuffle file $file was not cleaned up") - } - } finally { - Utils.deleteRecursively(tmpDir) - } - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala deleted file mode 100644 index 05306f408847d..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
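Both the removed UnsafeShuffleSuite tests and their SortShuffleSuite replacements above follow the same cleanup-check pattern: snapshot the files in the local dir, run the shuffle, diff the snapshots, then verify the new files disappear after removeShuffle. A minimal standalone sketch of that pattern using only the JDK (names are illustrative; file creation and deletion stand in for the shuffle and its cleanup):

    import java.io.File
    import java.nio.file.Files

    object CleanupCheckSketch {
      private def allFiles(dir: File): Set[File] =
        Option(dir.listFiles()).map(_.toSet).getOrElse(Set.empty[File])

      def main(args: Array[String]): Unit = {
        val tempDir = Files.createTempDirectory("cleanup-check").toFile
        val before = allFiles(tempDir)

        // "Work" that produces output files, standing in for shuffledRdd.count().
        val created = Seq("shuffle_0_0_0.data", "shuffle_0_0_0.index").map { name =>
          val f = new File(tempDir, name); f.createNewFile(); f
        }

        val createdByWork = allFiles(tempDir) -- before
        assert(createdByWork == created.toSet)

        // "Cleanup", standing in for blockManager.master.removeShuffle(...).
        created.foreach(_.delete())
        assert(created.forall(!_.exists()), "output files were not cleaned up")
        tempDir.delete()
        println("cleanup verified")
      }
    }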
- */ - -package org.apache.spark.util.collection - -import java.nio.ByteBuffer - -import org.scalatest.Matchers._ - -import org.apache.spark.SparkFunSuite - -class ChainedBufferSuite extends SparkFunSuite { - test("write and read at start") { - // write from start of source array - val buffer = new ChainedBuffer(8) - buffer.capacity should be (0) - verifyWriteAndRead(buffer, 0, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 0, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 0, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 0, 0, 0, 8) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 0, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 0, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at middle") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 3) - - // write from start of source array - verifyWriteAndRead(buffer, 3, 0, 0, 4) - buffer.capacity should be (8) - - // write from middle of source array - verifyWriteAndRead(buffer, 3, 5, 0, 4) - buffer.capacity should be (8) - - // read to middle of target array - verifyWriteAndRead(buffer, 3, 0, 5, 4) - buffer.capacity should be (8) - - // write up to border - verifyWriteAndRead(buffer, 3, 0, 0, 5) - buffer.capacity should be (8) - - // expand into second buffer - verifyWriteAndRead(buffer, 3, 0, 0, 12) - buffer.capacity should be (16) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 3, 0, 0, 28) - buffer.capacity should be (32) - } - - test("write and read at later buffer") { - val buffer = new ChainedBuffer(8) - - // fill to a middle point - verifyWriteAndRead(buffer, 0, 0, 0, 11) - - // write from start of source array - verifyWriteAndRead(buffer, 11, 0, 0, 4) - buffer.capacity should be (16) - - // write from middle of source array - verifyWriteAndRead(buffer, 11, 5, 0, 4) - buffer.capacity should be (16) - - // read to middle of target array - verifyWriteAndRead(buffer, 11, 0, 5, 4) - buffer.capacity should be (16) - - // write up to border - verifyWriteAndRead(buffer, 11, 0, 0, 5) - buffer.capacity should be (16) - - // expand into second buffer - verifyWriteAndRead(buffer, 11, 0, 0, 12) - buffer.capacity should be (24) - - // expand into multiple buffers - verifyWriteAndRead(buffer, 11, 0, 0, 28) - buffer.capacity should be (40) - } - - - // Used to make sure we're writing different bytes each time - var rangeStart = 0 - - /** - * @param buffer The buffer to write to and read from. - * @param offsetInBuffer The offset to write to in the buffer. - * @param offsetInSource The offset in the array that the bytes are written from. - * @param offsetInTarget The offset in the array to read the bytes into. 
- * @param length The number of bytes to read and write - */ - def verifyWriteAndRead( - buffer: ChainedBuffer, - offsetInBuffer: Int, - offsetInSource: Int, - offsetInTarget: Int, - length: Int): Unit = { - val source = new Array[Byte](offsetInSource + length) - (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource) - buffer.write(offsetInBuffer, source, offsetInSource, length) - val target = new Array[Byte](offsetInTarget + length) - buffer.read(offsetInBuffer, target, offsetInTarget, length) - ByteBuffer.wrap(source, offsetInSource, length) should be - (ByteBuffer.wrap(target, offsetInTarget, length)) - - rangeStart += 100 - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala deleted file mode 100644 index 3b67f6206495a..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
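The capacity assertions in the suite deleted above all follow from one rule: after a write that ends at logical offset end, ChainedBuffer capacity is that end rounded up to a whole number of chunks. A small check of that rule against the suite's own numbers (plain Scala, not Spark code):

    object ChunkCapacitySketch {
      def capacityAfterWrite(endOffset: Long, chunkSize: Int): Long = {
        val chunks = (endOffset + chunkSize - 1) / chunkSize
        chunks * chunkSize
      }

      def main(args: Array[String]): Unit = {
        // chunkSize = 8, matching `new ChainedBuffer(8)` in the removed tests.
        assert(capacityAfterWrite(4, 8) == 8)         // "write and read at start": one chunk
        assert(capacityAfterWrite(12, 8) == 16)       // expand into a second chunk
        assert(capacityAfterWrite(28, 8) == 32)       // expand into multiple chunks
        assert(capacityAfterWrite(11 + 12, 8) == 24)  // "later buffer" case: offset 11, length 12
        println("capacity rule matches the removed suite")
      }
    }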
- */ - -package org.apache.spark.util.collection - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import com.google.common.io.ByteStreams - -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.Mockito.RETURNS_SMART_NULLS -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.Matchers._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.DiskBlockObjectWriter - -class PartitionedSerializedPairBufferSuite extends SparkFunSuite { - test("OrderedInputStream single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - - val bytes = ByteStreams.toByteArray(buffer.orderedInputStream) - - val baos = new ByteArrayOutputStream() - val stream = serializerInstance.serializeStream(baos) - stream.writeObject(10) - stream.writeObject(struct) - stream.close() - - baos.toByteArray should be (bytes) - } - - test("insert single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (1) - elements.head should be (((4, 10), struct)) - } - - test("insert multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val elements = buffer.partitionedDestructiveSortedIterator(None).toArray - elements.size should be (3) - elements(0) should be (((4, 2), struct2)) - elements(1) should be (((5, 3), struct3)) - elements(2) should be (((6, 1), struct1)) - } - - test("write single record") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct = SomeStruct("something", 5) - buffer.insert(4, 10, struct) - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - it.nextPartition should be (4) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - stream.readObject[AnyRef]() should be (10) - stream.readObject[AnyRef]() should be (struct) - } - - test("write multiple records") { - val serializerInstance = new KryoSerializer(new SparkConf()).newInstance - val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance) - val struct1 = SomeStruct("something1", 8) - buffer.insert(6, 1, struct1) - val struct2 = SomeStruct("something2", 9) - buffer.insert(4, 2, struct2) - val struct3 = SomeStruct("something3", 10) - buffer.insert(5, 3, struct3) - - val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val (writer, baos) = createMockWriter() - assert(it.hasNext) - 
it.nextPartition should be (4) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (5) - it.writeNext(writer) - assert(it.hasNext) - it.nextPartition should be (6) - it.writeNext(writer) - assert(!it.hasNext) - - val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) - val iter = stream.asIterator - iter.next() should be (2) - iter.next() should be (struct2) - iter.next() should be (3) - iter.next() should be (struct3) - iter.next() should be (1) - iter.next() should be (struct1) - assert(!iter.hasNext) - } - - def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { - val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) - val baos = new ByteArrayOutputStream() - when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - val args = invocationOnMock.getArguments - val bytes = args(0).asInstanceOf[Array[Byte]] - val offset = args(1).asInstanceOf[Int] - val length = args(2).asInstanceOf[Int] - baos.write(bytes, offset, length) - } - }) - (writer, baos) - } -} - -case class SomeStruct(str: String, num: Int) diff --git a/docs/configuration.md b/docs/configuration.md index 46d92ceb762d6..be9c36bdfe3de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.manager sort - Implementation to use for shuffling data. There are three implementations available: - sort, hash and the new (1.5+) tungsten-sort. + Implementation to use for shuffling data. There are two implementations available: + sort and hash. Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. - Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly - implementation with a fall back to regular sort based shuffle if its requirements are not - met. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0872d3f3e7093..b5e661d3ecfa8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,6 +37,7 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("network"), + MimaBuild.excludeSparkPackage("unsafe"), // These are needed if checking against the sbt build, since they are part of // the maven-generated artifacts in 1.3. excludePackage("org.spark-project.jetty"), @@ -44,7 +45,11 @@ object MimaExcludes { // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar") + excludePackage("org.apache.spark.sql.columnar"), + // The shuffle package is considered private. + excludePackage("org.apache.spark.shuffle"), + // The collections utlities are considered pricate. 
+ excludePackage("org.apache.spark.util.collection") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ @@ -750,4 +755,4 @@ object MimaExcludes { MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") case _ => Seq() } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 1d3379a5e2d91..7f60c8f5eaa95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree @@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf val shuffleManager = SparkEnv.get.shuffleManager - val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || - shuffleManager.isInstanceOf[UnsafeShuffleManager] + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { @@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { - // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting - // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an aggregator or ordering and the record serializer - // has certain properties. If this optimization is enabled, we can safely avoid the copy. + } else if (serializer.supportsRelocationOfSerializedObjects) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties. If this optimization is enabled, we can safely avoid the copy. // // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only // need to check whether the optimization is enabled and supported by our serializer. - // - // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. 
This code - // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls - // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In - // both cases, we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. true } } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 75d1fced594c4..1680d7e0a85ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten Utils.tryWithSafeFinally { val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1024") + .set("spark.shuffle.spill.initialMemoryThreshold", "1") .set("spark.shuffle.sort.bypassMergeThreshold", "0") .set("spark.testing.memory", "80000") @@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") // prepare data val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 1000).iterator.map { i => + val data = (1 to 10000).iterator.map { i => (i, converter(Row(i))) } val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( @@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } } - test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") { - val conf = new SparkConf() - .set("spark.shuffle.manager", "tungsten-sort") + test("SPARK-10403: unsafe row serializer with SortShuffleManager") { + val conf = new SparkConf().set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) From 42d225f449c633be7465493c57b9881303ee14ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 22 Oct 2015 10:53:59 -0700 Subject: [PATCH 005/324] [SPARK-11216][SQL][FOLLOW-UP] add encoder/decoder for external row address comments in https://github.com/apache/spark/pull/9184 Author: Wenchen Fan Closes #9212 from cloud-fan/encoder. 
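
As background for the encoder/decoder change in this patch, the following is a rough, illustrative sketch of the external-row round trip that `RowEncoder` is meant to provide. It is not part of the patch: the schema and values are invented for the example, and the calls mirror the `RowEncoder.apply`, `toRow` and `fromRow` signatures shown in the diff that follows.

```scala
// Hypothetical usage sketch of the internal RowEncoder API -- not part of this patch.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType)))

// Encoder that converts external Rows to/from the internal binary representation.
val encoder = RowEncoder(schema)

// External Row -> Catalyst InternalRow (the "encode" direction).
val internalRow = encoder.toRow(Row("Ann", 30))

// InternalRow -> external Row (the "decode" direction). RowEncoder builds its
// construct expressions from BoundReferences, so no extra `bind` call is shown here.
val roundTripped = encoder.fromRow(internalRow)
```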
--- .../spark/sql/catalyst/encoders/ClassEncoder.scala | 14 +++----------- .../spark/sql/catalyst/encoders/RowEncoder.scala | 9 ++++++--- .../spark/sql/catalyst/expressions/objects.scala | 8 +++++++- .../sql/catalyst/encoders/RowEncoderSuite.scala | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala index f3a1063871775..54096f18cbea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala @@ -48,20 +48,12 @@ case class ClassEncoder[T]( private val dataType = ObjectType(clsTag.runtimeClass) override def toRow(t: T): InternalRow = { - if (t == null) { - null - } else { - inputRow(0) = t - extractProjection(inputRow) - } + inputRow(0) = t + extractProjection(inputRow) } override def fromRow(row: InternalRow): T = { - if (row eq null) { - null.asInstanceOf[T] - } else { - constructProjection(row).get(0, dataType).asInstanceOf[T] - } + constructProjection(row).get(0, dataType).asInstanceOf[T] } override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3e74aabd078df..5142856afdcac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +/** + * A factory for constructing encoders that convert external row to/from the Spark SQL + * internal binary representation. + */ object RowEncoder { - def apply(schema: StructType): ClassEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -136,7 +139,7 @@ object RowEncoder { constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) ) } - CreateRow(fields) + CreateExternalRow(fields) } private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { @@ -195,7 +198,7 @@ object RowEncoder { Literal.create(null, externalDataTypeFor(f.dataType)), constructorFor(getField(input, i, f.dataType), f.dataType)) } - CreateRow(convertedFields) + CreateExternalRow(convertedFields) } private def getField( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 8fc00ad1bcb04..b42d6c5c1e14e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -456,7 +456,13 @@ case class MapObjects( } } -case class CreateRow(children: Seq[Expression]) extends Expression { +/** + * Constructs a new external row, using the result of evaluating the specified expressions + * as content. + * + * @param children A list of expression to use as content of the external row. 
+ */
+case class CreateExternalRow(children: Seq[Expression]) extends Expression {
   override def dataType: DataType = ObjectType(classOf[Row])
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 6041b62b74bdd..e8301e8e06b52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -73,7 +73,7 @@ class RowEncoderSuite extends SparkFunSuite {
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)
-      val inputGenerator = RandomDataGenerator.forType(schema).get
+      val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
 
       var input: Row = null
       try {

From 7bb6d31cff279776f90744407291682774cfe1c2 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Thu, 22 Oct 2015 11:31:47 -0700
Subject: [PATCH 006/324] [SPARK-11232][CORE] Use 'offer' instead of 'put' to make sure calling send won't be interrupted

The current `NettyRpcEndpointRef.send` can be interrupted because it uses `LinkedBlockingQueue.put`, which may hang the application. Imagine the following execution order:

  | thread 1: TaskRunner.kill | thread 2: TaskRunner.run
------------- | ------------- | -------------
1 | killed = true |
2 | | if (killed) {
3 | | throw new TaskKilledException
4 | | case _: TaskKilledException  _: InterruptedException if task.killed =>
5 | task.kill(interruptThread): interruptThread is true |
6 | | execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
7 | | localEndpoint.send(StatusUpdate(taskId, state, serializedData)): in LocalBackend

Then `localEndpoint.send(StatusUpdate(taskId, state, serializedData))` will throw `InterruptedException`. This will prevent the executor from updating the task status and hang the application. A failure caused by the above issue can be seen here: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/44062/consoleFull

Since `receivers` is an unbounded `LinkedBlockingQueue`, we can just use `LinkedBlockingQueue.offer` to resolve this issue.

Author: zsxwing

Closes #9198 from zsxwing/dont-interrupt-send.
---
 .../scala/org/apache/spark/rpc/netty/Dispatcher.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index f1a8273f157ef..7bf44a6565b61 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -66,7 +66,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
       }
       val data = endpoints.get(name)
       endpointRefs.put(data.endpoint, data.ref)
-      receivers.put(data)  // for the OnStart message
+      receivers.offer(data)  // for the OnStart message
     }
     endpointRef
   }
@@ -80,7 +80,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     val data = endpoints.remove(name)
     if (data != null) {
       data.inbox.stop()
-      receivers.put(data)  // for the OnStop message
+      receivers.offer(data)  // for the OnStop message
     }
     // Don't clean `endpointRefs` here because it's possible that some messages are being processed
     // now and they can use `getRpcEndpointRef`.
So `endpointRefs` will be cleaned in Inbox via @@ -163,7 +163,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { true } else { data.inbox.post(createMessageFn(data.ref)) - receivers.put(data) + receivers.offer(data) false } } @@ -183,7 +183,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { // Stop all endpoints. This will queue all endpoints for processing by the message loops. endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) // Enqueue a message that tells the message loops to stop. - receivers.put(PoisonPill) + receivers.offer(PoisonPill) threadpool.shutdown() } @@ -218,7 +218,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val data = receivers.take() if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it. - receivers.put(PoisonPill) + receivers.offer(PoisonPill) return } data.inbox.process(Dispatcher.this) From 3535b91ddc9fd05b613a121e09263b0f378bd5fa Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 22 Oct 2015 11:39:06 -0700 Subject: [PATCH 007/324] [SPARK-11163] Remove unnecessary addPendingTask calls. This commit removes unnecessary calls to addPendingTask in TaskSetManager.executorLost. These calls are unnecessary: for tasks that are still pending and haven't been launched, they're still in all of the correct pending lists, so calling addPendingTask has no effect. For tasks that are currently running (which may still be in the pending lists, depending on how they were scheduled), we call addPendingTask in handleFailedTask, so the calls at the beginning of executorLost are redundant. I think these calls are left over from when we re-computed the locality levels in addPendingTask; now that we call recomputeLocality separately, I don't think these are necessary. Now that those calls are removed, the readding parameter in addPendingTask is no longer necessary, so this commit also removes that parameter. markhamstra can you take a look at this? cc vanzin Author: Kay Ousterhout Closes #9154 from kayousterhout/SPARK-11163. --- .../spark/scheduler/TaskSetManager.scala | 27 ++++--------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c02597c4365c9..987800d3d1f1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -177,14 +177,11 @@ private[spark] class TaskSetManager( var emittedTaskSizeWarning = false - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there + /** Add a task to all the pending-task lists that it should be on. 
*/
+  private def addPendingTask(index: Int) {
+    // Utility method that adds `index` to a list only if it's not already there
     def addTo(list: ArrayBuffer[Int]) {
-      if (!readding || !list.contains(index)) {
+      if (!list.contains(index)) {
         list += index
       }
     }
@@ -219,9 +216,7 @@ private[spark] class TaskSetManager(
       addTo(pendingTasksWithNoPrefs)
     }
 
-    if (!readding) {
-      allPendingTasks += index  // No point scanning this whole list to find the old task there
-    }
+    allPendingTasks += index  // No point scanning this whole list to find the old task there
   }
 
   /**
@@ -783,18 +778,6 @@ private[spark] class TaskSetManager(
 
   /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
   override def executorLost(execId: String, host: String, reason: ExecutorLossReason) {
-    logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
-    // Re-enqueue pending tasks for this host based on the status of the cluster. Note
-    // that it's okay if we add a task to the same queue twice (if it had multiple preferred
-    // locations), because dequeueTaskFromList will skip already-running tasks.
-    for (index <- getPendingTasksForExecutor(execId)) {
-      addPendingTask(index, readding = true)
-    }
-    for (index <- getPendingTasksForHost(host)) {
-      addPendingTask(index, readding = true)
-    }
-
     // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
     // and we are not using an external shuffle server which could serve the shuffle outputs.
     // The reason is the next stage wouldn't be able to fetch the data from this dead executor

From d4950e6be48954125eeb1be550c102636521bde3 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Thu, 22 Oct 2015 13:11:37 -0700
Subject: [PATCH 008/324] [SPARK-9735][SQL] Respect the user-specified schema over the inferred partition schema for HadoopFsRelation

This enables the unit test `hadoopFsRelationSuite.Partition column type casting`. It previously threw an exception like the one below, because the automatically inferred partition schema was given higher priority than the user-specified one.
``` java.lang.ClassCastException: java.lang.Integer cannot be cast to org.apache.spark.unsafe.types.UTF8String at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45) at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.getUTF8String(rows.scala:220) at org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(generated.java:62) at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212) at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) 07:44:01.344 ERROR org.apache.spark.executor.Executor: Exception in task 14.0 in stage 3.0 (TID 206) java.lang.ClassCastException: java.lang.Integer cannot be cast to org.apache.spark.unsafe.types.UTF8String at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45) at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.getUTF8String(rows.scala:220) at org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(generated.java:62) at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212) at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at 
scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) ``` Author: Cheng Hao Closes #8026 from chenghao-intel/partition_discovery. --- .../apache/spark/sql/sources/interfaces.scala | 29 +++++++++++-- .../sql/sources/hadoopFsRelationSuites.scala | 42 +++++++++++++------ 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7b030b7d73bd5..84eef0f8a672c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration @@ -544,11 +544,32 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { - val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference) + userDefinedPartitionColumns match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false) + + // Without auto inference, all of value in the `row` should be null or in StringType, + // we need to cast into the data type that user specified. 
+ def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getString(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + + case _ => + // user did not provide a partitioning schema + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()) + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 42b9b3d6340d8..e3605bb3f6bf0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -510,21 +510,39 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - // HadoopFsRelation.discoverPartitions() called by refresh(), which will ignore - // the given partition data type. - ignore("Partition column type casting") { + test("SPARK-9735 Partition column type casting") { withTempPath { file => - val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) - - input - .write - .format(dataSourceName) - .mode(SaveMode.Overwrite) - .partitionBy("ps", "p2") - .saveAsTable("t") + val df = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 1.0d, p2, 123, 123.123f)).toDF("a", "b", "p1", "p2", "p3", "f") + + val input = df.select( + 'a, + 'b, + 'p1.cast(StringType).as('ps1), + 'p2, + 'p3.cast(FloatType).as('pf1), + 'f) withTempTable("t") { - checkAnswer(sqlContext.table("t"), input.collect()) + input + .write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("ps1", "p2", "pf1", "f") + .saveAsTable("t") + + input + .write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("ps1", "p2", "pf1", "f") + .saveAsTable("t") + + val realData = input.collect() + + checkAnswer(sqlContext.table("t"), realData ++ realData) } } } From 188ea348fdcf877d86f3c433cd15f6468fe3b42a Mon Sep 17 00:00:00 2001 From: guoxi Date: Thu, 22 Oct 2015 13:56:18 -0700 Subject: [PATCH 009/324] [SPARK-11242][SQL] In conf/spark-env.sh.template SPARK_DRIVER_MEMORY is documented incorrectly Minor fix on the comment Author: guoxi Closes #9201 from xguo27/SPARK-11242. --- conf/spark-env.sh.template | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 990ded420be72..771251f90ee36 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -36,10 +36,10 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) -# - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). -# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) +# - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). +# - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 
1000M, 2G) (Default: 1G)
 # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark)
 # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’)
 # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.

From 53e83a3a77cafc2ccd0764ecdb8b3ba735bc51fc Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 22 Oct 2015 15:20:17 -0700
Subject: [PATCH 010/324] [SPARK-11116][SQL] First Draft of Dataset API

*This PR adds a new experimental API to Spark, tentatively named Datasets.*

A `Dataset` is a strongly-typed collection of objects that can be transformed in parallel using functional or relational operations. Example usage is as follows:

### Functional

```scala
> val ds: Dataset[Int] = Seq(1, 2, 3).toDS()
> ds.filter(_ % 1 == 0).collect()
res1: Array[Int] = Array(1, 2, 3)
```

### Relational

```scala
scala> ds.toDF().show()
+-----+
|value|
+-----+
|    1|
|    2|
|    3|
+-----+

> ds.select(expr("value + 1").as[Int]).collect()
res11: Array[Int] = Array(2, 3, 4)
```

## Comparison to RDDs

A `Dataset` differs from an `RDD` in the following ways:
 - The creation of a `Dataset` requires the presence of an explicit `Encoder` that can be used to serialize the object into a binary format. Encoders are also capable of mapping the schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime reflection based serialization.
 - Internally, a `Dataset` is represented by a Catalyst logical plan and the data is stored in the encoded form. This representation allows for additional logical operations and enables many operations (sorting, shuffling, etc.) to be performed without deserializing to an object.

A `Dataset` can be converted to an `RDD` by calling the `.rdd` method.

## Comparison to DataFrames

A `Dataset` can be thought of as a specialized DataFrame, where the elements map to a specific JVM object type, instead of to a generic `Row` container. A DataFrame can be transformed into a specific Dataset by calling `df.as[ElementType]`. Similarly, you can transform a strongly-typed `Dataset` to a generic DataFrame by calling `ds.toDF()`.

## Implementation Status and TODOs

This is a rough cut at the least controversial parts of the API. The primary purpose here is to get something committed so that we can better parallelize further work and get early feedback on the API. The following is being deferred to future PRs:
 - Joins and Aggregations (prototype here https://github.com/apache/spark/commit/f11f91e6f08c8cf389b8388b626cd29eec32d937)
 - Support for Java

Additionally, the responsibility for binding an encoder to a given schema is currently done in a fairly ad-hoc fashion. This is an internal detail, and what we are doing today works for the cases we care about. However, as we add more APIs we'll probably need to do this in a more principled way (i.e. separate resolution from binding as we do in DataFrames).

## COMPATIBILITY NOTE

Long term we plan to make `DataFrame` extend `Dataset[Row]`. However, making this change to the class hierarchy would break the function signatures for the existing function operations (map, flatMap, etc). As such, this class should be considered a preview of the final API. Changes will be made to the interface after Spark 1.6.

Author: Michael Armbrust

Closes #9190 from marmbrus/dataset-infra.
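
To make the DataFrame/Dataset round trip described above concrete, here is a minimal illustrative sketch (not part of this patch). It assumes a Spark shell session against this version of the API with `sqlContext.implicits._` in scope; the `Person` case class is invented for the example, and the only calls used are the ones named above (`toDS()`, `toDF()`, `as[ElementType]`, and the functional operators).

```scala
// Illustrative sketch only -- `Person` is a hypothetical case class, and we assume
// `import sqlContext.implicits._` has already been run in the shell.
case class Person(name: String, age: Int)

// Strongly typed: each element is a Person, serialized through an implicit Encoder.
val ds = Seq(Person("Ann", 30), Person("Bob", 25)).toDS()

// Drop down to the untyped, Row-based DataFrame view of the same plan...
val df = ds.toDF()

// ...and come back to a typed view with `as[ElementType]`.
val people = df.as[Person]

// Functional operations run against the JVM objects themselves.
people.filter(_.age > 26).collect()   // Array(Person(Ann,30))
```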
--- .../spark/sql/catalyst/ScalaReflection.scala | 8 +- .../sql/catalyst/encoders/ClassEncoder.scala | 38 +- .../spark/sql/catalyst/encoders/Encoder.scala | 19 +- .../catalyst/encoders/ProductEncoder.scala | 12 +- .../catalyst/encoders/primitiveTypes.scala | 100 +++++ .../spark/sql/catalyst/encoders/tuples.scala | 173 ++++++++ .../catalyst/expressions/AttributeMap.scala | 7 + .../catalyst/expressions/AttributeSet.scala | 4 + .../expressions/complexTypeCreator.scala | 8 + .../sql/catalyst/expressions/package.scala | 12 + .../plans/logical/basicOperators.scala | 72 +++- .../encoders/PrimitiveEncoderSuite.scala | 43 ++ .../encoders/ProductEncoderSuite.scala | 21 +- .../scala/org/apache/spark/sql/Column.scala | 15 + .../org/apache/spark/sql/DataFrame.scala | 11 + .../scala/org/apache/spark/sql/Dataset.scala | 392 ++++++++++++++++++ .../org/apache/spark/sql/DatasetHolder.scala | 30 ++ .../org/apache/spark/sql/GroupedDataset.scala | 68 +++ .../org/apache/spark/sql/SQLContext.scala | 12 + .../org/apache/spark/sql/SQLImplicits.scala | 16 +- .../spark/sql/execution/GroupedIterator.scala | 141 +++++++ .../spark/sql/execution/SparkStrategies.scala | 8 + .../spark/sql/execution/basicOperators.scala | 79 ++++ .../spark/sql/DatasetPrimitiveSuite.scala | 103 +++++ .../org/apache/spark/sql/DatasetSuite.scala | 124 ++++++ .../org/apache/spark/sql/QueryTest.scala | 8 + 26 files changed, 1501 insertions(+), 23 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 27c96f41221ad..713c6b547d9b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -411,9 +411,9 @@ trait ScalaReflection { } /** Returns expressions for extracting all the fields from the given type. 
*/ - def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = { + def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children + extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct] } } @@ -497,11 +497,11 @@ trait ScalaReflection { } } - CreateStruct(params.head.map { p => + CreateNamedStruct(params.head.flatMap { p => val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - extractorFor(fieldValue, fieldType) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil }) case t if t <:< localTypeOf[Array[_]] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala index 54096f18cbea1..b484b8fde6369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} @@ -41,9 +41,11 @@ case class ClassEncoder[T]( clsTag: ClassTag[T]) extends Encoder[T] { - private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) private val inputRow = new GenericMutableRow(1) + @transient private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) private val dataType = ObjectType(clsTag.runtimeClass) @@ -64,4 +66,36 @@ case class ClassEncoder[T]( copy(constructExpression = boundExpression) } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(oldSchema) + val attributeToNewPosition = AttributeMap.byIndex(newSchema) + copy(constructExpression = constructExpression transform { + case r: BoundReference => + r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) + }) + } + + override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = { + var remaining = schema + copy(constructExpression = constructExpression transform { + case u: UnresolvedAttribute => + val pos = remaining.head + remaining = remaining.drop(1) + pos + }) + } + + protected val attrs = extractExpressions.map(_.collect { + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }.headOption.getOrElse("")) + + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index bdb1c0959da87..efb872ddb81e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * and reuse internal buffers to improve performance. */ trait Encoder[T] { + /** Returns the schema of encoding this type of object as a Row. */ def schema: StructType @@ -46,13 +47,27 @@ trait Encoder[T] { /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must bind the encoder to a specific schema before you can call this function. + * you must `bind` an encoder to a specific schema before you can call this function. */ def fromRow(row: InternalRow): T /** * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the - * given schema + * given schema. */ def bind(schema: Seq[Attribute]): Encoder[T] + + /** + * Binds this encoder to the given schema positionally. In this binding, the first reference to + * any input is mapped to `schema(0)`, and so on for each input that is encountered. + */ + def bindOrdinals(schema: Seq[Attribute]): Encoder[T] + + /** + * Given an encoder that has already been bound to a given schema, returns a new encoder that + * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, + * when you are trying to use an encoder on grouping keys that were orriginally part of a larger + * row, but now you have projected out only the key expressions. + */ + def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 4f7ce455ada99..34f5e6c030f58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,15 +31,17 @@ import org.apache.spark.sql.types.{ObjectType, StructType} object ProductEncoder { def apply[T <: Product : TypeTag]: ClassEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. 
- val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType] val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = ScalaReflection.extractorsFor[T](inputObject) + val extractExpression = ScalaReflection.extractorsFor[T](inputObject) val constructExpression = ScalaReflection.constructorFor[T] - new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls)) - } - + new ClassEncoder[T]( + extractExpression.dataType, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala new file mode 100644 index 0000000000000..a93f2d7c6115d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ + +/** An encoder for primitive Long types. */ +case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] { + private val row = UnsafeRow.createFromByteArray(64, 1) + + override def clsTag: ClassTag[Long] = ClassTag.Long + override def schema: StructType = + StructType(StructField(fieldName, LongType) :: Nil) + + override def fromRow(row: InternalRow): Long = row.getLong(ordinal) + + override def toRow(t: Long): InternalRow = { + row.setLong(ordinal, t) + row + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this + override def bind(schema: Seq[Attribute]): Encoder[Long] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this +} + +/** An encoder for primitive Integer types. 
*/ +case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] { + private val row = UnsafeRow.createFromByteArray(64, 1) + + override def clsTag: ClassTag[Int] = ClassTag.Int + override def schema: StructType = + StructType(StructField(fieldName, IntegerType) :: Nil) + + override def fromRow(row: InternalRow): Int = row.getInt(ordinal) + + override def toRow(t: Int): InternalRow = { + row.setInt(ordinal, t) + row + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this + override def bind(schema: Seq[Attribute]): Encoder[Int] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this +} + +/** An encoder for String types. */ +case class StringEncoder( + fieldName: String = "value", + ordinal: Int = 0) extends Encoder[String] { + + val record = new SpecificMutableRow(StringType :: Nil) + + @transient + lazy val projection = + GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil) + + override def schema: StructType = + StructType( + StructField("value", StringType, nullable = false) :: Nil) + + override def clsTag: ClassTag[String] = scala.reflect.classTag[String] + + + override final def fromRow(row: InternalRow): String = { + row.getString(ordinal) + } + + override final def toRow(value: String): InternalRow = { + val utf8String = UTF8String.fromString(value) + record(0) = utf8String + // TODO: this is a bit of a hack to produce UnsafeRows + projection(record) + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this + override def bind(schema: Seq[Attribute]): Encoder[String] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala new file mode 100644 index 0000000000000..a48eeda7d2e6f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructField, StructType} + +// Most of this file is codegen. +// scalastyle:off + +/** + * A set of composite encoders that take sub encoders and map each of their objects to a + * Scala tuple. Note that currently the implementation is fairly limited and only supports going + * from an internal row to a tuple. + */ +object TupleEncoder { + + /** Code generator for composite tuple encoders. 
*/ + def main(args: Array[String]): Unit = { + (2 to 5).foreach { i => + val types = (1 to i).map(t => s"T$t").mkString(", ") + val tupleType = s"($types)" + val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ") + val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ") + val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ") + + println( + s""" + |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] { + | val schema = StructType(Array($fields)) + | + | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType] + | + | def fromRow(row: InternalRow): $tupleType = { + | ($fromRow) + | } + | + | override def toRow(t: $tupleType): InternalRow = + | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + | + | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = { + | this + | } + | + | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] = + | throw new UnsupportedOperationException("Tuple Encoders only support bind.") + | + | + | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] = + | throw new UnsupportedOperationException("Tuple Encoders only support bind.") + |} + """.stripMargin) + } + } +} + +class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema))) + + def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)] + + def fromRow(row: InternalRow): (T1, T2) = { + (e1.fromRow(row), e2.fromRow(row)) + } + + override def toRow(t: (T1, T2)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema))) + + def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)] + + def fromRow(row: InternalRow): (T1, T2, T3) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema))) + + def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, 
T3, T4)] + + def fromRow(row: InternalRow): (T1, T2, T3, T4) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3, T4)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema))) + + def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)] + + def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 96a11e352ec50..ef3cc554b79c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -26,6 +26,13 @@ object AttributeMap { def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } + + /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */ + def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex) + + /** Given a schema, constructs a map from ordinal to Attribute. */ + def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] = + schema.zipWithIndex.map { case (a, i) => i -> a }.toMap } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 5345696570b41..3831535574205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -31,6 +31,10 @@ protected class AttributeEquals(val a: Attribute) { } object AttributeSet { + /** Returns an empty [[AttributeSet]]. 
*/ + val empty = apply(Iterable.empty) + + /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a5f02e2463aed..059e45bd684ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -125,6 +125,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { */ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + /** + * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 30b7f8d3766a5..f1fa13daa77eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructField, StructType} /** * A set of classes that can be used to represent trees of relational expressions. A key goal of @@ -80,4 +81,15 @@ package object expressions { /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection } + + + /** + * Helper functions for working with `Seq[Attribute]`. + */ + implicit class AttributeSeq(attrs: Seq[Attribute]) { + /** Creates a StructType with a schema matching this `Seq[Attribute]`. */ + def toStructType: StructType = { + StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ae9482c10f126..21a55a5371841 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ @@ -417,7 +418,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { } /** - * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * Returns a new RDD that has exactly `numPartitions` partitions. Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user * asked for `coalesce` or `repartition`. 
[[RepartitionByExpression]] is used when the consumer * of the output requires some specific ordering or distribution of the data. @@ -443,3 +444,72 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } +/** + * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are + * used respectively to decode/encode from the JVM object representation expected by `func.` + */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} + +/** Factory for constructing new `AppendColumn` nodes. */ +object AppendColumn { + def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { + val attrs = implicitly[Encoder[U]].schema.toAttributes + new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`, concatenating the + * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to + * decode/encode from the JVM object representation expected by `func.` + */ +case class AppendColumn[T, U]( + func: T => U, + tEncoder: Encoder[T], + uEncoder: Encoder[U], + newColumns: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ newColumns + override def missingInput: AttributeSet = super.missingInput -- newColumns +} + +/** Factory for constructing new `MapGroups` nodes. */ +object MapGroups { + def apply[K : Encoder, T : Encoder, U : Encoder]( + func: (K, Iterator[T]) => Iterator[U], + groupingAttributes: Seq[Attribute], + child: LogicalPlan): MapGroups[K, T, U] = { + new MapGroups( + func, + implicitly[Encoder[K]], + implicitly[Encoder[T]], + implicitly[Encoder[U]], + groupingAttributes, + implicitly[Encoder[U]].schema.toAttributes, + child) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => Iterator[U], + kEncoder: Encoder[K], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala new file mode 100644 index 0000000000000..52f8383faca92 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite + +class PrimitiveEncoderSuite extends SparkFunSuite { + test("long encoder") { + val enc = new LongEncoder() + val row = enc.toRow(10) + assert(row.getLong(0) == 10) + assert(enc.fromRow(row) == 10) + } + + test("int encoder") { + val enc = new IntEncoder() + val row = enc.toRow(10) + assert(row.getInt(0) == 10) + assert(enc.fromRow(row) == 10) + } + + test("string encoder") { + val enc = new StringEncoder() + val row = enc.toRow("test") + assert(row.getString(0) == "test") + assert(enc.fromRow(row) == "test") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index 02e43ddb35478..7735acbcbad41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -248,12 +248,16 @@ class ProductEncoderSuite extends SparkFunSuite { val types = convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") - val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq - case (other, _) => - other - }.mkString("[", ",", "]") + val encodedData = try { + convertedData.toSeq(encoder.schema).zip(encoder.schema).map { + case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => + a.toArray[Any](at.elementType).toSeq + case (other, _) => + other + }.mkString("[", ",", "]") + } catch { + case e: Throwable => s"Failed to toSeq: $e" + } fail( s"""Encoded/Decoded data does not match input data @@ -272,8 +276,9 @@ class ProductEncoderSuite extends SparkFunSuite { |Construct Expressions: |${boundEncoder.constructExpression.treeString} | - """.stripMargin) + """.stripMargin) + } } - } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 37d559c8e4301..de11a1699afd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql + import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -36,6 +38,11 @@ private[sql] object Column { def unapply(col: Column): Option[Expression] = Some(col.expr) } +/** + * A [[Column]] where an [[Encoder]] has been given for the expected return type. 
+ * @since 1.6.0 + */ +class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) /** * :: Experimental :: @@ -69,6 +76,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { override def hashCode: Int = this.expr.hashCode + /** + * Provides a type hint about the expected return value of this column. This information can + * be used by operations such as `select` on a [[Dataset]] to automatically convert the + * results into the correct JVM types. + * @since 1.6.0 + */ + def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + /** * Extracts a value or values from a complex type. * The following types of extraction are supported: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2f10aa9f3c446..bf25bcde208e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} @@ -258,6 +259,16 @@ class DataFrame private[sql]( // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = this + /** + * :: Experimental :: + * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the + * specified type, `U`. + * @group basic + * @since 1.6.0 + */ + @Experimental + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution) + /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala new file mode 100644 index 0000000000000..96213c7630400 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.StructType + +/** + * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel + * using functional or relational operations. + * + * A [[Dataset]] differs from an [[RDD]] in the following ways: + * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored + * in the encoded form. This representation allows for additional logical operations and + * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to + * an object. + * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be + * used to serialize the object into a binary format. Encoders are also capable of mapping the + * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime + * reflection based serialization. Operations that change the type of object stored in the + * dataset also need an encoder for the new type. + * + * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific + * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into + * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed + * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`. + * + * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However, + * making this change to the class hierarchy would break the function signatures for the existing + * functional operations (map, flatMap, etc). As such, this class should be considered a preview + * of the final API. Changes will be made to the interface after Spark 1.6. + * + * @since 1.6.0 + */ +@Experimental +class Dataset[T] private[sql]( + @transient val sqlContext: SQLContext, + @transient val queryExecution: QueryExecution)( + implicit val encoder: Encoder[T]) extends Serializable { + + private implicit def classTag = encoder.clsTag + + private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = + this(sqlContext, new QueryExecution(sqlContext, plan)) + + /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ + def schema: StructType = encoder.schema + + /* ************* * + * Conversions * + * ************* */ + + /** + * Returns a new `Dataset` where each record has been mapped on to the specified type. + * TODO: should bind here... + * TODO: document binding rules + * @since 1.6.0 + */ + def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]]) + + /** + * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have + * the same name after two Datasets have been joined. + */ + def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) + + /** + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] + * objects that allow fields to be accessed by ordinal or name. 
+ */ + def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) + + + /** + * Returns this Dataset. + * @since 1.6.0 + */ + def toDS(): Dataset[T] = this + + /** + * Converts this Dataset to an RDD. + * @since 1.6.0 + */ + def rdd: RDD[T] = { + val tEnc = implicitly[Encoder[T]] + val input = queryExecution.analyzed.output + queryExecution.toRdd.mapPartitions { iter => + val bound = tEnc.bind(input) + iter.map(bound.fromRow) + } + } + + /* *********************** * + * Functional Operations * + * *********************** */ + + /** + * Concise syntax for chaining custom transformations. + * {{{ + * def featurize(ds: Dataset[T]) = ... + * + * dataset + * .transform(featurize) + * .transform(...) + * }}} + * + * @since 1.6.0 + */ + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) + + /** + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + + /** + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + + /** + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + new Dataset( + sqlContext, + MapPartitions[T, U]( + func, + implicitly[Encoder[T]], + implicitly[Encoder[U]], + implicitly[Encoder[U]].schema.toAttributes, + logicalPlan)) + } + + def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = + mapPartitions(_.flatMap(func)) + + /* ************** * + * Side effects * + * ************** */ + + /** + * Runs `func` on each element of this Dataset. + * @since 1.6.0 + */ + def foreach(func: T => Unit): Unit = rdd.foreach(func) + + /** + * Runs `func` on each partition of this Dataset. + * @since 1.6.0 + */ + def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + + /* ************* * + * Aggregation * + * ************* */ + + /** + * Reduces the elements of this Dataset using the specified binary function. The given function + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: (T, T) => T): T = rdd.reduce(func) + + /** + * Aggregates the elements of each partition, and then the results for all the partitions, using a + * given associative and commutative function and a neutral "zero value". + * + * This behaves somewhat differently than the fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then those results will be folded into the final result. + * If op is not commutative, then the result may differ from that of a fold applied to a + * non-distributed collection. + * @since 1.6.0 + */ + def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) + + /** + * Returns a [[GroupedDataset]] where the data is grouped by the given key function. 
+ * @since 1.6.0 + */ + def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { + val inputPlan = queryExecution.analyzed + val withGroupingKey = AppendColumn(func, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedDataset( + implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns), + implicitly[Encoder[T]].bind(inputPlan.output), + executed, + inputPlan.output, + withGroupingKey.newColumns) + } + + /* ****************** * + * Typed Relational * + * ****************** */ + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * + * {{{ + * val ds = Seq(1, 2, 3).toDS() + * val newDS = ds.select(e[Int]("value + 1")) + * }}} + * @since 1.6.0 + */ + def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { + new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + } + + // Codegen + // scalastyle:off + + /** sbt scalaShell; println(Seq(1).toDS().genSelect) */ + private def genSelect: String = { + (2 to 5).map { n => + val types = (1 to n).map(i =>s"U$i").mkString(", ") + val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ") + val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ") + val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ") + s""" + |/** + | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + | * @since 1.6.0 + | */ + |def select[$types]($args): Dataset[($types)] = { + | implicit val te = new Tuple${n}Encoder($encoders) + | new Dataset[($types)](sqlContext, + | Project( + | $schema :: Nil, + | logicalPlan)) + |} + | + """.stripMargin + }.mkString("\n") + } + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = { + implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder) + new Dataset[(U1, U2)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = { + implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder) + new Dataset[(U1, U2, U3)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = { + implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder) + new Dataset[(U1, U2, U3, U4)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
+ * @since 1.6.0 + */ + def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = { + implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder) + new Dataset[(U1, U2, U3, U4, U5)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil, + logicalPlan)) + } + + // scalastyle:on + + /* **************** * + * Set operations * + * **************** */ + + /** + * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]]. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def distinct: Dataset[T] = withPlan(Distinct) + + /** + * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also + * present in `other`. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def intersect(other: Dataset[T]): Dataset[T] = + withPlan[T](other)(Intersect) + + /** + * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * @since 1.6.0 + */ + def union(other: Dataset[T]): Dataset[T] = + withPlan[T](other)(Union) + + /** + * Returns a new [[Dataset]] where any elements present in `other` have been removed. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + + /* ************************** * + * Gather to Driver Actions * + * ************************** */ + + /** Returns the first element in this [[Dataset]]. */ + def first(): T = rdd.first() + + /** Collects the elements to an Array. */ + def collect(): Array[T] = rdd.collect() + + /** Returns the first `num` elements of this [[Dataset]] as an Array. */ + def take(num: Int): Array[T] = rdd.take(num) + + /* ******************** * + * Internal Functions * + * ******************** */ + + private[sql] def logicalPlan = queryExecution.analyzed + + private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan))) + + private[sql] def withPlan[R : Encoder]( + other: Dataset[_])( + f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = + new Dataset[R]( + sqlContext, + sqlContext.executePlan( + f(logicalPlan, other.logicalPlan))) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala new file mode 100644 index 0000000000000..17817cbcc5e05 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -0,0 +1,30 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * A container for a [[DataFrame]], used for implicit conversions. + * + * @since 1.3.0 + */ +private[sql] case class DatasetHolder[T](df: Dataset[T]) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDS(): Dataset[T] = df +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala new file mode 100644 index 0000000000000..89a16dd8b0acc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing + * [[Dataset]]. + */ +class GroupedDataset[K, T] private[sql]( + private val kEncoder: Encoder[K], + private val tEncoder: Encoder[T], + queryExecution: QueryExecution, + private val dataAttributes: Seq[Attribute], + private val groupingAttributes: Seq[Attribute]) extends Serializable { + + private implicit def kEnc = kEncoder + private implicit def tEnc = tEncoder + private def logicalPlan = queryExecution.analyzed + private def sqlContext = queryExecution.sqlContext + + /** + * Returns a [[Dataset]] that contains each unique key. + */ + def keys: Dataset[K] = { + new Dataset[K]( + sqlContext, + Distinct( + Project(groupingAttributes, logicalPlan))) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. 
+ * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapGroups(f, groupingAttributes, logicalPlan)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a107639947aa2..5e7198f974389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,6 +21,7 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference + import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -33,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -487,6 +489,16 @@ class SQLContext private[sql]( DataFrame(this, logicalPlan) } + + def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { + val enc = implicitly[Encoder[T]] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d).copy()) + val plan = new LocalRelation(attributes, encoded) + + new Dataset[T](this, plan) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index bf03c61088426..af8474df0de80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation + import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -30,9 +34,19 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. */ -private[sql] abstract class SQLImplicits { +abstract class SQLImplicits { protected def _sqlContext: SQLContext + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + + implicit def newIntEncoder: Encoder[Int] = new IntEncoder() + implicit def newLongEncoder: Encoder[Long] = new LongEncoder() + implicit def newStringEncoder: Encoder[String] = new StringEncoder() + + implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { + DatasetHolder(_sqlContext.createDataset(s)) + } + /** * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. 
* @since 1.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala new file mode 100644 index 0000000000000..10742cf7348f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} + +object GroupedIterator { + def apply( + input: Iterator[InternalRow], + keyExpressions: Seq[Expression], + inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = { + if (input.hasNext) { + new GroupedIterator(input, keyExpressions, inputSchema) + } else { + Iterator.empty + } + } +} + +/** + * Iterates over a presorted set of rows, chunking it up by the grouping expression. Each call to + * next will return a pair containing the current group and an iterator that will return all the + * elements of that group. Iterators for each group are lazily constructed by extracting rows + * from the input iterator. As such, full groups are never materialized by this class. + * + * Example input: + * {{{ + * Input: [a, 1], [b, 2], [b, 3] + * Grouping: x#1 + * InputSchema: x#1, y#2 + * }}} + * + * Result: + * {{{ + * First call to next(): ([a], Iterator([a, 1]) + * Second call to next(): ([b], Iterator([b, 2], [b, 3]) + * }}} + * + * Note, the class does not handle the case of an empty input for simplicity of implementation. + * Use the factory to construct a new instance. + * + * @param input An iterator of rows. This iterator must be ordered by the groupingExpressions or + * it is possible for the same group to appear more than once. + * @param groupingExpressions The set of expressions used to do grouping. The result of evaluating + * these expressions will be returned as the first part of each call + * to `next()`. + * @param inputSchema The schema of the rows in the `input` iterator. + */ +class GroupedIterator private( + input: Iterator[InternalRow], + groupingExpressions: Seq[Expression], + inputSchema: Seq[Attribute]) + extends Iterator[(InternalRow, Iterator[InternalRow])] { + + /** Compares two input rows and returns 0 if they are in the same group. */ + val sortOrder = groupingExpressions.map(SortOrder(_, Ascending)) + val keyOrdering = GenerateOrdering.generate(sortOrder, inputSchema) + + /** Creates a row containing only the key for a given input row. 
*/ + val keyProjection = GenerateUnsafeProjection.generate(groupingExpressions, inputSchema) + + /** + * Holds null or the row that will be returned on next call to `next()` in the inner iterator. + */ + var currentRow = input.next() + + /** Holds a copy of an input row that is in the current group. */ + var currentGroup = currentRow.copy() + var currentIterator: Iterator[InternalRow] = null + assert(keyOrdering.compare(currentGroup, currentRow) == 0) + + // Return true if we already have the next iterator or fetching a new iterator is successful. + def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator + + def next(): (InternalRow, Iterator[InternalRow]) = { + assert(hasNext) // Ensure we have fetched the next iterator. + val ret = (keyProjection(currentGroup), currentIterator) + currentIterator = null + ret + } + + def fetchNextGroupIterator(): Boolean = { + if (currentRow != null || input.hasNext) { + val inputIterator = new Iterator[InternalRow] { + // Return true if we have a row and it is in the current group, or if fetching a new row is + // successful. + def hasNext = { + (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) || + fetchNextRowInGroup() + } + + def fetchNextRowInGroup(): Boolean = { + if (currentRow != null || input.hasNext) { + currentRow = input.next() + if (keyOrdering.compare(currentGroup, currentRow) == 0) { + // The row is in the current group. Continue the inner iterator. + true + } else { + // We got a row, but its not in the right group. End this inner iterator and prepare + // for the next group. + currentIterator = null + currentGroup = currentRow.copy() + false + } + } else { + // There is no more input so we are done. + false + } + } + + def next(): InternalRow = { + assert(hasNext) // Ensure we have fetched the next row. 
+ val res = currentRow + currentRow = null + res + } + } + currentIterator = inputIterator + true + } else { + false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79bd1a41808de..637deff4e2202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -372,6 +372,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + + case logical.MapPartitions(f, tEnc, uEnc, output, child) => + execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil + case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil + case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => + execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil + case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index dc38fe59feed5..2bb3dba5bd2ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.MutablePair @@ -311,3 +313,80 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl protected override def doExecute(): RDD[InternalRow] = child.execute() } + +/** + * Applies the given function to each input row and encodes the result. + */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val tBoundEncoder = tEncoder.bind(child.output) + func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) + } + } +} + +/** + * Applies the given function to each input row, appending the encoded result at the end of the row. 
+ */ +case class AppendColumns[T, U]( + func: T => U, + tEncoder: Encoder[T], + uEncoder: Encoder[U], + newColumns: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val tBoundEncoder = tEncoder.bind(child.output) + val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) + iter.map { row => + val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row))) + combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + } + } + } +} + +/** + * Groups the input rows together and calls the function with each group and an iterator containing + * all elements in the group. The result of this function is encoded and flattened before + * being output. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => Iterator[U], + kEncoder: Encoder[K], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val groupKeyEncoder = kEncoder.bind(groupingAttributes) + + grouped.flatMap { case (key, rowIter) => + val result = func( + groupKeyEncoder.fromRow(key), + rowIter.map(tEncoder.fromRow)) + result.map(uEncoder.toRow) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala new file mode 100644 index 0000000000000..32443557fb8e0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +case class IntClass(value: Int) + +class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(1, 2, 3, 4, 5, 6) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(1, 2, 3).toDS().as[IntClass] + checkAnswer( + ds, + IntClass(1), IntClass(2), IntClass(3)) + + assert(ds.collect().head == IntClass(1)) + } + + test("map") { + val ds = Seq(1, 2, 3).toDS() + checkAnswer( + ds.map(_ + 1), + 2, 3, 4) + } + + test("filter") { + val ds = Seq(1, 2, 3, 4).toDS() + checkAnswer( + ds.filter(_ % 2 == 0), + 2, 4) + } + + test("foreach") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(acc +=) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(acc +=)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.reduce(_ + _) == 6) + } + + test("fold") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.fold(0)(_ + _) == 6) + } + + test("groupBy function, keys") { + val ds = Seq(1, 2, 3, 4, 5).toDS() + val grouped = ds.groupBy(_ % 2) + checkAnswer( + grouped.keys, + 0, 1) + } + + test("groupBy function, mapGroups") { + val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() + val grouped = ds.groupBy(_ % 2) + val agged = grouped.mapGroups { case (g, iter) => + val name = if (g == 0) "even" else "odd" + Iterator((name, iter.size)) + } + + checkAnswer( + agged, + ("even", 5), ("odd", 6)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala new file mode 100644 index 0000000000000..08496249c60cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +case class ClassData(a: String, b: Int) + +class DatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + checkAnswer( + ds, + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + assert(ds.collect().head == ClassData("a", 1)) + } + + test("as case class - reordered fields by name") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) + } + + test("map") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.map(v => (v._1, v._2 + 1)), + ("a", 2), ("b", 3), ("c", 4)) + } + + test("select") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select(expr("_2 + 1").as[Int]), + 2, 3, 4) + } + + test("select 3") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("_2").as[Int], + expr("_2 + 1").as[Int]), + ("a", 1, 2), ("b", 2, 3), ("c", 3, 4)) + } + + test("filter") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.filter(_._1 == "b"), + ("b", 2)) + } + + test("foreach") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(v => acc += v._2) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(v => acc += v._2)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("fold") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("groupBy function, keys") { + val ds = Seq(("a", 1), ("b", 1)).toDS() + val grouped = ds.groupBy(v => (1, v._2)) + checkAnswer( + grouped.keys, + (1, 1)) + } + + test("groupBy function, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g._1, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index e3c5a426671d0..aba567512fe32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder} abstract class QueryTest extends PlanTest { @@ -53,6 +55,12 @@ abstract class QueryTest extends PlanTest { } } + protected def checkAnswer[T : Encoder](ds: => 
Dataset[T], expectedAnswer: T*): Unit = { + checkAnswer( + ds.toDF(), + sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param df the [[DataFrame]] to be executed From 163d53e829c166f061589cc379f61642d4c9a40f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Lipt=C3=A1k?= Date: Thu, 22 Oct 2015 15:27:11 -0700 Subject: [PATCH 011/324] [SPARK-7021] Add JUnit output for Python unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WIP Author: Gábor Lipták Closes #8323 from gliptak/SPARK-7021. --- python/pyspark/ml/tests.py | 9 ++++++++- python/pyspark/mllib/tests.py | 9 ++++++++- python/pyspark/sql/tests.py | 9 ++++++++- python/pyspark/streaming/tests.py | 11 ++++++++++- python/pyspark/tests.py | 19 ++++++++++++++----- 5 files changed, 48 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6a2577d66f287..7a16cf52cccb2 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -20,6 +20,10 @@ """ import sys +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version_info[:2] <= (2, 6): try: @@ -368,4 +372,7 @@ def test_fit_maximize_metric(self): if __name__ == "__main__": - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 2ad69a0ab1d3d..f8e8e0e0adbea 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -31,6 +31,10 @@ from numpy import sum as array_sum from py4j.protocol import Py4JJavaError +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version > '3': basestring = str @@ -1538,7 +1542,10 @@ def test_load_vectors(self): if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f465e1fa20941..6356d4bd6669b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -31,6 +31,10 @@ import datetime import py4j +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version_info[:2] <= (2, 6): try: @@ -1222,4 +1226,7 @@ def test_window_functions_without_partitionBy(self): if __name__ == "__main__": - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 49634252fd465..2c908daa8b214 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -27,6 +27,11 @@ import shutil from functools import reduce +try: + import xmlrunner +except ImportError: + xmlrunner = None + if sys.version_info[:2] <= (2, 6): try: import unittest2 as unittest @@ -1303,4 +1308,8 @@ def search_kinesis_asl_assembly_jar(): for testcase in testcases: sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - unittest.TextTestRunner(verbosity=3).run(tests) + if xmlrunner: + unittest.main(tests, verbosity=3, + 
testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3c51809444401..5bd94476597ab 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -35,6 +35,10 @@ import hashlib from py4j.protocol import Py4JJavaError +try: + import xmlrunner +except ImportError: + xmlrunner = None if sys.version_info[:2] <= (2, 6): try: @@ -249,10 +253,12 @@ def __getattr__(self, item): # Regression test for SPARK-3415 def test_pickling_file_handles(self): - ser = CloudPickleSerializer() - out1 = sys.stderr - out2 = ser.loads(ser.dumps(out1)) - self.assertEqual(out1, out2) + # to be corrected with SPARK-11160 + if not xmlrunner: + ser = CloudPickleSerializer() + out1 = sys.stderr + out2 = ser.loads(ser.dumps(out1)) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -2006,7 +2012,10 @@ def test_statcounter_array(self): print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: print("NOTE: Skipping NumPy tests as it does not seem to be installed") - unittest.main() + if xmlrunner: + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + else: + unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: From 34e71c6d89c1f2b6236dbf0d75cd12da08003c84 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 22 Oct 2015 15:58:08 -0700 Subject: [PATCH 012/324] [SPARK-11251] Fix page size calculation in local mode ``` // My machine only has 8 cores $ bin/spark-shell --master local[32] scala> val df = sc.parallelize(Seq((1, 1), (2, 2))).toDF("a", "b") scala> df.as("x").join(df.as("y"), $"x.a" === $"y.a").count() Caused by: java.io.IOException: Unable to acquire 2097152 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPage(UnsafeExternalSorter.java:351) ``` Author: Andrew Or Closes #9209 from andrewor14/fix-local-page-size. --- .../scala/org/apache/spark/SparkContext.scala | 48 ++++++++++++++----- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../OutputCommitCoordinatorSuite.scala | 3 +- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ccba3ed9e643c..a6857b4c7d882 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -269,7 +269,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) } private[spark] def env: SparkEnv = _env @@ -2560,6 +2560,21 @@ object SparkContext extends Logging { res } + /** + * The number of driver cores to use for execution in local mode, 0 otherwise. 
+ */ + private[spark] def numDriverCores(master: String): Int = { + def convertToInt(threads: String): Int = { + if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + } + master match { + case "local" => 1 + case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) + case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) + case _ => 0 // driver is not used for execution + } + } + /** * Create a task scheduler based on a given master URL. * Return a 2-tuple of the scheduler backend and the task scheduler. @@ -2567,18 +2582,7 @@ object SparkContext extends Logging { private def createTaskScheduler( sc: SparkContext, master: String): (SchedulerBackend, TaskScheduler) = { - // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r + import SparkMasterRegex._ // When running locally, don't try to re-execute tasks on failure. val MAX_LOCAL_TASK_FAILURES = 1 @@ -2719,6 +2723,24 @@ object SparkContext extends Logging { } } +/** + * A collection of regexes for extracting information from the master string. + */ +private object SparkMasterRegex { + // Regular expression used for local[N] and local[*] master formats + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r +} + /** * A class encapsulating how to convert some type T to Writable. It stores both the Writable class * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. 
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 704158bfc7643..b5c35c569e45f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -190,6 +190,7 @@ object SparkEnv extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus, + numCores: Int, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") @@ -202,6 +203,7 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, + numUsableCores = numCores, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -241,8 +243,8 @@ object SparkEnv extends Logging { port: Int, isDriver: Boolean, isLocal: Boolean, + numUsableCores: Int, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 6d08d7c5b7d2a..48456a9cd6e7b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -87,7 +87,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { outputCommitCoordinator = spy(new OutputCommitCoordinator(conf, isDriver = true)) // Use Mockito.spy() to maintain the default infrastructure everywhere else. // This mocking allows us to control the coordinator responses in test cases. - SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator)) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, + SparkContext.numDriverCores(master), Some(outputCommitCoordinator)) } } // Use Mockito.spy() to maintain the default infrastructure everywhere else From a88c66ca8780c7228dc909f904d31cd9464ee0e3 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 22 Oct 2015 21:01:01 -0700 Subject: [PATCH 013/324] [SPARK-11098][CORE] Add Outbox to cache the sending messages to resolve the message disorder issue The current NettyRpc has a message order issue because it uses a thread pool to send messages. E.g., running the following two lines in the same thread, ``` ref.send("A") ref.send("B") ``` The remote endpoint may see "B" before "A" because sending "A" and "B" are in parallel. To resolve this issue, this PR added an outbox for each connection, and if we are connecting to the remote node when sending messages, just cache the sending messages in the outbox and send them one by one when the connection is established. Author: zsxwing Closes #9197 from zsxwing/rpc-outbox. 
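To make the ordering fix concrete, here is a minimal, self-contained sketch of the per-connection outbox idea. It is illustrative only: `SimpleOutbox` and the `Connection` trait below are hypothetical stand-ins, not the `Outbox`/`TransportClient` types added by this patch, and the sketch omits the connect future, the draining flag and the failure notification that the real class needs.
```scala
import java.util.{ArrayDeque => JArrayDeque}

// Hypothetical stand-in for the transport client; not Spark's TransportClient.
trait Connection {
  def sendRpc(payload: Array[Byte]): Unit
}

// Minimal sketch: one outbox per remote address; messages are queued until the
// asynchronous connect finishes, then drained in FIFO order.
class SimpleOutbox {
  private val queue = new JArrayDeque[Array[Byte]]()
  private var conn: Connection = null
  private var stopped = false

  def send(payload: Array[Byte]): Unit = synchronized {
    require(!stopped, "outbox is stopped")
    queue.add(payload)
    drain()
  }

  // Called by the connect task once the connection is established.
  def onConnected(c: Connection): Unit = synchronized {
    conn = c
    drain()
  }

  def stop(): Unit = synchronized {
    stopped = true
    conn = null
    queue.clear()
  }

  // Must be called while holding the lock; preserves per-address send order.
  private def drain(): Unit = {
    if (conn == null) return
    while (!queue.isEmpty) {
      conn.sendRpc(queue.poll())
    }
  }
}
```
The ordering guarantee comes from keeping a single FIFO queue per remote address: nothing reaches the transport until the connection exists, and queued messages are sent in the order they were enqueued.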
--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 145 +++++++----- .../org/apache/spark/rpc/netty/Outbox.scala | 222 ++++++++++++++++++ 2 files changed, 310 insertions(+), 57 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e01cf1a29e95b..284284eb805b7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv( // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool // to implement non-blocking send/ask. // TODO: a non-blocking TransportClientFactory.createClient in future - private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( + private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", conf.getInt("spark.rpc.connect.threads", 64)) @volatile private var server: TransportServer = _ + private val stopped = new AtomicBoolean(false) + + /** + * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]], + * we just put messages to its [[Outbox]] to implement a non-blocking `send` method. + */ + private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]() + + /** + * Remove the address's Outbox and stop it. + */ + private[netty] def removeOutbox(address: RpcAddress): Unit = { + val outbox = outboxes.remove(address) + if (outbox != null) { + outbox.stop() + } + } + def start(port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { @@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv( dispatcher.stop(endpointRef) } + private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = { + val targetOutbox = { + val outbox = outboxes.get(address) + if (outbox == null) { + val newOutbox = new Outbox(this, address) + val oldOutbox = outboxes.putIfAbsent(address, newOutbox) + if (oldOutbox == null) { + newOutbox + } else { + oldOutbox + } + } else { + outbox + } + } + if (stopped.get) { + // It's possible that we put `targetOutbox` after stopping. So we need to clean it. + outboxes.remove(address) + targetOutbox.stop() + } else { + targetOutbox.send(message) + } + } + private[netty] def send(message: RequestMessage): Unit = { val remoteAddr = message.receiver.address if (remoteAddr == address) { @@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv( val ack = response.asInstanceOf[Ack] logTrace(s"Received ack from ${ack.sender}") case Failure(e) => - logError(s"Exception when sending $message", e) + logWarning(s"Exception when sending $message", e) }(ThreadUtils.sameThread) } else { // Message to a remote RPC endpoint. 
- try { - // `createClient` will block if it cannot find a known connection, so we should run it in - // clientConnectionExecutor - clientConnectionExecutor.execute(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { - logError(s"Exception when sending $message", e) - } - - override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[Ack](response) - logDebug(s"Receive ack from ${ack.sender}") - } - }) - } - }) - } catch { - case e: RejectedExecutionException => - // `send` after shutting clientConnectionExecutor down, ignore it - logWarning(s"Cannot send $message because RpcEnv is stopped") - } + postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Exception when sending $message", e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + val ack = deserialize[Ack](response) + logDebug(s"Receive ack from ${ack.sender}") + } + })) } } + private[netty] def createClient(address: RpcAddress): TransportClient = { + clientFactory.createClient(address.host, address.port) + } + private[netty] def ask(message: RequestMessage): Future[Any] = { val promise = Promise[Any]() val remoteAddr = message.receiver.address @@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv( } }(ThreadUtils.sameThread) } else { - try { - // `createClient` will block if it cannot find a known connection, so we should run it in - // clientConnectionExecutor - clientConnectionExecutor.execute(new Runnable { - override def run(): Unit = { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - } - - override def onSuccess(response: Array[Byte]): Unit = { - val reply = deserialize[AskResponse](response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - } - }) - } - }) - } catch { - case e: RejectedExecutionException => + postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { if (!promise.tryFailure(e)) { - logWarning(s"Ignore failure", e) + logWarning("Ignore Exception", e) } - } + } + + override def onSuccess(response: Array[Byte]): Unit = { + val reply = deserialize[AskResponse](response) + if (reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + } + })) } promise.future } @@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv( } private def cleanup(): Unit = { + if (!stopped.compareAndSet(false, true)) { + return + } + + val iter = outboxes.values().iterator() + while (iter.hasNext()) { + val outbox = iter.next() + outboxes.remove(outbox.address) + outbox.stop() + } if (timeoutScheduler != null) { timeoutScheduler.shutdownNow() } @@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler( val addr = 
client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + nettyEnv.removeOutbox(clientAddr) val messageOpt: Option[RemoteProcessDisconnected] = synchronized { remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala new file mode 100644 index 0000000000000..7d9d593b36241 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.Callable +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.rpc.RpcAddress + +private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback) + +private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { + + outbox => // Give this an alias so we can use it more clearly in closures. + + @GuardedBy("this") + private val messages = new java.util.LinkedList[OutboxMessage] + + @GuardedBy("this") + private var client: TransportClient = null + + /** + * connectFuture points to the connect task. If there is no connect task, connectFuture will be + * null. + */ + @GuardedBy("this") + private var connectFuture: java.util.concurrent.Future[Unit] = null + + @GuardedBy("this") + private var stopped = false + + /** + * If there is any thread draining the message queue + */ + @GuardedBy("this") + private var draining = false + + /** + * Send a message. If there is no active connection, cache it and launch a new connection. If + * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]]. + */ + def send(message: OutboxMessage): Unit = { + val dropped = synchronized { + if (stopped) { + true + } else { + messages.add(message) + false + } + } + if (dropped) { + message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + } else { + drainOutbox() + } + } + + /** + * Drain the message queue. If there is other draining thread, just exit. If the connection has + * not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to setup the + * connection. 
+ */ + private def drainOutbox(): Unit = { + var message: OutboxMessage = null + synchronized { + if (stopped) { + return + } + if (connectFuture != null) { + // We are connecting to the remote address, so just exit + return + } + if (client == null) { + // There is no connect task but client is null, so we need to launch the connect task. + launchConnectTask() + return + } + if (draining) { + // There is some thread draining, so just exit + return + } + message = messages.poll() + if (message == null) { + return + } + draining = true + } + while (true) { + try { + val _client = synchronized { client } + if (_client != null) { + _client.sendRpc(message.content, message.callback) + } else { + assert(stopped == true) + } + } catch { + case NonFatal(e) => + handleNetworkFailure(e) + return + } + synchronized { + if (stopped) { + return + } + message = messages.poll() + if (message == null) { + draining = false + return + } + } + } + } + + private def launchConnectTask(): Unit = { + connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] { + + override def call(): Unit = { + try { + val _client = nettyEnv.createClient(address) + outbox.synchronized { + client = _client + if (stopped) { + closeClient() + } + } + } catch { + case ie: InterruptedException => + // exit + return + case NonFatal(e) => + outbox.synchronized { connectFuture = null } + handleNetworkFailure(e) + return + } + outbox.synchronized { connectFuture = null } + // It's possible that no thread is draining now. If we don't drain here, we cannot send the + // messages until the next message arrives. + drainOutbox() + } + }) + } + + /** + * Stop [[Inbox]] and notify the waiting messages with the cause. + */ + private def handleNetworkFailure(e: Throwable): Unit = { + synchronized { + assert(connectFuture == null) + if (stopped) { + return + } + stopped = true + closeClient() + } + // Remove this Outbox from nettyEnv so that the further messages will create a new Outbox along + // with a new connection + nettyEnv.removeOutbox(address) + + // Notify the connection failure for the remaining messages + // + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.callback.onFailure(e) + message = messages.poll() + } + assert(messages.isEmpty) + } + + private def closeClient(): Unit = synchronized { + // Not sure if `client.close` is idempotent. Just for safety. + if (client != null) { + client.close() + } + client = null + } + + /** + * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a + * [[SparkException]]. + */ + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + if (connectFuture != null) { + connectFuture.cancel(true) + } + closeClient() + } + + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message = messages.poll() + } + } +} From fa6a4fbf08c8cca36cbe9f0d2bd20bc7be2ca45d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 22 Oct 2015 22:41:21 -0700 Subject: [PATCH 014/324] [SPARK-11134][CORE] Increase LauncherBackendSuite timeout. This test can take a little while to finish on slow / loaded machines. 
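For context, the suite relies on ScalaTest's `eventually`, which re-runs the assertion block until it passes or the timeout expires, so a larger timeout simply allows more retries on a loaded machine. The sketch below is a generic illustration of that polling pattern (the object and method names are made up), not the suite itself.
```scala
import org.scalatest.Assertions._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

object PollingSketch {
  // Re-evaluates the block every 100 ms until it succeeds, failing only after
  // 30 seconds have elapsed.
  def waitUntilEqual(actual: () => String, expected: String): Unit = {
    eventually(timeout(30.seconds), interval(100.millis)) {
      assert(actual() == expected)
    }
  }
}
```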
Author: Marcelo Vanzin Closes #9235 from vanzin/SPARK-11134. --- .../org/apache/spark/launcher/LauncherBackendSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala index 07e8869833e95..639d1daa36c73 100644 --- a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -54,13 +54,13 @@ class LauncherBackendSuite extends SparkFunSuite with Matchers { .startApplication() try { - eventually(timeout(10 seconds), interval(100 millis)) { + eventually(timeout(30 seconds), interval(100 millis)) { handle.getAppId() should not be (null) } handle.stop() - eventually(timeout(10 seconds), interval(100 millis)) { + eventually(timeout(30 seconds), interval(100 millis)) { handle.getState() should be (SparkAppHandle.State.KILLED) } } finally { From b1c1597e3c47f1912809f3c5ab21833fa4241b54 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 22 Oct 2015 22:42:15 -0700 Subject: [PATCH 015/324] Fix a (very tiny) typo Author: Jacek Laskowski Closes #9230 from jaceklaskowski/utils-seconds-typo. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 55950405f0488..5a976ee839b1e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -952,7 +952,7 @@ private[spark] object Utils extends Logging { } /** - * Convert a time parameter such as (50s, 100ms, or 250us) to microseconds for internal use. If + * Convert a time parameter such as (50s, 100ms, or 250us) to seconds for internal use. If * no suffix is provided, the passed number is assumed to be in seconds. */ def timeStringAsSeconds(str: String): Long = { From cdea0174e32a5f4c28fd59899b2e9774994303d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 23 Oct 2015 00:00:21 -0700 Subject: [PATCH 016/324] [SPARK-11273][SQL] Move ArrayData/MapData/DataTypeParser to catalyst.util package Author: Reynold Xin Closes #9239 from rxin/types-private. 
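For downstream (internal) callers the change is purely an import move; the classes themselves behave exactly as before. A hedged illustration (the object name below is made up, and these remain internal catalyst classes, so real usages live inside Spark itself):
```scala
// Before this patch the same classes lived in org.apache.spark.sql.types:
//   import org.apache.spark.sql.types.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
// After this patch they come from catalyst.util instead:
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}

object PackageMoveSketch {
  // Constructors and behaviour are unchanged; only the package differs.
  val keys: ArrayData = new GenericArrayData(Array[Any]("a", "b"))
  val values: ArrayData = new GenericArrayData(Array[Any](1, 2))
  val map: MapData = new ArrayBasedMapData(keys, values)
}
```
DataTypeParser moves the same way, which is why SqlParser, Column, DDLParser and the Hive metastore catalog in the file list above pick up the new `org.apache.spark.sql.catalyst.util.DataTypeParser` import.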
--- .../apache/spark/mllib/linalg/Matrices.scala | 1 + .../apache/spark/mllib/linalg/Vectors.scala | 1 + .../expressions/SpecializedGetters.java | 4 ++-- .../catalyst/expressions/UnsafeArrayData.java | 1 + .../catalyst/expressions/UnsafeMapData.java | 2 +- .../execution/UnsafeExternalRowSorter.java | 2 +- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../apache/spark/sql/catalyst/SqlParser.scala | 1 + .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/JoinedRow.scala | 1 + .../sql/catalyst/expressions/aggregates.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 1 + .../codegen/GenerateSafeProjection.scala | 1 + .../expressions/collectionOperations.scala | 1 + .../expressions/complexTypeCreator.scala | 2 +- .../expressions/complexTypeExtractors.scala | 1 + .../sql/catalyst/expressions/generators.scala | 1 + .../sql/catalyst/expressions/objects.scala | 1 + .../expressions/regexpExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 1 + .../expressions/stringExpressions.scala | 1 + .../util}/AbstractScalaRowIterator.scala | 2 +- .../util}/ArrayBasedMapData.scala | 4 ++-- .../{types => catalyst/util}/ArrayData.scala | 3 ++- .../util}/DataTypeParser.scala | 3 ++- .../util}/GenericArrayData.scala | 3 ++- .../{types => catalyst/util}/MapData.scala | 4 +++- .../apache/spark/sql/types/StructType.scala | 2 +- .../encoders/ProductEncoderSuite.scala | 20 +++++++++---------- .../expressions/UnsafeRowConverterSuite.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 1 + .../util}/DataTypeParserSuite.scala | 3 ++- .../types/{decimal => }/DecimalSuite.scala | 3 +-- .../scala/org/apache/spark/sql/Column.scala | 4 ++-- .../sql/execution/datasources/DDLParser.scala | 1 + .../datasources/json/JacksonGenerator.scala | 2 +- .../datasources/json/JacksonParser.scala | 2 +- .../parquet/CatalystRowConverter.scala | 2 +- .../apache/spark/sql/execution/python.scala | 3 ++- .../spark/sql/test/ExamplePointUDT.scala | 1 + .../org/apache/spark/sql/UnsafeRowSuite.scala | 5 +++-- .../spark/sql/UserDefinedTypeSuite.scala | 2 ++ .../sql/columnar/ColumnarTestUtils.scala | 3 ++- .../execution/RowFormatConvertersSuite.scala | 3 ++- .../spark/sql/hive/HiveInspectors.scala | 6 +++--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- .../org/apache/spark/sql/hive/hiveUDFs.scala | 1 + .../spark/sql/hive/HiveInspectorSuite.scala | 1 + 50 files changed, 76 insertions(+), 48 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{ => catalyst/util}/AbstractScalaRowIterator.scala (96%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/ArrayBasedMapData.scala (96%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/ArrayData.scala (97%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/DataTypeParser.scala (98%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/GenericArrayData.scala (98%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{types => catalyst/util}/MapData.scala (93%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/{types => catalyst/util}/DataTypeParserSuite.scala (98%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/types/{decimal => }/DecimalSuite.scala (99%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 8ba6e4e78d969..8879dcf75c9bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -26,6 +26,7 @@ import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 3642e9286504f..dcdc614455d34 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -30,6 +30,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index 8f1027f3164c8..eea7149d02594 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; +import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 761f0447943e8..3513960b41813 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -21,6 +21,7 @@ import java.math.BigInteger; import java.nio.ByteBuffer; +import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index 5bebe2a96e391..651eb1ff0c561 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -19,7 +19,7 @@ import java.nio.ByteBuffer; -import org.apache.spark.sql.types.MapData; +import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.Platform; /** diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 1d27182912c8a..7d94e0566faa9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -26,7 +26,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.sql.AbstractScalaRowIterator; +import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index f25591794abdb..2ec0ff53c89c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -27,7 +27,7 @@ import scala.language.existentials import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 713c6b547d9b7..c25161ee81b66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 08ca325b21777..833368b7d5898 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 5142856afdcac..e9cc00a2b64ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -22,7 +22,7 @@ import 
scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 99d7444dc470f..5564e242b0472 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index d3560df0792eb..935c3aa28c999 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 95061c4635879..70819be5af5b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a4ec5085fa153..f0f7a6cf0cc4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,6 +27,7 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import 
org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ee50587ed097e..f0ed8645d923f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 75c66bc271fe0..89d87726ac649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,6 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 059e45bd684ed..1854dfaa7db35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,7 +21,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index a2b5a6a58090e..41cd0a104a1f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c0845e1a0102f..1a2092c909c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b42d6c5c1e14e..81855289762c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.util.GenericArrayData import scala.language.existentials diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 64f15945c790d..9e484c5ed83bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -22,7 +22,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 017efd2a166a7..cfc68fc00bea8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index abc5c94589baa..8770c4b76c2e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -22,6 +22,7 @@ import java.util.{HashMap, 
Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala index 1090bdb5a4bd3..6d35f140cf23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.catalyst.util /** * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index e5ffe32217351..70b028d2b3f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -42,7 +42,7 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte ArrayBasedMapData.toScalaMap(this).hashCode() } - override def toString(): String = { + override def toString: String = { s"keys: $keyArray, values: $valueArray" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index b4ea300f5f306..cad4a08b0d839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -15,11 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types.DataType abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 6e081ea9237bd..2b83651f9086d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -15,13 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import scala.language.implicitConversions import scala.util.matching.Regex import scala.util.parsing.combinator.syntactical.StandardTokenParsers import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.types._ /** * This is a data type parser that can be used to parse string representations of data types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9448d88d6c5f0..e9bf7b33e35be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala similarity index 93% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala index f50969f0f0b79..40db6067adf71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala @@ -15,7 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.types.DataType abstract class MapData extends Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d6b436724b2a0..11fce4beaf55f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer -import scala.math.max import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} +import org.apache.spark.sql.catalyst.util.DataTypeParser /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index 7735acbcbad41..008d0bea8a941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.catalyst.encoders -import java.util - -import org.apache.spark.sql.types.{StructField, ArrayType, ArrayData} - import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{StructField, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -166,43 +164,43 @@ class ProductEncoderSuite extends SparkFunSuite { null: Array[Byte])) encodeDecodeTestCustom(("Array[Byte]", Array[Byte](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Int] null", null: Array[Int])) encodeDecodeTestCustom(("Array[Int]", Array[Int](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Long] null", null: Array[Long])) encodeDecodeTestCustom(("Array[Long]", Array[Long](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Double] null", null: Array[Double])) encodeDecodeTestCustom(("Array[Double]", Array[Double](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Float] null", null: Array[Float])) encodeDecodeTestCustom(("Array[Float]", Array[Float](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Boolean] null", null: Array[Boolean])) encodeDecodeTestCustom(("Array[Boolean]", Array[Boolean](true, false))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTest(("Array[Short] null", null: Array[Short])) encodeDecodeTestCustom(("Array[Short]", Array[Short](1, 2, 3))) - { (l, r) => util.Arrays.equals(l._2, r._2) } + { (l, r) => java.util.Arrays.equals(l._2, r._2) } encodeDecodeTestCustom(("java.sql.Timestamp", new java.sql.Timestamp(1))) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c6aad34e972b5..68545f33e5465 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 098944a9f4fc5..5adcac39c6514 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala similarity index 98% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1ba290753ce48..1e3409a9db6eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class DataTypeParserSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala similarity index 99% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index f9aceb8d3b13e..50683947da224 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types.decimal +package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.Decimal import org.scalatest.PrivateMethodTester import scala.language.postfixOps diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index de11a1699afd9..e4f4cf1533ac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql - import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 446739d5b8a2c..6969b423d01b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index d7d6edeb6c6d3..3f34520afe6b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils} import scala.collection.Map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 09b8a9e936a1d..b2e52011a7276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 49007e45ecf87..b16c46579f7c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -32,7 +32,7 @@ import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index d4e6980967e82..d611b0011da16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a741a45f1c527..8d4854b698ed7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.test +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 7d1ee39d4b539..00f1526576cc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.{KryoSerializer, JavaSerializer} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index d17671d48a2fc..a229e5814df89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} + import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 964cdb52b245a..a5882f7870e37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -22,7 +22,8 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, AtomicType, Decimal} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 5dc37e5c3c238..b3fceeab64cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} +import org.apache.spark.sql.types.{ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 43c238fd49e0e..36f0708f9da3d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String @@ -50,8 +50,8 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[org.apache.spark.sql.types.MapData]] - * List: [[org.apache.spark.sql.types.ArrayData]] + * Map: [[MapData]] + * List: [[ArrayData]] * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5819cb9d08778..fdb576bedbbaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -32,11 +32,12 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.execution.{FileRelation, datasources} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index f57b206999399..2ccad474b4f7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ArrayData import 
org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 81a70b8d42267..8bb9058cd74ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row From 16dc9f344c08deee104090106cb0a537a90e33fc Mon Sep 17 00:00:00 2001 From: Rohan Bhanderi Date: Fri, 23 Oct 2015 01:10:46 -0700 Subject: [PATCH 017/324] Fix typo "Received" to "Receiver" in streaming-kafka-integration.md Removed typo on line 8 in markdown : "Received" -> "Receiver" Author: Rohan Bhanderi Closes #9242 from RohanBhanderi/patch-1. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 5db39ae54a274..ab7f0117c0b7f 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -5,7 +5,7 @@ title: Spark Streaming + Kafka Integration Guide [Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach -This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. +This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. 
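For context on the receiver-based approach described in the guide above, a minimal sketch of what such a job might look like is shown below. This is an illustration only, assuming a Spark Streaming 1.x application built against the `spark-streaming-kafka` artifact; the ZooKeeper quorum, consumer group id, topic map, and checkpoint path are all placeholder values, and `MEMORY_AND_DISK_SER` is chosen because the Write Ahead Log already provides durability, so in-memory replication can be dropped.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object ReceiverBasedKafkaSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("ReceiverBasedKafkaSketch")
      // Enable the Write Ahead Log so data received from Kafka survives failures.
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")

    val ssc = new StreamingContext(conf, Seconds(2))
    // A checkpoint directory is required when the Write Ahead Log is enabled.
    ssc.checkpoint("hdfs:///tmp/kafka-wal-checkpoint") // placeholder path

    // One receiver reading a placeholder topic with two consumer threads,
    // using the Kafka high-level consumer API under the hood.
    val kafkaStream = KafkaUtils.createStream(
      ssc,
      "zk1:2181,zk2:2181",        // ZooKeeper quorum (placeholder)
      "example-consumer-group",   // consumer group id (placeholder)
      Map("example-topic" -> 2),  // topic -> number of consumer threads
      StorageLevel.MEMORY_AND_DISK_SER)

    // The stream yields (key, message) pairs; count messages per batch.
    kafkaStream.map(_._2).count().print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```
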
From 487d409e71767c76399217a07af8de1bb0da7aa8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 23 Oct 2015 01:33:14 -0700 Subject: [PATCH 018/324] [SPARK-11243][SQL] zero out padding bytes in UnsafeRow For nested StructType, the underline buffer could be used for others before, we should zero out the padding bytes for those primitive types that have less than 8 bytes. cc cloud-fan Author: Davies Liu Closes #9217 from davies/zero_out. --- .../expressions/codegen/UnsafeRowWriter.java | 20 ++++++++++++++----- .../codegen/GeneratedProjectionSuite.scala | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index adbe2621870df..048b7749d8fb4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -100,19 +100,27 @@ public void alignToWords(int numBytes) { } public void write(int ordinal, boolean value) { - Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putBoolean(holder.buffer, offset, value); } public void write(int ordinal, byte value) { - Platform.putByte(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putByte(holder.buffer, offset, value); } public void write(int ordinal, short value) { - Platform.putShort(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putShort(holder.buffer, offset, value); } public void write(int ordinal, int value) { - Platform.putInt(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putInt(holder.buffer, offset, value); } public void write(int ordinal, long value) { @@ -123,7 +131,9 @@ public void write(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value); + final long offset = getFieldOffset(ordinal); + Platform.putLong(holder.buffer, offset, 0L); + Platform.putFloat(holder.buffer, offset, value); } public void write(int ordinal, double value) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 5adcac39c6514..1522ee34e43a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -99,4 +99,24 @@ class GeneratedProjectionSuite extends SparkFunSuite { val row2 = safeProj(unsafeRow) assert(row2 === row) } + + test("padding bytes should be zeroed out") { + val types = Seq(BooleanType, ByteType, ShortType, IntegerType, FloatType, BinaryType, + StringType) + val struct = StructType(types.map(StructField("", _, true))) + val fields = Array[DataType](StringType, struct) + val unsafeProj = 
UnsafeProjection.create(fields) + + val innerRow = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes, + UTF8String.fromString("")) + val row1 = InternalRow(UTF8String.fromString(""), innerRow) + val unsafe1 = unsafeProj(row1).copy() + // create a Row with long String before the inner struct + val row2 = InternalRow(UTF8String.fromString("a_long_string").repeat(10), innerRow) + val unsafe2 = unsafeProj(row2).copy() + assert(unsafe1.getStruct(1, 7) === unsafe2.getStruct(1, 7)) + val unsafe3 = unsafeProj(row1).copy() + assert(unsafe1 === unsafe3) + assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) + } } From 03ccb22080965d44fc0e1fc94dc75a96bfa26b8a Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 23 Oct 2015 08:31:01 -0700 Subject: [PATCH 019/324] [SPARK-10382] Make example code in user guide testable A POC code for making example code in user guide testable. mengxr We still need to talk about the labels in code. Author: Xusen Yin Closes #9109 from yinxusen/SPARK-10382. --- docs/_plugins/include_example.rb | 96 ++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 docs/_plugins/include_example.rb diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb new file mode 100644 index 0000000000000..0f4184c7462be --- /dev/null +++ b/docs/_plugins/include_example.rb @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +require 'liquid' +require 'pygments' + +module Jekyll + class IncludeExampleTag < Liquid::Tag + + def initialize(tag_name, markup, tokens) + @markup = markup + super + end + + def render(context) + site = context.registers[:site] + config_dir = (site.config['code_dir'] || '../examples/src/main').sub(/^\//,'') + @code_dir = File.join(site.source, config_dir) + + clean_markup = @markup.strip + @file = File.join(@code_dir, clean_markup) + @lang = clean_markup.split('.').last + + code = File.open(@file).read.encode("UTF-8") + code = select_lines(code) + + Pygments.highlight(code, :lexer => @lang) + end + + # Trim the code block so as to have the same indention, regardless of their positions in the + # code file. + def trim_codeblock(lines) + # Select the minimum indention of the current code block. + min_start_spaces = lines + .select { |l| l.strip.size !=0 } + .map { |l| l[/\A */].size } + .min + + lines.map { |l| l[min_start_spaces .. -1] } + end + + # Select lines according to labels in code. Currently we use "$example on$" and "$example off$" + # as labels. Note that code blocks identified by the labels should not overlap. + def select_lines(code) + lines = code.each_line.to_a + + # Select the array of start labels from code. + startIndices = lines + .each_with_index + .select { |l, i| l.include? 
"$example on$" } + .map { |l, i| i } + + # Select the array of end labels from code. + endIndices = lines + .each_with_index + .select { |l, i| l.include? "$example off$" } + .map { |l, i| i } + + raise "Start indices amount is not equal to end indices amount, please check the code." \ + unless startIndices.size == endIndices.size + + raise "No code is selected by include_example, please check the code." \ + if startIndices.size == 0 + + # Select and join code blocks together, with a space line between each of two continuous + # blocks. + lastIndex = -1 + result = "" + startIndices.zip(endIndices).each do |start, endline| + raise "Overlapping between two example code blocks are not allowed." if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$." if start == endline + lastIndex = endline + range = Range.new(start + 1, endline - 1) + result += trim_codeblock(lines[range]).join + result += "\n" + end + result + end + end +end + +Liquid::Template.register_tag('include_example', Jekyll::IncludeExampleTag) From 282a15f78e08f0dc9e696945be4fc973011a96d9 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 23 Oct 2015 08:43:49 -0700 Subject: [PATCH 020/324] [SPARK-10277] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.regression Author: Yu ISHIKAWA Closes #8684 from yu-iskw/SPARK-10277. --- python/pyspark/mllib/regression.py | 102 ++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 961b5e80b013c..6f00d1df209c0 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,7 +18,7 @@ import numpy as np from numpy import array -from pyspark import RDD +from pyspark import RDD, since from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector @@ -43,6 +43,8 @@ class LabeledPoint(object): column matrix) Note: 'label' and 'features' are accessible as class attributes. + + .. versionadded:: 1.0.0 """ def __init__(self, label, features): @@ -66,6 +68,8 @@ class LinearModel(object): :param weights: Weights computed for every feature. :param intercept: Intercept computed for this model. + + .. versionadded:: 0.9.0 """ def __init__(self, weights, intercept): @@ -73,11 +77,15 @@ def __init__(self, weights, intercept): self._intercept = float(intercept) @property + @since("1.0.0") def weights(self): + """Weights computed for every feature.""" return self._coeff @property + @since("1.0.0") def intercept(self): + """Intercept computed for this model.""" return self._intercept def __repr__(self): @@ -94,8 +102,11 @@ class LinearRegressionModelBase(LinearModel): True >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, x): """ Predict the value of the dependent variable given a vector or @@ -163,14 +174,20 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. 
versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a LinearRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a LinearRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -199,8 +216,20 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): + """ + Train a linear regression model with no regularization using Stochastic Gradient Descent. + This solves the least squares regression formulation + f(weights) = 1/n ||A weights-y||^2^ + (which is the mean squared error). + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -313,14 +342,20 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a LassoModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a LassoModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -330,8 +365,19 @@ def load(cls, sc, path): class LassoWithSGD(object): + """ + Train a regression model with L1-regularization using Stochastic Gradient Descent. + This solves the l1-regularized least squares regression formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -434,14 +480,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + + .. versionadded:: 0.9.0 """ + @since("1.4.0") def save(self, sc, path): + """Save a RidgeRegressionMode.""" java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( _py2java(sc, self._coeff), self.intercept) java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a RidgeRegressionMode.""" java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( sc._jsc.sc(), path) weights = _java2py(sc, java_model.weights()) @@ -451,8 +503,19 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): + """ + Train a regression model with L2-regularization using Stochastic Gradient Descent. 
+ This solves the l2-regularized least squares regression formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with + its corresponding right hand side label y. + See also the documentation for the precise formulation. + + .. versionadded:: 0.9.0 + """ @classmethod + @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): @@ -531,6 +594,8 @@ class IsotonicRegressionModel(Saveable, Loader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.4.0 """ def __init__(self, boundaries, predictions, isotonic): @@ -538,6 +603,7 @@ def __init__(self, boundaries, predictions, isotonic): self.predictions = predictions self.isotonic = isotonic + @since("1.4.0") def predict(self, x): """ Predict labels for provided features. @@ -562,7 +628,9 @@ def predict(self, x): return x.map(lambda v: self.predict(v)) return np.interp(x, self.boundaries, self.predictions) + @since("1.4.0") def save(self, sc, path): + """Save a IsotonicRegressionModel.""" java_boundaries = _py2java(sc, self.boundaries.tolist()) java_predictions = _py2java(sc, self.predictions.tolist()) java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel( @@ -570,7 +638,9 @@ def save(self, sc, path): java_model.save(sc._jsc.sc(), path) @classmethod + @since("1.4.0") def load(cls, sc, path): + """Load a IsotonicRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load( sc._jsc.sc(), path) py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray() @@ -579,8 +649,29 @@ def load(cls, sc, path): class IsotonicRegression(object): + """ + Isotonic regression. + Currently implemented using parallelized pool adjacent violators algorithm. + Only univariate (single feature) algorithm supported. + + Sequential PAV implementation based on: + Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + Available from [[http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf]] + + Sequential PAV parallelization based on: + Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + "An approach to parallelizing isotonic regression." + Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + Available from [[http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf]] + + @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] + + .. versionadded:: 1.4.0 + """ @classmethod + @since("1.4.0") def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. @@ -598,10 +689,13 @@ class StreamingLinearAlgorithm(object): Base class that has to be inherited by any StreamingLinearAlgorithm. Prevents reimplementation of methods predictOn and predictOnValues. + + .. versionadded:: 1.5.0 """ def __init__(self, model): self._model = model + @since("1.5.0") def latestModel(self): """ Returns the latest model. @@ -616,6 +710,7 @@ def _validate(self, dstream): raise ValueError( "Model must be intialized using setInitialWeights") + @since("1.5.0") def predictOn(self, dstream): """ Make predictions on a dstream. 
@@ -625,6 +720,7 @@ def predictOn(self, dstream): self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) + @since("1.5.0") def predictOnValues(self, dstream): """ Make predictions on a keyed dstream. @@ -649,6 +745,8 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. :param convergenceTol: A condition which decides iteration termination. + + .. versionadded:: 1.5.0 """ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize @@ -659,6 +757,7 @@ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, conver super(StreamingLinearRegressionWithSGD, self).__init__( model=self._model) + @since("1.5.0") def setInitialWeights(self, initialWeights): """ Set the initial value of weights. @@ -669,6 +768,7 @@ def setInitialWeights(self, initialWeights): self._model = LinearRegressionModel(initialWeights, 0) return self + @since("1.5.0") def trainOn(self, dstream): """Train the model on the incoming dstream.""" self._validate(dstream) From 4e38defae13b2b13e196b4d172722ef5e6266c66 Mon Sep 17 00:00:00 2001 From: Jayant Shekar Date: Fri, 23 Oct 2015 08:45:13 -0700 Subject: [PATCH 021/324] [SPARK-6723] [MLLIB] Model import/export for ChiSqSelector This is a PR for Parquet-based model import/export. * Added save/load for ChiSqSelectorModel * Updated the test suite ChiSqSelectorSuite Author: Jayant Shekar Closes #6785 from jayantshekhar/SPARK-6723. --- .../spark/mllib/feature/ChiSqSelector.scala | 70 ++++++++++++++++++- .../mllib/feature/ChiSqSelectorSuite.scala | 26 +++++++ 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index b1524cf377808..5246faf221914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: @@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD @Since("1.3.0") @Experimental class ChiSqSelectorModel @Since("1.3.0") ( - @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") ( s"Only sparse and dense vectors are supported but got ${other.getClass}.") } } + + @Since("1.6.0") + override def save(sc: SparkContext, path: String): Unit = { + ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { + @Since("1.6.0") + override def load(sc: 
SparkContext, path: String): ChiSqSelectorModel = { + ChiSqSelectorModel.SaveLoadV1_0.load(sc, path) + } + + private[feature] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(feature: Int) + + private[feature] + val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" + + def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(model.selectedFeatures.length) { i => + Data(model.selectedFeatures(i)) + } + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) + + } + + def load(sc: SparkContext, path: String): ChiSqSelectorModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val dataFrame = sqlContext.read.parquet(Loader.dataPath(path)) + val dataArray = dataFrame.select("feature") + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val features = dataArray.map { + case Row(feature: Int) => (feature) + }.collect() + + return new ChiSqSelectorModel(features) + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 889727fb55823..734800a9afad6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { }.collect().toSet assert(filteredData == preFilteredData) } + + test("model load / save") { + val model = ChiSqSelectorSuite.createModel() + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val sameModel = ChiSqSelectorModel.load(sc, path) + ChiSqSelectorSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object ChiSqSelectorSuite extends SparkFunSuite { + + def createModel(): ChiSqSelectorModel = { + val arr = Array(1, 2, 3, 4) + new ChiSqSelectorModel(arr) + } + + def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = { + assert(a.selectedFeatures.deep == b.selectedFeatures.deep) + } } From e1a897b657eb62e837026f7b3efafb9a6424ec4f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 23 Oct 2015 13:04:06 -0700 Subject: [PATCH 022/324] [SPARK-11274] [SQL] Text data source support for Spark SQL. This adds API for reading and writing text files, similar to SparkContext.textFile and RDD.saveAsTextFile. 
``` SQLContext.read.text("/path/to/something.txt") DataFrame.write.text("/path/to/write.txt") ``` Using the new Dataset API, this also supports ``` val ds: Dataset[String] = SQLContext.read.text("/path/to/something.txt").as[String] ``` Author: Reynold Xin Closes #9240 from rxin/SPARK-11274. --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../apache/spark/sql/DataFrameReader.scala | 16 ++ .../apache/spark/sql/DataFrameWriter.scala | 18 ++ .../datasources/json/JSONRelation.scala | 7 +- .../datasources/text/DefaultSource.scala | 160 ++++++++++++++++++ sql/core/src/test/resources/text-suite.txt | 4 + .../datasources/text/TextSuite.scala | 81 +++++++++ 7 files changed, 283 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala create mode 100644 sql/core/src/test/resources/text-suite.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index ca50000b4756e..1ca2044057e56 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,4 @@ org.apache.spark.sql.execution.datasources.jdbc.DefaultSource org.apache.spark.sql.execution.datasources.json.DefaultSource org.apache.spark.sql.execution.datasources.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.text.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8651a3569d6f..824220d85e04d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -302,6 +302,22 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) } + /** + * Loads a text file and returns a [[DataFrame]] with a single string column named "text". + * Each line in the text file is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * sqlContext.read.text("/path/to/spark/README.md") + * + * // Java: + * sqlContext.read().text("/path/to/spark/README.md") + * }}} + * + * @param path input path + * @since 1.6.0 + */ + def text(path: String): DataFrame = format("text").load(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 764510ab4b4bd..7887e559a3025 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -244,6 +244,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. 
+ * + * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { val props = new Properties() @@ -317,6 +319,22 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def orc(path: String): Unit = format("orc").save(path) + /** + * Saves the content of the [[DataFrame]] in a text file at the specified path. + * The DataFrame must have only one column that is of string type. + * Each row becomes a new line in the output file. For example: + * {{{ + * // Scala: + * df.write.text("/path/to/output") + * + * // Java: + * df.write().text("/path/to/output") + * }}} + * + * @since 1.6.0 + */ + def text(path: String): Unit = format("text").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index d05e6efa83c84..794b889a93627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -161,11 +161,10 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with SparkHadoopMapRedUtil with Logging { - val writer = new CharArrayWriter() + private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - val result = new Text() + private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private[this] val result = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala new file mode 100644 index 0000000000000..ab26c57ad1923 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.text + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{Path, FileStatus} +import org.apache.hadoop.io.{NullWritable, Text, LongWritable} +import org.apache.hadoop.mapred.{TextInputFormat, JobConf} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A data source for reading text files. + */ +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + dataSchema.foreach(verifySchema) + new TextRelation(None, partitionColumns, paths)(sqlContext) + } + + override def shortName(): String = "text" + + private def verifySchema(schema: StructType): Unit = { + if (schema.size != 1) { + throw new AnalysisException( + s"Text data source supports only a single column, and you have ${schema.size} columns.") + } + val tpe = schema(0).dataType + if (tpe != StringType) { + throw new AnalysisException( + s"Text data source supports only a string column, but you have ${tpe.simpleString}.") + } + } +} + +private[sql] class TextRelation( + val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + override val paths: Array[String] = Array.empty[String]) + (@transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) { + + /** Data schema is always a single column, named "text". */ + override def dataSchema: StructType = new StructType().add("text", StringType) + + /** This is an internal data source that outputs internal row format. */ + override val needConversion: Boolean = false + + /** Read path. */ + override def buildScan(inputPaths: Array[FileStatus]): RDD[Row] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val paths = inputPaths.map(_.getPath).sortBy(_.toUri) + + if (paths.nonEmpty) { + FileInputFormat.setInputPaths(job, paths: _*) + } + + sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + .mapPartitions { iter => + var buffer = new Array[Byte](1024) + val row = new GenericMutableRow(1) + iter.map { case (_, line) => + if (line.getLength > buffer.length) { + buffer = new Array[Byte](line.getLength) + } + System.arraycopy(line.getBytes, 0, buffer, 0, line.getLength) + row.update(0, UTF8String.fromBytes(buffer, 0, line.getLength)) + row + } + }.asInstanceOf[RDD[Row]] + } + + /** Write path. 
*/ + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new TextOutputWriter(path, dataSchema, context) + } + } + } + + override def equals(other: Any): Boolean = other match { + case that: TextRelation => + paths.toSet == that.paths.toSet && partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode(paths.toSet, partitionColumns) + } +} + +class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) + extends OutputWriter + with SparkHadoopMapRedUtil { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + val utf8string = row.getUTF8String(0) + buffer.set(utf8string.getBytes) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/text-suite.txt new file mode 100644 index 0000000000000..e8fd967197fe8 --- /dev/null +++ b/sql/core/src/test/resources/text-suite.txt @@ -0,0 +1,4 @@ +This is a test file for the text data source +1+1 +数据砖头 +"doh" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala new file mode 100644 index 0000000000000..0a2306c06646c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.text + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.util.Utils + + +class TextSuite extends QueryTest with SharedSQLContext { + + test("reading text file") { + verifyFrame(sqlContext.read.format("text").load(testFile)) + } + + test("SQLContext.read.text() API") { + verifyFrame(sqlContext.read.text(testFile)) + } + + test("writing") { + val df = sqlContext.read.text(testFile) + + val tempFile = Utils.createTempDir() + tempFile.delete() + df.write.text(tempFile.getCanonicalPath) + verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath)) + + Utils.deleteRecursively(tempFile) + } + + test("error handling for invalid schema") { + val tempFile = Utils.createTempDir() + tempFile.delete() + + val df = sqlContext.range(2) + intercept[AnalysisException] { + df.write.text(tempFile.getCanonicalPath) + } + + intercept[AnalysisException] { + sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + } + } + + private def testFile: String = { + Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString + } + + /** Verifies data and schema. */ + private def verifyFrame(df: DataFrame): Unit = { + // schema + assert(df.schema == new StructType().add("text", StringType)) + + // verify content + val data = df.collect() + assert(data(0) == Row("This is a test file for the text data source")) + assert(data(1) == Row("1+1")) + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + // scalastyle:off + assert(data(2) == Row("数据砖头")) + // scalastyle:on + assert(data(3) == Row("\"doh\"")) + assert(data.length == 4) + } +} From 4725cb988b98f367c07214c4c3cfd1206fb2b5c2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 23 Oct 2015 17:15:13 -0700 Subject: [PATCH 023/324] [SPARK-11194] [SQL] Use MutableURLClassLoader for the classLoader in IsolatedClientLoader. https://issues.apache.org/jira/browse/SPARK-11194 Author: Yin Huai Closes #9170 from yhuai/SPARK-11194. --- .../hive/client/IsolatedClientLoader.scala | 79 ++++++++++++------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 567e4d7b411ec..f99c3ed2ae987 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -30,7 +30,7 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{MutableURLClassLoader, Utils} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext @@ -148,39 +148,51 @@ private[hive] class IsolatedClientLoader( protected def classToPath(name: String): String = name.replaceAll("\\.", "/") + ".class" - /** The classloader that is used to load an isolated version of Hive. 
*/ - private[hive] var classLoader: ClassLoader = if (isolationOn) { - new URLClassLoader(allJars, rootClassLoader) { - override def loadClass(name: String, resolve: Boolean): Class[_] = { - val loaded = findLoadedClass(name) - if (loaded == null) doLoadClass(name, resolve) else loaded - } - def doLoadClass(name: String, resolve: Boolean): Class[_] = { - val classFileName = name.replaceAll("\\.", "/") + ".class" - if (isBarrierClass(name)) { - // For barrier classes, we construct a new copy of the class. - val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) - logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") - defineClass(name, bytes, 0, bytes.length) - } else if (!isSharedClass(name)) { - logDebug(s"hive class: $name - ${getResource(classToPath(name))}") - super.loadClass(name, resolve) - } else { - // For shared classes, we delegate to baseClassLoader. - logDebug(s"shared class: $name") - baseClassLoader.loadClass(name) + /** + * The classloader that is used to load an isolated version of Hive. + * This classloader is a special URLClassLoader that exposes the addURL method. + * So, when we add jar, we can add this new jar directly through the addURL method + * instead of stacking a new URLClassLoader on top of it. + */ + private[hive] val classLoader: MutableURLClassLoader = { + val isolatedClassLoader = + if (isolationOn) { + new URLClassLoader(allJars, rootClassLoader) { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + val loaded = findLoadedClass(name) + if (loaded == null) doLoadClass(name, resolve) else loaded + } + def doLoadClass(name: String, resolve: Boolean): Class[_] = { + val classFileName = name.replaceAll("\\.", "/") + ".class" + if (isBarrierClass(name)) { + // For barrier classes, we construct a new copy of the class. + val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) + logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") + defineClass(name, bytes, 0, bytes.length) + } else if (!isSharedClass(name)) { + logDebug(s"hive class: $name - ${getResource(classToPath(name))}") + super.loadClass(name, resolve) + } else { + // For shared classes, we delegate to baseClassLoader. + logDebug(s"shared class: $name") + baseClassLoader.loadClass(name) + } + } } + } else { + baseClassLoader } - } - } else { - baseClassLoader + // Right now, we create a URLClassLoader that gives preference to isolatedClassLoader + // over its own URLs when it loads classes and resources. + // We may want to use ChildFirstURLClassLoader based on + // the configuration of spark.executor.userClassPathFirst, which gives preference + // to its own URLs over the parent class loader (see Executor's createClassLoader method). + new NonClosableMutableURLClassLoader(isolatedClassLoader) } private[hive] def addJar(path: String): Unit = synchronized { val jarURL = new java.io.File(path).toURI.toURL - // TODO: we should avoid of stacking classloaders (use a single URLClassLoader and add jars - // to that) - classLoader = new java.net.URLClassLoader(Array(jarURL), classLoader) + classLoader.addURL(jarURL) } /** The isolated client interface to Hive. */ @@ -221,3 +233,14 @@ private[hive] class IsolatedClientLoader( */ private[hive] var cachedHive: Any = null } + +/** + * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. + * This class loader cannot be closed (its `close` method is a no-op). 
+ */ +private[sql] class NonClosableMutableURLClassLoader( + parent: ClassLoader) + extends MutableURLClassLoader(Array.empty, parent) { + + override def close(): Unit = {} +} From 2462dbcce89d657bca17ae311c99c2a4bee4a5fa Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 23 Oct 2015 21:38:04 -0700 Subject: [PATCH 024/324] [SPARK-10971][SPARKR] RRunner should allow setting path to Rscript. Add a new spark conf option "spark.sparkr.r.driver.command" to specify the executable for an R script in client modes. The existing spark conf option "spark.sparkr.r.command" is used to specify the executable for an R script in cluster modes for both driver and workers. See also [launch R worker script](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/api/r/RRDD.scala#L395). BTW, [envrionment variable "SPARKR_DRIVER_R"](https://github.com/apache/spark/blob/master/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java#L275) is used to locate R shell on the local host. For your information, PYSPARK has two environment variables serving simliar purpose: PYSPARK_PYTHON Python binary executable to use for PySpark in both driver and workers (default is `python`). PYSPARK_DRIVER_PYTHON Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). pySpark use the code [here](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala#L41) to determine the python executable for a python script. Author: Sun Rui Closes #9179 from sun-rui/SPARK-10971. --- .../org/apache/spark/deploy/RRunner.scala | 11 ++++++++++- docs/configuration.md | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 58cc1f9d963df..ed183cf16a9cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -40,7 +40,16 @@ object RRunner { // Time to wait for SparkR backend to initialize in seconds val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt - val rCommand = "Rscript" + val rCommand = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + var cmd = sys.props.getOrElse("spark.sparkr.r.command", "Rscript") + cmd = sys.props.getOrElse("spark.r.command", cmd) + if (sys.props.getOrElse("spark.submit.deployMode", "client") == "client") { + cmd = sys.props.getOrElse("spark.r.driver.command", cmd) + } + cmd + } // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode diff --git a/docs/configuration.md b/docs/configuration.md index be9c36bdfe3de..682384d4249e0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1589,6 +1589,20 @@ Apart from these, the following properties are also available, and may be useful Number of threads used by RBackend to handle RPC calls from SparkR package. + + spark.r.command + Rscript + + Executable for executing R scripts in cluster modes for both driver and workers. + + + + spark.r.driver.command + spark.r.command + + Executable for executing R scripts in client modes for driver. Ignored in cluster modes. 
+ + #### Cluster Managers @@ -1628,6 +1642,10 @@ The following variables can be set in `spark-env.sh`: PYSPARK_DRIVER_PYTHON Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + + SPARKR_DRIVER_R + R binary executable to use for SparkR shell (default is R). + SPARK_LOCAL_IP IP address of the machine to bind to. From 5e458125018029cef5cde3390f4a55dd4e164fde Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 23 Oct 2015 21:42:00 -0700 Subject: [PATCH 025/324] [SPARK-11294][SPARKR] Improve R doc for read.df, write.df, saveAsTable Add examples for read.df, write.df; fix grouping for read.df, loadDF; fix formatting and text truncation for write.df, saveAsTable. Several text issues: ![image](https://cloud.githubusercontent.com/assets/8969467/10708590/1303a44e-79c3-11e5-854f-3a2e16854cd7.png) - text collapsed into a single paragraph - text truncated at 2 places, eg. "overwrite: Existing data is expected to be overwritten by the contents of error:" shivaram Author: felixcheung Closes #9261 from felixcheung/rdocreadwritedf. --- R/pkg/R/DataFrame.R | 27 +++++++++++++-------------- R/pkg/R/SQLContext.R | 16 +++++++++++----- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 993be82a47f75..2acbd081cd504 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1572,18 +1572,17 @@ setMethod("except", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. \cr #' #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' #' @rdname write.df #' @name write.df @@ -1596,6 +1595,7 @@ setMethod("except", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") +#' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), @@ -1637,18 +1637,17 @@ setMethod("saveDF", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. 
There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. \cr #' #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' #' @rdname saveAsTable #' @name saveAsTable diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 399f53657a68c..1bf025cce4376 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -452,14 +452,21 @@ dropTempTable <- function(sqlContext, tableName) { #' #' @param sqlContext SQLContext to use #' @param path The path of files to load -#' @param source the name of external data source +#' @param source The name of external data source +#' @param schema The data schema defined in structType #' @return DataFrame +#' @rdname read.df +#' @name read.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, "path/to/file.json", source = "json") +#' df1 <- read.df(sqlContext, "path/to/file.json", source = "json") +#' schema <- structType(structField("name", "string"), +#' structField("info", "map")) +#' df2 <- read.df(sqlContext, mapTypeJsonPath, "json", schema) +#' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") #' } read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { @@ -482,9 +489,8 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) dataFrame(sdf) } -#' @aliases loadDF -#' @export - +#' @rdname read.df +#' @name loadDF loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { read.df(sqlContext, path, source, schema, ...) } From ffed00493a0aa2373a04e3aa374404936fbe15c7 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 23 Oct 2015 22:56:55 -0700 Subject: [PATCH 026/324] =?UTF-8?q?[SPARK-11125]=20[SQL]=20Uninformative?= =?UTF-8?q?=20exception=20when=20running=20spark-sql=20witho=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ut building with -Phive-thriftserver and SPARK_PREPEND_CLASSES is set This is the exception after this patch. Please help review. 
``` java.lang.NoClassDefFoundError: org/apache/hadoop/hive/cli/CliDriver at java.lang.ClassLoader.defineClass1(Native Method) at java.lang.ClassLoader.defineClass(ClassLoader.java:800) at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142) at java.net.URLClassLoader.defineClass(URLClassLoader.java:449) at java.net.URLClassLoader.access$100(URLClassLoader.java:71) at java.net.URLClassLoader$1.run(URLClassLoader.java:361) at java.net.URLClassLoader$1.run(URLClassLoader.java:355) at java.security.AccessController.doPrivileged(Native Method) at java.net.URLClassLoader.findClass(URLClassLoader.java:354) at java.lang.ClassLoader.loadClass(ClassLoader.java:425) at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:308) at java.lang.ClassLoader.loadClass(ClassLoader.java:412) at java.lang.ClassLoader.loadClass(ClassLoader.java:358) at java.lang.Class.forName0(Native Method) at java.lang.Class.forName(Class.java:270) at org.apache.spark.util.Utils$.classForName(Utils.scala:173) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:647) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:180) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:205) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:120) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: java.lang.ClassNotFoundException: org.apache.hadoop.hive.cli.CliDriver at java.net.URLClassLoader$1.run(URLClassLoader.java:366) at java.net.URLClassLoader$1.run(URLClassLoader.java:355) at java.security.AccessController.doPrivileged(Native Method) at java.net.URLClassLoader.findClass(URLClassLoader.java:354) at java.lang.ClassLoader.loadClass(ClassLoader.java:425) at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:308) at java.lang.ClassLoader.loadClass(ClassLoader.java:358) ... 21 more Failed to load hive class. You need to build Spark with -Phive and -Phive-thriftserver. ``` Author: Jeff Zhang Closes #9134 from zjffdu/SPARK-11125. --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ad92f5635af35..640cc325281a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -655,6 +655,15 @@ object SparkSubmit { // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + case e: NoClassDefFoundError => + e.printStackTrace(printStream) + if (e.getMessage.contains("org/apache/hadoop/hive")) { + // scalastyle:off println + printStream.println(s"Failed to load hive class.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println + } + System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } // SPARK-4170 From e5bc8c27577f96c1ae5dc8cf9bf41cbe2877ffe3 Mon Sep 17 00:00:00 2001 From: dima Date: Sat, 24 Oct 2015 18:16:45 +0100 Subject: [PATCH 027/324] [SPARK-11245] update twitter4j to 4.0.4 version update twitter4j to 4.0.4 version https://issues.apache.org/jira/browse/SPARK-11245 Author: dima Closes #9221 from pronix/twitter4j_update. 
--- external/twitter/pom.xml | 2 +- .../apache/spark/streaming/twitter/TwitterInputDStream.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 4c22ec8b3b154..087270de90b3f 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -51,7 +51,7 @@ org.twitter4j twitter4j-stream - 3.0.3 + 4.0.4 org.scalacheck diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index d7de74b350543..9a85a6597c27f 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -87,7 +87,7 @@ class TwitterReceiver( val query = new FilterQuery if (filters.size > 0) { - query.track(filters.toArray) + query.track(filters.mkString(",")) newTwitterStream.filter(query) } else { newTwitterStream.sample() From 28132ceb10d0c127495ce8cb36135e1cb54164d7 Mon Sep 17 00:00:00 2001 From: Jeffrey Naisbitt Date: Sat, 24 Oct 2015 18:21:36 +0100 Subject: [PATCH 028/324] [SPARK-11264] bin/spark-class can't find assembly jars with certain GREP_OPTIONS set Temporarily remove GREP_OPTIONS if set in bin/spark-class. Some GREP_OPTIONS will modify the output of the grep commands that are looking for the assembly jars. For example, if the -n option is specified, the grep output will look like: 5:spark-assembly-1.5.1-hadoop2.4.0.jar This will not match the regular expressions, and so the jar files will not be found. We could improve the regular expression to handle this case and trim off extra characters, but it is difficult to know which options may or may not be set. Unsetting GREP_OPTIONS within the script handles all the cases and gives the desired output. Author: Jeffrey Naisbitt Closes #9231 from naisbitt/unset-GREP_OPTIONS. --- bin/spark-class | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/spark-class b/bin/spark-class index e38e08dec40e4..8cae6ccbabe7c 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -42,6 +42,7 @@ else ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" fi +GREP_OPTIONS= num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 From 146da0d8100490a6e49a6c076ec253cdaf9f8905 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sun, 25 Oct 2015 01:33:22 +0100 Subject: [PATCH 029/324] Fix typos Two typos squashed. BTW Let me know how to proceed with other typos if I ran across any. I don't feel well to leave them aside as much as sending pull requests with such tiny changes. Guide me. Author: Jacek Laskowski Closes #9250 from jaceklaskowski/typos-hunting. 
--- core/src/main/scala/org/apache/spark/SparkConf.scala | 2 +- .../main/scala/org/apache/spark/metrics/MetricsSystem.scala | 2 +- .../main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 3 ++- core/src/main/scala/org/apache/spark/util/ThreadUtils.scala | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 58d3b846fd80d..f023e4b21cb40 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -621,7 +621,7 @@ private[spark] object SparkConf extends Logging { /** * Return whether the given config should be passed to an executor on start-up. * - * Certain akka and authentication configs are required of the executor when it connects to + * Certain akka and authentication configs are required from the executor when it connects to * the scheduler, while the rest of the spark configs can be inherited from the driver later. */ def isExecutorStartupConf(name: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 48afe3ae3511f..fdf76d312db3b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -197,7 +197,7 @@ private[spark] class MetricsSystem private ( } } catch { case e: Exception => { - logError("Sink class " + classPath + " cannot be instantialized") + logError("Sink class " + classPath + " cannot be instantiated") throw e } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index f25f3ed0d9037..cb9a3008107d7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -22,7 +22,8 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId /** - * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. + * Low-level task scheduler interface, currently implemented exclusively by + * [[org.apache.spark.scheduler.TaskSchedulerImpl]]. * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks * for a single SparkContext. These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 15e7519d708c6..53283448c87b1 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -80,7 +80,7 @@ private[spark] object ThreadUtils { } /** - * Wrapper over newSingleThreadScheduledExecutor. + * Wrapper over ScheduledThreadPoolExecutor. 
*/ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() From b67dc6a4342577e73b0600b51052c286c4569960 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 25 Oct 2015 10:31:44 +0100 Subject: [PATCH 030/324] [SPARK-11299][DOC] Fix link to Scala DataFrame Functions reference The SQL programming guide's link to the DataFrame functions reference points to the wrong location; this patch fixes that. Author: Josh Rosen Closes #9269 from JoshRosen/SPARK-11299. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 30206c6f6fd93..f07c9573696ed 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -215,7 +215,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.DataFrame). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). From 92b9c5edd90f7b89efc687c0cea6778daa1a6b66 Mon Sep 17 00:00:00 2001 From: Alexander Slesarenko Date: Sun, 25 Oct 2015 10:37:10 +0100 Subject: [PATCH 031/324] [SPARK-6428][SQL] Removed unnecessary typecasts in MutableInt, MutableDouble etc. marmbrus rxin I believe these typecasts are not required in the presence of explicit return types. Author: Alexander Slesarenko Closes #9262 from aslesarenko/remove-typecasts. 
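To illustrate why the casts are redundant, here is a minimal standalone sketch (a hypothetical `MutableIntExample` class, not the Spark code): when the declared return type of `copy()` already names the concrete class, the trailing `asInstanceOf` is a no-op and can simply be dropped.

```
// Hypothetical class for illustration only; mirrors the pattern in
// SpecificMutableRow.scala where the cast is removed.
final class MutableIntExample {
  var isNull: Boolean = true
  var value: Int = 0

  // The return type is explicit, so the compiler already knows the result
  // is a MutableIntExample; no asInstanceOf[MutableIntExample] is needed.
  def copy(): MutableIntExample = {
    val newCopy = new MutableIntExample
    newCopy.isNull = isNull
    newCopy.value = value
    newCopy
  }
}
```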
--- .../expressions/SpecificMutableRow.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4f56f94bd4ca4..475cbe005a6ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * val newCopy = new Mutable$tpe * newCopy.isNull = isNull * newCopy.value = value - * newCopy.asInstanceOf[this.type] + * newCopy * } * }""" * }.foreach(println) @@ -78,7 +78,7 @@ final class MutableInt extends MutableValue { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableInt] + newCopy } } @@ -93,7 +93,7 @@ final class MutableFloat extends MutableValue { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableFloat] + newCopy } } @@ -108,7 +108,7 @@ final class MutableBoolean extends MutableValue { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableBoolean] + newCopy } } @@ -123,7 +123,7 @@ final class MutableDouble extends MutableValue { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableDouble] + newCopy } } @@ -138,7 +138,7 @@ final class MutableShort extends MutableValue { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableShort] + newCopy } } @@ -153,7 +153,7 @@ final class MutableLong extends MutableValue { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableLong] + newCopy } } @@ -168,7 +168,7 @@ final class MutableByte extends MutableValue { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableByte] + newCopy } } @@ -183,7 +183,7 @@ final class MutableAny extends MutableValue { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[MutableAny] + newCopy } } From 80279ac1875d488f7000f352a958a35536bd4c2e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 25 Oct 2015 19:05:45 +0000 Subject: [PATCH 032/324] [SPARK-11287] Fixed class name to properly start TestExecutor from deploy.client.TestClient Executing deploy.client.TestClient fails due to bad class name for TestExecutor in ApplicationDescription. Author: Bryan Cutler Closes #9255 from BryanCutler/fix-TestClient-classname-SPARK-11287. 
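For context, a small hypothetical sketch (illustrative names, not the Spark code) of why the class name is derived this way: a Scala `object` compiles to a class whose JVM name carries a trailing `$`, so stripping that suffix yields the class that exposes the static `main` forwarder a launcher can start.

```
// Illustrative objects only; not part of Spark.
object DemoExecutor {
  def main(args: Array[String]): Unit = println("demo executor started")
}

object ShowLaunchableName {
  def main(args: Array[String]): Unit = {
    val objectClass = DemoExecutor.getClass.getCanonicalName // "DemoExecutor$"
    val launchable  = objectClass.stripSuffix("$")           // "DemoExecutor"
    println(s"object class: $objectClass, launchable class: $launchable")
  }
}
```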
--- .../main/scala/org/apache/spark/deploy/client/TestClient.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 1c79089303e3d..adb3f02258029 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -48,8 +48,9 @@ private[spark] object TestClient { val url = args(0) val conf = new SparkConf val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val executorClassname = TestExecutor.getClass.getCanonicalName.stripSuffix("$") val desc = new ApplicationDescription("TestClient", Some(1), 512, - Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") + Command(executorClassname, Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() From 63accc79625d8a03d0624717af5e1d81b18a6da3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 25 Oct 2015 21:18:35 -0700 Subject: [PATCH 033/324] [SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka This PR allows users to map a Kinesis `Record` to a generic `T` when creating a Kinesis stream. This is particularly useful, if you would like to do extra work with Kinesis metadata such as sequence number, and partition key. TODO: - [x] add tests Author: Burak Yavuz Closes #8954 from brkyvz/kinesis-handler. --- .../kinesis/KinesisBackedBlockRDD.scala | 35 ++- .../kinesis/KinesisInputDStream.scala | 15 +- .../streaming/kinesis/KinesisReceiver.scala | 18 +- .../kinesis/KinesisRecordProcessor.scala | 4 +- .../streaming/kinesis/KinesisUtils.scala | 247 ++++++++++++++++-- .../kinesis/JavaKinesisStreamSuite.java | 29 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 16 +- .../kinesis/KinesisReceiverSuite.scala | 4 +- .../kinesis/KinesisStreamSuite.scala | 44 +++- 9 files changed, 337 insertions(+), 75 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 5d32fa699ae5b..000897a4e7290 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.kinesis import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} @@ -67,7 +68,7 @@ class KinesisBackedBlockRDDPartition( * sequence numbers of the corresponding blocks. 
*/ private[kinesis] -class KinesisBackedBlockRDD( +class KinesisBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, val regionName: String, val endpointUrl: String, @@ -75,8 +76,9 @@ class KinesisBackedBlockRDD( @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, + val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, val awsCredentialsOption: Option[SerializableAWSCredentials] = None - ) extends BlockRDD[Array[Byte]](sc, blockIds) { + ) extends BlockRDD[T](sc, blockIds) { require(blockIds.length == arrayOfseqNumberRanges.length, "Number of blockIds is not equal to the number of sequence number ranges") @@ -90,23 +92,23 @@ class KinesisBackedBlockRDD( } } - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + override def compute(split: Partition, context: TaskContext): Iterator[T] = { val blockManager = SparkEnv.get.blockManager val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] val blockId = partition.blockId - def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + def getBlockFromBlockManager(): Option[Iterator[T]] = { logDebug(s"Read partition data of $this from block manager, block $blockId") - blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]]) } - def getBlockFromKinesis(): Iterator[Array[Byte]] = { - val credenentials = awsCredentialsOption.getOrElse { + def getBlockFromKinesis(): Iterator[T] = { + val credentials = awsCredentialsOption.getOrElse { new DefaultAWSCredentialsProviderChain().getCredentials() } partition.seqNumberRanges.ranges.iterator.flatMap { range => - new KinesisSequenceRangeIterator( - credenentials, endpointUrl, regionName, range, retryTimeoutMs) + new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, + range, retryTimeoutMs).map(messageHandler) } } if (partition.isBlockIdValid) { @@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator( endpointUrl: String, regionId: String, range: SequenceNumberRange, - retryTimeoutMs: Int - ) extends NextIterator[Array[Byte]] with Logging { + retryTimeoutMs: Int) extends NextIterator[Record] with Logging { private val client = new AmazonKinesisClient(credentials) private val streamName = range.streamName @@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator( client.setEndpoint(endpointUrl, "kinesis", regionId) - override protected def getNext(): Array[Byte] = { - var nextBytes: Array[Byte] = null + override protected def getNext(): Record = { + var nextRecord: Record = null if (toSeqNumberReceived) { finished = true } else { @@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator( } else { // Get the record, copy the data into a byte array and remember its sequence number - val nextRecord: Record = internalIterator.next() - val byteBuffer = nextRecord.getData() - nextBytes = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(nextBytes) + nextRecord = internalIterator.next() lastSeqNumber = nextRecord.getSequenceNumber() // If the this record's sequence number matches the stopping sequence number, then make sure @@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator( toSeqNumberReceived = true } } - } - nextBytes + nextRecord } override protected def close(): Unit = { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 2e4204dcb6f1a..72ab6357a53b0 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming.kinesis +import scala.reflect.ClassTag + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} @@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.{Duration, StreamingContext, Time} -private[kinesis] class KinesisInputDStream( +private[kinesis] class KinesisInputDStream[T: ClassTag]( @transient _ssc: StreamingContext, streamName: String, endpointUrl: String, @@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream( checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, + messageHandler: Record => T, awsCredentialsOption: Option[SerializableAWSCredentials] - ) extends ReceiverInputDStream[Array[Byte]](_ssc) { + ) extends ReceiverInputDStream[T](_ssc) { private[streaming] - override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = { + override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = { // This returns true even for when blockInfos is empty val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty) @@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream( context.sc, regionName, endpointUrl, blockIds, seqNumRanges, isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, + messageHandler = messageHandler, awsCredentialsOption = awsCredentialsOption) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + @@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream( } } - override def getReceiver(): Receiver[Array[Byte]] = { + override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption) + checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 6e0988c1af8a1..134d627cdaffa 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -80,7 +80,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies * the credentials */ -private[kinesis] class KinesisReceiver( +private[kinesis] class KinesisReceiver[T]( val streamName: String, endpointUrl: String, regionName: String, @@ -88,8 +88,9 @@ private[kinesis] class KinesisReceiver( checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, - awsCredentialsOption: Option[SerializableAWSCredentials] - ) extends 
Receiver[Array[Byte]](storageLevel) with Logging { receiver => + messageHandler: Record => T, + awsCredentialsOption: Option[SerializableAWSCredentials]) + extends Receiver[T](storageLevel) with Logging { receiver => /* * ================================================================================= @@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver( /** Add records of the given shard to the current block being generated */ private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { if (records.size > 0) { - val dataIterator = records.iterator().asScala.map { record => - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - byteArray - } + val dataIterator = records.iterator().asScala.map(messageHandler) val metadata = SequenceNumberRange(streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) @@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver( /** Store the block along with its associated ranges */ private def storeBlockWithRanges( - blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = { + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) if (rangesToReportOption.isEmpty) { stop("Error while storing block into Spark, could not find sequence number ranges " + @@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver( /** Callback method called when a block is ready to be pushed / stored. */ def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { storeBlockWithRanges(blockId, - arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]]) + arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]]) } /** Callback called in case of any error in internal of the BlockGenerator */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index b2405123321e3..1d5178790ec4c 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -41,8 +41,8 @@ import org.apache.spark.Logging * @param checkpointState represents the checkpoint state including the next checkpoint time. * It's injected here for mocking purposes. 
*/ -private[kinesis] class KinesisRecordProcessor( - receiver: KinesisReceiver, +private[kinesis] class KinesisRecordProcessor[T]( + receiver: KinesisReceiver[T], workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index c799fadf2d5ce..2849fd8a82102 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,16 +16,120 @@ */ package org.apache.spark.streaming.kinesis +import scala.reflect.ClassTag + import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Duration, StreamingContext} - object KinesisUtils { + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
+ */ + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T): ReceiverInputDStream[T] = { + val cleanedHandler = ssc.sc.clean(messageHandler) + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, None) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + // scalastyle:off + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsAccessKeyId: String, + awsSecretKey: String): ReceiverInputDStream[T] = { + // scalastyle:on + val cleanedHandler = ssc.sc.clean(messageHandler) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + } + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. 
@@ -61,12 +165,12 @@ object KinesisUtils { regionName: String, initialPositionInStream: InitialPositionInStream, checkpointInterval: Duration, - storageLevel: StorageLevel - ): ReceiverInputDStream[Array[Byte]] = { + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None) + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + defaultMessageHandler, None) } } @@ -109,12 +213,11 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): ReceiverInputDStream[Array[Byte]] = { + awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName), + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) } } @@ -156,11 +259,113 @@ object KinesisUtils { storageLevel: StorageLevel ): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), - initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None) + new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, + getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName, + checkpointInterval, storageLevel, defaultMessageHandler, None) } } + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
+ * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param recordClass Class of the records in DStream + */ + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T]): JavaReceiverInputDStream[T] = { + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
+ * @param recordClass Class of the records in DStream + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + */ + // scalastyle:off + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T], + awsAccessKeyId: String, + awsSecretKey: String): JavaReceiverInputDStream[T] = { + // scalastyle:on + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, + awsAccessKeyId, awsSecretKey) + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -198,8 +403,8 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { - createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel) + createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) } /** @@ -241,10 +446,10 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): JavaReceiverInputDStream[Array[Byte]] = { - createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { + createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, + defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } /** @@ -297,6 +502,14 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } } /** diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index 87954a31f60ce..3f0f6793d2d21 100644 --- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -17,14 +17,19 @@ package org.apache.spark.streaming.kinesis; +import com.amazonaws.services.kinesis.model.Record; +import org.junit.Test; + +import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; -import org.junit.Test; 
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import java.nio.ByteBuffer; + /** * Demonstrate the use of the KinesisUtils Java API */ @@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { public void testKinesisStream() { // Tests the API, does not actually test data receiving JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", - "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); - + + ssc.stop(); + } + + + private static Function handler = new Function() { + @Override + public String call(Record record) { + return record.getPartitionKey() + "-" + record.getSequenceNumber(); + } + }; + + @Test + public void testCustomHandler() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class); + ssc.stop(); } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index a89e5627e014c..9f9e146a08d46 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -73,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll testIfEnabled("Basic reading from Kinesis") { // Verify all data using multiple ranges in a single RDD partition - val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, - fakeBlockIds(1), + val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(1), Array(SequenceNumberRanges(allRanges.toArray)) ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData1.toSet === testData.toSet) // Verify all data using one range in each of the multiple RDD partitions - val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, - fakeBlockIds(allRanges.size), + val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData2.toSet === testData.toSet) // Verify ordering within each partition - val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl, - fakeBlockIds(allRanges.size), + val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, + testUtils.endpointUrl, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collectPartitions() assert(receivedData3.length === allRanges.size) @@ -209,7 +209,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll }, "Incorrect configuration of RDD, unexpected ranges set" ) - val rdd = new KinesisBackedBlockRDD( + val rdd = new 
KinesisBackedBlockRDD[Array[Byte]]( sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges) val collectedData = rdd.map { bytes => new String(bytes).toInt @@ -223,7 +223,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll if (testIsBlockValid) { require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") - val rdd2 = new KinesisBackedBlockRDD( + val rdd2 = new KinesisBackedBlockRDD[Array[Byte]]( sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) intercept[SparkException] { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 3d136aec2e702..17ab444704f44 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -52,14 +52,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) val batch = Arrays.asList(record1, record2) - var receiverMock: KinesisReceiver = _ + var receiverMock: KinesisReceiver[Array[Byte]] = _ var checkpointerMock: IRecordProcessorCheckpointer = _ var checkpointClockMock: ManualClock = _ var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ override def beforeFunction(): Unit = { - receiverMock = mock[KinesisReceiver] + receiverMock = mock[KinesisReceiver[Array[Byte]]] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] checkpointStateMock = mock[KinesisCheckpointState] diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 1177dc758100d..ba84e557dfcc2 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -24,6 +24,7 @@ import scala.util.Random import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.Record import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -31,6 +32,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler.ReceivedBlockInfo @@ -113,9 +115,9 @@ class KinesisStreamSuite extends KinesisFunSuite val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream", dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey) - assert(inputStream.isInstanceOf[KinesisInputDStream]) + 
assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]]) - val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream] + val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]] val time = Time(1000) // Generate block info data for testing @@ -134,8 +136,8 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that the generated KinesisBackedBlockRDD has the all the right information val blockInfos = Seq(blockInfo1, blockInfo2) val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) - nonEmptyRDD shouldBe a [KinesisBackedBlockRDD] - val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD] + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) @@ -151,7 +153,7 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that KinesisBackedBlockRDD is generated even when there are no blocks val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) - emptyRDD shouldBe a [KinesisBackedBlockRDD] + emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] emptyRDD.partitions shouldBe empty // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid @@ -192,6 +194,32 @@ class KinesisStreamSuite extends KinesisFunSuite ssc.stop(stopSparkContext = false) } + testIfEnabled("custom message handling") { + val awsCredentials = KinesisTestUtils.getAWSCredentials() + def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5 + val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, + testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, + Seconds(10), StorageLevel.MEMORY_ONLY, addFive, + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + stream shouldBe a [ReceiverInputDStream[Int]] + + val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] + stream.foreachRDD { rdd => + collected ++= rdd.collect() + logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + } + ssc.start() + + val testData = 1 to 10 + eventually(timeout(120 seconds), interval(10 second)) { + testUtils.pushData(testData) + val modData = testData.map(_ + 5) + assert(collected === modData.toSet, "\nData received does not match data sent") + } + ssc.stop(stopSparkContext = false) + } + testIfEnabled("failure recovery") { val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) val checkpointDir = Utils.createTempDir().getAbsolutePath @@ -210,7 +238,7 @@ class KinesisStreamSuite extends KinesisFunSuite // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { - val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq collectedData(time) = (kRdd.arrayOfseqNumberRanges, data) }) @@ -243,10 +271,10 @@ class KinesisStreamSuite extends KinesisFunSuite times.foreach { time => val (arrayOfSeqNumRanges, data) = collectedData(time) val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] - rdd shouldBe a [KinesisBackedBlockRDD] + rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] // Verify the recovered sequence 
ranges - val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD] + val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size) arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) => assert(expected.ranges.toSeq === found.ranges.toSeq)

From 85e654c5ec87e666a8845bfd77185c1ea57b268a Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 25 Oct 2015 21:19:52 -0700
Subject: [PATCH 034/324] [SPARK-10984] Simplify *MemoryManager class structure

This patch refactors the MemoryManager class structure. After #9000, Spark had the following classes:

- MemoryManager
- StaticMemoryManager
- ExecutorMemoryManager
- TaskMemoryManager
- ShuffleMemoryManager

This is fairly confusing. To simplify things, this patch consolidates several of these classes:

- ShuffleMemoryManager and ExecutorMemoryManager were merged into MemoryManager.
- TaskMemoryManager is moved into Spark Core.

**Key changes and tasks**:

- [x] Merge ExecutorMemoryManager into MemoryManager.
- [x] Move pooling logic into Allocator.
- [x] Move TaskMemoryManager from `spark-unsafe` to `spark-core`.
- [x] Refactor the existing Tungsten TaskMemoryManager interactions so Tungsten code uses only this and not both this and ShuffleMemoryManager.
- [x] Refactor non-Tungsten code to use the TaskMemoryManager instead of ShuffleMemoryManager.
- [x] Merge ShuffleMemoryManager into MemoryManager.
  - [x] Move code
  - [x] ~~Simplify 1/n calculation.~~ **Will defer to followup, since this needs more work.**
- [x] Port ShuffleMemoryManagerSuite tests.
- [x] Move classes from `unsafe` package to `memory` package.
- [ ] Figure out how to handle the hacky use of the memory managers in HashedRelation's broadcast variable construction.
- [x] Test porting and cleanup: several tests relied on mock functionality (such as `TestShuffleMemoryManager.markAsOutOfMemory`) which has been changed or broken during the memory manager consolidation
  - [x] AbstractBytesToBytesMapSuite
  - [x] UnsafeExternalSorterSuite
  - [x] UnsafeFixedWidthAggregationMapSuite
  - [x] UnsafeKVExternalSorterSuite

**Compatibility notes**:

- This patch introduces breaking changes in `ExternalAppendOnlyMap`, which is marked as `DeveloperAPI` (likely for legacy reasons): this class now cannot be used outside of a task.

Author: Josh Rosen

Closes #9127 from JoshRosen/SPARK-10984.
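As a rough sketch of the caller-side contract after this consolidation (illustrative only — the helper name and the `spill` callback below are hypothetical, while `TaskMemoryManager.allocatePage` and its null-on-failure behavior come from the diff that follows): operators now request pages directly from their per-task `TaskMemoryManager`, get `null` back when the unified execution pool cannot grant the request, spill, and retry, instead of doing the separate `tryToAcquire`/`release` dance against a `ShuffleMemoryManager`.

```scala
// Sketch only: mirrors the allocate-spill-retry pattern adopted by the sorters in this patch.
// `acquirePageOrSpill` and the `spill` callback are hypothetical names used for illustration.
import java.io.IOException

import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.unsafe.memory.MemoryBlock

def acquirePageOrSpill(
    tmm: TaskMemoryManager,
    pageSizeBytes: Long)(spill: () => Unit): MemoryBlock = {
  // allocatePage() draws on the unified execution pool and returns null if it cannot
  // grant the full page, so callers spill their own data and retry once.
  var page = tmm.allocatePage(pageSizeBytes)
  if (page == null) {
    spill()
    page = tmm.allocatePage(pageSizeBytes)
    if (page == null) {
      throw new IOException(s"Unable to acquire $pageSizeBytes bytes of memory")
    }
  }
  page
}
```

The same pattern replaces the old tryToAcquire/release retry loops in ShuffleExternalSorter, UnsafeExternalSorter and BytesToBytesMap further down in this patch.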
--- .../spark}/memory/TaskMemoryManager.java | 111 +++--- .../shuffle/sort/PackedRecordPointer.java | 4 +- .../shuffle/sort/ShuffleExternalSorter.java | 57 ++- .../shuffle/sort/UnsafeShuffleWriter.java | 7 +- .../spark/unsafe/map/BytesToBytesMap.java | 36 +- .../sort/RecordPointerAndKeyPrefix.java | 4 +- .../unsafe/sort/UnsafeExternalSorter.java | 51 +-- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 23 +- .../scala/org/apache/spark/TaskContext.scala | 2 +- .../org/apache/spark/TaskContextImpl.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../apache/spark/memory/MemoryManager.scala | 197 ++++++++++- .../spark/memory/StaticMemoryManager.scala | 12 +- .../spark/memory/UnifiedMemoryManager.scala | 12 +- .../org/apache/spark/scheduler/Task.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 209 ----------- .../shuffle/sort/SortShuffleManager.scala | 1 - .../shuffle/sort/SortShuffleWriter.scala | 6 +- .../collection/ExternalAppendOnlyMap.scala | 49 ++- .../util/collection/ExternalSorter.scala | 8 +- .../spark/util/collection/Spillable.scala | 16 +- .../spark}/memory/TaskMemoryManagerSuite.java | 25 +- .../sort/PackedRecordPointerSuite.java | 12 +- .../sort/ShuffleInMemorySorterSuite.java | 9 +- .../sort/UnsafeShuffleWriterSuite.java | 53 ++- .../map/AbstractBytesToBytesMapSuite.java | 108 ++---- .../map/BytesToBytesMapOffHeapSuite.java | 7 +- .../map/BytesToBytesMapOnHeapSuite.java | 7 +- .../sort/UnsafeExternalSorterSuite.java | 34 +- .../sort/UnsafeInMemorySorterSuite.java | 13 +- .../scala/org/apache/spark/FailureSuite.scala | 4 +- .../memory/GrantEverythingMemoryManager.scala | 51 +-- .../spark/memory/MemoryManagerSuite.scala | 134 +++++++ .../spark/memory/MemoryTestingUtils.scala | 37 ++ .../memory/StaticMemoryManagerSuite.scala | 24 +- .../memory/UnifiedMemoryManagerSuite.scala | 26 +- .../shuffle/ShuffleMemoryManagerSuite.scala | 326 ------------------ .../BlockManagerReplicationSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 8 +- .../ExternalAppendOnlyMapSuite.scala | 60 ++-- .../util/collection/ExternalSorterSuite.scala | 48 ++- .../execution/UnsafeExternalRowSorter.java | 1 - .../UnsafeFixedWidthAggregationMap.java | 12 +- .../sql/execution/UnsafeKVExternalSorter.java | 22 +- .../TungstenAggregationIterator.scala | 9 +- .../datasources/WriterContainer.scala | 3 +- .../sql/execution/joins/HashedRelation.scala | 21 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 54 ++- .../UnsafeKVExternalSorterSuite.scala | 19 +- .../execution/UnsafeRowSerializerSuite.scala | 10 +- .../TungstenAggregationIteratorSuite.scala | 4 +- .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- .../unsafe/memory/ExecutorMemoryManager.java | 111 ------ .../unsafe/memory/HeapMemoryAllocator.java | 51 ++- .../spark/unsafe/memory/MemoryBlock.java | 5 +- 58 files changed, 888 insertions(+), 1255 deletions(-) rename {unsafe/src/main/java/org/apache/spark/unsafe => core/src/main/java/org/apache/spark}/memory/TaskMemoryManager.java (78%) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala rename {unsafe/src/test/java/org/apache/spark/unsafe => core/src/test/java/org/apache/spark}/memory/TaskMemoryManagerSuite.java (74%) rename sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala => core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala 
(56%) create mode 100644 core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala delete mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java similarity index 78% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java rename to core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 97b2c93f0dc37..7b31c90dac666 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.unsafe.memory; +package org.apache.spark.memory; import java.util.*; @@ -23,6 +23,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.unsafe.memory.MemoryBlock; + /** * Manages the memory allocated by an individual task. *

@@ -87,13 +89,9 @@ public class TaskMemoryManager { */ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); - /** - * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean - * up leaked memory. - */ - private final HashSet allocatedNonPageMemory = new HashSet(); + private final MemoryManager memoryManager; - private final ExecutorMemoryManager executorMemoryManager; + private final long taskAttemptId; /** * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods @@ -103,16 +101,38 @@ public class TaskMemoryManager { private final boolean inHeap; /** - * Construct a new MemoryManager. + * Construct a new TaskMemoryManager. */ - public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { - this.inHeap = executorMemoryManager.inHeap; - this.executorMemoryManager = executorMemoryManager; + public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { + this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); + this.memoryManager = memoryManager; + this.taskAttemptId = taskAttemptId; + } + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * @return number of bytes successfully granted (<= N). + */ + public long acquireExecutionMemory(long size) { + return memoryManager.acquireExecutionMemory(size, taskAttemptId); + } + + /** + * Release N bytes of execution memory. + */ + public void releaseExecutionMemory(long size) { + memoryManager.releaseExecutionMemory(size, taskAttemptId); + } + + public long pageSizeBytes() { + return memoryManager.pageSizeBytes(); } /** * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is - * intended for allocating large blocks of memory that will be shared between operators. + * intended for allocating large blocks of Tungsten memory that will be shared between operators. + * + * Returns `null` if there was not enough memory to allocate the page. */ public MemoryBlock allocatePage(long size) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { @@ -129,7 +149,15 @@ public MemoryBlock allocatePage(long size) { } allocatedPages.set(pageNumber); } - final MemoryBlock page = executorMemoryManager.allocate(size); + final long acquiredExecutionMemory = acquireExecutionMemory(size); + if (acquiredExecutionMemory != size) { + releaseExecutionMemory(acquiredExecutionMemory); + synchronized (this) { + allocatedPages.clear(pageNumber); + } + return null; + } + final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size); page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { @@ -152,45 +180,16 @@ public void freePage(MemoryBlock page) { if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } - // Cannot access a page once it's freed. - executorMemoryManager.free(page); - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended - * to be used for allocating operators' internal data structures. For data pages that you want to - * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since - * that will enable intra-memory pointers (see - * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's - * top-level Javadoc for more details). 
- */ - public MemoryBlock allocate(long size) throws OutOfMemoryError { - assert(size > 0) : "Size must be positive, but got " + size; - final MemoryBlock memory = executorMemoryManager.allocate(size); - synchronized(allocatedNonPageMemory) { - allocatedNonPageMemory.add(memory); - } - return memory; - } - - /** - * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. - */ - public void free(MemoryBlock memory) { - assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; - executorMemoryManager.free(memory); - synchronized(allocatedNonPageMemory) { - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; - } + long pageSize = page.size(); + memoryManager.tungstenMemoryAllocator().free(page); + releaseExecutionMemory(pageSize); } /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. * - * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/ * @param offsetInPage an offset in this page which incorporates the base offset. In other words, * this should be the value that you would pass as the base offset into an * UNSAFE call (e.g. page.baseOffset() + something). @@ -270,17 +269,15 @@ public long cleanUpAllAllocatedMemory() { } } - synchronized (allocatedNonPageMemory) { - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); - } - } + freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); + return freedBytes; } + + /** + * Returns the memory consumption, in bytes, for the current task + */ + public long getMemoryConsumptionForThisTask() { + return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index c11711966fa8c..f8f2b220e181d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.memory.TaskMemoryManager; + /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. *

@@ -26,7 +28,7 @@ * * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the - * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * 13-bit page numbers assigned by {@link TaskMemoryManager}), this * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. *

* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 85fdaa8115fa3..f43236f41ae7b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -33,14 +33,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -72,7 +71,6 @@ final class ShuffleExternalSorter { @VisibleForTesting final int maxRecordSizeBytes; private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; @@ -105,7 +103,6 @@ final class ShuffleExternalSorter { public ShuffleExternalSorter( TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, int initialSize, @@ -113,7 +110,6 @@ public ShuffleExternalSorter( SparkConf conf, ShuffleWriteMetrics writeMetrics) throws IOException { this.taskMemoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; @@ -124,7 +120,7 @@ public ShuffleExternalSorter( this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); @@ -140,9 +136,9 @@ public ShuffleExternalSorter( private void initializeForWriting() throws IOException { // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested); if (memoryAcquired != memoryRequested) { - shuffleMemoryManager.release(memoryAcquired); + taskMemoryManager.releaseExecutionMemory(memoryAcquired); throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } @@ -272,6 +268,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { */ @VisibleForTesting void spill() throws IOException { + assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -281,7 +278,7 @@ void spill() throws IOException { 
writeSortedFile(false); final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); inMemSorter = null; - shuffleMemoryManager.release(inMemSorterMemoryUsage); + taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); @@ -316,9 +313,13 @@ private long freeMemory() { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); - shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } + if (inMemSorter != null) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); + } allocatedPages.clear(); currentPage = null; currentPagePosition = -1; @@ -337,8 +338,9 @@ public void cleanupResources() { } } if (inMemSorter != null) { - shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); inMemSorter = null; + taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); } } @@ -353,21 +355,20 @@ private void growPointerArrayIfNecessary() throws IOException { logger.debug("Attempting to expand sort pointer array"); final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { - shuffleMemoryManager.release(memoryAcquired); + taskMemoryManager.releaseExecutionMemory(memoryAcquired); spill(); } else { inMemSorter.expandPointerArray(); - shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage); } } } - + /** * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. + * memory from the memory manager and spill if the requested memory can not be obtained. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record size. 
This must be less than or equal to the page size (records @@ -386,17 +387,14 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -430,17 +428,14 @@ public void insertRecord( long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e8f050cb2dab1..f6c5c944bd77b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -49,12 +49,11 @@ import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -69,7 +68,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final SerializerInstance serializer; private final 
Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; @@ -103,7 +101,6 @@ public UnsafeShuffleWriter( BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, SerializedShuffleHandle handle, int mapId, TaskContext taskContext, @@ -117,7 +114,6 @@ public UnsafeShuffleWriter( this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); @@ -197,7 +193,6 @@ private void open() throws IOException { assert (sorter == null); sorter = new ShuffleExternalSorter( memoryManager, - shuffleMemoryManager, blockManager, taskContext, INITIAL_SORT_BUFFER_SIZE, diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index b24eed3952fd6..f035bdac810bd 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -26,7 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -34,7 +33,7 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -70,8 +69,6 @@ public final class BytesToBytesMap { private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; - /** * A linked list for tracking all allocated data pages so that we can free all of our memory. 
*/ @@ -169,13 +166,11 @@ public final class BytesToBytesMap { public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { this.taskMemoryManager = taskMemoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -201,21 +196,18 @@ public BytesToBytesMap( public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this( taskMemoryManager, - shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, @@ -260,7 +252,6 @@ private void advanceToNextPage() { if (destructive && currentPage != null) { dataPagesIterator.remove(); this.bmap.taskMemoryManager.freePage(currentPage); - this.bmap.shuffleMemoryManager.release(currentPage.size()); } currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); @@ -572,14 +563,12 @@ public boolean putNewKey( if (useOverflowPage) { // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryRequested = requiredSize + 8; - final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryGranted != memoryRequested) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", memoryRequested); + final long overflowPageSize = requiredSize + 8; + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { + logger.debug("Failed to acquire {} bytes of memory", overflowPageSize); return false; } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested); dataPages.add(overflowPage); dataPage = overflowPage; dataPageBaseObject = overflowPage.getBaseObject(); @@ -655,17 +644,15 @@ public boolean putNewKey( } /** - * Acquire a new page from the {@link ShuffleMemoryManager}. + * Acquire a new page from the memory manager. * @return whether there is enough space to allocate the new page. 
*/ private boolean acquireNewPage() { - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (newPage == null) { logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; @@ -705,7 +692,6 @@ public void free() { MemoryBlock dataPage = dataPagesIterator.next(); dataPagesIterator.remove(); taskMemoryManager.freePage(dataPage); - shuffleMemoryManager.release(dataPage.size()); } assert(dataPages.isEmpty()); } @@ -714,10 +700,6 @@ public TaskMemoryManager getTaskMemoryManager() { return taskMemoryManager; } - public ShuffleMemoryManager getShuffleMemoryManager() { - return shuffleMemoryManager; - } - public long getPageSizeBytes() { return pageSizeBytes; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java index 0c4ebde407cfc..dbf6770e07391 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -17,9 +17,11 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.memory.TaskMemoryManager; + final class RecordPointerAndKeyPrefix { /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * A pointer to a record; see {@link TaskMemoryManager} for a * description of how these addresses are encoded. 
*/ public long recordPointer; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 0a311d2d935ac..e317ea391c556 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,12 +32,11 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -52,7 +51,6 @@ public final class UnsafeExternalSorter { private final RecordComparator recordComparator; private final int initialSize; private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private ShuffleWriteMetrics writeMetrics; @@ -82,7 +80,6 @@ public final class UnsafeExternalSorter { public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, @@ -90,26 +87,24 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( int initialSize, long pageSizeBytes, UnsafeInMemorySorter inMemorySorter) throws IOException { - return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); } public static UnsafeExternalSorter create( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes) throws IOException { - return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); } private UnsafeExternalSorter( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, @@ -118,7 +113,6 @@ private UnsafeExternalSorter( long pageSizeBytes, @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { this.taskMemoryManager = taskMemoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.recordComparator = recordComparator; @@ -261,7 +255,6 @@ private long freeMemory() { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); - shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } // TODO: track in-memory sorter memory usage (SPARK-10474) @@ -309,8 +302,7 @@ private void growPointerArrayIfNecessary() throws IOException { /** * Allocates 
more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. + * memory from the memory manager and spill if the requested memory can not be obtained. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records @@ -335,23 +327,20 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { } /** - * Acquire a new page from the {@link ShuffleMemoryManager}. + * Acquire a new page from the memory manager. * * If there is not enough space to allocate the new page, spill all existing ones * and try again. If there is still not enough space, report error to the caller. */ private void acquireNewPage() throws IOException { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -379,17 +368,14 @@ public void insertRecord( long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); @@ -441,17 +427,14 @@ public void insertKVRecord( long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. 
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f7787e1019c2b..5aad72c374c37 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -21,7 +21,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b5c35c569e45f..398e0936906a3 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -38,9 +38,8 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** @@ -70,10 +69,7 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - // TODO: unify these *MemoryManager classes (SPARK-10984) val memoryManager: MemoryManager, - val shuffleMemoryManager: ShuffleMemoryManager, - val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -340,13 +336,11 @@ object SparkEnv extends Logging { val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) val memoryManager: MemoryManager = if (useLegacyMemoryManager) { - new StaticMemoryManager(conf) + new StaticMemoryManager(conf, numUsableCores) } else { - new UnifiedMemoryManager(conf) + new UnifiedMemoryManager(conf, numUsableCores) } - val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores) - val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( @@ -405,15 +399,6 @@ object 
SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - val executorMemoryManager: ExecutorMemoryManager = { - val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { - MemoryAllocator.UNSAFE - } else { - MemoryAllocator.HEAP - } - new ExecutorMemoryManager(allocator) - } - val envInstance = new SparkEnv( executorId, rpcEnv, @@ -431,8 +416,6 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, memoryManager, - shuffleMemoryManager, - executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 63cca80b2d734..af558d6e5b474 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,8 +21,8 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 5df94c6d3a103..f0ae83a9341bd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -20,9 +20,9 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} private[spark] class TaskContextImpl( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c3491bb8b1cf3..9e88d488c0379 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -29,10 +29,10 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -179,7 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 7168ac549106f..6c9a71c3855b0 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -17,20 
+17,38 @@ package org.apache.spark.memory +import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import com.google.common.annotations.VisibleForTesting +import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.memory.MemoryAllocator /** * An abstract memory manager that enforces how memory is shared between execution and storage. * * In this context, execution memory refers to that used for computation in shuffles, joins, * sorts and aggregations, while storage memory refers to that used for caching and propagating - * internal data across the cluster. There exists one of these per JVM. + * internal data across the cluster. There exists one MemoryManager per JVM. + * + * The MemoryManager abstract base class itself implements policies for sharing execution memory + * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of + * some task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever + * this set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. */ -private[spark] abstract class MemoryManager extends Logging { +private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging { + + // -- Methods related to memory allocation policies and bookkeeping ------------------------------ // The memory store used to evict cached blocks private var _memoryStore: MemoryStore = _ @@ -42,8 +60,10 @@ private[spark] abstract class MemoryManager extends Logging { } // Amount of execution/storage memory in use, accesses must be synchronized on `this` - protected var _executionMemoryUsed: Long = 0 - protected var _storageMemoryUsed: Long = 0 + @GuardedBy("this") protected var _executionMemoryUsed: Long = 0 + @GuardedBy("this") protected var _storageMemoryUsed: Long = 0 + // Map from taskAttemptId -> memory consumption in bytes + @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]() /** * Set the [[MemoryStore]] used by this manager to evict cached blocks. @@ -65,15 +85,6 @@ private[spark] abstract class MemoryManager extends Logging { // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) - /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). - */ - def acquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long - /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. * Blocks evicted in the process, if any, are added to `evictedBlocks`. 
@@ -102,9 +113,92 @@ private[spark] abstract class MemoryManager extends Logging { } /** - * Release N bytes of execution memory. + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return number of bytes successfully granted (<= N). + */ + @VisibleForTesting + private[memory] def doAcquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + + /** + * Try to acquire up to `numBytes` of execution memory for the current task and return the number + * of bytes obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies + * that control global sharing of memory between execution and storage. */ - def releaseExecutionMemory(numBytes: Long): Unit = synchronized { + private[memory] + final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!executionMemoryForTask.contains(taskAttemptId)) { + executionMemoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + notifyAll() + } + + // Once the cross-task memory allocation policy has decided to grant more memory to a task, + // this method is called in order to actually obtain that execution memory, potentially + // triggering eviction of storage memory: + def acquire(toGrant: Long): Long = synchronized { + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + executionMemoryForTask(taskAttemptId) += acquired + acquired + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
+ // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = executionMemoryForTask.keys.size + val curMem = executionMemoryForTask(taskAttemptId) + val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum + + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; + // don't let it be negative + val maxToGrant = + math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, freeMemory) + + if (curMem < maxExecutionMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if ( + freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) { + return acquire(toGrant) + } else { + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free") + wait() + } + } else { + return acquire(toGrant) + } + } + 0L // Never reached + } + + @VisibleForTesting + private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized { if (numBytes > _executionMemoryUsed) { logWarning(s"Attempted to release $numBytes bytes of execution " + s"memory when we only have ${_executionMemoryUsed} bytes") @@ -114,6 +208,36 @@ private[spark] abstract class MemoryManager extends Logging { } } + /** + * Release numBytes of execution memory belonging to the given task. + */ + private[memory] + final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { + val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on $numBytes bytes but task only has $curMem") + } + if (executionMemoryForTask.contains(taskAttemptId)) { + executionMemoryForTask(taskAttemptId) -= numBytes + if (executionMemoryForTask(taskAttemptId) <= 0) { + executionMemoryForTask.remove(taskAttemptId) + } + releaseExecutionMemory(numBytes) + } + notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { + val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId) + releaseExecutionMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + /** * Release N bytes of storage memory. */ @@ -155,4 +279,43 @@ private[spark] abstract class MemoryManager extends Logging { _storageMemoryUsed } + /** + * Returns the execution memory consumption, in bytes, for the given task. + */ + private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { + executionMemoryForTask.getOrElse(taskAttemptId, 0L) + } + + // -- Fields related to Tungsten managed memory ------------------------------------------------- + + /** + * The default page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. 
+ */ + val pageSizeBytes: Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } + + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryIsAllocatedInHeap: Boolean = + !conf.getBoolean("spark.unsafe.offHeap", false) + + /** + * Allocates memory for use by Unsafe/Tungsten code. + */ + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = + if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE } diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index fa44f3723415d..9c2c2e90a2282 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -33,14 +33,16 @@ import org.apache.spark.storage.{BlockId, BlockStatus} private[spark] class StaticMemoryManager( conf: SparkConf, override val maxExecutionMemory: Long, - override val maxStorageMemory: Long) - extends MemoryManager { + override val maxStorageMemory: Long, + numCores: Int) + extends MemoryManager(conf, numCores) { - def this(conf: SparkConf) { + def this(conf: SparkConf, numCores: Int) { this( conf, StaticMemoryManager.getMaxExecutionMemory(conf), - StaticMemoryManager.getMaxStorageMemory(conf)) + StaticMemoryManager.getMaxStorageMemory(conf), + numCores) } // Max number of bytes worth of blocks to evict when unrolling @@ -52,7 +54,7 @@ private[spark] class StaticMemoryManager( * Acquire N bytes of memory for execution. * @return number of bytes successfully granted (<= N). */ - override def acquireExecutionMemory( + override def doAcquireExecutionMemory( numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { assert(numBytes >= 0) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 5bf78d5b674b3..a3093030a0f93 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -42,10 +42,14 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * up most of the storage space, in which case the new blocks will be evicted immediately * according to their respective storage levels. */ -private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager { +private[spark] class UnifiedMemoryManager( + conf: SparkConf, + maxMemory: Long, + numCores: Int) + extends MemoryManager(conf, numCores) { - def this(conf: SparkConf) { - this(conf, UnifiedMemoryManager.getMaxMemory(conf)) + def this(conf: SparkConf, numCores: Int) { + this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores) } /** @@ -91,7 +95,7 @@ private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) exte * Blocks evicted in the process, if any, are added to `evictedBlocks`. 
* @return number of bytes successfully granted (<= N). */ - override def acquireExecutionMemory( + private[memory] override def doAcquireExecutionMemory( numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { assert(numBytes >= 0) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9edf9f048f9fd..4fb32ba8cb188 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,8 +25,8 @@ import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -89,10 +89,6 @@ private[spark] abstract class Task[T]( } finally { context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for shuffles - SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() - } Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7c3e2b5a3703b..b0abda4a81b8d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -98,13 +98,14 @@ private[spark] class BlockStoreShuffleReader[K, C]( case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. - val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - sorter.iterator + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala deleted file mode 100644 index 9bd18da47f1a2..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark._ -import org.apache.spark.memory.{StaticMemoryManager, MemoryManager} -import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.unsafe.array.ByteArrayMethods - -/** - * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling - * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory - * from this pool and release it as it spills data out. When a task ends, all its memory will be - * released by the Executor. - * - * This class tries to ensure that each task gets a reasonable share of memory, instead of some - * task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state - * and using wait() and notifyAll() to signal changes. - * - * Use `ShuffleMemoryManager.create()` factory method to create a new instance. - * - * @param memoryManager the interface through which this manager acquires execution memory - * @param pageSizeBytes number of bytes for each page, by default. - */ -private[spark] -class ShuffleMemoryManager protected ( - memoryManager: MemoryManager, - val pageSizeBytes: Long) - extends Logging { - - private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes - - private def currentTaskAttemptId(): Long = { - // In case this is called on the driver, return an invalid task attempt id. - Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) - } - - /** - * Try to acquire up to numBytes memory for the current task, and return the number of bytes - * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active tasks) before it is forced to spill. This can - * happen if the number of tasks increases but an older task had a lot of memory already. 
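The class being deleted here coordinated waiting tasks with wait() and notifyAll() on a shared monitor, and the replacement logic in MemoryManager keeps that scheme. A minimal, illustrative sketch of the underlying pattern (a toy, not the real classes):

    // Toy blocking pool: a task blocks in acquire() until another task releases memory.
    class TinyBlockingPool(max: Long) {
      private var used = 0L
      def acquire(n: Long): Unit = synchronized {
        while (used + n > max) wait()  // sleep until notifyAll() below wakes us
        used += n
      }
      def release(n: Long): Unit = synchronized {
        used -= n
        notifyAll()                    // wake every waiter so it can re-check its share
      }
    }

The real code layers the 1/N and 1/(2N) bookkeeping on top of this same monitor discipline.
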
- */ - def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!taskMemory.contains(taskAttemptId)) { - taskMemory(taskAttemptId) = 0L - // This will later cause waiting tasks to wake up and check numTasks again - memoryManager.notifyAll() - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). - // TODO: simplify this to limit each task to its own slot - while (true) { - val numActiveTasks = taskMemory.keys.size - val curMem = taskMemory(taskAttemptId) - val maxMemory = memoryManager.maxExecutionMemory - val freeMemory = maxMemory - taskMemory.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - // Only give it as much memory as is free, which might be none if it reached 1 / numTasks - val toGrant = math.min(maxToGrant, freeMemory) - - if (curMem < maxMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { - return acquire(toGrant) - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") - memoryManager.wait() - } - } else { - return acquire(toGrant) - } - } - 0L // Never reached - } - - /** - * Acquire N bytes of execution memory from the memory manager for the current task. - * @return number of bytes actually acquired (<= N). - */ - private def acquire(numBytes: Long): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - // TODO: just do this in `acquireExecutionMemory` (SPARK-10985) - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } - taskMemory(taskAttemptId) += acquired - acquired - } - - /** Release numBytes bytes for the current task. 
*/ - def release(numBytes: Long): Unit = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - val curMem = taskMemory.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") - } - if (taskMemory.contains(taskAttemptId)) { - taskMemory(taskAttemptId) -= numBytes - memoryManager.releaseExecutionMemory(numBytes) - } - memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed - } - - /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisTask(): Unit = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.remove(taskAttemptId).foreach { numBytes => - memoryManager.releaseExecutionMemory(numBytes) - } - memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed - } - - /** Returns the memory consumption, in bytes, for the current task */ - def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.getOrElse(taskAttemptId, 0L) - } -} - - -private[spark] object ShuffleMemoryManager { - - def create( - conf: SparkConf, - memoryManager: MemoryManager, - numCores: Int): ShuffleMemoryManager = { - val maxMemory = memoryManager.maxExecutionMemory - val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) - new ShuffleMemoryManager(memoryManager, pageSize) - } - - /** - * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size. - */ - def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { - val conf = new SparkConf - val memoryManager = new StaticMemoryManager( - conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue) - new ShuffleMemoryManager(memoryManager, pageSizeBytes) - } - - @VisibleForTesting - def createForTesting(maxMemory: Long): ShuffleMemoryManager = { - create(maxMemory, 4 * 1024 * 1024) - } - - /** - * Sets the page size, in bytes. - * - * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value - * by looking at the number of cores available to the process, and the total amount of memory, - * and then divide it by a factor of safety. 
- */ - private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { - val minPageSize = 1L * 1024 * 1024 // 1MB - val maxPageSize = 64L * minPageSize // 64MB - val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() - // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case - val safetyFactor = 16 - val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) - val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.getSizeAsBytes("spark.buffer.pageSize", default) - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 1105167d39d8d..66b6bbc61fe8e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -133,7 +133,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), - env.shuffleMemoryManager, unsafeShuffleHandle, mapId, context, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index bbd9c1ab53cd8..808317b017a0f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -52,13 +52,13 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. new ExternalSorter[K, V, V]( - aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) + context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } sorter.insertAll(records) @@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). 
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) + val partitionLengths = sorter.writePartitionedFile(blockId, outputFile) shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cfa58f5ef408a..f6d81ee5bf05e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator import org.apache.spark.executor.ShuffleWriteMetrics @@ -55,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, - blockManager: BlockManager = SparkEnv.get.blockManager) + blockManager: BlockManager = SparkEnv.get.blockManager, + context: TaskContext = TaskContext.get()) extends Iterable[(K, C)] with Serializable with Logging with Spillable[SizeTracker] { + if (context == null) { + throw new IllegalStateException( + "Spillable collections should not be instantiated outside of tasks") + } + + // Backwards-compatibility constructor for binary compatibility + def this( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer, + blockManager: BlockManager) { + this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) + } + + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -118,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C]( * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. */ def insertAll(entries: Iterator[Product2[K, V]]): Unit = { + if (currentMap == null) { + throw new IllegalStateException( + "Cannot insert new elements into a map after calling iterator") + } // An update function for the map that we reuse across entries to avoid allocating // a new closure each time var curEntry: Product2[K, V] = null @@ -215,17 +239,26 @@ class ExternalAppendOnlyMap[K, V, C]( } /** - * Return an iterator that merges the in-memory map with the spilled maps. + * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. 
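The ExternalAppendOnlyMap changes in this hunk thread the TaskContext through the constructor and derive the memory manager from it, instead of reading a global from SparkEnv. A hedged sketch of that dependency direction using stand-in types (TaskCtxSketch and the member names below are illustrative only):

    // Stand-in types showing the shape of the change: the collection receives the task's
    // context and asks it for the per-task memory manager, failing fast outside a task.
    trait TaskCtxSketch { def taskMemoryManager(): AnyRef }

    abstract class SpillableCollectionSketch(context: TaskCtxSketch) {
      if (context == null) {
        throw new IllegalStateException(
          "Spillable collections should not be instantiated outside of tasks")
      }
      protected def taskMemoryManager: AnyRef = context.taskMemoryManager()
    }
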
*/ override def iterator: Iterator[(K, C)] = { + if (currentMap == null) { + throw new IllegalStateException( + "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") + } if (spilledMaps.isEmpty) { - currentMap.iterator + CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap()) } else { new ExternalIterator() } } + private def freeCurrentMap(): Unit = { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } + /** * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps */ @@ -237,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = currentMap.destructiveSortedIterator(keyComparator) + private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]]( + currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap()) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -493,12 +527,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - val context = TaskContext.get() - // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in - // a TaskContext. - if (context != null) { - context.addTaskCompletionListener(context => cleanup()) - } + context.addTaskCompletionListener(context => cleanup()) } /** Convenience function to hash the given (K, C) pair by the key. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c48c453a90d01..a44e72b7c16d3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} @@ -87,6 +88,7 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * - Users are expected to call stop() at the end to delete all the intermediate files. 
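Both the destructive iterator above and the shuffle reader change earlier in this patch wrap their output in a CompletionIterator so cleanup runs exactly when the caller drains the data. A small sketch of that pattern; the helper name is made up, but the CompletionIterator call has the same shape as its use in the diff:

    import org.apache.spark.util.CompletionIterator

    // Run a cleanup action once the wrapped iterator has been fully consumed.
    object CompletionIteratorSketch {
      def withCleanup[A](underlying: Iterator[A])(cleanup: => Unit): Iterator[A] =
        CompletionIterator[A, Iterator[A]](underlying, cleanup)

      def main(args: Array[String]): Unit = {
        var freed = false
        val it = withCleanup(Iterator(1, 2, 3)) { freed = true }
        println(it.sum)  // 6; exhausting the iterator triggers the cleanup
        println(freed)   // true
      }
    }
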
*/ private[spark] class ExternalSorter[K, V, C]( + context: TaskContext, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, @@ -94,6 +96,8 @@ private[spark] class ExternalSorter[K, V, C]( extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + private val conf = SparkEnv.get.conf private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) @@ -640,7 +644,6 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - context: TaskContext, outputFile: File): Array[Long] = { // Track location of each range in the output file @@ -686,8 +689,11 @@ private[spark] class ExternalSorter[K, V, C]( } def stop(): Unit = { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() + releaseMemory() } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index d2a68ca7a3b4c..a76891acf0baf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,8 @@ package org.apache.spark.util.collection -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.{Logging, SparkEnv} /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -40,7 +40,7 @@ private[spark] trait Spillable[C] extends Logging { protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory - private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + protected[this] def taskMemoryManager: TaskMemoryManager // Initial threshold for the size of a collection before we start tracking its memory usage // For testing only @@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -92,7 +92,7 @@ private[spark] trait Spillable[C] extends Logging { spill(collection) _elementsRead = 0 _memoryBytesSpilled += currentMemory - releaseMemoryForThisThread() + releaseMemory() } shouldSpill } @@ -103,11 +103,11 @@ private[spark] trait Spillable[C] extends Logging { def memoryBytesSpilled: Long = _memoryBytesSpilled /** - * Release our memory back to the shuffle pool so that other threads can grab it. + * Release our memory back to the execution pool so that other tasks can grab it. 
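Spillable keeps its existing growth policy and only swaps the pool it asks: every 32 elements it tries to double the tracked size, and it spills when the grant cannot keep the threshold ahead of actual usage. A standalone sketch of that decision, where the acquire function stands in for taskMemoryManager.acquireExecutionMemory:

    object MaybeSpillSketch {
      // Returns (shouldSpill, newThreshold) for one growth check of a spillable collection.
      def shouldSpill(elementsRead: Long,
                      currentMemory: Long,
                      myMemoryThreshold: Long,
                      acquire: Long => Long): (Boolean, Long) = {
        if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
          val amountToRequest = 2 * currentMemory - myMemoryThreshold    // aim for 2x current usage
          val newThreshold = myMemoryThreshold + acquire(amountToRequest) // grant may be partial
          (currentMemory >= newThreshold, newThreshold)  // spill if the pool could not keep up
        } else {
          (false, myMemoryThreshold)  // only re-check every 32 elements read
        }
      }

      // e.g. a pool that denies the request forces a spill:
      // shouldSpill(64, 5000, 5000, _ => 0L) == (true, 5000)
    }
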
*/ - private def releaseMemoryForThisThread(): Unit = { + def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) + taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold) myMemoryThreshold = initialMemoryThreshold } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java similarity index 74% rename from unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java rename to core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 06fb081183659..f381db0c62653 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -15,33 +15,28 @@ * limitations under the License. */ -package org.apache.spark.unsafe.memory; +package org.apache.spark.memory; import org.junit.Assert; import org.junit.Test; -public class TaskMemoryManagerSuite { +import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryBlock; - @Test - public void leakedNonPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - manager.allocate(1024); // leak memory - Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); - } +public class TaskMemoryManagerSuite { @Test public void leakedPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); manager.allocatePage(4096); // leak memory Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } @Test public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); final MemoryBlock dataPage = manager.allocatePage(256); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. 
This test exercises that corner-case: @@ -53,8 +48,8 @@ public void encodePageNumberAndOffsetOffHeap() { @Test public void encodePageNumberAndOffsetOnHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 232ae4d926bcd..7fb2f92ca80e8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -21,18 +21,19 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; public class PackedRecordPointerSuite { @Test public void heap() { + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128); final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, @@ -49,8 +50,9 @@ public void heap() { @Test public void offHeap() { + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128); final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 1ef3c5ff64bac..5049a5306ff21 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -24,11 +24,11 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { @@ -58,8 +58,9 @@ public void 
testBasicSorting() throws Exception { "Lychee", "Mango" }; + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 29d9823b1f71b..d65926949c036 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -39,7 +39,6 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -54,19 +53,15 @@ import org.apache.spark.serializer.*; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.GrantEverythingMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + TaskMemoryManager taskMemoryManager; final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); File mergedOutputFile; File tempDir; @@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite { final Serializer serializer = new KryoSerializer(new SparkConf()); TaskMetrics taskMetrics; - @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @@ -111,11 +105,11 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf().set("spark.buffer.pageSize", "128m"); + conf = new SparkConf() + .set("spark.buffer.pageSize", "128m") + .set("spark.unsafe.offHeap", "false"); taskMetrics = new TaskMetrics(); - - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); + taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -203,7 +197,6 @@ private UnsafeShuffleWriter createWriter( blockManager, shuffleBlockResolver, taskMemoryManager, - shuffleMemoryManager, new SerializedShuffleHandle(0, 1, 
shuffleDep), 0, // map id taskContext, @@ -405,11 +398,12 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { @Test public void writeEnoughDataToTriggerSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to allocate new data page - .then(returnsFirstArg()); // Grant new sort buffer and data page. + taskMemoryManager = spy(taskMemoryManager); + doCallRealMethod() // initialize sort buffer + .doCallRealMethod() // allocate initial data page + .doReturn(0L) // deny request to allocate new page + .doCallRealMethod() // grant new sort buffer and data page + .when(taskMemoryManager).acquireExecutionMemory(anyLong()); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; @@ -417,7 +411,7 @@ public void writeEnoughDataToTriggerSpill() throws Exception { dataToWrite.add(new Tuple2(i, bigByteArray)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -432,18 +426,19 @@ public void writeEnoughDataToTriggerSpill() throws Exception { @Test public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to grow sort buffer - .then(returnsFirstArg()); // Grant new sort buffer and data page. 
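The spill tests now spy on the real TaskMemoryManager and deny a single request, instead of stubbing a mocked shuffle pool. A hedged sketch of that stubbing shape with the Mockito 1.x API these suites use, against a made-up class (only the doCallRealMethod/doReturn chaining is the point; Pool is not a Spark type):

    import org.mockito.Matchers.anyLong
    import org.mockito.Mockito.{doCallRealMethod, spy}

    // Let real calls through, deny exactly one in the middle to push the code under test
    // down its spill path. Pool is an illustrative stand-in, not TaskMemoryManager.
    object SpyStubbingSketch {
      class Pool { def acquire(n: Long): Long = n }

      def poolDeniedOnce(): Pool = {
        val pool = spy(new Pool)
        doCallRealMethod()     // first call behaves normally
          .doReturn(0L)        // second call pretends the pool is exhausted
          .doCallRealMethod()  // later calls behave normally again
          .when(pool).acquire(anyLong())
        pool
      }
    }
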
+ taskMemoryManager = spy(taskMemoryManager); + doCallRealMethod() // initialize sort buffer + .doCallRealMethod() // allocate initial data page + .doReturn(0L) // deny request to allocate new page + .doCallRealMethod() // grant new sort buffer and data page + .when(taskMemoryManager).acquireExecutionMemory(anyLong()); final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -509,13 +504,13 @@ public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; final long pageSizeBytes = 256; final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + taskMemoryManager = spy(taskMemoryManager); + when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, shuffleBlockResolver, taskMemoryManager, - shuffleMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index ab480b60adaed..6e52496cf933b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -21,15 +21,13 @@ import java.nio.ByteBuffer; import java.util.*; +import org.apache.spark.memory.TaskMemoryManager; import org.junit.*; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.*; -import static org.mockito.AdditionalMatchers.geq; -import static org.mockito.Mockito.*; -import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.Platform; @@ -39,42 +37,29 @@ public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); - private ShuffleMemoryManager shuffleMemoryManager; + private GrantEverythingMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; - private TaskMemoryManager sizeLimitedTaskMemoryManager; private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { - shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES); - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); - // Mocked memory manager for tests that check the maximum array size, since actually allocating - // such large arrays will cause us to run out of memory in our tests. 
- sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class); - when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer( - new Answer() { - @Override - public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { - if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { - throw new OutOfMemoryError("Requested array size exceeds VM limit"); - } - return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]); - } - } - ); + memoryManager = + new GrantEverythingMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); } @After public void tearDown() { Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - Assert.assertEquals(0L, leakedShuffleMemory); + if (taskMemoryManager != null) { + long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask(); + taskMemoryManager = null; + Assert.assertEquals(0L, leakedMemory); } } - protected abstract MemoryAllocator getMemoryAllocator(); + protected abstract boolean useOffHeapMemoryAllocator(); private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; @@ -110,8 +95,7 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.numElements()); final int keyLengthInWords = 10; @@ -126,8 +110,7 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -179,8 +162,7 @@ public void setAndRetrieveAKey() { private void iteratorTestBase(boolean destructive) throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -265,8 +247,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 24; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); + final BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record @@ -335,9 +317,7 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
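As the surrounding comment says, Java arrays hash by identity, which is why the tests wrap raw byte arrays in ByteBuffers before using them as keys in the expected-results map. A tiny illustration:

    import java.nio.ByteBuffer

    // Arrays compare by reference; ByteBuffer gives content-based equals/hashCode,
    // which is what a HashMap key needs.
    object ArrayKeySketch {
      def main(args: Array[String]): Unit = {
        val a = Array[Byte](1, 2, 3)
        val b = Array[Byte](1, 2, 3)
        assert(!a.equals(b))                                  // different instances: "unequal"
        assert(ByteBuffer.wrap(a).equals(ByteBuffer.wrap(b))) // same contents: equal
        assert(ByteBuffer.wrap(a).hashCode == ByteBuffer.wrap(b).hashCode)
      }
    }
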
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES); - + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { @@ -386,8 +366,7 @@ public void randomizedStressTest() { @Test public void randomizedTestWithRecordsLargerThanPageSize() { final long pageSizeBytes = 128; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes); // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. final Map expected = new HashMap(); @@ -436,9 +415,9 @@ public void randomizedTestWithRecordsLargerThanPageSize() { @Test public void failureToAllocateFirstPage() { - shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024); - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + memoryManager.markExecutionAsOutOfMemory(); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); + memoryManager.markExecutionAsOutOfMemory(); try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = @@ -454,12 +433,14 @@ public void failureToAllocateFirstPage() { @Test public void failureToGrow() { - shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10); - BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024); try { boolean success = true; int i; - for (i = 0; i < 1024; i++) { + for (i = 0; i < 127; i++) { + if (i > 0) { + memoryManager.markExecutionAsOutOfMemory(); + } final long[] arr = new long[]{i}; final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); success = @@ -478,7 +459,7 @@ public void failureToGrow() { @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES); + new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -486,36 +467,13 @@ public void initialCapacityBoundsChecking() { try { new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, + taskMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } - - // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager - // Can allocate _at_ the max capacity - // BytesToBytesMap map = new BytesToBytesMap( - // sizeLimitedTaskMemoryManager, - // shuffleMemoryManager, - // BytesToBytesMap.MAX_CAPACITY, - // PAGE_SIZE_BYTES); - // map.free(); - } - - // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager - @Ignore - public void resizingLargeMap() { - // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, - BytesToBytesMap.MAX_CAPACITY - 64, - PAGE_SIZE_BYTES); - 
map.growAndRehash(); - map.free(); } @Test @@ -523,8 +481,7 @@ public void testPeakMemoryUsed() { final long recordLengthBytes = 24; final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes); // Since BytesToBytesMap is append-only, we expect the total memory consumption to be // monotonically increasing. More specifically, every time we allocate a new page it @@ -564,8 +521,7 @@ public void testPeakMemoryUsed() { @Test public void testAcquirePageInConstructor() { - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); assertEquals(1, map.getNumDataPages()); map.free(); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java index 5a10de49f54fe..f0bad4d760c1d 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.UNSAFE; + protected boolean useOffHeapMemoryAllocator() { + return true; } - } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java index 12cc9b25d93b3..d76bb4fd05c5f 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.HEAP; + protected boolean useOffHeapMemoryAllocator() { + return false; } - } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a5bbaa95fa456..94d50b94fde3f 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -46,20 +46,19 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import 
org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final GrantEverythingMemoryManager memoryManager = + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -82,7 +81,6 @@ public int compare( SparkConf sparkConf; File tempDir; - ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @@ -102,7 +100,6 @@ public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); - shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); @@ -143,13 +140,7 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th @After public void tearDown() { try { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); - } - assertEquals(0, leakedUnsafeMemory); + assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); } finally { Utils.deleteRecursively(tempDir); tempDir = null; @@ -178,7 +169,6 @@ private static void insertRecord( private UnsafeExternalSorter newSorter() throws IOException { return UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, @@ -236,12 +226,16 @@ public void testSortingEmptyArrays() throws Exception { @Test public void spillingOccursInResponseToMemoryPressure() throws Exception { - shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes); final UnsafeExternalSorter sorter = newSorter(); - final int numRecords = (int) pageSizeBytes / 4; - for (int i = 0; i <= numRecords; i++) { + // This should be enough records to completely fill up a data page: + final int numRecords = (int) (pageSizeBytes / (4 + 4)); + for (int i = 0; i < numRecords; i++) { insertNumber(sorter, numRecords - i); } + assertEquals(1, sorter.getNumberOfAllocatedPages()); + memoryManager.markExecutionAsOutOfMemory(); + // The insertion of this record should trigger a spill: + insertNumber(sorter, 0); // Ensure that spill files were created assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1)); // Read back the sorted data: @@ -255,6 +249,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } + assertEquals(numRecords + 1, i); sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); } @@ -323,7 +318,6 @@ public void testPeakMemoryUsed() throws Exception { final long numRecordsPerPage = pageSizeBytes / 
recordLengthBytes; final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 778e813df6b54..d5de56a0512f9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,11 +26,11 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; public class UnsafeInMemorySorterSuite { @@ -43,7 +43,8 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -64,8 +65,8 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Lychee", "Mango" }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index f58756e6f6179..0242cbc9244a8 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // cause is preserved val thrownDueToTaskFailure = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + TaskContext.get().taskMemoryManager().allocatePage(128) throw new Exception("intentional task failure") iter }.count() @@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // If the task succeeded but memory was leaked, then the task should fail due to that leak val thrownDueToMemoryLeak = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + TaskContext.get().taskMemoryManager().allocatePage(128) iter }.count() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala similarity index 56% rename from 
sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala rename to core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala index c4358f409b6ef..fe102d8aeb2a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala @@ -15,51 +15,25 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.memory import scala.collection.mutable -import org.apache.spark.memory.MemoryManager -import org.apache.spark.shuffle.ShuffleMemoryManager -import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} - -/** - * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. - */ -class TestShuffleMemoryManager - extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) { - private var oom = false - - override def tryToAcquire(numBytes: Long): Long = { +class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) { + private[memory] override def doAcquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { if (oom) { oom = false 0 } else { - // Uncomment the following to trace memory allocations. - // println(s"tryToAcquire $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - val acquired = super.tryToAcquire(numBytes) - acquired + _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory + numBytes } } - - override def release(numBytes: Long): Unit = { - // Uncomment the following to trace memory releases. 
- // println(s"release $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - super.release(numBytes) - } - - def markAsOutOfMemory(): Unit = { - oom = true - } -} - -private class GrantEverythingMemoryManager extends MemoryManager { - override def acquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes override def acquireStorageMemory( blockId: BlockId, numBytes: Long, @@ -68,8 +42,13 @@ private class GrantEverythingMemoryManager extends MemoryManager { blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { } override def releaseStorageMemory(numBytes: Long): Unit = { } override def maxExecutionMemory: Long = Long.MaxValue override def maxStorageMemory: Long = Long.MaxValue + + private var oom = false + + def markExecutionAsOutOfMemory(): Unit = { + oom = true + } } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 36e4566310715..1265087743a98 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -19,10 +19,14 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} + import org.mockito.Matchers.{any, anyLong} import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.storage.MemoryStore @@ -126,6 +130,136 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, "ensure free space should not have been called!") } + + /** + * Create a MemoryManager with the specified execution memory limit and no storage memory. + */ + protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager + + // -- Tests of sharing of execution memory between tasks ---------------------------------------- + // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. 
+ + implicit val ec = ExecutionContext.global + + test("single task requesting execution memory") { + val manager = createMemoryManager(1000L) + val taskMemoryManager = new TaskMemoryManager(manager, 0) + + assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + + taskMemoryManager.releaseExecutionMemory(500L) + assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L) + + taskMemoryManager.cleanUpAllAllocatedMemory() + assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + } + + test("two tasks requesting full execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 500 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result1, futureTimeout) === 500L) + assert(Await.result(t2Result1, futureTimeout) === 500L) + + // Have both tasks each request 500 bytes more; both should immediately return 0 as they are + // both now at 1 / N + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result2, 200.millis) === 0L) + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("two tasks cannot grow past 1 / N of execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 250 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) } + assert(Await.result(t1Result1, futureTimeout) === 250L) + assert(Await.result(t2Result1, futureTimeout) === 250L) + + // Have both tasks each request 500 bytes more. + // We should only grant 250 bytes to each of them on this second request + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result2, futureTimeout) === 250L) + assert(Await.result(t2Result2, futureTimeout) === 250L) + } + + test("tasks can block to get at least 1 / 2N of execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. 
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + t1MemManager.releaseExecutionMemory(250L) + // The memory freed from t1 should now be granted to t2. + assert(Await.result(t2Result1, futureTimeout) === 250L) + // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) } + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("TaskMemoryManager.cleanUpAllAllocatedMemory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + // t1 releases all of its memory, so t2 should be able to grab all of the memory + t1MemManager.cleanUpAllAllocatedMemory() + assert(Await.result(t2Result1, futureTimeout) === 500L) + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t2Result2, futureTimeout) === 500L) + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t2Result3, 200.millis) === 0L) + } + + test("tasks should not be granted a negative amount of execution memory") { + // This is a regression test for SPARK-4715. + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) } + assert(Await.result(t1Result1, futureTimeout) === 700L) + + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) } + assert(Await.result(t2Result1, futureTimeout) === 300L) + + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) } + assert(Await.result(t1Result2, 200.millis) === 0L) + } } private object MemoryManagerSuite { diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala new file mode 100644 index 0000000000000..4b4c3b0311328 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} + +/** + * Helper methods for mocking out memory-management-related classes in tests. + */ +object MemoryTestingUtils { + def fakeTaskContext(env: SparkEnv): TaskContext = { + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) + new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = env.metricsSystem, + internalAccumulators = Seq.empty) + } +} diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 6cae1f871e24b..885c450d6d4f5 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -36,27 +36,35 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { maxExecutionMem: Long, maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem) + conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } + override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + new StaticMemoryManager( + conf, + maxExecutionMemory = maxMemory, + maxStorageMemory = 0, + numCores = 1) + } + test("basic execution memory") { val maxExecutionMem = 1000L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) assert(mm.executionMemoryUsed === 0L) - assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) // Acquire up to the max - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) mm.releaseExecutionMemory(800L) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired mm.releaseExecutionMemory(maxExecutionMem) @@ -108,10 +116,10 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("ain't nobody love like you do") val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) // Only execution memory should increase - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) 
assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e7baa50dc2cd0..0c97f2bd89651 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -34,11 +34,15 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. */ private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { - val mm = new UnifiedMemoryManager(conf, maxMemory) + val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } + override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + new UnifiedMemoryManager(conf, maxMemory, numCores = 1) + } + private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { mm invokePrivate PrivateMethod[Long]('storageRegionSize)() } @@ -56,18 +60,18 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val maxMemory = 1000L val (mm, _) = makeThings(maxMemory) assert(mm.executionMemoryUsed === 0L) - assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) // Acquire up to the max - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) assert(mm.executionMemoryUsed === maxMemory) - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) assert(mm.executionMemoryUsed === maxMemory) mm.releaseExecutionMemory(800L) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired mm.releaseExecutionMemory(maxMemory) @@ -132,12 +136,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes require(mm.storageMemoryUsed > storageRegionSize, s"bad test: storage memory used should exceed the storage region") // Execution needs to request 250 bytes to evict storage memory - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) assertEnsureFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted - assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) assertEnsureFreeSpaceCalled(ms, 200L) assert(mm.executionMemoryUsed === 300L) mm.releaseAllStorageMemory() @@ -151,7 +155,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with 
PrivateMethodTes s"bad test: storage memory used should be within the storage region") // Execution cannot evict storage because the latter is within the storage fraction, // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300 - assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) assertEnsureFreeSpaceNotCalled(ms) @@ -170,7 +174,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes require(executionRegionSize === expectedExecutionRegionSize, "bad test: storage region size is unexpected") // Acquire enough execution memory to exceed the execution region - assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) @@ -188,7 +192,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.releaseExecutionMemory(maxMemory) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region - assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala deleted file mode 100644 index 5877aa042d4af..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.util.concurrent.CountDownLatch -import java.util.concurrent.atomic.AtomicInteger - -import org.mockito.Mockito._ -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.{SparkFunSuite, TaskContext} -import org.apache.spark.executor.TaskMetrics - -class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { - - val nextTaskAttemptId = new AtomicInteger() - - /** Launch a thread with the given body block and return it. 
*/ - private def startThread(name: String)(body: => Unit): Thread = { - val thread = new Thread("ShuffleMemorySuite " + name) { - override def run() { - try { - val taskAttemptId = nextTaskAttemptId.getAndIncrement - val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) - val taskMetrics = new TaskMetrics - when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) - when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics) - TaskContext.setTaskContext(mockTaskContext) - body - } finally { - TaskContext.unset() - } - } - } - thread.start() - thread - } - - test("single task requesting memory") { - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - assert(manager.tryToAcquire(100L) === 100L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(200L) === 100L) - assert(manager.tryToAcquire(100L) === 0L) - assert(manager.tryToAcquire(100L) === 0L) - - manager.release(500L) - assert(manager.tryToAcquire(300L) === 300L) - assert(manager.tryToAcquire(300L) === 200L) - - manager.releaseMemoryForThisTask() - assert(manager.tryToAcquire(1000L) === 1000L) - assert(manager.tryToAcquire(100L) === 0L) - } - - test("two threads requesting full memory") { - // Two threads request 500 bytes first, wait for each other to get it, and then request - // 500 more; we should immediately return 0 as both are now at 1 / N - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 500L) - assert(state.t2Result1 === 500L) - assert(state.t1Result2 === 0L) - assert(state.t2Result2 === 0L) - } - - - test("tasks cannot grow past 1 / N") { - // Two tasks request 250 bytes first, wait for each other to get it, and then request - // 500 more; we should only grant 250 bytes to each of them on this second request - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t2Result1 = r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 250L) - assert(state.t2Result1 === 250L) - 
assert(state.t1Result2 === 250L) - assert(state.t2Result2 === 250L) - } - - test("tasks can block to get at least 1 / 2N memory") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases 250 bytes, which should then be granted to t2. Further requests - // by t2 will return false right away because it now has 1 / 2N of the memory. - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result = -1L - var t2Result2 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise - Thread.sleep(300) - manager.release(250L) - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val result = manager.tryToAcquire(250L) - val endTime = System.currentTimeMillis() - state.synchronized { - state.t2Result = result - // A second call should return 0 because we're now already at 1 / 2N - state.t2Result2 = manager.tryToAcquire(100L) - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both threads should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released 250 - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result === 250L, "t2 could not allocate memory") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - assert(state.t2Result2 === 0L, "t1 got extra memory the second time") - } - } - - test("releaseMemoryForThisTask") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
- - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result1 = -1L - var t2Result2 = -1L - var t2Result3 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other task blocks for some time otherwise - Thread.sleep(300) - manager.releaseMemoryForThisTask() - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val r1 = manager.tryToAcquire(500L) - val endTime = System.currentTimeMillis() - val r2 = manager.tryToAcquire(500L) - val r3 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.t2Result2 = r2 - state.t2Result3 = r3 - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both tasks should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released all of it - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") - assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") - assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - } - } - - test("tasks should not be granted a negative size") { - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - manager.tryToAcquire(700L) - - val latch = new CountDownLatch(1) - startThread("t1") { - manager.tryToAcquire(300L) - latch.countDown() - } - latch.await() // Wait until `t1` calls `tryToAcquire` - - val granted = manager.tryToAcquire(300L) - assert(0 === granted, "granted is negative") - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index cc44c676b27ac..6e3f500e15dc0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -61,7 +61,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val store = new BlockManager(name, rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(store.memoryStore) @@ -261,7 +261,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") 
when(failableTransfer.port).thenReturn(1000) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1) val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) memManager.setMemoryStore(failableStore.memoryStore) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f3fab33ca2e31..d49015afcd594 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -68,7 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) @@ -823,7 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200) + val memoryManager = new StaticMemoryManager( + conf, + maxExecutionMemory = Long.MaxValue, + maxStorageMemory = 1200, + numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 5cb506ea2164e..dc3185a6d505a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec - +import org.apache.spark.memory.MemoryTestingUtils class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} @@ -32,8 +32,11 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = buf1 ++= buf2 - private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( - createCombiner[T], mergeValue[T], mergeCombiners[T]) + private def createExternalMap[T] = { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( + createCombiner[T], mergeValue[T], mergeCombiners[T], context = context) + } private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -49,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends 
SparkFunSuite with LocalSparkContext { conf } - test("simple insert") { + test("single insert insert") { val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] - - // Single insert map.insert(1, 10) - var it = map.iterator + val it = map.iterator assert(it.hasNext) val kv = it.next() assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10)) assert(!it.hasNext) + sc.stop() + } - // Multiple insert + test("multiple insert") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + map.insert(1, 10) map.insert(2, 20) map.insert(3, 30) - it = map.iterator + val it = map.iterator assert(it.hasNext) assert(it.toSet === Set[(Int, ArrayBuffer[Int])]( (1, ArrayBuffer[Int](10)), @@ -144,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] + val nullInt = null.asInstanceOf[Int] map.insert(1, 5) map.insert(2, 6) map.insert(3, 7) - assert(map.size === 3) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( - (1, Seq[Int](5)), - (2, Seq[Int](6)), - (3, Seq[Int](7)) - )) - - // Null keys - val nullInt = null.asInstanceOf[Int] + map.insert(4, nullInt) map.insert(nullInt, 8) - assert(map.size === 4) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( + map.insert(nullInt, nullInt) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted)) + assert(result === Set[(Int, Seq[Int])]( (1, Seq[Int](5)), (2, Seq[Int](6)), (3, Seq[Int](7)), - (nullInt, Seq[Int](8)) + (4, Seq[Int](nullInt)), + (nullInt, Seq[Int](nullInt, 8)) )) - // Null values - map.insert(4, nullInt) - map.insert(nullInt, nullInt) - assert(map.size === 5) - val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) - assert(result === Set[(Int, Set[Int])]( - (1, Set[Int](5)), - (2, Set[Int](6)), - (3, Set[Int](7)), - (4, Set[Int](nullInt)), - (nullInt, Set[Int](nullInt, 8)) - )) sc.stop() } @@ -344,7 +334,9 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val map = + new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
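For reference, the test changes above all follow the same pattern: build a fake TaskContext from the SparkEnv with MemoryTestingUtils and hand it to the spillable collection explicitly. A minimal sketch of that pattern follows; the setup code and the fact that it must live under the org.apache.spark package (these collections are not public API) are assumptions based on the hunks in this patch, not a definitive usage guide.

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.memory.MemoryTestingUtils
    import org.apache.spark.util.collection.ExternalAppendOnlyMap

    // Assumed sketch: a local SparkContext provides the SparkEnv whose memory
    // manager backs the fake TaskContext; the map then acquires execution memory
    // through that task-level manager instead of a ShuffleMemoryManager.
    val sc = new SparkContext("local", "test", new SparkConf(false))
    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
    val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
      (i: Int) => ArrayBuffer(i),                                // createCombiner
      (buf: ArrayBuffer[Int], i: Int) => buf += i,               // mergeValue
      (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++= b2, // mergeCombiners
      context = context)
    map.insert(1, 10)
    assert(map.iterator.next() == (1, ArrayBuffer(10)))
    sc.stop()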
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index e2cb791771d99..d7b2d07a40052 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.collection +import org.apache.spark.memory.MemoryTestingUtils + import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -98,6 +100,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -109,7 +112,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner _, mergeValue _, mergeCombiners _) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) val collisionPairs = Seq( ("Aa", "BB"), // 2112 @@ -158,8 +161,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) - val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) + val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) @@ -180,6 +184,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -188,7 +193,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) - val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) + val sorter = + new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None) sorter.insertAll( (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) assert(sorter.numSpills > 0, "sorter did not spill") @@ -204,6 +210,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -214,7 +221,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), @@ -271,31 +278,32 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def emptyDataStream(conf: SparkConf) { conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None) assert(sorter.iterator.toSeq === Seq()) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), None, None) + context, Some(agg), Some(new HashPartitioner(3)), None, None) assert(sorter2.iterator.toSeq === Seq()) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord), None) assert(sorter3.iterator.toSeq === Seq()) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), None, None) + context, None, Some(new HashPartitioner(3)), None, None) assert(sorter4.iterator.toSeq === Seq()) sorter4.stop() } @@ -303,6 +311,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { 
private def fewElementsPerPartition(conf: SparkConf) { conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] @@ -313,28 +322,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None) sorter.insertAll(elements.iterator) assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), None, None) + context, Some(agg), Some(new HashPartitioner(7)), None, None) sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord), None) sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) + context, None, Some(new HashPartitioner(7)), None, None) sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() @@ -345,12 +354,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { conf.set("spark.shuffle.manager", "sort") conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val ord = implicitly[Ordering[Int]] val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord), None) sorter.insertAll(elements) assert(sorter.numSpills > 0, "sorter did not spill") val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -432,8 +442,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val diskBlockManager = sc.env.blockManager.diskBlockManager val ord = implicitly[Ordering[Int]] val expectedSize = if (withFailures) size - 1 else size + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord), None) if (withFailures) { intercept[SparkException] { sorter.insertAll((0 until size).iterator.map { i => @@ -501,7 +512,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { None } val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None - val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter = + new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None) sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) 
}) if (withSpilling) { assert(sorter.numSpills > 0, "sorter did not spill") @@ -538,8 +551,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val testData = Array.tabulate(size) { _ => rand.nextInt().toString } + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter1 = new ExternalSorter[String, String, String]( - None, None, Some(wrongOrdering), None) + context, None, None, Some(wrongOrdering), None) val thrown = intercept[IllegalArgumentException] { sorter1.insertAll(testData.iterator.map(i => (i, i))) assert(sorter1.numSpills > 0, "sorter did not spill") @@ -561,7 +575,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) sorter2.insertAll(testData.iterator.map(i => (i, i))) assert(sorter2.numSpills > 0, "sorter did not spill") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 7d94e0566faa9..810c74fd2fb96 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -67,7 +67,6 @@ public UnsafeExternalRowSorter( final TaskContext taskContext = TaskContext.get(); sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), - sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, new RowComparator(ordering, schema.length()), diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 09511ff35f785..82c645df284de 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -22,7 +22,6 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkEnv; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -32,7 +31,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -88,8 +87,6 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. - * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with - * other tasks. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) @@ -99,15 +96,14 @@ public UnsafeFixedWidthAggregationMap( StructType aggregationBufferSchema, StructType groupingKeySchema, TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); + this.map = + new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; // Initialize the buffer for aggregation value @@ -256,7 +252,7 @@ public void printPerfMetrics() { public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map); + SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9df5780e4fd84..46301f0042954 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -24,7 +24,6 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.TaskContext; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; @@ -34,7 +33,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.collection.unsafe.sort.*; /** @@ -50,14 +49,19 @@ public final class UnsafeKVExternalSorter { private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes) - throws IOException { - this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null); + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes) throws IOException { + this(keySchema, valueSchema, blockManager, pageSizeBytes, null); } - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes, + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -73,7 +77,6 @@ public UnsafeKVExternalSorter(StructType 
keySchema, StructType valueSchema, if (map == null) { sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, @@ -115,7 +118,6 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( taskContext.taskMemoryManager(), - shuffleMemoryManager, blockManager, taskContext, new KVComparator(ordering, keySchema.length()), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 7cd0f7b81e46c..fb2fc98e34fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import scala.collection.mutable.ArrayBuffer import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType * * This iterator first uses hash-based aggregation to process input rows. It uses * a hash map to store groups and their corresponding aggregation buffers. If we - * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * this map cannot allocate memory from memory manager, * it switches to sort-based aggregation. The process of the switch has the following step: * - Step 1: Sort all entries of the hash map based on values of grouping expressions and * spill them to disk. 
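As a point of reference for the constructor change above, a caller now passes only the schemas, the block manager, and a page size taken from the task's memory manager. The sketch below mirrors the WriterContainer call site later in this patch; the wrapper method and an already-set TaskContext are assumptions, not part of the API itself.

    import org.apache.spark.{SparkEnv, TaskContext}
    import org.apache.spark.sql.execution.UnsafeKVExternalSorter
    import org.apache.spark.sql.types.StructType

    // Assumed sketch: the schemas come from elsewhere and a TaskContext is set
    // for the running task; the page size is read from the TaskMemoryManager
    // now that the ShuffleMemoryManager argument is gone.
    def makeSorter(keySchema: StructType, valueSchema: StructType): UnsafeKVExternalSorter = {
      new UnsafeKVExternalSorter(
        keySchema,
        valueSchema,
        SparkEnv.get.blockManager,
        TaskContext.get().taskMemoryManager().pageSizeBytes)
    }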
@@ -480,10 +480,9 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, + TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity - SparkEnv.get.shuffleMemoryManager.pageSizeBytes, + TaskContext.get().taskMemoryManager().pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index cfd64c1d9eb34..1b59b19d9420d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -344,8 +344,7 @@ private[sql] class DynamicPartitionWriterContainer( StructType.fromAttributes(partitionColumns), StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, - SparkEnv.get.shuffleMemoryManager, - SparkEnv.get.shuffleMemoryManager.pageSizeBytes) + TaskContext.get().taskMemoryManager().pageSizeBytes) sorter.insertKV(currentKey, getOutputRow(inputRow)) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bc255b27502b2..cc8abb1ba463c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.local.LocalNode import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.MemoryLocation import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -320,21 +320,20 @@ private[joins] final class UnsafeHashedRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory - val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + // TODO(josh): This needs to be revisited before we merge this patch; making this change now + // so that tests compile: + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) - val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes) + val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new 
SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) - // Dummy shuffle memory manager which always grants all memory allocation requests. - // We use this because it doesn't make sense count shared broadcast variables' memory usage - // towards individual tasks' quotas. In the future, we should devise a better way of handling - // this. - val shuffleMemoryManager = - ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes) + // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit + // during code review binaryMap = new BytesToBytesMap( taskMemoryManager, - shuffleMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 9385e5734db5c..dd92dda480601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -49,7 +49,8 @@ case class Sort( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) val context = TaskContext.get() @@ -124,7 +125,7 @@ case class TungstenSort( } } - val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 1739798a24e0a..dbf4863b767bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,13 +23,12 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String /** @@ -49,23 +48,22 @@ class UnsafeFixedWidthAggregationMapSuite private def emptyAggregationBuffer: InternalRow = InternalRow(0) private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + private var memoryManager: GrantEverythingMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null - private var shuffleMemoryManager: TestShuffleMemoryManager = null def testWithMemoryLeakDetection(name: String)(f: => Unit) { def 
cleanup(): Unit = { if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) - assert(leakedShuffleMemory === 0) taskMemoryManager = null } TaskContext.unset() } test(name) { - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - shuffleMemoryManager = new TestShuffleMemoryManager + val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + memoryManager = new GrantEverythingMemoryManager(conf) + taskMemoryManager = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -110,7 +108,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity, PAGE_SIZE_BYTES, false // disable perf metrics @@ -125,7 +122,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -153,7 +149,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -176,14 +171,13 @@ class UnsafeFixedWidthAggregationMapSuite testWithMemoryLeakDetection("test external sorting") { // Memory consumption in the beginning of the task. - val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -200,7 +194,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) + assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -214,7 +208,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -238,7 +232,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -258,7 +251,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -281,14 +274,13 @@ class UnsafeFixedWidthAggregationMapSuite testWithMemoryLeakDetection("test external sorting with empty records") { // Memory consumption in the beginning of the task. 
- val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, StructType(Nil), StructType(Nil), taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -303,7 +295,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) + assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -311,7 +303,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -332,34 +324,28 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { - val smm = ShuffleMemoryManager.createForTesting(65536) val pageSize = 4096 val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, groupKeySchema, taskMemoryManager, - smm, 128, // initial capacity pageSize, false // disable perf metrics ) - // Insert into the map until we've run out of space val rand = new Random(42) - var hasSpace = true - while (hasSpace) { + for (i <- 1 to 100) { val str = rand.nextString(1024) val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) - if (buf == null) { - hasSpace = false - } else { - buf.setInt(0, str.length) - } + buf.setInt(0, str.length) } - - // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte - assert(smm.tryToAcquire(1) === 0) + // Simulate running out of space + memoryManager.markExecutionAsOutOfMemory() + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + assert(buf == null) // Convert the map into a sorter. 
This used to fail before the fix for SPARK-10474 // because we would try to acquire space for the in-memory sorter pointer array before diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index d3be568a8758c..13dc1754c9ff0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -108,9 +108,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { inputData: Seq[(InternalRow, InternalRow)], pageSize: Long, spill: Boolean): Unit = { - - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val shuffleMemMgr = new TestShuffleMemoryManager + val memoryManager = + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, partitionId = 0, @@ -121,14 +121,14 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { internalAccumulators = Seq.empty)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) // 1% chance we will spill if (rand.nextDouble() < 0.01 && spill) { - shuffleMemMgr.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -170,12 +170,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering)) // Make sure there is no memory leak - val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory - if (shuffleMemMgr != null) { - val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() - assert(0L === leakedShuffleMemory) - } - assert(0 === leakedUnsafeMemory) + assert(0 === taskMemMgr.cleanUpAllAllocatedMemory) TaskContext.unset() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1680d7e0a85ce..d32572b54b8a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} 
import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -112,7 +113,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val data = (1 to 10000).iterator.map { i => (i, converter(Row(i))) } + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) + val taskContext = new TaskContextImpl( + 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, partitioner = Some(new HashPartitioner(10)), serializer = Some(new UnsafeRowSerializer(numFields = 1))) @@ -122,10 +128,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { assert(sorter.numSpills > 0) // Merging spilled files should not throw assertion error - val taskContext = - new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } { // Clean up if (sc != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index cc0ac1b07c21a..475037bd45379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.memory.TaskMemoryManager class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { test("memory acquired on construction") { - val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0) val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) TaskContext.setTaskContext(taskContext) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index b2b6848719639..c17fb7238151b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -254,7 +254,7 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, memManager, 
mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java deleted file mode 100644 index cbbe8594627a5..0000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import java.lang.ref.WeakReference; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import javax.annotation.concurrent.GuardedBy; - -/** - * Manages memory for an executor. Individual operators / tasks allocate memory through - * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. - */ -public class ExecutorMemoryManager { - - /** - * Allocator, exposed for enabling untracked allocations of temporary data structures. - */ - public final MemoryAllocator allocator; - - /** - * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. - */ - final boolean inHeap; - - @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap>>(); - - private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; - - /** - * Construct a new ExecutorMemoryManager. - * - * @param allocator the allocator that will be used - */ - public ExecutorMemoryManager(MemoryAllocator allocator) { - this.inHeap = allocator instanceof HeapMemoryAllocator; - this.allocator = allocator; - } - - /** - * Returns true if allocations of the given size should go through the pooling mechanism and - * false otherwise. - */ - private boolean shouldPool(long size) { - // Very small allocations are less likely to benefit from pooling. - // At some point, we should explore supporting pooling for off-heap memory, but for now we'll - // ignore that case in the interest of simplicity. - return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). 
- */ - MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { - synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); - if (pool != null) { - while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); - return memory; - } - } - bufferPoolsBySize.remove(size); - } - } - return allocator.allocate(size); - } else { - return allocator.allocate(size); - } - } - - void free(MemoryBlock memory) { - final long size = memory.size(); - if (shouldPool(size)) { - synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); - if (pool == null) { - pool = new LinkedList>(); - bufferPoolsBySize.put(size, pool); - } - pool.add(new WeakReference(memory)); - } - } else { - allocator.free(memory); - } - } - -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 6722301df19d1..ebe90d9e63d83 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -17,22 +17,71 @@ package org.apache.spark.unsafe.memory; +import javax.annotation.concurrent.GuardedBy; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; + /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ public class HeapMemoryAllocator implements MemoryAllocator { + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap<>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. 
+ return size >= POOLING_THRESHOLD_BYTES; + } + @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + } long[] array = new long[(int) (size / 8)]; return MemoryBlock.fromLongArray(array); } @Override public void free(MemoryBlock memory) { - // Do nothing + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList<>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference<>(memory)); + } + } else { + // Do nothing + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index dd75820834370..e3e79471154df 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -30,9 +30,10 @@ public class MemoryBlock extends MemoryLocation { /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * MemoryManager. This is package-private and is modified by MemoryManager. + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. */ - int pageNumber = -1; + public int pageNumber = -1; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); From 87f82a5fb9c4350a97c761411069245f07aad46f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 25 Oct 2015 21:57:34 -0700 Subject: [PATCH 035/324] [SPARK-11127][STREAMING] upgrade AWS SDK and Kinesis Client Library (KCL) AWS SDK 1.9.40 is the latest 1.9.x release. KCL 1.5.1 is the latest release that using AWS SDK 1.9.x. The main goal is to have Kinesis consumer be able to read messages generated from Kinesis Producer Library (KPL). The API should be compatible with old versions. tdas brkyvz Author: Xiangrui Meng Closes #9153 from mengxr/SPARK-11127. --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 445e65c0459bf..3dfc434fb553b 100644 --- a/pom.xml +++ b/pom.xml @@ -152,8 +152,8 @@ 1.7.7 hadoop2 0.7.1 - 1.9.16 - 1.3.0 + 1.9.40 + 1.4.0 4.3.2 From 07ced43424447699e47106de9ca2fa714756bdeb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 25 Oct 2015 22:47:39 -0700 Subject: [PATCH 036/324] [SPARK-11253] [SQL] reset all accumulators in physical operators before execute an action With this change, our query execution listener can get the metrics correctly. The UI still looks good after this change. screen shot 2015-10-23 at 11 25 14 am screen shot 2015-10-23 at 11 25 01 am Author: Wenchen Fan Closes #9215 from cloud-fan/metric. 
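Context for the change: resetting each operator's SQLMetric before an action runs means a QueryExecutionListener sees per-action values instead of totals accumulated across repeated actions on the same DataFrame. The test added below exercises exactly that; a minimal, hedged sketch of the pattern, assuming a `sqlContext` with its implicits in scope and an operator that exposes a `numInputRows` metric (metric names vary by physical operator):

```
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
import sqlContext.implicits._

val listener = new QueryExecutionListener {
  // Only the success path matters for this sketch.
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
    // Because metrics are reset before each action, this value covers this action only,
    // not the accumulated total across earlier actions on the same DataFrame.
    val rows = qe.executedPlan.longMetric("numInputRows").value.value
    println(s"$funcName read $rows input rows")
  }
}

sqlContext.listenerManager.register(listener)
val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
df.collect()   // listener reports 1 input row
df.collect()   // reports 1 again, rather than an accumulated 2
sqlContext.listenerManager.unregister(listener)
```

Unregistering at the end mirrors the clean-up the test changes below add, so listeners do not leak across callers.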
--- .../org/apache/spark/sql/DataFrame.scala | 3 + .../sql/execution/metric/SQLMetrics.scala | 7 +- .../sql/util/DataFrameCallbackSuite.scala | 81 ++++++++++++++++++- 3 files changed, 87 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index bf25bcde208e2..25ad3bb993f4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1974,6 +1974,9 @@ class DataFrame private[sql]( */ private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { try { + df.queryExecution.executedPlan.foreach { plan => + plan.metrics.valuesIterator.foreach(_.reset()) + } val start = System.nanoTime() val result = action(df) val end = System.nanoTime() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 075b7ad881112..1c253e3942e95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -28,7 +28,12 @@ import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} */ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( name: String, val param: SQLMetricParam[R, T]) - extends Accumulable[R, T](param.zero, param, Some(name), true) + extends Accumulable[R, T](param.zero, param, Some(name), true) { + + def reset(): Unit = { + this.value = param.zero + } +} /** * Create a layer for specialized metric. We cannot add `@specialized` to diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index eb056cd519717..b46b0d2f6040a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.util -import org.apache.spark.SparkException +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ import org.apache.spark.sql.{functions, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.test.SharedSQLContext -import scala.collection.mutable.ArrayBuffer - class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { import testImplicits._ import functions._ @@ -54,6 +54,8 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1)._1 == "count") assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) assert(metrics(1)._3 > 0) + + sqlContext.listenerManager.unregister(listener) } test("execute callback functions when a DataFrame action failed") { @@ -79,5 +81,78 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0)._1 == "collect") assert(metrics(0)._2.analyzed.isInstanceOf[Project]) assert(metrics(0)._3.getMessage == e.getMessage) + + sqlContext.listenerManager.unregister(listener) + } + + test("get numRows metrics by callback") { + val metrics = ArrayBuffer.empty[Long] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += qe.executedPlan.longMetric("numInputRows").value.value + } + } + sqlContext.listenerManager.register(listener) + + val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() + df.collect() + df.collect() + Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect() + + assert(metrics.length == 3) + assert(metrics(0) == 1) + assert(metrics(1) == 1) + assert(metrics(2) == 2) + + sqlContext.listenerManager.unregister(listener) + } + + // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never + // updated, we can filter it out later. However, when we aggregate(sum) accumulator values at + // driver side for SQL physical operators, these -1 values will make our result smaller. + // A easy fix is to create a new SQLMetric(including new MetricValue, MetricParam, etc.), but we + // can do it later because the impact is just too small (1048576 tasks for 1 MB). + ignore("get size metrics by callback") { + val metrics = ArrayBuffer.empty[Long] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += qe.executedPlan.longMetric("dataSize").value.value + val bottomAgg = qe.executedPlan.children(0).children(0) + metrics += bottomAgg.longMetric("dataSize").value.value + } + } + sqlContext.listenerManager.register(listener) + + val sparkListener = new SaveInfoListener + sqlContext.sparkContext.addSparkListener(sparkListener) + + val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") + df.groupBy("i").count().collect() + + def getPeakExecutionMemory(stageId: Int): Long = { + val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables + .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + + assert(peakMemoryAccumulator.size == 1) + peakMemoryAccumulator.head._2.value.toLong + } + + assert(sparkListener.getCompletedStageInfos.length == 2) + val bottomAggDataSize = getPeakExecutionMemory(0) + val topAggDataSize = getPeakExecutionMemory(1) + + // For this simple case, the peakExecutionMemory of a stage should be the data size of the + // aggregate operator, as we only have one memory consuming operator per stage. + assert(metrics.length == 2) + assert(metrics(0) == topAggDataSize) + assert(metrics(1) == bottomAggDataSize) + + sqlContext.listenerManager.unregister(listener) } } From 05c4bdb57947f44924b4fbdd8e4e2101f2f816f5 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 26 Oct 2015 09:25:19 +0100 Subject: [PATCH 037/324] [SPARK-11279][PYSPARK] Add DataFrame#toDF in PySpark Author: Jeff Zhang Closes #9248 from zjffdu/SPARK-11279. 
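For context, the Python method added below is a thin py4j wrapper: it forwards the new column names to the JVM DataFrame's varargs `toDF` via `self._jdf.toDF(self._jseq(cols))`. A hedged Scala sketch of that underlying call, assuming a `sqlContext` with implicits in scope (the data and column names are illustrative only):

```
import sqlContext.implicits._

// Build a small DataFrame, then rename its columns with the JVM-side toDF that the
// new PySpark wrapper delegates to.
val df = Seq((2, "Alice"), (5, "Bob")).toDF("age", "name")
val renamed = df.toDF("f1", "f2")
renamed.show()   // columns are now f1 and f2; the data is unchanged
```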
--- python/pyspark/sql/dataframe.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 36fc6e0611dc1..3baff8147753d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1266,6 +1266,18 @@ def drop(self, col): raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix + def toDF(self, *cols): + """Returns a new class:`DataFrame` that with new specified column names + + :param cols: list of new column names (string) + + >>> df.toDF('f1', 'f2').collect() + [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] + """ + jdf = self._jdf.toDF(self._jseq(cols)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. From 616be29c7f2ebc184bd5ec97210da36a2174d80c Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 26 Oct 2015 09:34:15 +0000 Subject: [PATCH 038/324] [SPARK-5966][WIP] Spark-submit deploy-mode cluster is not compatible with master local> MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … master local> Author: Kevin Yu Closes #9220 from kevinyu98/working_on_spark-5966. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 640cc325281a9..84ae122f44370 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -328,6 +328,8 @@ object SparkSubmit { case (STANDALONE, CLUSTER) if args.isR => printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") + case (LOCAL, CLUSTER) => + printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => From 3689beb98b6a6db61e35049fdb57b0cd6aad8019 Mon Sep 17 00:00:00 2001 From: Narine Kokhlikyan Date: Mon, 26 Oct 2015 15:12:25 -0700 Subject: [PATCH 039/324] [SPARK-10979][SPARKR] Sparkrmerge: Add merge to DataFrame with R signature Add merge function to DataFrame, which supports R signature. https://stat.ethz.ch/R-manual/R-devel/library/base/html/merge.html Author: Narine Kokhlikyan Closes #9012 from NarineK/sparkrmerge. --- R/pkg/R/DataFrame.R | 140 ++++++++++++++++++++++++++++++- R/pkg/inst/tests/test_sparkSQL.R | 37 +++++++- 2 files changed, 169 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 2acbd081cd504..c8944459542af 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1457,15 +1457,147 @@ setMethod("join", dataFrame(sdf) }) -#' @rdname merge +#' #' @name merge #' @aliases join +#' @title Merges two data frames +#' @param x the first data frame to be joined +#' @param y the second data frame to be joined +#' @param by a character vector specifying the join columns. If by is not +#' specified, the common column names in \code{x} and \code{y} will be used. +#' @param by.x a character vector specifying the joining columns for x. +#' @param by.y a character vector specifying the joining columns for y. 
+#' @param all.x a boolean value indicating whether all the rows in x should +#' be including in the join +#' @param all.y a boolean value indicating whether all the rows in y should +#' be including in the join +#' @param sort a logical argument indicating whether the resulting columns should be sorted +#' @details If all.x and all.y are set to FALSE, a natural join will be returned. If +#' all.x is set to TRUE and all.y is set to FALSE, a left outer join will +#' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right +#' outer join will be returned. If all.x and all.y are set to TRUE, a full +#' outer join will be returned. +#' @rdname merge +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) +#' merge(df1, df2) # Performs a Cartesian +#' merge(df1, df2, by = "col1") # Performs an inner join based on expression +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE, all.y = TRUE) +#' merge(df1, df2, by.x = "col1", by.y = "col2", all = TRUE, sort = FALSE) +#' merge(df1, df2, by = "col1", all = TRUE, suffixes = c("-X", "-Y")) +#' } setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), - function(x, y, joinExpr = NULL, joinType = NULL, ...) { - join(x, y, joinExpr, joinType) - }) + function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, + all = FALSE, all.x = all, all.y = all, + sort = TRUE, suffixes = c("_x","_y"), ... ) { + + if (length(suffixes) != 2) { + stop("suffixes must have length 2") + } + + # join type is identified based on the values of all, all.x and all.y + # default join type is inner, according to R it should be natural but since it + # is not supported in spark inner join is used + joinType <- "inner" + if (all || (all.x && all.y)) { + joinType <- "outer" + } else if (all.x) { + joinType <- "left_outer" + } else if (all.y) { + joinType <- "right_outer" + } + # join expression is based on by.x, by.y if both by.x and by.y are not missing + # or on by, if by.x or by.y are missing or have different lengths + if (length(by.x) > 0 && length(by.x) == length(by.y)) { + joinX <- by.x + joinY <- by.y + } else if (length(by) > 0) { + # if join columns have the same name for both dataframes, + # they are used in join expression + joinX <- by + joinY <- by + } else { + # if by or both by.x and by.y have length 0, use Cartesian Product + joinRes <- join(x, y) + return (joinRes) + } + + # sets alias for making colnames unique in dataframes 'x' and 'y' + colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1]) + colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2]) + + # selects columns with their aliases from dataframes + # in case same column names are present in both data frames + xsel <- select(x, colsX) + ysel <- select(y, colsY) + + # generates join conditions and adds them into a list + # it also considers alias names of the columns while generating join conditions + joinColumns <- lapply(seq_len(length(joinX)), function(i) { + colX <- joinX[[i]] + colY <- joinY[[i]] + + if (colX %in% by) { + colX <- paste(colX, suffixes[1], sep = "") + } + if (colY %in% by) { + colY <- paste(colY, suffixes[2], sep = "") + } + + colX <- getColumn(xsel, colX) + colY <- getColumn(ysel, colY) + + colX == colY + }) + + # concatenates join columns with '&' and executes join + joinExpr <- 
Reduce("&", joinColumns) + joinRes <- join(xsel, ysel, joinExpr, joinType) + + # sorts the result by 'by' columns if sort = TRUE + if (sort && length(by) > 0) { + colNameWithSuffix <- paste(by, suffixes[2], sep = "") + joinRes <- do.call("arrange", c(joinRes, colNameWithSuffix, decreasing = FALSE)) + } + + joinRes + }) + +#' +#' Creates a list of columns by replacing the intersected ones with aliases. +#' The name of the alias column is formed by concatanating the original column name and a suffix. +#' +#' @param x a DataFrame on which the +#' @param intersectedColNames a list of intersected column names +#' @param suffix a suffix for the column name +#' @return list of columns +#' +generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { + allColNames <- names(x) + # sets alias for making colnames unique in dataframe 'x' + cols <- lapply(allColNames, function(colName) { + col <- getColumn(x, colName) + if (colName %in% intersectedColNames) { + newJoin <- paste(colName, suffix, sep = "") + if (newJoin %in% allColNames){ + stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.", + "Please use different suffixes for the intersected columns.") + } + col <- alias(col, newJoin) + } + col + }) + cols +} #' UnionAll #' diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 67d8b23cd7b8d..540854d114b23 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1105,11 +1105,40 @@ test_that("join() and merge() on a DataFrame", { expect_equal(count(joined9), 4) expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) - merged <- select(merge(df, df2, df$name == df2$name, "outer"), - alias(df$age + 5, "newAge"), df$name, df2$test) - expect_equal(names(merged), c("newAge", "name", "test")) + merged <- merge(df, df2, by.x = "name", by.y = "name", all.x = TRUE, all.y = TRUE) expect_equal(count(merged), 4) - expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19) + + merged <- merge(df, df2, suffixes = c("-X","-Y")) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30) + + merged <- merge(df, df2, by = "name", suffixes = c("-X","-Y"), sort = FALSE) + expect_equal(count(merged), 3) + expect_equal(names(merged), c("age", "name-X", "name-Y", "test")) + expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael") + + merged <- merge(df, df2, by = "name", all = T, sort = T) + expect_equal(count(merged), 4) + expect_equal(names(merged), c("age", "name_x", "name_y", "test")) + expect_equal(collect(orderBy(merged, merged$"name_y"))$"name_x"[1], "Andy") + + merged <- merge(df, df2, by = NULL) + expect_equal(count(merged), 12) + expect_equal(names(merged), c("age", "name", "name", "test")) + + mockLines3 <- c("{\"name\":\"Michael\", \"name_y\":\"Michael\", \"test\": \"yes\"}", + "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}", + "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}", + "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") + jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines3, jsonPath3) + df3 <- jsonFile(sqlContext, jsonPath3) + expect_error(merge(df, df3), + paste("The following column name: name_y occurs more than once in the 
'DataFrame'.", + "Please use different suffixes for the intersected columns.", sep = "")) }) test_that("toJSON() returns an RDD of the correct values", { From b60aab8a95e2a35a1d4023a9d0a0d9724e4164f9 Mon Sep 17 00:00:00 2001 From: Frank Rosner Date: Mon, 26 Oct 2015 15:46:59 -0700 Subject: [PATCH 040/324] [SPARK-11258] Converting a Spark DataFrame into an R data.frame is slow / requires a lot of memory https://issues.apache.org/jira/browse/SPARK-11258 I was not able to locate an existing unit test for this function so I wrote one. Author: Frank Rosner Closes #9222 from FRosner/master. --- .../org/apache/spark/sql/api/r/SQLUtils.scala | 16 ++++---- .../spark/sql/api/r/SQLUtilsSuite.scala | 38 +++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index b0120a8d0dc4f..b3f134614c6bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -130,16 +130,18 @@ private[r] object SQLUtils { } def dfToCols(df: DataFrame): Array[Array[Any]] = { - // localDF is Array[Row] - val localDF = df.collect() + val localDF: Array[Row] = df.collect() val numCols = df.columns.length + val numRows = localDF.length - // result is Array[Array[Any]] - (0 until numCols).map { colIdx => - localDF.map { row => - row(colIdx) + val colArray = new Array[Array[Any]](numCols) + for (colNo <- 0 until numCols) { + colArray(colNo) = new Array[Any](numRows) + for (rowNo <- 0 until numRows) { + colArray(colNo)(rowNo) = localDF(rowNo)(colNo) } - }.toArray + } + colArray } def saveMode(mode: String): SaveMode = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala new file mode 100644 index 0000000000000..f54e23e3aa6cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala @@ -0,0 +1,38 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.api.r + +import org.apache.spark.sql.test.SharedSQLContext + +class SQLUtilsSuite extends SharedSQLContext { + + import testImplicits._ + + test("dfToCols should collect and transpose a data frame") { + val df = Seq( + (1, 2, 3), + (4, 5, 6) + ).toDF + assert(SQLUtils.dfToCols(df) === Array( + Array(1, 4), + Array(2, 5), + Array(3, 6) + )) + } + +} From 4bb2b3698ffed58cc5159db36f8b11573ad26b23 Mon Sep 17 00:00:00 2001 From: Alexander Slesarenko Date: Mon, 26 Oct 2015 23:49:14 +0100 Subject: [PATCH 041/324] [SQL][DOC] Minor document fixes in interfaces.scala rxin just noticed this while reading the code. Author: Alexander Slesarenko Closes #9284 from aslesarenko/doc-typos. --- .../org/apache/spark/sql/sources/interfaces.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 84eef0f8a672c..a9a013e936fd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.SerializableConfiguration * This allows users to give the data source alias as the format type over the fully qualified * class name. * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * @since 1.5.0 */ @@ -74,7 +74,7 @@ trait DataSourceRegister { * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * @since 1.3.0 */ @@ -100,7 +100,7 @@ trait RelationProvider { * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that * users need to provide a schema when using a [[SchemaRelationProvider]]. @@ -135,7 +135,7 @@ trait SchemaRelationProvider { * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the * data source 'org.apache.spark.sql.json.DefaultSource' * - * A new instance of this class with be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a DDL call is made. * * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when @@ -195,7 +195,7 @@ trait CreatableRelationProvider { * implementation should inherit from one of the descendant `Scan` classes, which define various * abstract methods for execution. * - * BaseRelations must also define a equality function that only returns true when the two + * BaseRelations must also define an equality function that only returns true when the two * instances will return the same data. This equality function is used when determining when * it is safe to substitute cached results for a given relation. 
* @@ -208,7 +208,7 @@ abstract class BaseRelation { /** * Returns an estimated size of this relation in bytes. This information is used by the planner - * to decided when it is safe to broadcast a relation and can be overridden by sources that + * to decide when it is safe to broadcast a relation and can be overridden by sources that * know the size ahead of time. By default, the system will assume that tables are too * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. @@ -383,7 +383,7 @@ abstract class OutputWriter { /** * ::Experimental:: - * A [[BaseRelation]] that provides much of the common code required for formats that store their + * A [[BaseRelation]] that provides much of the common code required for relations that store their * data to an HDFS compatible filesystem. * * For the read path, similar to [[PrunedFilteredScan]], it can eliminate unneeded columns and From d4c397a64af4cec899fdaa3e617ed20333cc567d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 26 Oct 2015 18:27:02 -0700 Subject: [PATCH 042/324] [SPARK-11325] [SQL] Alias 'alias' in Scala's DataFrame API Author: Nong Li Closes #9286 from nongli/spark-11325. --- .../scala/org/apache/spark/sql/DataFrame.scala | 14 ++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 25ad3bb993f4e..32d9b0b1d9888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -698,6 +698,20 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) + /** + * Returns a new [[DataFrame]] with an alias set. Same as `as`. + * @group dfops + * @since 1.6.0 + */ + def alias(alias: String): DataFrame = as(alias) + + /** + * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`. + * @group dfops + * @since 1.6.0 + */ + def alias(alias: Symbol): DataFrame = as(alias) + /** * Selects a set of column based expressions. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f4c7aa34e560c..59565a6b13d40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -105,6 +105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.head(2).head.schema === testData.schema) } + test("dataframe alias") { + val df = Seq(Tuple1(1)).toDF("c").as("t") + val dfAlias = df.alias("t2") + df.col("t.c") + dfAlias.col("t2.c") + } + test("simple explode") { val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") From 82464fb2e02ca4e4d425017815090497b79dc93f Mon Sep 17 00:00:00 2001 From: Stephen De Gennaro Date: Mon, 26 Oct 2015 19:55:10 -0700 Subject: [PATCH 043/324] [SPARK-10947] [SQL] With schema inference from JSON into a Dataframe, add option to infer all primitive object types as strings Currently, when a schema is inferred from a JSON file using sqlContext.read.json, the primitive object types are inferred as string, long, boolean, etc. However, if the inferred type is too specific (JSON obviously does not enforce types itself), this can cause issues with merging dataframe schemas. 
This pull request adds the option "primitivesAsString" to the JSON DataFrameReader which when true (defaults to false if not set) will infer all primitives as strings. Below is an example usage of this new functionality. ``` val jsonDf = sqlContext.read.option("primitivesAsString", "true").json(sampleJsonFile) scala> jsonDf.printSchema() root |-- bigInteger: string (nullable = true) |-- boolean: string (nullable = true) |-- double: string (nullable = true) |-- integer: string (nullable = true) |-- long: string (nullable = true) |-- null: string (nullable = true) |-- string: string (nullable = true) ``` Author: Stephen De Gennaro Closes #9249 from stephend-realitymine/stephend-primitives. --- .../apache/spark/sql/DataFrameReader.scala | 10 +- .../datasources/json/InferSchema.scala | 20 ++- .../datasources/json/JSONRelation.scala | 14 +- .../datasources/json/JsonSuite.scala | 138 +++++++++++++++++- 4 files changed, 171 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 824220d85e04d..6a194a443ab17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -256,8 +256,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ def json(jsonRDD: RDD[String]): DataFrame = { val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble + val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean sqlContext.baseRelationToDataFrame( - new JSONRelation(Some(jsonRDD), samplingRatio, userSpecifiedSchema, None, None)(sqlContext)) + new JSONRelation( + Some(jsonRDD), + samplingRatio, + primitivesAsString, + userSpecifiedSchema, + None, + None)(sqlContext) + ) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index d0780028dacb1..b9914c581a657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -35,7 +35,8 @@ private[sql] object InferSchema { def apply( json: RDD[String], samplingRatio: Double = 1.0, - columnNameOfCorruptRecords: String): StructType = { + columnNameOfCorruptRecords: String, + primitivesAsString: Boolean = false): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) { json @@ -50,7 +51,7 @@ private[sql] object InferSchema { try { Utils.tryWithResource(factory.createParser(row)) { parser => parser.nextToken() - inferField(parser) + inferField(parser, primitivesAsString) } } catch { case _: JsonParseException => @@ -70,14 +71,14 @@ private[sql] object InferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser): DataType = { + private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser) + inferField(parser, primitivesAsString) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -92,7 
+93,10 @@ private[sql] object InferSchema { case START_OBJECT => val builder = Seq.newBuilder[StructField] while (nextUntil(parser, END_OBJECT)) { - builder += StructField(parser.getCurrentName, inferField(parser), nullable = true) + builder += StructField( + parser.getCurrentName, + inferField(parser, primitivesAsString), + nullable = true) } StructType(builder.result().sortBy(_.name)) @@ -103,11 +107,15 @@ private[sql] object InferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser)) + elementType = compatibleType(elementType, inferField(parser, primitivesAsString)) } ArrayType(elementType) + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType + + case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ parser.getNumberType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 794b889a93627..5f104fca7d629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -52,14 +52,23 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) - new JSONRelation(None, samplingRatio, dataSchema, None, partitionColumns, paths)(sqlContext) + new JSONRelation( + None, + samplingRatio, + primitivesAsString, + dataSchema, + None, + partitionColumns, + paths)(sqlContext) } } private[sql] class JSONRelation( val inputRDD: Option[RDD[String]], val samplingRatio: Double, + val primitivesAsString: Boolean, val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], @@ -105,7 +114,8 @@ private[sql] class JSONRelation( InferSchema( inputRDD.getOrElse(createBaseRdd(files)), samplingRatio, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord, + primitivesAsString) } checkConstraints(jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 7540223bf2771..d3fd409291f29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -632,6 +632,136 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Loading a JSON dataset primitivesAsString returns schema with primitive types as strings") { + val dir = Utils.createTempDir() + dir.delete() + val path = dir.getCanonicalPath + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(path) + + val expectedSchema = StructType( + StructField("bigInteger", StringType, true) :: + 
StructField("boolean", StringType, true) :: + StructField("double", StringType, true) :: + StructField("integer", StringType, true) :: + StructField("long", StringType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select * from jsonTable"), + Row("92233720368547758070", + "true", + "1.7976931348623157E308", + "10", + "21474836470", + null, + "this is a simple string.") + ) + } + + test("Loading a JSON dataset primitivesAsString returns complex fields as strings") { + val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1) + + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(StringType, true), true), true) :: + StructField("arrayOfBigInteger", ArrayType(StringType, true), true) :: + StructField("arrayOfBoolean", ArrayType(StringType, true), true) :: + StructField("arrayOfDouble", ArrayType(StringType, true), true) :: + StructField("arrayOfInteger", ArrayType(StringType, true), true) :: + StructField("arrayOfLong", ArrayType(StringType, true), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType, true), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", StringType, true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil), true), true) :: + StructField("struct", StructType( + StructField("field1", StringType, true) :: + StructField("field2", StringType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( + StructField("field1", ArrayType(StringType, true), true) :: + StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + + jsonDF.registerTempTable("jsonTable") + + // Access elements of a primitive array. + checkAnswer( + sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from jsonTable"), + Row("str1", "str2", null) + ) + + // Access an array of null values. + checkAnswer( + sql("select arrayOfNull from jsonTable"), + Row(Seq(null, null, null, null)) + ) + + // Access elements of a BigInteger array (we use DecimalType internally). + checkAnswer( + sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"), + Row("922337203685477580700", "-922337203685477580800", null) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray1[0], arrayOfArray1[1] from jsonTable"), + Row(Seq("1", "2", "3"), Seq("str1", "str2")) + ) + + // Access elements of an array of arrays. + checkAnswer( + sql("select arrayOfArray2[0], arrayOfArray2[1] from jsonTable"), + Row(Seq("1", "2", "3"), Seq("1.1", "2.1", "3.1")) + ) + + // Access elements of an array inside a filed with the type of ArrayType(ArrayType). + checkAnswer( + sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from jsonTable"), + Row("str2", "2.1") + ) + + // Access elements of an array of structs. + checkAnswer( + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + Row( + Row("true", "str1", null), + Row("false", null, null), + Row(null, null, null), + null) + ) + + // Access a struct and fields inside of it. 
+ checkAnswer( + sql("select struct, struct.field1, struct.field2 from jsonTable"), + Row( + Row("true", "92233720368547758070"), + "true", + "92233720368547758070") :: Nil + ) + + // Access an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1, structWithArrayFields.field2 from jsonTable"), + Row(Seq("4", "5", "6"), Seq("str1", "str2")) + ) + + // Access elements of an array field of a struct. + checkAnswer( + sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"), + Row("5", null) + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() @@ -960,9 +1090,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val jsonDF = sqlContext.read.json(primitiveFieldAndType) val primTable = sqlContext.read.json(jsonDF.toJSON) - primTable.registerTempTable("primativeTable") + primTable.registerTempTable("primitiveTable") checkAnswer( - sql("select * from primativeTable"), + sql("select * from primitiveTable"), Row(new java.math.BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, @@ -1039,24 +1169,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val relation0 = new JSONRelation( Some(empty), 1.0, + false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), None, None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, + false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), None, None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, + false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), None, None)(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, + false, Some(StructType(StructField("b", IntegerType, true) :: Nil)), None, None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) From dc3220ce11c7513b1452c82ee82cb86e908bcc2d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 26 Oct 2015 20:58:18 -0700 Subject: [PATCH 044/324] [SPARK-11209][SPARKR] Add window functions into SparkR [step 1]. Author: Sun Rui Closes #9193 from sun-rui/SPARK-11209. --- R/pkg/NAMESPACE | 4 + R/pkg/R/functions.R | 98 +++++++++++++++++++ R/pkg/R/generics.R | 16 +++ R/pkg/inst/tests/test_sparkSQL.R | 2 + .../apache/spark/api/r/RBackendHandler.scala | 3 +- 5 files changed, 122 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 52f7a0106aae6..b73bed3128242 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -119,6 +119,7 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "cumeDist", "date_add", "date_format", "date_sub", @@ -150,8 +151,10 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "lag", "last", "last_day", + "lead", "least", "length", "levenshtein", @@ -177,6 +180,7 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "ntile", "otherwise", "pmod", "quarter", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a72fb7bb42fef..366290fe66276 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2013,3 +2013,101 @@ setMethod("ifelse", "otherwise", no) column(jc) }) + +###################### Window functions###################### + +#' cumeDist +#' +#' Window function: returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row. 
+#' +#' N = total number of rows in the partition +#' cumeDist(x) = number of values before (and including) x / N +#' +#' This is equivalent to the CUME_DIST function in SQL. +#' +#' @rdname cumeDist +#' @name cumeDist +#' @family window_funcs +#' @export +#' @examples \dontrun{cumeDist()} +setMethod("cumeDist", + signature(x = "missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + column(jc) + }) + +#' lag +#' +#' Window function: returns the value that is `offset` rows before the current row, and +#' `defaultValue` if there is less than `offset` rows before the current row. For example, +#' an `offset` of one will return the previous row at any given point in the window partition. +#' +#' This is equivalent to the LAG function in SQL. +#' +#' @rdname lag +#' @name lag +#' @family window_funcs +#' @export +#' @examples \dontrun{lag(df$c)} +setMethod("lag", + signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + function(x, offset, defaultValue = NULL) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + + jc <- callJStatic("org.apache.spark.sql.functions", + "lag", col, as.integer(offset), defaultValue) + column(jc) + }) + +#' lead +#' +#' Window function: returns the value that is `offset` rows after the current row, and +#' `null` if there is less than `offset` rows after the current row. For example, +#' an `offset` of one will return the next row at any given point in the window partition. +#' +#' This is equivalent to the LEAD function in SQL. +#' +#' @rdname lead +#' @name lead +#' @family window_funcs +#' @export +#' @examples \dontrun{lead(df$c)} +setMethod("lead", + signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + function(x, offset, defaultValue = NULL) { + col <- if (class(x) == "Column") { + x@jc + } else { + x + } + + jc <- callJStatic("org.apache.spark.sql.functions", + "lead", col, as.integer(offset), defaultValue) + column(jc) + }) + +#' ntile +#' +#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window +#' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second +#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. +#' +#' This is equivalent to the NTILE function in SQL. +#' +#' @rdname ntile +#' @name ntile +#' @family window_funcs +#' @export +#' @examples \dontrun{ntile(1)} +setMethod("ntile", + signature(x = "numeric"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ntile", as.integer(x)) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4a419f785e92c..c11c3c8d3e150 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -714,6 +714,10 @@ setGeneric("countDistinct", function(x, ...) 
{ standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname cumeDist +#' @export +setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) + #' @rdname datediff #' @export setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) @@ -790,6 +794,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +#' @rdname lag +#' @export +setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) + #' @rdname last #' @export setGeneric("last", function(x) { standardGeneric("last") }) @@ -798,6 +806,10 @@ setGeneric("last", function(x) { standardGeneric("last") }) #' @export setGeneric("last_day", function(x) { standardGeneric("last_day") }) +#' @rdname lead +#' @export +setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) + #' @rdname least #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) @@ -858,6 +870,10 @@ setGeneric("negate", function(x) { standardGeneric("negate") }) #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) +#' @rdname ntile +#' @export +setGeneric("ntile", function(x) { standardGeneric("ntile") }) + #' @rdname countDistinct #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 540854d114b23..e1d4499925fe7 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -829,6 +829,8 @@ test_that("column functions", { c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c13 <- cumeDist() + ntile(1) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 2a792d81994fd..0095548c463cc 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -224,7 +224,8 @@ private[r] class RBackendHandler(server: RBackend) case _ => parameterType } } - if (!parameterWrapperType.isInstance(args(i))) { + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { argMatched = false } } From a150e6c1b03b64a35855b8074b2fe077a6081a34 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 26 Oct 2015 21:14:26 -0700 Subject: [PATCH 045/324] [SPARK-10562] [SQL] support mixed case partitionBy column names for tables stored in metastore https://issues.apache.org/jira/browse/SPARK-10562 Author: Wenchen Fan Closes #9226 from cloud-fan/par. 
--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 61 +++++++++++-------- .../sql/hive/MetastoreDataSourcesSuite.scala | 9 ++- .../sql/hive/execution/SQLQuerySuite.scala | 11 ++++ 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fdb576bedbbaf..f4d45714fae4e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -143,6 +143,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + def partColsFromParts: Option[Seq[String]] = { + table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols => + (0 until numPartCols.toInt).map { index => + val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull + if (partCol == null) { + throw new AnalysisException( + "Could not read partitioned columns from the metastore because it is corrupted " + + s"(missing part $index of the it, $numPartCols parts are expected).") + } + + partCol + } + } + } + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. // After SPARK-6024, we removed this flag. // Although we are not using spark.sql.sources.schema any more, we need to still support. @@ -155,7 +170,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // We only need names at here since userSpecifiedSchema we loaded from the metastore // contains partition columns. We can always get datatypes of partitioning columns // from userSpecifiedSchema. - val partitionColumns = table.partitionColumns.map(_.name) + val partitionColumns = partColsFromParts.getOrElse(Nil) // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... @@ -218,25 +233,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - val metastorePartitionColumns = userSpecifiedSchema.map { schema => - val fields = partitionColumns.map(col => schema(col)) - fields.map { field => - HiveColumn( - name = field.name, - hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), - comment = "") - }.toSeq - }.getOrElse { - if (partitionColumns.length > 0) { - // The table does not have a specified schema, which means that the schema will be inferred - // when we load the table. So, we are not expecting partition columns and we will discover - // partitions when we load the table. However, if there are specified partition columns, - // we simply ignore them and provide a warning message. - logWarning( - s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + - s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") + if (userSpecifiedSchema.isDefined && partitionColumns.length > 0) { + tableProperties.put("spark.sql.sources.schema.numPartCols", partitionColumns.length.toString) + partitionColumns.zipWithIndex.foreach { case (partCol, index) => + tableProperties.put(s"spark.sql.sources.schema.partCol.$index", partCol) } - Seq.empty[HiveColumn] + } + + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { + // The table does not have a specified schema, which means that the schema will be inferred + // when we load the table. 
So, we are not expecting partition columns and we will discover + // partitions when we load the table. However, if there are specified partition columns, + // we simply ignore them and provide a warning message. + logWarning( + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } val tableType = if (isExternal) { @@ -255,8 +266,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive HiveTable( specifiedDatabase = Option(dbName), name = tblName, - schema = Seq.empty, - partitionColumns = metastorePartitionColumns, + schema = Nil, + partitionColumns = Nil, tableType = tableType, properties = tableProperties.toMap, serdeProperties = options) @@ -272,14 +283,14 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } - val partitionColumns = schemaToHiveColumn(relation.partitionColumns) - val dataColumns = schemaToHiveColumn(relation.schema).filterNot(partitionColumns.contains) + assert(partitionColumns.isEmpty) + assert(relation.partitionColumns.isEmpty) HiveTable( specifiedDatabase = Option(dbName), name = tblName, - schema = dataColumns, - partitionColumns = partitionColumns, + schema = schemaToHiveColumn(relation.schema), + partitionColumns = Nil, tableType = tableType, properties = tableProperties.toMap, serdeProperties = options, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index d2928876887bd..f74eb1500b989 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -753,10 +753,15 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv invalidateTable(tableName) val metastoreTable = catalog.client.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + + val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt + assert(numPartCols == 2) + val actualPartitionColumns = StructType( - metastoreTable.partitionColumns.map(c => - StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + (0 until numPartCols).map { index => + df.schema(metastoreTable.properties(s"spark.sql.sources.schema.partCol.$index")) + }) // Make sure partition columns are correctly stored in metastore. 
assert( expectedPartitionColumns.sameType(actualPartitionColumns), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 396150be76e83..fd380641dcc71 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1410,4 +1410,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-10562: partition by column with mixed case name") { + withTable("tbl10562") { + val df = Seq(2012 -> "a").toDF("Year", "val") + df.write.partitionBy("Year").saveAsTable("tbl10562") + checkAnswer(sql("SELECT Year FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT yEAr FROM tbl10562"), Row(2012)) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year > 2015"), Nil) + checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) + } + } } From 943d4fa204a827ca8ecc39d9cf04e86890ee9840 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 26 Oct 2015 21:17:53 -0700 Subject: [PATCH 046/324] [SPARK-11289][DOC] Substitute code examples in ML features extractors with include_example mengxr https://issues.apache.org/jira/browse/SPARK-11289 I make some changes in ML feature extractors. I.e. TF-IDF, Word2Vec, and CountVectorizer. I add new example code in spark/examples, hope it is the right place to add those examples. Author: Xusen Yin Closes #9266 from yinxusen/SPARK-11289. --- docs/ml-features.md | 217 +----------------- .../ml/JavaCountVectorizerExample.java | 69 ++++++ .../spark/examples/ml/JavaTfIdfExample.java | 79 +++++++ .../examples/ml/JavaWord2VecExample.java | 67 ++++++ examples/src/main/python/ml/tf_idf_example.py | 47 ++++ .../src/main/python/ml/word2vec_example.py | 45 ++++ .../examples/ml/CountVectorizerExample.scala | 59 +++++ .../spark/examples/ml/TfIdfExample.scala | 53 +++++ .../spark/examples/ml/Word2VecExample.scala | 53 +++++ 9 files changed, 480 insertions(+), 209 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java create mode 100644 examples/src/main/python/ml/tf_idf_example.py create mode 100644 examples/src/main/python/ml/word2vec_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 44a98829393e7..142afac2f3f95 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -37,23 +37,7 @@ In the following code segment, we start with a set of sentences. We split each Refer to the [HashingTF Scala docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and the [IDF Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IDF) for more details on the API. 
-{% highlight scala %} -import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} - -val sentenceData = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsData = tokenizer.transform(sentenceData) -val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) -val featurizedData = hashingTF.transform(wordsData) -val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") -val idfModel = idf.fit(featurizedData) -val rescaledData = idfModel.transform(featurizedData) -rescaledData.select("features", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TfIdfExample.scala %}

@@ -61,49 +45,7 @@ rescaledData.select("features", "label").take(3).foreach(println) Refer to the [HashingTF Java docs](api/java/org/apache/spark/ml/feature/HashingTF.html) and the [IDF Java docs](api/java/org/apache/spark/ml/feature/IDF.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.IDF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsData = tokenizer.transform(sentenceData); -int numFeatures = 20; -HashingTF hashingTF = new HashingTF() - .setInputCol("words") - .setOutputCol("rawFeatures") - .setNumFeatures(numFeatures); -DataFrame featurizedData = hashingTF.transform(wordsData); -IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); -IDFModel idfModel = idf.fit(featurizedData); -DataFrame rescaledData = idfModel.transform(featurizedData); -for (Row r : rescaledData.select("features", "label").take(3)) { - Vector features = r.getAs(0); - Double label = r.getDouble(1); - System.out.println(features); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTfIdfExample.java %}
@@ -111,24 +53,7 @@ for (Row r : rescaledData.select("features", "label").take(3)) { Refer to the [HashingTF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.HashingTF) and the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import HashingTF, IDF, Tokenizer - -sentenceData = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsData = tokenizer.transform(sentenceData) -hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) -featurizedData = hashingTF.transform(wordsData) -idf = IDF(inputCol="rawFeatures", outputCol="features") -idfModel = idf.fit(featurizedData) -rescaledData = idfModel.transform(featurizedData) -for features_label in rescaledData.select("features", "label").take(3): - print(features_label) -{% endhighlight %} +{% include_example python/ml/tf_idf_example.py %}
@@ -149,26 +74,7 @@ In the following code segment, we start with a set of documents, each of which i Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Word2Vec - -// Input data: Each row is a bag of words from a sentence or document. -val documentDF = sqlContext.createDataFrame(Seq( - "Hi I heard about Spark".split(" "), - "I wish Java could use case classes".split(" "), - "Logistic regression models are neat".split(" ") -).map(Tuple1.apply)).toDF("text") - -// Learn a mapping from words to Vectors. -val word2Vec = new Word2Vec() - .setInputCol("text") - .setOutputCol("result") - .setVectorSize(3) - .setMinCount(0) -val model = word2Vec.fit(documentDF) -val result = model.transform(documentDF) -result.select("result").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/Word2VecExample.scala %}
@@ -176,43 +82,7 @@ result.select("result").take(3).foreach(println) Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; - -JavaSparkContext jsc = ... -SQLContext sqlContext = ... - -// Input data: Each row is a bag of words from a sentence or document. -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), - RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), - RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) -}); -DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); - -// Learn a mapping from words to Vectors. -Word2Vec word2Vec = new Word2Vec() - .setInputCol("text") - .setOutputCol("result") - .setVectorSize(3) - .setMinCount(0); -Word2VecModel model = word2Vec.fit(documentDF); -DataFrame result = model.transform(documentDF); -for (Row r: result.select("result").take(3)) { - System.out.println(r); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaWord2VecExample.java %}
@@ -220,22 +90,7 @@ for (Row r: result.select("result").take(3)) { Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Word2Vec - -# Input data: Each row is a bag of words from a sentence or document. -documentDF = sqlContext.createDataFrame([ - ("Hi I heard about Spark".split(" "), ), - ("I wish Java could use case classes".split(" "), ), - ("Logistic regression models are neat".split(" "), ) -], ["text"]) -# Learn a mapping from words to Vectors. -word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result") -model = word2Vec.fit(documentDF) -result = model.transform(documentDF) -for feature in result.select("result").take(3): - print(feature) -{% endhighlight %} +{% include_example python/ml/word2vec_example.py %}
@@ -283,30 +138,7 @@ Refer to the [CountVectorizer Scala docs](api/scala/index.html#org.apache.spark. and the [CountVectorizerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.CountVectorizer -import org.apache.spark.mllib.util.CountVectorizerModel - -val df = sqlContext.createDataFrame(Seq( - (0, Array("a", "b", "c")), - (1, Array("a", "b", "b", "c", "a")) -)).toDF("id", "words") - -// fit a CountVectorizerModel from the corpus -val cvModel: CountVectorizerModel = new CountVectorizer() - .setInputCol("words") - .setOutputCol("features") - .setVocabSize(3) - .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary - .fit(df) - -// alternatively, define CountVectorizerModel with a-priori vocabulary -val cvm = new CountVectorizerModel(Array("a", "b", "c")) - .setInputCol("words") - .setOutputCol("features") - -cvModel.transform(df).select("features").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/CountVectorizerExample.scala %}
@@ -315,40 +147,7 @@ Refer to the [CountVectorizer Java docs](api/java/org/apache/spark/ml/feature/Co and the [CountVectorizerModel Java docs](api/java/org/apache/spark/ml/feature/CountVectorizerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.CountVectorizer; -import org.apache.spark.ml.feature.CountVectorizerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; - -// Input data: Each row is a bag of words from a sentence or document. -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("a", "b", "c")), - RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) -)); -StructType schema = new StructType(new StructField [] { - new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); - -// fit a CountVectorizerModel from the corpus -CountVectorizerModel cvModel = new CountVectorizer() - .setInputCol("text") - .setOutputCol("feature") - .setVocabSize(3) - .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary - .fit(df); - -// alternatively, define CountVectorizerModel with a-priori vocabulary -CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) - .setInputCol("text") - .setOutputCol("feature"); - -cvModel.transform(df).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java new file mode 100644 index 0000000000000..ac33adb65292f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.CountVectorizer; +import org.apache.spark.ml.feature.CountVectorizerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaCountVectorizerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaCountVectorizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("a", "b", "c")), + RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) + )); + StructType schema = new StructType(new StructField [] { + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + // fit a CountVectorizerModel from the corpus + CountVectorizerModel cvModel = new CountVectorizer() + .setInputCol("text") + .setOutputCol("feature") + .setVocabSize(3) + .setMinDF(2) + .fit(df); + + // alternatively, define CountVectorizerModel with a-priori vocabulary + CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) + .setInputCol("text") + .setOutputCol("feature"); + + cvModel.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java new file mode 100644 index 0000000000000..a41a5ec9bff05 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.IDF; +import org.apache.spark.ml.feature.IDFModel; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTfIdfExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTfIdfExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(0, "I wish Java could use case classes"), + RowFactory.create(1, "Logistic regression models are neat") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); + int numFeatures = 20; + HashingTF hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("rawFeatures") + .setNumFeatures(numFeatures); + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { + Vector features = r.getAs(0); + Double label = r.getDouble(1); + System.out.println(features); + System.out.println(label); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java new file mode 100644 index 0000000000000..d472375ca9825 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.Word2Vec; +import org.apache.spark.ml.feature.Word2VecModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaWord2VecExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaWord2VecExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + // Learn a mapping from words to Vectors. + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + for (Row r : result.select("result").take(3)) { + System.out.println(r); + } + // $example off$ + } +} diff --git a/examples/src/main/python/ml/tf_idf_example.py b/examples/src/main/python/ml/tf_idf_example.py new file mode 100644 index 0000000000000..c92313378eec7 --- /dev/null +++ b/examples/src/main/python/ml/tf_idf_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import HashingTF, IDF, Tokenizer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="TfIdfExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (0, "I wish Java could use case classes"), + (1, "Logistic regression models are neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsData = tokenizer.transform(sentenceData) + hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) + featurizedData = hashingTF.transform(wordsData) + idf = IDF(inputCol="rawFeatures", outputCol="features") + idfModel = idf.fit(featurizedData) + rescaledData = idfModel.transform(featurizedData) + for features_label in rescaledData.select("features", "label").take(3): + print(features_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/word2vec_example.py b/examples/src/main/python/ml/word2vec_example.py new file mode 100644 index 0000000000000..53c77feb10145 --- /dev/null +++ b/examples/src/main/python/ml/word2vec_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Word2Vec +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="Word2VecExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Input data: Each row is a bag of words from a sentence or document. + documentDF = sqlContext.createDataFrame([ + ("Hi I heard about Spark".split(" "), ), + ("I wish Java could use case classes".split(" "), ), + ("Logistic regression models are neat".split(" "), ) + ], ["text"]) + # Learn a mapping from words to Vectors. + word2Vec = Word2Vec(vectorSize=3, minCount=0, inputCol="text", outputCol="result") + model = word2Vec.fit(documentDF) + result = model.transform(documentDF) + for feature in result.select("result").take(3): + print(feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala new file mode 100644 index 0000000000000..ba916f66c4c07 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object CountVectorizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("CounterVectorizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, Array("a", "b", "c")), + (1, Array("a", "b", "b", "c", "a")) + )).toDF("id", "words") + + // fit a CountVectorizerModel from the corpus + val cvModel: CountVectorizerModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) + .setMinDF(2) + .fit(df) + + // alternatively, define CountVectorizerModel with a-priori vocabulary + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features") + + cvModel.transform(df).select("features").show() + // $example off$ + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala new file mode 100644 index 0000000000000..40c33e4e7d44e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TfIdfExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("TfIdfExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceData = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (0, "I wish Java could use case classes"), + (1, "Logistic regression models are neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val wordsData = tokenizer.transform(sentenceData) + val hashingTF = new HashingTF() + .setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) + val featurizedData = hashingTF.transform(wordsData) + val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") + val idfModel = idf.fit(featurizedData) + val rescaledData = idfModel.transform(featurizedData) + rescaledData.select("features", "label").take(3).foreach(println) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala new file mode 100644 index 0000000000000..631ab4c8efa0d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Word2Vec +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object Word2VecExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("Word2Vec example") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Input data: Each row is a bag of words from a sentence or document. + val documentDF = sqlContext.createDataFrame(Seq( + "Hi I heard about Spark".split(" "), + "I wish Java could use case classes".split(" "), + "Logistic regression models are neat".split(" ") + ).map(Tuple1.apply)).toDF("text") + + // Learn a mapping from words to Vectors. 
+ val word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0) + val model = word2Vec.fit(documentDF) + val result = model.transform(documentDF) + result.select("result").take(3).foreach(println) + // $example off$ + } +} +// scalastyle:on println From 5d4f6abec4e371093e01c084656173e9cfabf29b Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 26 Oct 2015 21:28:18 -0700 Subject: [PATCH 047/324] [SPARK-10271][PYSPARK][MLLIB] Added @since tags to pyspark.mllib.clustering Duplicated the since decorator from pyspark.sql into pyspark (also tweaked to handle functions without docstrings). Added since to methods + "versionadded::" to classes (derived from the git file history in pyspark). Author: noelsmith Closes #8627 from noel-smith/SPARK-10271-since-mllib-clustering. --- python/pyspark/mllib/clustering.py | 69 +++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 6964a45db2493..c451df17cf264 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -28,7 +28,7 @@ from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector @@ -96,21 +96,26 @@ class KMeansModel(Saveable, Loader): ... initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)])) >>> model.clusterCenters [array([-1000., -1000.]), array([ 5., 5.]), array([ 1000., 1000.])] + + .. versionadded:: 0.9.0 """ def __init__(self, centers): self.centers = centers @property + @since('1.0.0') def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return self.centers @property + @since('1.4.0') def k(self): """Total number of clusters.""" return len(self.centers) + @since('0.9.0') def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 @@ -126,6 +131,7 @@ def predict(self, x): best_distance = distance return best + @since('1.4.0') def computeCost(self, rdd): """ Return the K-means cost (sum of squared distances of points to @@ -135,20 +141,32 @@ def computeCost(self, rdd): [_convert_to_vector(c) for c in self.centers]) return cost + @since('1.4.0') def save(self, sc, path): + """ + Save this model to the given path. + """ java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.4.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path) return KMeansModel(_java2py(sc, java_model.clusterCenters())) class KMeans(object): + """ + .. versionadded:: 0.9.0 + """ @classmethod + @since('0.9.0') def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): """Train a k-means clustering model.""" @@ -222,9 +240,12 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): True >>> labels[3]==labels[4] True + + .. 
versionadded:: 1.3.0 """ @property + @since('1.4.0') def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is @@ -233,6 +254,7 @@ def weights(self): return array(self.call("weights")) @property + @since('1.4.0') def gaussians(self): """ Array of MultivariateGaussian where gaussians[i] represents @@ -243,10 +265,12 @@ def gaussians(self): for gaussian in zip(*self.call("gaussians"))] @property + @since('1.4.0') def k(self): """Number of gaussians in mixture.""" return len(self.weights) + @since('1.3.0') def predict(self, x): """ Find the cluster to which the points in 'x' has maximum membership @@ -262,6 +286,7 @@ def predict(self, x): raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @since('1.3.0') def predictSoft(self, x): """ Find the membership of each point in 'x' to all mixture components. @@ -279,6 +304,7 @@ def predictSoft(self, x): "but got %s." % type(x)) @classmethod + @since('1.5.0') def load(cls, sc, path): """Load the GaussianMixtureModel from disk. @@ -302,8 +328,11 @@ class GaussianMixture(object): :param maxIterations: Number of iterations. Default to 100 :param seed: Random Seed :param initialModel: GaussianMixtureModel for initializing learning + + .. versionadded:: 1.3.0 """ @classmethod + @since('1.3.0') def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """Train a Gaussian Mixture clustering model.""" initialModelWeights = None @@ -358,15 +387,19 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.5.0 """ @property + @since('1.5.0') def k(self): """ Returns the number of clusters. """ return self.call("k") + @since('1.5.0') def assignments(self): """ Returns the cluster assignments of this model. @@ -375,7 +408,11 @@ def assignments(self): lambda x: (PowerIterationClustering.Assignment(*x))) @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ model = cls._load_java(sc, path) wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) return PowerIterationClusteringModel(wrapper) @@ -390,9 +427,12 @@ class PowerIterationClustering(object): From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data. + + .. versionadded:: 1.5.0 """ @classmethod + @since('1.5.0') def train(cls, rdd, k, maxIterations=100, initMode="random"): """ :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the @@ -415,6 +455,8 @@ def train(cls, rdd, k, maxIterations=100, initMode="random"): class Assignment(namedtuple("Assignment", ["id", "cluster"])): """ Represents an (id, cluster) tuple. + + .. versionadded:: 1.5.0 """ @@ -474,17 +516,21 @@ class StreamingKMeansModel(KMeansModel): 0 >>> stkm.predict([1.5, 1.5]) 1 + + .. versionadded:: 1.5.0 """ def __init__(self, clusterCenters, clusterWeights): super(StreamingKMeansModel, self).__init__(centers=clusterCenters) self._clusterWeights = list(clusterWeights) @property + @since('1.5.0') def clusterWeights(self): """Return the cluster weights.""" return self._clusterWeights @ignore_unicode_prefix + @since('1.5.0') def update(self, data, decayFactor, timeUnit): """Update the centroids, according to data @@ -523,6 +569,8 @@ class StreamingKMeans(object): :param decayFactor: float, forgetfulness of the previous centroids. :param timeUnit: can be "batches" or "points". 
If points, then the decayfactor is raised to the power of no. of new points. + + .. versionadded:: 1.5.0 """ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): self._k = k @@ -533,6 +581,7 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): self._timeUnit = timeUnit self._model = None + @since('1.5.0') def latestModel(self): """Return the latest model""" return self._model @@ -547,16 +596,19 @@ def _validate(self, dstream): "Expected dstream to be of type DStream, " "got type %s" % type(dstream)) + @since('1.5.0') def setK(self, k): """Set number of clusters.""" self._k = k return self + @since('1.5.0') def setDecayFactor(self, decayFactor): """Set decay factor.""" self._decayFactor = decayFactor return self + @since('1.5.0') def setHalfLife(self, halfLife, timeUnit): """ Set number of batches after which the centroids of that @@ -566,6 +618,7 @@ def setHalfLife(self, halfLife, timeUnit): self._decayFactor = exp(log(0.5) / halfLife) return self + @since('1.5.0') def setInitialCenters(self, centers, weights): """ Set initial centers. Should be set before calling trainOn. @@ -573,6 +626,7 @@ def setInitialCenters(self, centers, weights): self._model = StreamingKMeansModel(centers, weights) return self + @since('1.5.0') def setRandomCenters(self, dim, weight, seed): """ Set the initial centres to be random samples from @@ -584,6 +638,7 @@ def setRandomCenters(self, dim, weight, seed): self._model = StreamingKMeansModel(clusterCenters, clusterWeights) return self + @since('1.5.0') def trainOn(self, dstream): """Train the model on the incoming dstream.""" self._validate(dstream) @@ -593,6 +648,7 @@ def update(rdd): dstream.foreachRDD(update) + @since('1.5.0') def predictOn(self, dstream): """ Make predictions on a dstream. @@ -601,6 +657,7 @@ def predictOn(self, dstream): self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) + @since('1.5.0') def predictOnValues(self, dstream): """ Make predictions on a keyed dstream. @@ -649,16 +706,21 @@ class LDAModel(JavaModelWrapper): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.5.0 """ + @since('1.5.0') def topicsMatrix(self): """Inferred topics, where each topic is represented by a distribution over terms.""" return self.call("topicsMatrix").toArray() + @since('1.5.0') def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") + @since('1.5.0') def save(self, sc, path): """Save the LDAModel on to disk. @@ -672,6 +734,7 @@ def save(self, sc, path): self._java_model.save(sc._jsc.sc(), path) @classmethod + @since('1.5.0') def load(cls, sc, path): """Load the LDAModel from disk. @@ -688,8 +751,12 @@ def load(cls, sc, path): class LDA(object): + """ + .. versionadded:: 1.5.0 + """ @classmethod + @since('1.5.0') def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): """Train a LDA model. From 3cac6614a4fe60b1446bf704d0a35787d385fb86 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 26 Oct 2015 21:47:42 -0700 Subject: [PATCH 048/324] [SPARK-11184][MLLIB] Declare most of .mllib code not-Experimental Remove "Experimental" from .mllib code that has been around since 1.4.0 or earlier Author: Sean Owen Closes #9169 from srowen/SPARK-11184. 
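For orientation before the diff below, the convention this patch converges on can be sketched in a few lines of Scala. This is only an illustrative sketch, not code from the patch: the package and class names (ExampleStableModel, ExampleUnstableModel) and the "1.6.0" version string are hypothetical, and the snippet is assumed to sit inside the Spark source tree where org.apache.spark.annotation is visible. The idea: APIs that have shipped in a stable release keep their @Since tags but drop @Experimental and the ":: Experimental ::" scaladoc banner, while genuinely new or unstable APIs keep the annotation.

    package org.apache.spark.mllib.example  // hypothetical package, assumed to live inside the Spark tree

    import org.apache.spark.annotation.{Experimental, Since}

    /**
     * Hypothetical stable model. After this patch such a class carries only @Since tags;
     * before it, the class (and often its constructor and members) would also have been
     * marked @Experimental, with a ":: Experimental ::" banner in the scaladoc.
     */
    @Since("1.3.0")
    class ExampleStableModel @Since("1.3.0") (
        @Since("1.3.0") val weights: Array[Double]) extends Serializable

    /**
     * Hypothetical API that is still considered unstable; @Experimental remains
     * appropriate here alongside the (assumed) version it first appeared in.
     */
    @Experimental
    @Since("1.6.0")
    class ExampleUnstableModel private[mllib] (val k: Int) extends Serializable

The annotation placement (on the class, its primary constructor, and constructor parameters) mirrors the pattern already used throughout the diff that follows.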
--- .../mllib/classification/ClassificationModel.scala | 4 +--- .../mllib/classification/LogisticRegression.scala | 10 +--------- .../apache/spark/mllib/classification/SVM.scala | 8 +------- .../StreamingLogisticRegressionWithSGD.scala | 4 +--- .../spark/mllib/clustering/GaussianMixture.scala | 5 +---- .../mllib/clustering/GaussianMixtureModel.scala | 6 +----- .../org/apache/spark/mllib/clustering/LDA.scala | 5 +---- .../apache/spark/mllib/clustering/LDAModel.scala | 9 --------- .../clustering/PowerIterationClustering.scala | 11 +---------- .../spark/mllib/clustering/StreamingKMeans.scala | 8 +------- .../evaluation/BinaryClassificationMetrics.scala | 5 +---- .../spark/mllib/evaluation/MulticlassMetrics.scala | 4 +--- .../spark/mllib/evaluation/RankingMetrics.scala | 4 +--- .../spark/mllib/evaluation/RegressionMetrics.scala | 4 +--- .../apache/spark/mllib/feature/ChiSqSelector.scala | 6 +----- .../spark/mllib/feature/ElementwiseProduct.scala | 4 +--- .../org/apache/spark/mllib/feature/HashingTF.scala | 4 +--- .../scala/org/apache/spark/mllib/feature/IDF.scala | 6 +----- .../apache/spark/mllib/feature/Normalizer.scala | 4 +--- .../spark/mllib/feature/StandardScaler.scala | 6 +----- .../org/apache/spark/mllib/feature/Word2Vec.scala | 12 +++--------- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 14 +------------- .../mllib/linalg/SingularValueDecomposition.scala | 2 -- .../mllib/linalg/distributed/BlockMatrix.scala | 5 +---- .../linalg/distributed/CoordinateMatrix.scala | 6 +----- .../linalg/distributed/IndexedRowMatrix.scala | 6 +----- .../spark/mllib/linalg/distributed/RowMatrix.scala | 6 +----- .../org/apache/spark/mllib/random/RandomRDDs.scala | 4 +--- .../mllib/regression/IsotonicRegression.scala | 8 +------- .../spark/mllib/regression/RegressionModel.scala | 3 +-- .../StreamingLinearRegressionWithSGD.scala | 4 +--- .../apache/spark/mllib/stat/KernelDensity.scala | 4 +--- .../org/apache/spark/mllib/stat/Statistics.scala | 4 +--- .../apache/spark/mllib/stat/test/TestResult.scala | 4 ---- .../org/apache/spark/mllib/tree/DecisionTree.scala | 4 +--- .../spark/mllib/tree/GradientBoostedTrees.scala | 4 +--- .../org/apache/spark/mllib/tree/RandomForest.scala | 4 +--- .../tree/configuration/BoostingStrategy.scala | 5 +---- .../mllib/tree/configuration/FeatureType.scala | 4 +--- .../tree/configuration/QuantileStrategy.scala | 4 +--- .../spark/mllib/tree/configuration/Strategy.scala | 5 +---- .../apache/spark/mllib/tree/impl/TimeTracker.scala | 3 --- .../spark/mllib/tree/model/DecisionTreeModel.scala | 4 +--- .../mllib/tree/model/treeEnsembleModels.scala | 7 +------ .../org/apache/spark/mllib/util/MLUtils.scala | 8 +------- 45 files changed, 43 insertions(+), 208 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 85a413243b049..5161bc72659c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -19,17 +19,15 @@ package org.apache.spark.mllib.classification import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Represents a classification model that predicts to which of a set of 
categories an example * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. */ -@Experimental @Since("0.8.0") trait ClassificationModel extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 5ceff5b2259ea..2d52abc122bf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} @@ -82,35 +82,29 @@ class LogisticRegressionModel @Since("1.3.0") ( private var threshold: Option[Double] = Some(0.5) /** - * :: Experimental :: * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. */ @Since("1.0.0") - @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } /** - * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. * It is only used for binary classification. */ @Since("1.3.0") - @Experimental def getThreshold: Option[Double] = threshold /** - * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. * It is only used for binary classification. */ @Since("1.0.0") - @Experimental def clearThreshold(): this.type = { threshold = None this @@ -359,13 +353,11 @@ class LogisticRegressionWithLBFGS } /** - * :: Experimental :: * Set the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. */ @Since("1.3.0") - @Experimental def setNumClasses(numClasses: Int): this.type = { require(numClasses > 1) numOfLinearPredictor = numClasses - 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 896565cd90e89..a8d3fd4177a23 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ @@ -43,32 +43,26 @@ class SVMModel @Since("1.1.0") ( private var threshold: Option[Double] = Some(0.0) /** - * :: Experimental :: * Sets the threshold that separates positive predictions from negative predictions. An example * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. 
*/ @Since("1.0.0") - @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) this } /** - * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. */ @Since("1.3.0") - @Experimental def getThreshold: Option[Double] = threshold /** - * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. */ @Since("1.0.0") - @Experimental def clearThreshold(): this.type = { threshold = None this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 75630054d1368..47bff5ebdde47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.StreamingLinearAlgorithm /** - * :: Experimental :: * Train or predict a logistic regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LogisticRegressionWithSGD` for model equation) @@ -43,7 +42,6 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * .trainOn(DStream) * }}} */ -@Experimental @Since("1.3.0") class StreamingLogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index f82bd82c20371..7b203e2f40815 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -30,8 +30,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: - * * This class performs expectation maximization for multivariate Gaussian * Mixture Models (GMMs). A GMM represents a composite distribution of * independent Gaussian distributions with associated "mixing" weights @@ -52,7 +50,6 @@ import org.apache.spark.util.Utils * is considered to have occurred. 
* @param maxIterations The maximum number of iterations to perform */ -@Experimental @Since("1.3.0") class GaussianMixture private ( private var k: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index a5902190d4637..2115f7d99c182 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -33,8 +33,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, Row} /** - * :: Experimental :: - * * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are * the respective mean and covariance for each Gaussian distribution i=1..k. @@ -45,7 +43,6 @@ import org.apache.spark.sql.{SQLContext, Row} * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ @Since("1.3.0") -@Experimental class GaussianMixtureModel @Since("1.3.0") ( @Since("1.3.0") val weights: Array[Double], @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { @@ -132,7 +129,6 @@ class GaussianMixtureModel @Since("1.3.0") ( } @Since("1.4.0") -@Experimental object GaussianMixtureModel extends Loader[GaussianMixtureModel] { private object SaveLoadV1_0 { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 92a321afb0ca3..eb802a365ed6e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -28,8 +28,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: - * * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. * * Terminology: @@ -45,7 +43,6 @@ import org.apache.spark.util.Utils * (Wikipedia)]] */ @Since("1.3.0") -@Experimental class LDA private ( private var k: Int, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 15129e0dd5a91..31d8a9fdea1c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -35,14 +35,11 @@ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue /** - * :: Experimental :: - * * Latent Dirichlet Allocation (LDA) model. 
* * This abstraction permits for different underlying representations, * including local and distributed data structures. */ -@Experimental @Since("1.3.0") abstract class LDAModel private[clustering] extends Saveable { @@ -184,15 +181,12 @@ abstract class LDAModel private[clustering] extends Saveable { } /** - * :: Experimental :: - * * Local LDA model. * This model stores only the inferred topics. * It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. * @param topics Inferred topics (vocabSize x k matrix). */ -@Experimental @Since("1.3.0") class LocalLDAModel private[clustering] ( @Since("1.3.0") val topics: Matrix, @@ -481,14 +475,11 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } /** - * :: Experimental :: - * * Distributed LDA model. * This model stores the inferred topics, the full training dataset, and the topic distributions. * When computing topics for new documents, it may give more accurate answers * than the [[LocalLDAModel]]. */ -@Experimental @Since("1.3.0") class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 6c76e26fd1626..7cd9b08fa8e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -21,7 +21,7 @@ import org.json4s.JsonDSL._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl @@ -33,15 +33,12 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.{Logging, SparkContext, SparkException} /** - * :: Experimental :: - * * Model produced by [[PowerIterationClustering]]. * * @param k number of clusters * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ @Since("1.3.0") -@Experimental class PowerIterationClusteringModel @Since("1.3.0") ( @Since("1.3.0") val k: Int, @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment]) @@ -107,8 +104,6 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode } /** - * :: Experimental :: - * * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by * [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise @@ -120,7 +115,6 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ -@Experimental @Since("1.3.0") class PowerIterationClustering private[clustering] ( private var k: Int, @@ -239,17 +233,14 @@ class PowerIterationClustering private[clustering] ( } @Since("1.3.0") -@Experimental object PowerIterationClustering extends Logging { /** - * :: Experimental :: * Cluster assignment. 
* @param id node id * @param cluster assigned cluster id */ @Since("1.3.0") - @Experimental case class Assignment(id: Long, cluster: Int) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 1d50ffec96faf..80843719f50b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -30,8 +30,6 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** - * :: Experimental :: - * * StreamingKMeansModel extends MLlib's KMeansModel for streaming * algorithms, so it can keep track of a continuously updated weight * associated with each cluster, and also update the model by @@ -65,7 +63,6 @@ import org.apache.spark.util.random.XORShiftRandom * as batches or points. */ @Since("1.2.0") -@Experimental class StreamingKMeansModel @Since("1.2.0") ( @Since("1.2.0") override val clusterCenters: Array[Vector], @Since("1.2.0") val clusterWeights: Array[Double]) @@ -149,8 +146,6 @@ class StreamingKMeansModel @Since("1.2.0") ( } /** - * :: Experimental :: - * * StreamingKMeans provides methods for configuring a * streaming k-means analysis, training the model on streaming, * and using the model to make predictions on streaming data. @@ -168,7 +163,6 @@ class StreamingKMeansModel @Since("1.2.0") ( * }}} */ @Since("1.2.0") -@Experimental class StreamingKMeans @Since("1.2.0") ( @Since("1.2.0") var k: Int, @Since("1.2.0") var decayFactor: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 508fe532b1306..12cf22095720a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -17,15 +17,13 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.Logging -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql.DataFrame /** - * :: Experimental :: * Evaluator for binary classification. * * @param scoreAndLabels an RDD of (score, label) pairs. @@ -43,7 +41,6 @@ import org.apache.spark.sql.DataFrame * partition boundaries. 
*/ @Since("1.0.0") -@Experimental class BinaryClassificationMetrics @Since("1.3.0") ( @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") val numBins: Int) extends Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 00e837661dfc2..c5104960cfcb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.evaluation import scala.collection.Map -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -32,7 +31,6 @@ import org.apache.spark.sql.DataFrame * @param predictionAndLabels an RDD of (prediction, label) pairs. */ @Since("1.1.0") -@Experimental class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index a7f43f0b110f5..cc01936dd34b2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD @@ -36,7 +36,6 @@ import org.apache.spark.rdd.RDD * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. */ @Since("1.2.0") -@Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) extends Logging with Serializable { @@ -159,7 +158,6 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] } -@Experimental object RankingMetrics { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 799ebb980ef01..1d8f4fe340fb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors @@ -25,13 +25,11 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate import org.apache.spark.sql.DataFrame /** - * :: Experimental :: * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. 
*/ @Since("1.2.0") -@Experimental class RegressionMetrics @Since("1.2.0") ( predictionAndObservations: RDD[(Double, Double)]) extends Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 5246faf221914..d4d022afde051 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -23,7 +23,7 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics @@ -33,13 +33,11 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.{SQLContext, Row} /** - * :: Experimental :: * Chi Squared selector model. * * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ @Since("1.3.0") -@Experimental class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { @@ -173,7 +171,6 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { } /** - * :: Experimental :: * Creates a ChiSquared feature selector. * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) @@ -181,7 +178,6 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { * select all features. */ @Since("1.3.0") -@Experimental class ChiSqSelector @Since("1.3.0") ( @Since("1.3.0") val numTopFeatures: Int) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index d0a6cf61687a8..c757fc7f06c58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ /** - * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. * @param scalingVec The values used to scale the reference vector's individual components. 
*/ @Since("1.4.0") -@Experimental class ElementwiseProduct @Since("1.4.0") ( @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index e47d524b61623..c93ed64183ad6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -22,20 +22,18 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * * @param numFeatures number of features (default: 2^20^) */ @Since("1.1.0") -@Experimental class HashingTF(val numFeatures: Int) extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 68078ccfa3d60..cffa9fba05c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Inverse document frequency (IDF). * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total * number of documents and `d(t)` is the number of documents that contain term `t`. @@ -38,7 +37,6 @@ import org.apache.spark.rdd.RDD * should appear for filtering */ @Since("1.1.0") -@Experimental class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) { @Since("1.1.0") @@ -159,10 +157,8 @@ private object IDF { } /** - * :: Experimental :: * Represents an IDF model that can transform term frequency vectors. */ -@Experimental @Since("1.1.0") class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 8d5a22520d6b8..af0c8e1d8a9d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** - * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -32,7 +31,6 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors * @param p Normalization in L^p^ space, p = 2 by default. 
*/ @Since("1.1.0") -@Experimental class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { @Since("1.1.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index f018b453bae7e..6fe573c528943 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,13 +18,12 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. * @@ -33,7 +32,6 @@ import org.apache.spark.rdd.RDD * @param withStd True by default. Scales the data to unit standard deviation. */ @Since("1.1.0") -@Experimental class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging { @Since("1.1.0") @@ -64,7 +62,6 @@ class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) exten } /** - * :: Experimental :: * Represents a StandardScaler model that can transform vectors. * * @param std column standard deviation values @@ -73,7 +70,6 @@ class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) exten * @param withMean whether to center the data before scaling */ @Since("1.1.0") -@Experimental class StandardScalerModel @Since("1.3.0") ( @Since("1.3.0") val std: Vector, @Since("1.1.0") val mean: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 58857c338f546..f3e4d346e358a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -31,15 +31,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.SQLContext /** * Entry in vocabulary @@ -53,7 +52,6 @@ private case class VocabWord( ) /** - * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus * and then learns vector representation of words in the vocabulary. @@ -71,7 +69,6 @@ private case class VocabWord( * Distributed Representations of Words and Phrases and their Compositionality. 
*/ @Since("1.1.0") -@Experimental class Word2Vec extends Serializable with Logging { private var vectorSize = 100 @@ -427,7 +424,6 @@ class Word2Vec extends Serializable with Logging { } /** - * :: Experimental :: * Word2Vec model * @param wordIndex maps each word to an index, which can retrieve the corresponding * vector from wordVectors @@ -435,7 +431,6 @@ class Word2Vec extends Serializable with Logging { * to the word mapped with index i can be retrieved by the slice * (i * vectorSize, i * vectorSize + vectorSize) */ -@Experimental @Since("1.1.0") class Word2VecModel private[mllib] ( private val wordIndex: Map[String, Int], @@ -558,7 +553,6 @@ class Word2VecModel private[mllib] ( } @Since("1.4.0") -@Experimental object Word2VecModel extends Loader[Word2VecModel] { private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index aea5c4f8a8a7d..70ef1ed30c71a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ @@ -33,15 +33,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: - * * Model trained by [[FPGrowth]], which holds frequent itemsets. * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type - * */ @Since("1.3.0") -@Experimental class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** @@ -56,8 +52,6 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( } /** - * :: Experimental :: - * * A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in * [[http://dx.doi.org/10.1145/1454008.1454027 Li et al., PFP: Parallel FP-Growth for Query * Recommendation]]. PFP distributes computation in such a way that each worker executes an @@ -74,7 +68,6 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( * */ @Since("1.3.0") -@Experimental class FPGrowth private ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { @@ -213,12 +206,7 @@ class FPGrowth private ( } } -/** - * :: Experimental :: - * - */ @Since("1.3.0") -@Experimental object FPGrowth { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 4dcf8f28f2023..4591cb88ef152 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -20,11 +20,9 @@ package org.apache.spark.mllib.linalg import org.apache.spark.annotation.{Experimental, Since} /** - * :: Experimental :: * Represents singular value decomposition (SVD) factors. 
*/ @Since("1.0.0") -@Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 81a6c0550bda7..09527dcf5d9e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{Logging, Partitioner, SparkException} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -115,8 +115,6 @@ private[mllib] object GridPartitioner { } /** - * :: Experimental :: - * * Represents a distributed matrix in blocks of local matrices. * * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that @@ -132,7 +130,6 @@ private[mllib] object GridPartitioner { * zero, the number of columns will be calculated when `numCols` is invoked. */ @Since("1.3.0") -@Experimental class BlockMatrix @Since("1.3.0") ( @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)], @Since("1.3.0") val rowsPerBlock: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 644f293d88a75..8a70f34e70f6a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -19,23 +19,20 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** - * :: Experimental :: * Represents an entry in an distributed matrix. * @param i row index * @param j column index * @param value value of the entry */ @Since("1.0.0") -@Experimental case class MatrixEntry(i: Long, j: Long, value: Double) /** - * :: Experimental :: * Represents a matrix in coordinate format. * * @param entries matrix entries @@ -45,7 +42,6 @@ case class MatrixEntry(i: Long, j: Long, value: Double) * columns will be determined by the max column index plus one. 
*/ @Since("1.0.0") -@Experimental class CoordinateMatrix @Since("1.0.0") ( @Since("1.0.0") val entries: RDD[MatrixEntry], private var nRows: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index b20ea0dc50da5..e6af0c0ec7ec2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -19,21 +19,18 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition /** - * :: Experimental :: * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. */ @Since("1.0.0") -@Experimental case class IndexedRow(index: Long, vector: Vector) /** - * :: Experimental :: * Represents a row-oriented [[org.apache.spark.mllib.linalg.distributed.DistributedMatrix]] with * indexed rows. * @@ -44,7 +41,6 @@ case class IndexedRow(index: Long, vector: Vector) * columns will be determined by the size of the first row. */ @Since("1.0.0") -@Experimental class IndexedRowMatrix @Since("1.0.0") ( @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index b8a7adceb15b6..52c0f19c645d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -26,8 +26,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BS import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -35,7 +34,6 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * Represents a row-oriented distributed Matrix with no meaningful row indices. * * @param rows rows stored as an RDD[Vector] @@ -45,7 +43,6 @@ import org.apache.spark.storage.StorageLevel * columns will be determined by the size of the first row. 
*/ @Since("1.0.0") -@Experimental class RowMatrix @Since("1.0.0") ( @Since("1.0.0") val rows: RDD[Vector], private var nRows: Long, @@ -676,7 +673,6 @@ class RowMatrix @Since("1.0.0") ( } @Since("1.0.0") -@Experimental object RowMatrix { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 41d7c4d355f61..b0a716936ae6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector @@ -29,10 +29,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** - * :: Experimental :: * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ -@Experimental @Since("1.1.0") object RandomRDDs { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 877d31ba41303..ec78ea24539b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -29,7 +29,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -37,8 +37,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext /** - * :: Experimental :: - * * Regression model for isotonic regression. * * @param boundaries Array of boundaries for which predictions are known. @@ -49,7 +47,6 @@ import org.apache.spark.sql.SQLContext * */ @Since("1.3.0") -@Experimental class IsotonicRegressionModel @Since("1.3.0") ( @Since("1.3.0") val boundaries: Array[Double], @Since("1.3.0") val predictions: Array[Double], @@ -233,8 +230,6 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } /** - * :: Experimental :: - * * Isotonic regression. * Currently implemented using parallelized pool adjacent violators algorithm. * Only univariate (single feature) algorithm supported. 
@@ -252,7 +247,6 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { * * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ -@Experimental @Since("1.3.0") class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 0e72d6591ce83..a95a54225a085 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.regression import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @Since("0.8.0") -@Experimental trait RegressionModel extends Serializable { /** * Predict values for the given data set using the model trained. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index fe1d487cdd078..fe2a46b9eecc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** - * :: Experimental :: * Train or predict a linear regression model on streaming data. Training uses * Stochastic Gradient Descent to update the model based on each new batch of * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation) @@ -40,7 +39,6 @@ import org.apache.spark.mllib.linalg.Vector * .setInitialWeights(Vectors.dense(...)) * .trainOn(DStream) */ -@Experimental @Since("1.1.0") class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 4a856f7f3434a..f253963270bc4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.stat import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD /** - * :: Experimental :: * Kernel density estimation. Given a sample from a population, estimate its probability density * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported. 
* @@ -39,7 +38,6 @@ import org.apache.spark.rdd.RDD * }}} */ @Since("1.4.0") -@Experimental class KernelDensity extends Serializable { import KernelDensity._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 84d64a5bfb38e..bcb33a7a04677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} @@ -30,11 +30,9 @@ import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovS import org.apache.spark.rdd.RDD /** - * :: Experimental :: * API for statistical functions in MLlib. */ @Since("1.1.0") -@Experimental object Statistics { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index b0916d3e84651..8a29fd39a9106 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -20,11 +20,9 @@ package org.apache.spark.mllib.stat.test import org.apache.spark.annotation.{Experimental, Since} /** - * :: Experimental :: * Trait for hypothesis test results. * @tparam DF Return type of `degreesOfFreedom`. */ -@Experimental @Since("1.1.0") trait TestResult[DF] { @@ -79,10 +77,8 @@ trait TestResult[DF] { } /** - * :: Experimental :: * Object containing the test results for the chi-squared hypothesis test. */ -@Experimental @Since("1.1.0") class ChiSqTestResult private[stat] (override val pValue: Double, @Since("1.1.0") override val degreesOfFreedom: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 53d6482f8057c..af1f7e74c004d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo @@ -36,7 +36,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom /** - * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. * It supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type @@ -44,7 +43,6 @@ import org.apache.spark.util.random.XORShiftRandom * categorical), depth of the tree, quantile calculation strategy, etc. 
*/ @Since("1.0.0") -@Experimental class DecisionTree @Since("1.0.0") (private val strategy: Strategy) extends Serializable with Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 66a07e31360d8..729a211574822 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint @@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: * A class that implements * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] * for regression and binary classification. @@ -50,7 +49,6 @@ import org.apache.spark.storage.StorageLevel * @param boostingStrategy Parameters for the gradient boosting algorithm. */ @Since("1.2.0") -@Experimental class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 63a902f3eb51e..a684cdd18c2fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils import org.apache.spark.util.random.SamplingUtils /** - * :: Experimental :: * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] * learning algorithm for classification and regression. * It supports both continuous and categorical features. @@ -66,7 +65,6 @@ import org.apache.spark.util.random.SamplingUtils * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. 
*/ -@Experimental private class RandomForest ( private val strategy: Strategy, private val numTrees: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index fc13bcfd8e998..d2513a9d5c5bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** - * :: Experimental :: * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. * * @param treeStrategy Parameters for the tree algorithm. We support regression and binary @@ -47,7 +46,6 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Since("1.2.0") -@Experimental case class BoostingStrategy @Since("1.4.0") ( // Required boosting parameters @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, @@ -79,7 +77,6 @@ case class BoostingStrategy @Since("1.4.0") ( } @Since("1.2.0") -@Experimental object BoostingStrategy { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index 4e0cd473def06..1470295d8a932 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -17,14 +17,12 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since /** - * :: Experimental :: * Enum to describe whether a feature is "continuous" or "categorical" */ @Since("1.0.0") -@Experimental object FeatureType extends Enumeration { @Since("1.0.0") type FeatureType = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 8262db8a4f111..1c16f136eb3eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -17,14 +17,12 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since /** - * :: Experimental :: * Enum for selecting the quantile calculation strategy */ @Since("1.0.0") -@Experimental object QuantileStrategy extends Enumeration { @Since("1.0.0") type QuantileStrategy = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 89cc13b7c06cf..372d6617a4014 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,13 +20,12 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty 
import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** - * :: Experimental :: * Stores all the configuration options for tree construction * @param algo Learning goal. Supported: * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], @@ -68,7 +67,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * [[org.apache.spark.SparkContext]], this setting is ignored. */ @Since("1.0.0") -@Experimental class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var algo: Algo, @Since("1.0.0") @BeanProperty var impurity: Impurity, @@ -179,7 +177,6 @@ class Strategy @Since("1.3.0") ( } @Since("1.2.0") -@Experimental object Strategy { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index aac84243d5ce1..70afaa162b2e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -19,12 +19,9 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable.{HashMap => MutableHashMap} -import org.apache.spark.annotation.Experimental - /** * Time tracker implementation which holds labeled timers. */ -@Experimental private[spark] class TimeTracker extends Serializable { private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index e1bf23f4c34bb..54c136aecf660 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} @@ -35,14 +35,12 @@ import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.util.Utils /** - * :: Experimental :: * Decision tree model for classification or regression. * This model stores the decision tree structure and parameters. 
* @param topNode root node * @param algo algorithm type -- classification or regression */ @Since("1.0.0") -@Experimental class DecisionTreeModel @Since("1.0.0") ( @Since("1.0.0") val topNode: Node, @Since("1.0.0") val algo: Algo) extends Serializable with Saveable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index df5b8feab5d5d..90e032e3d9842 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -38,16 +38,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils - /** - * :: Experimental :: * Represents a random forest model. * * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles */ @Since("1.2.0") -@Experimental class RandomForestModel @Since("1.2.0") ( @Since("1.2.0") override val algo: Algo, @Since("1.2.0") override val trees: Array[DecisionTreeModel]) @@ -108,7 +105,6 @@ object RandomForestModel extends Loader[RandomForestModel] { } /** - * :: Experimental :: * Represents a gradient boosted trees model. * * @param algo algorithm for the ensemble model, either Classification or Regression @@ -116,7 +112,6 @@ object RandomForestModel extends Loader[RandomForestModel] { * @param treeWeights tree ensemble weights */ @Since("1.2.0") -@Experimental class GradientBoostedTreesModel @Since("1.2.0") ( @Since("1.2.0") override val algo: Algo, @Since("1.2.0") override val trees: Array[DecisionTreeModel], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 81c2f0ce6e12c..414ea99cfd8c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,9 +19,7 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD @@ -30,8 +28,6 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -263,13 +259,11 @@ object MLUtils { } /** - * :: Experimental :: * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. 
*/ @Since("1.0.0") - @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => From 8b292b19c9b3aaaa51b919a12132e099e5be832d Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 26 Oct 2015 22:00:24 -0700 Subject: [PATCH 049/324] [SPARK-10654][MLLIB] Add columnSimilarities to IndexedRowMatrix Add columnSimilarities to IndexedRowMatrix by delegating to functionality already in RowMatrix. With a test. Author: Reza Zadeh Closes #8792 from rezazadeh/colsims. --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 13 +++++++++++++ .../linalg/distributed/IndexedRowMatrixSuite.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index e6af0c0ec7ec2..976299124cedd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -68,6 +68,19 @@ class IndexedRowMatrix @Since("1.0.0") ( nRows } + + /** + * Compute all cosine similarities between columns of this matrix using the brute-force + * approach of computing normalized dot products. + * + * @return An n x n sparse upper-triangular matrix of cosine similarities between + * columns of this matrix. + */ + @Since("1.6.0") + def columnSimilarities(): CoordinateMatrix = { + toRowMatrix().columnSimilarities() + } + /** * Drops row indices and converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 0ecb7a221a503..6de6cf2fa8634 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -153,6 +153,18 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("similar columns") { + val A = new IndexedRowMatrix(indexedRows) + val gram = A.computeGramianMatrix().toBreeze.toDenseMatrix + + val G = A.columnSimilarities().toBreeze() + + for (i <- 0 until n; j <- i + 1 until n) { + val trueResult = gram(i, j) / scala.math.sqrt(gram(i, i) * gram(j, j)) + assert(math.abs(G(i, j) - trueResult) < 1e-6) + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } From d77d198fcc7c532a699f062a3e3877a7679809da Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 26 Oct 2015 23:53:41 -0700 Subject: [PATCH 050/324] [SPARK-11297] Add new code tags mengxr https://issues.apache.org/jira/browse/SPARK-11297 Add new code tags to hold the same look and feel with previous documents. Author: Xusen Yin Closes #9265 from yinxusen/SPARK-11297. 
--- docs/css/main.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/css/main.css b/docs/css/main.css index 89305a7d3a358..d770173be1014 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -74,6 +74,10 @@ code { color: #444444; } +div .highlight pre { + font-size: 12px; +} + a code { color: #0088cc; } From feb8d6a44fbfc31a880aaaac0cfcaadc91786073 Mon Sep 17 00:00:00 2001 From: Sem Mulder Date: Tue, 27 Oct 2015 07:55:10 +0000 Subject: [PATCH 051/324] [SPARK-11276][CORE] SizeEstimator prevents class unloading The SizeEstimator keeps a cache of ClassInfos but this cache uses Class objects as keys. Which results in strong references to the Class objects. If these classes are dynamically created this prevents the corresponding ClassLoader from being GCed. Leading to PermGen exhaustion. We use a Map with WeakKeys to prevent this issue. Author: Sem Mulder Closes #9244 from SemMulder/fix-sizeestimator-classunloading. --- .../main/scala/org/apache/spark/util/SizeEstimator.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 14b1f2a17e707..23ee4eff0881b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import com.google.common.collect.MapMaker + import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} @@ -29,7 +31,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet - /** * :: DeveloperApi :: * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in @@ -73,7 +74,8 @@ object SizeEstimator extends Logging { private val ALIGN_SIZE = 8 // A cache of ClassInfo objects for each class - private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] + // We use weakKeys to allow GC of dynamically created classes + private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]() // Object and pointer sizes are arch dependent private var is64bit = false From 8f888eea1aef5a28916ec406a99fc19648681ecf Mon Sep 17 00:00:00 2001 From: Nick Evans Date: Tue, 27 Oct 2015 01:29:06 -0700 Subject: [PATCH 052/324] [SPARK-11270][STREAMING] Add improved equality testing for TopicAndPartition from the Kafka Streaming API jerryshao tdas I know this is kind of minor, and I know you all are busy, but this brings this class in line with the `OffsetRange` class, and makes tests a little more concise. Instead of doing something like: ``` assert topic_and_partition_instance._topic == "foo" assert topic_and_partition_instance._partition == 0 ``` You can do something like: ``` assert topic_and_partition_instance == TopicAndPartition("foo", 0) ``` Before: ``` >>> from pyspark.streaming.kafka import TopicAndPartition >>> TopicAndPartition("foo", 0) == TopicAndPartition("foo", 0) False ``` After: ``` >>> from pyspark.streaming.kafka import TopicAndPartition >>> TopicAndPartition("foo", 0) == TopicAndPartition("foo", 0) True ``` I couldn't find any tests - am I missing something? Author: Nick Evans Closes #9236 from manygrams/topic_and_partition_equality. 
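To illustrate the weak-keyed cache idea behind the SPARK-11276 SizeEstimator fix above: a hedged sketch of how Guava's `MapMaker` keeps `Class[_]` keys weakly referenced, so dynamically created classes (and their class loaders) remain collectible. The `ClassInfo` case class here is a hypothetical stand-in, not SizeEstimator's private type.

```scala
import java.util.concurrent.ConcurrentMap
import com.google.common.collect.MapMaker

// Hypothetical per-class metadata, standing in for SizeEstimator's private ClassInfo.
case class ClassInfo(shellSize: Long)

// weakKeys() means the map holds no strong references to the Class objects, so a
// dynamically generated class (and its ClassLoader) can be garbage collected once
// nothing else references it; a plain ConcurrentHashMap would pin them in PermGen.
val classInfos: ConcurrentMap[Class[_], ClassInfo] =
  new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()

classInfos.put(classOf[String], ClassInfo(shellSize = 40L)) // 40 is an illustrative value
println(classInfos.get(classOf[String]))
```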
--- python/pyspark/streaming/kafka.py | 10 ++++++++++ python/pyspark/streaming/tests.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index b35bbaf404cc5..06e159172ab51 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -254,6 +254,16 @@ def __init__(self, topic, partition): def _jTopicAndPartition(self, helper): return helper.createTopicAndPartition(self._topic, self._partition) + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self._topic == other._topic + and self._partition == other._partition) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + class Broker(object): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2c908daa8b214..f7fa481d50235 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -898,6 +898,16 @@ def transformWithOffsetRanges(rdd): self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + def test_topic_and_partition_equality(self): + topic_and_partition_a = TopicAndPartition("foo", 0) + topic_and_partition_b = TopicAndPartition("foo", 0) + topic_and_partition_c = TopicAndPartition("bar", 0) + topic_and_partition_d = TopicAndPartition("foo", 1) + + self.assertEqual(topic_and_partition_a, topic_and_partition_b) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds From 17f499920776e0e995434cfa300ff2ff38658fa8 Mon Sep 17 00:00:00 2001 From: maxwell Date: Tue, 27 Oct 2015 01:31:28 -0700 Subject: [PATCH 053/324] [SPARK-5569][STREAMING] fix ObjectInputStreamWithLoader for supporting load array classes. When use Kafka DirectStream API to create checkpoint and restore saved checkpoint when restart, ClassNotFound exception would occur. The reason for this error is that ObjectInputStreamWithLoader extends the ObjectInputStream class and override its resolveClass method. But Instead of Using Class.forName(desc,false,loader), Spark uses loader.loadClass(desc) to instance the class, which do not works with array class. For example: Class.forName("[Lorg.apache.spark.streaming.kafka.OffsetRange.",false,loader) works well while loader.loadClass("[Lorg.apache.spark.streaming.kafka.OffsetRange") would throw an class not found exception. details of the difference between Class.forName and loader.loadClass can be found here. http://bugs.java.com/view_bug.do?bug_id=6446627 Author: maxwell Author: DEMING ZHU Closes #8955 from maxwellzdm/master. 
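To make the `Class.forName` vs. `loadClass` distinction described in SPARK-5569 above concrete, a small hedged sketch (using `Array[String]` instead of `OffsetRange`, purely for illustration):

```scala
// getName() on an array class yields the JVM descriptor form, e.g. "[Ljava.lang.String;".
val arrayName = classOf[Array[String]].getName
val loader = Thread.currentThread().getContextClassLoader

// Class.forName understands array descriptors and resolves the element type via the loader.
val resolved = Class.forName(arrayName, false, loader)
println(resolved) // class [Ljava.lang.String;

// ClassLoader.loadClass does not handle array descriptors, which is why the old
// resolveClass override failed when a checkpoint contained array objects.
try {
  loader.loadClass(arrayName)
} catch {
  case e: ClassNotFoundException => println(s"loadClass failed as expected: $e")
}
```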
--- .../apache/spark/streaming/Checkpoint.scala | 4 ++- .../spark/streaming/CheckpointSuite.scala | 35 +++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 8a6050f5227bf..b7de6dde61c63 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -352,7 +352,9 @@ class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoade override def resolveClass(desc: ObjectStreamClass): Class[_] = { try { - return loader.loadClass(desc.getName()) + // scalastyle:off classforname + return Class.forName(desc.getName(), false, loader) + // scalastyle:on classforname } catch { case e: Exception => } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index a6956533c07a5..84f5294aa39cc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming -import java.io.File +import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} +import org.apache.spark.TestUtils import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag @@ -34,7 +35,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} /** * This test suites tests the checkpointing functionality of DStreams - @@ -579,6 +580,36 @@ class CheckpointSuite extends TestSuiteBase { } } + // This tests whether spark can deserialize array object + // refer to SPARK-5569 + test("recovery from checkpoint contains array object") { + // create a class which is invisible to app class loader + val jar = TestUtils.createJarWithClasses( + classNames = Seq("testClz"), + toStringValue = "testStringValue" + ) + + // invisible to current class loader + val appClassLoader = getClass.getClassLoader + intercept[ClassNotFoundException](appClassLoader.loadClass("testClz")) + + // visible to mutableURLClassLoader + val loader = new MutableURLClassLoader( + Array(jar), appClassLoader) + assert(loader.loadClass("testClz").newInstance().toString == "testStringValue") + + // create and serialize Array[testClz] + // scalastyle:off classforname + val arrayObj = Class.forName("[LtestClz;", false, loader) + // scalastyle:on classforname + val bos = new ByteArrayOutputStream() + new ObjectOutputStream(bos).writeObject(arrayObj) + + // deserialize the Array[testClz] + val ois = new ObjectInputStreamWithLoader( + new ByteArrayInputStream(bos.toByteArray), loader) + assert(ois.readObject().asInstanceOf[Class[_]].getName == "[LtestClz;") + } /** * Tests a streaming operation under checkpointing, by restarting the operation From 958a0ec8fa58ff091f595db2b574a7aa3ff41253 Mon Sep 17 00:00:00 2001 From: Jia Li Date: Tue, 27 Oct 2015 10:57:08 +0100 Subject: [PATCH 054/324] [SPARK-11277][SQL] sort_array throws exception scala.MatchError I'm new to spark. 
I was trying out the sort_array function then hit this exception. I looked into the spark source code. I found the root cause is that sort_array does not check for an array of NULLs. It's not meaningful to sort an array of entirely NULLs anyway. I'm adding a check on the input array type to SortArray. If the array consists of NULLs entirely, there is no need to sort such array. I have also added a test case for this. Please help to review my fix. Thanks! Author: Jia Li Closes #9247 from jliwork/SPARK-11277. --- .../sql/catalyst/expressions/collectionOperations.scala | 6 +++++- .../catalyst/expressions/CollectionFunctionsSuite.scala | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 89d87726ac649..2cf19b939f734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -89,6 +90,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -109,7 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def nullSafeEval(array: Any, ascending: Any): Any = { val elementType = base.dataType.asInstanceOf[ArrayType].elementType val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + if (elementType != NullType) { + java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + } new GenericArrayData(data.asInstanceOf[Array[Any]]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index a3e81888dfd0d..1aae4678d6278 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -49,6 +49,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) @@ -64,6 +65,12 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, 
ArrayType(StringType)), null) + checkEvaluation(new SortArray(a4), Seq(null, null)) + + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) + + checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) } test("Array contains") { From 360ed832f5213b805ac28cf1d2828be09480f2d6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 27 Oct 2015 11:28:59 +0100 Subject: [PATCH 055/324] [SPARK-11303][SQL] filter should not be pushed down into sample When sampling and then filtering DataFrame, the SQL Optimizer will push down filter into sample and produce wrong result. This is due to the sampler is calculated based on the original scope rather than the scope after filtering. Author: Yanbo Liang Closes #9294 from yanboliang/spark-11303. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 4 ---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0139b9e87ce84..d37f43888fd4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -74,10 +74,6 @@ object DefaultOptimizer extends Optimizer { object SamplePushDown extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down filter into sample - case Filter(condition, s @ Sample(lb, up, replace, seed, child)) => - Sample(lb, up, replace, seed, - Filter(condition, child)) // Push down projection into sample case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => Sample(lb, up, replace, seed, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 298c32290697a..f5ae3ae49b460 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1860,4 +1860,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1)) } } + + test("SPARK-11303: filter should not be pushed down into sample") { + val df = sqlContext.range(100) + List(true, false).foreach { withReplacement => + val sampled = df.sample(withReplacement, 0.1, 1) + val sampledOdd = sampled.filter("id % 2 != 0") + val sampledEven = sampled.filter("id % 2 = 0") + assert(sampled.count() == sampledOdd.count() + sampledEven.count()) + } + } } From 9fc16a82adb5f3db2a250765c11393794404a51b Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 27 Oct 2015 10:46:43 -0700 Subject: [PATCH 056/324] [SPARK-11306] Fix hang when JVM exits. This commit fixes a bug where, in Standalone mode, if a task fails and crashes the JVM, the failure is considered a "normal failure" (meaning it's considered unrelated to the task), so the failure isn't counted against the task's maximum number of failures: https://github.com/apache/spark/commit/af3bc59d1f5d9d952c2d7ad1af599c49f1dbdaf0#diff-a755f3d892ff2506a7aa7db52022d77cL138. As a result, if a task fails in a way that results in it crashing the JVM, it will continuously be re-launched, resulting in a hang. This commit fixes that problem. This bug was introduced by #8007; andrewor14 mccheah vanzin can you take a look at this? 
This error is hard to trigger because we handle executor losses through 2 code paths (the second is via Akka, where Akka notices that the executor endpoint is disconnected). In my setup, the Akka code path completes first, and doesn't have this bug, so things work fine (see my recent email to the dev list about this). If I manually disable the Akka code path, I can see the hang (and this commit fixes the issue). Author: Kay Ousterhout Closes #9273 from kayousterhout/SPARK-11306. --- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 2625c3e7ac718..a4214c496166d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -137,7 +137,7 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code, isNormalExit = true, message) + case Some(code) => ExecutorExited(code, isNormalExit = false, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) From 3bdbbc6c972567861044dd6a6dc82f35cd12442d Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Tue, 27 Oct 2015 11:05:14 -0700 Subject: [PATCH 057/324] [SPARK-6488][MLLIB][PYTHON] Support addition/multiplication in PySpark's BlockMatrix This PR adds addition and multiplication to PySpark's `BlockMatrix` class via `add` and `multiply` functions. Author: Mike Dusenberry Closes #9139 from dusenberrymw/SPARK-6488_Add_Addition_and_Multiplication_to_PySpark_BlockMatrix. --- python/pyspark/mllib/linalg/distributed.py | 68 ++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index aec407de90aa3..0e76050788630 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -775,6 +775,74 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") + def add(self, other): + """ + Adds two block matrices together. The matrices must have the + same size and matching `rowsPerBlock` and `colsPerBlock` values. + If one of the sub matrix blocks that are being added is a + SparseMatrix, the resulting sub matrix block will also be a + SparseMatrix, even if it is being added to a DenseMatrix. If + two dense sub matrix blocks are added, the output block will + also be a DenseMatrix. 
+ + >>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)]) + >>> mat1 = BlockMatrix(blocks1, 3, 2) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.add(mat2).toLocalMatrix() + DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0) + + >>> mat1.add(mat3).toLocalMatrix() + DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + + def multiply(self, other): + """ + Left multiplies this BlockMatrix by `other`, another + BlockMatrix. The `colsPerBlock` of this matrix must equal the + `rowsPerBlock` of `other`. If `other` contains any SparseMatrix + blocks, they will have to be converted to DenseMatrix blocks. + The output BlockMatrix will only consist of DenseMatrix blocks. + This may cause some performance issues until support for + multiplying two sparse matrices is added. + + >>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6]) + >>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12]) + >>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) + >>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]) + >>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12]) + >>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)]) + >>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)]) + >>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)]) + >>> mat1 = BlockMatrix(blocks1, 2, 3) + >>> mat2 = BlockMatrix(blocks2, 3, 2) + >>> mat3 = BlockMatrix(blocks3, 3, 2) + + >>> mat1.multiply(mat2).toLocalMatrix() + DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0) + + >>> mat1.multiply(mat3).toLocalMatrix() + DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0) + """ + if not isinstance(other, BlockMatrix): + raise TypeError("Other should be a BlockMatrix, got %s" % type(other)) + + other_java_block_matrix = other._java_matrix_wrapper._java_model + java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix) + return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock) + def toLocalMatrix(self): """ Collect the distributed matrix on the driver as a DenseMatrix. From 5a5f65905a202e59bc85170b01c57a883718ddf6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 27 Oct 2015 13:28:52 -0700 Subject: [PATCH 058/324] [SPARK-11347] [SQL] Support for joinWith in Datasets This PR adds a new operation `joinWith` to a `Dataset`, which returns a `Tuple` for each pair where a given `condition` evaluates to true. ```scala case class ClassData(a: String, b: Int) val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() > ds1.joinWith(ds2, $"_1" === $"a").collect() res0: Array((ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) ``` This operation is similar to the relation `join` function with one important difference in the result schema. 
Since `joinWith` preserves objects present on either side of the join, the result schema is similarly nested into a tuple under the column names `_1` and `_2`. This type of join can be useful both for preserving type-safety with the original object types as well as working with relational data where either side of the join has column names in common. ## Required Changes to Encoders In the process of working on this patch, several deficiencies to the way that we were handling encoders were discovered. Specifically, it turned out to be very difficult to `rebind` the non-expression based encoders to extract the nested objects from the results of joins (and also typed selects that return tuples). As a result the following changes were made. - `ClassEncoder` has been renamed to `ExpressionEncoder` and has been improved to also handle primitive types. Additionally, it is now possible to take arbitrary expression encoders and rewrite them into a single encoder that returns a tuple. - All internal operations on `Dataset`s now require an `ExpressionEncoder`. If the users tries to pass a non-`ExpressionEncoder` in, an error will be thrown. We can relax this requirement in the future by constructing a wrapper class that uses expressions to project the row to the expected schema, shielding the users code from the required remapping. This will give us a nice balance where we don't force user encoders to understand attribute references and binding, but still allow our native encoder to leverage runtime code generation to construct specific encoders for a given schema that avoid an extra remapping step. - Additionally, the semantics for different types of objects are now better defined. As stated in the `ExpressionEncoder` scaladoc: - Classes will have their sub fields extracted by name using `UnresolvedAttribute` expressions and `UnresolvedExtractValue` expressions. - Tuples will have their subfields extracted by position using `BoundReference` expressions. - Primitives will have their values extracted from the first ordinal with a schema that defaults to the name `value`. - Finally, the binding lifecycle for `Encoders` has now been unified across the codebase. Encoders are now `resolved` to the appropriate schema in the constructor of `Dataset`. This process replaces an unresolved expressions with concrete `AttributeReference` expressions. Binding then happens on demand, when an encoder is going to be used to construct an object. This closely mirrors the lifecycle for standard expressions when executing normal SQL or `DataFrame` queries. Author: Michael Armbrust Closes #9300 from marmbrus/datasets-tuples. 
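A short, hedged sketch restating the example from the commit message above, to show the tuple-typed result and the `_1`/`_2` nesting; it assumes a `SQLContext` named `sqlContext` with its implicits in scope and is not part of the patch itself.

```scala
import org.apache.spark.sql.Dataset
import sqlContext.implicits._ // assumed: an existing SQLContext named `sqlContext`

case class ClassData(a: String, b: Int)

val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()

// joinWith keeps both input objects intact: the element type of the result is a
// tuple, and the underlying row schema nests each side under `_1` and `_2`.
val joined: Dataset[(ClassData, (String, Int))] = ds1.joinWith(ds2, $"_1" === $"a")

joined.collect().foreach { case (left, right) =>
  println(s"${left.a}/${left.b} joined with $right")
}
```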
--- .../spark/sql/catalyst/ScalaReflection.scala | 43 +++- .../sql/catalyst/encoders/ClassEncoder.scala | 101 -------- .../spark/sql/catalyst/encoders/Encoder.scala | 38 +-- .../catalyst/encoders/ExpressionEncoder.scala | 217 ++++++++++++++++++ .../catalyst/encoders/ProductEncoder.scala | 47 ---- .../sql/catalyst/encoders/RowEncoder.scala | 5 +- .../sql/catalyst/encoders/package.scala} | 29 +-- .../catalyst/encoders/primitiveTypes.scala | 100 -------- .../spark/sql/catalyst/encoders/tuples.scala | 173 -------------- .../plans/logical/basicOperators.scala | 28 +-- ...ite.scala => ExpressionEncoderSuite.scala} | 39 ++-- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 190 ++++++++------- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 13 +- .../spark/sql/execution/basicOperators.scala | 16 +- .../org/apache/spark/sql/DatasetSuite.scala | 89 ++++++- .../org/apache/spark/sql/QueryTest.scala | 44 +++- 18 files changed, 563 insertions(+), 615 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala rename sql/catalyst/src/{test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala => main/scala/org/apache/spark/sql/catalyst/encoders/package.scala} (56%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/{ProductEncoderSuite.scala => ExpressionEncoderSuite.scala} (91%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c25161ee81b66..9cbb7c2ffdc76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,6 +146,10 @@ trait ScalaReflection { * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling unbind/bind with a new schema. */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) @@ -159,8 +163,14 @@ trait ScalaReflection { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType) = + path + .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + /** Returns the current path or throws an error. 
*/ - def getPath = path.getOrElse(sys.error("Constructors must start at a class type")) + def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true)) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => @@ -387,12 +397,17 @@ trait ScalaReflection { val className: String = t.erasure.typeSymbol.asClass.fullName val cls = Utils.classForName(className) - val arguments = params.head.map { p => + val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = dataTypeFor(fieldType) + val dataType = schemaFor(fieldType).dataType - constructorFor(fieldType, Some(addToPath(fieldName))) + // For tuples, we based grab the inner fields by ordinal instead of name. + if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } } val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) @@ -413,7 +428,10 @@ trait ScalaReflection { /** Returns expressions for extracting all the fields from the given type. */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct] + extractorFor(inputObject, typeTag[T].tpe) match { + case s: CreateNamedStruct => s + case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) + } } } @@ -602,6 +620,21 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if t <:< definitions.IntTpe => + BoundReference(0, IntegerType, false) + case t if t <:< definitions.LongTpe => + BoundReference(0, LongType, false) + case t if t <:< definitions.DoubleTpe => + BoundReference(0, DoubleType, false) + case t if t <:< definitions.FloatTpe => + BoundReference(0, FloatType, false) + case t if t <:< definitions.ShortTpe => + BoundReference(0, ShortType, false) + case t if t <:< definitions.ByteTpe => + BoundReference(0, ByteType, false) + case t if t <:< definitions.BooleanTpe => + BoundReference(0, BooleanType, false) + case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala deleted file mode 100644 index b484b8fde6369..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A generic encoder for JVM objects. - * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. - * @param clsTag A classtag for `T`. - */ -case class ClassEncoder[T]( - schema: StructType, - extractExpressions: Seq[Expression], - constructExpression: Expression, - clsTag: ClassTag[T]) - extends Encoder[T] { - - @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) - private val inputRow = new GenericMutableRow(1) - - @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) - private val dataType = ObjectType(clsTag.runtimeClass) - - override def toRow(t: T): InternalRow = { - inputRow(0) = t - extractProjection(inputRow) - } - - override def fromRow(row: InternalRow): T = { - constructProjection(row).get(0, dataType).asInstanceOf[T] - } - - override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { - val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - val resolvedExpression = analyzedPlan.expressions.head.children.head - val boundExpression = BindReferences.bindReference(resolvedExpression, schema) - - copy(constructExpression = boundExpression) - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) - } - - override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = { - var remaining = schema - copy(constructExpression = constructExpression transform { - case u: UnresolvedAttribute => - val pos = remaining.head - remaining = remaining.drop(1) - pos - }) - } - - protected val attrs = extractExpressions.map(_.collect { - case a: Attribute => s"#${a.exprId}" - case b: BoundReference => s"[${b.ordinal}]" - }.headOption.getOrElse("")) - - - protected val schemaString = - schema - .zip(attrs) - .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") - - override def toString: String = s"class[$schemaString]" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index efb872ddb81e5..329a132d3d8b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst.encoders + import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.Attribute 
-import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType /** @@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. */ -trait Encoder[T] { +trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ def schema: StructType /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] - - /** - * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to - * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should - * copy the result before making another call if required. - */ - def toRow(t: T): InternalRow - - /** - * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `bind` an encoder to a specific schema before you can call this function. - */ - def fromRow(row: InternalRow): T - - /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the - * given schema. - */ - def bind(schema: Seq[Attribute]): Encoder[T] - - /** - * Binds this encoder to the given schema positionally. In this binding, the first reference to - * any input is mapped to `schema(0)`, and so on for each input that is encountered. - */ - def bindOrdinals(schema: Seq[Attribute]): Encoder[T] - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder that - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were orriginally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala new file mode 100644 index 0000000000000..c287aebeeee05 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.util.Utils + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType} + +/** + * A factory for constructing encoders that convert objects and primitves to and from the + * internal row format using catalyst expressions and code generation. By default, the + * expressions used to retrieve values from an input row when producing an object will be created as + * follows: + * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions + * and [[UnresolvedExtractValue]] expressions. + * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions. + * - Primitives will have their values extracted from the first ordinal with a schema that defaults + * to the name `value`. + */ +object ExpressionEncoder { + def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(typeTag[T].tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpression = ScalaReflection.extractorsFor[T](inputObject) + val constructExpression = ScalaReflection.constructorFor[T] + + new ExpressionEncoder[T]( + extractExpression.dataType, + flat, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } + + /** + * Given a set of N encoders, constructs a new encoder that produce objects as items in an + * N-tuple. Note that these encoders should first be bound correctly to the combined input + * schema. + */ + def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + val schema = + StructType( + encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + } + val constructExpression = + NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + /** A helper for producing encoders of Tuple2 from other encoders. */ + def tuple[T1, T2]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = + tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] +} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param extractExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object. + * @param clsTag A classtag for `T`. 
+ */ +case class ExpressionEncoder[T]( + schema: StructType, + flat: Boolean, + extractExpressions: Seq[Expression], + constructExpression: Expression, + clsTag: ClassTag[T]) + extends Encoder[T] { + + if (flat) require(extractExpressions.size == 1) + + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private val inputRow = new GenericMutableRow(1) + + @transient + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + + /** + * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to + * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should + * copy the result before making another call if required. + */ + def toRow(t: T): InternalRow = { + inputRow(0) = t + extractProjection(inputRow) + } + + /** + * Returns an object of type `T`, extracting the required values from the provided row. Note that + * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * function. + */ + def fromRow(row: InternalRow): T = try { + constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + } + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + copy(constructExpression = analyzedPlan.expressions.head.children.head) + } + + /** + * Returns a copy of this encoder where the expressions used to construct an object from an input + * row have been bound to the ordinals of the given schema. Note that you need to first call + * resolve before bind. + */ + def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) + } + + /** + * Replaces any bound references in the schema with the attributes at the corresponding ordinal + * in the provided schema. This can be used to "relocate" a given encoder to pull values from + * a different schema than it was initially bound to. It can also be used to assign attributes + * to ordinal based extraction (i.e. because the input data was a tuple). + */ + def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(schema) + copy(constructExpression = constructExpression transform { + case b: BoundReference => positionToAttribute(b.ordinal) + }) + } + + /** + * Given an encoder that has already been bound to a given schema, returns a new encoder + * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, + * when you are trying to use an encoder on grouping keys that were originally part of a larger + * row, but now you have projected out only the key expressions. 
+ */ + def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(oldSchema) + val attributeToNewPosition = AttributeMap.byIndex(newSchema) + copy(constructExpression = constructExpression transform { + case r: BoundReference => + r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) + }) + } + + /** + * Returns a copy of this encoder where the expressions used to create an object given an + * input row have been modified to pull the object out from a nested struct, instead of the + * top level fields. + */ + def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { + copy(constructExpression = constructExpression transform { + case u: Attribute if u != input => + UnresolvedExtractValue(input, Literal(u.name)) + case b: BoundReference if b != input => + GetStructField( + input, + StructField(s"i[${b.ordinal}]", b.dataType), + b.ordinal) + }) + } + + protected val attrs = extractExpressions.flatMap(_.collect { + case _: UnresolvedAttribute => "" + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }) + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 34f5e6c030f58..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL - * internal binary representation. - */ -object ProductEncoder { - def apply[T <: Product : TypeTag]: ClassEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. 
- val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(typeTag[T].tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] - - new ClassEncoder[T]( - extractExpression.dataType, - extractExpression.flatten, - constructExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e9cc00a2b64ce..0b42130a013b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String * internal binary representation. */ object RowEncoder { - def apply(schema: StructType): ClassEncoder[Row] = { + def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val extractExpressions = extractorsFor(inputObject, schema) val constructExpression = constructorFor(schema) - new ClassEncoder[Row]( + new ExpressionEncoder[Row]( schema, + flat = false, extractExpressions.asInstanceOf[CreateStruct].children, constructExpression, ClassTag(cls)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala similarity index 56% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 52f8383faca92..d4642a500672e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -15,29 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.encoders +package org.apache.spark.sql.catalyst -import org.apache.spark.SparkFunSuite - -class PrimitiveEncoderSuite extends SparkFunSuite { - test("long encoder") { - val enc = new LongEncoder() - val row = enc.toRow(10) - assert(row.getLong(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("int encoder") { - val enc = new IntEncoder() - val row = enc.toRow(10) - assert(row.getInt(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("string encoder") { - val enc = new StringEncoder() - val row = enc.toRow("test") - assert(row.getString(0) == "test") - assert(enc.fromRow(row) == "test") +package object encoders { + private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + case e: ExpressionEncoder[A] => e + case _ => sys.error(s"Only expression encoders are supported today") } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala deleted file mode 100644 index a93f2d7c6115d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
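Editor's note: the new encoders package object above introduces encoderFor, which narrows a context-bound Encoder to an ExpressionEncoder and fails fast for anything else; the logical and physical operators later in this patch call it in place of implicitly[Encoder[...]]. A small sketch of the call pattern (schemaOf is a hypothetical helper, not in the patch, and encoderFor is private[sql], so this only compiles from inside the org.apache.spark.sql package):

  import org.apache.spark.sql.catalyst.encoders._

  // Any A with an implicit Encoder can be narrowed to an ExpressionEncoder, exposing its schema
  // (mirrors encoderFor[U].schema.toAttributes as used by AppendColumn and MapGroups below).
  def schemaOf[A : Encoder] = encoderFor[A].schema.toAttributes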
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ - -/** An encoder for primitive Long types. */ -case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Long] = ClassTag.Long - override def schema: StructType = - StructType(StructField(fieldName, LongType) :: Nil) - - override def fromRow(row: InternalRow): Long = row.getLong(ordinal) - - override def toRow(t: Long): InternalRow = { - row.setLong(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this - override def bind(schema: Seq[Attribute]): Encoder[Long] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this -} - -/** An encoder for primitive Integer types. */ -case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Int] = ClassTag.Int - override def schema: StructType = - StructType(StructField(fieldName, IntegerType) :: Nil) - - override def fromRow(row: InternalRow): Int = row.getInt(ordinal) - - override def toRow(t: Int): InternalRow = { - row.setInt(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this - override def bind(schema: Seq[Attribute]): Encoder[Int] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this -} - -/** An encoder for String types. 
*/ -case class StringEncoder( - fieldName: String = "value", - ordinal: Int = 0) extends Encoder[String] { - - val record = new SpecificMutableRow(StringType :: Nil) - - @transient - lazy val projection = - GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil) - - override def schema: StructType = - StructType( - StructField("value", StringType, nullable = false) :: Nil) - - override def clsTag: ClassTag[String] = scala.reflect.classTag[String] - - - override final def fromRow(row: InternalRow): String = { - row.getString(ordinal) - } - - override final def toRow(value: String): InternalRow = { - val utf8String = UTF8String.fromString(value) - record(0) = utf8String - // TODO: this is a bit of a hack to produce UnsafeRows - projection(record) - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this - override def bind(schema: Seq[Attribute]): Encoder[String] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala deleted file mode 100644 index a48eeda7d2e6f..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructField, StructType} - -// Most of this file is codegen. -// scalastyle:off - -/** - * A set of composite encoders that take sub encoders and map each of their objects to a - * Scala tuple. Note that currently the implementation is fairly limited and only supports going - * from an internal row to a tuple. - */ -object TupleEncoder { - - /** Code generator for composite tuple encoders. 
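Editor's note: the hand-written LongEncoder, IntEncoder and StringEncoder removed here are subsumed by flat ExpressionEncoders (see the SQLImplicits change later in this patch). A brief sketch of the equivalent round trip, assuming the same resolve-then-bind contract documented earlier; it is not taken from the patch itself.

  val longEnc = ExpressionEncoder[Long](flat = true)   // as constructed in SQLImplicits below
  val attrs = longEnc.schema.toAttributes
  val bound = longEnc.resolve(attrs).bind(attrs)
  bound.fromRow(longEnc.toRow(10L))                    // 10L, replacing LongEncoder's toRow/fromRow pair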
*/ - def main(args: Array[String]): Unit = { - (2 to 5).foreach { i => - val types = (1 to i).map(t => s"T$t").mkString(", ") - val tupleType = s"($types)" - val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ") - val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ") - val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ") - - println( - s""" - |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] { - | val schema = StructType(Array($fields)) - | - | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType] - | - | def fromRow(row: InternalRow): $tupleType = { - | ($fromRow) - | } - | - | override def toRow(t: $tupleType): InternalRow = - | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - | - | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = { - | this - | } - | - | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - | - | - | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - |} - """.stripMargin) - } - } -} - -class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema))) - - def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)] - - def fromRow(row: InternalRow): (T1, T2) = { - (e1.fromRow(row), e2.fromRow(row)) - } - - override def toRow(t: (T1, T2)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema))) - - def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)] - - def fromRow(row: InternalRow): (T1, T2, T3) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, 
T3, T4)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 21a55a5371841..d2d3db0a44484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ @@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode { */ case class MapPartitions[T, U]( func: Iterator[T] => Iterator[U], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def missingInput: AttributeSet = AttributeSet.empty @@ -460,8 +460,8 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. 
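Editor's note: with the generated Tuple*Encoder classes deleted, composite encoders are now built with ExpressionEncoder.tuple, which the new Dataset.selectUntyped and joinWith rely on later in this patch. A hedged sketch composing two flat encoders into a pair encoder, using the two-argument form seen in joinWith:

  implicit val pairEnc: Encoder[(String, Int)] =
    ExpressionEncoder.tuple(
      ExpressionEncoder[String](flat = true),
      ExpressionEncoder[Int](flat = true))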
*/ object AppendColumn { def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { - val attrs = implicitly[Encoder[U]].schema.toAttributes - new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child) + val attrs = encoderFor[U].schema.toAttributes + new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) } } @@ -472,8 +472,8 @@ object AppendColumn { */ case class AppendColumn[T, U]( func: T => U, - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns @@ -488,11 +488,11 @@ object MapGroups { child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( func, - implicitly[Encoder[K]], - implicitly[Encoder[T]], - implicitly[Encoder[U]], + encoderFor[K], + encoderFor[T], + encoderFor[U], groupingAttributes, - implicitly[Encoder[U]].schema.toAttributes, + encoderFor[U].schema.toAttributes, child) } } @@ -504,9 +504,9 @@ object MapGroups { */ case class MapGroups[K, T, U]( func: (K, Iterator[T]) => Iterator[U], - kEncoder: Encoder[K], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala similarity index 91% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 008d0bea8a941..a374da4da1f08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -47,7 +47,16 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) -class ProductEncoderSuite extends SparkFunSuite { +class ExpressionEncoderSuite extends SparkFunSuite { + + encodeDecodeTest(1) + encodeDecodeTest(1L) + encodeDecodeTest(1.toDouble) + encodeDecodeTest(1.toFloat) + encodeDecodeTest(true) + encodeDecodeTest(false) + encodeDecodeTest(1.toShort) + encodeDecodeTest(1.toByte) encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) @@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite { { (l, r) => l._2.toString == r._2.toString } /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) = + protected def encodeDecodeTest[T : TypeTag](inputData: T) = encodeDecodeTestCustom[T](inputData)((l, r) => l == r) /** * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it * matches the original. 
*/ - protected def encodeDecodeTestCustom[T <: Product : TypeTag]( + protected def encodeDecodeTestCustom[T : TypeTag]( inputData: T)( c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData") { - val encoder = try ProductEncoder[T] catch { + test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { + val encoder = try ExpressionEncoder[T]() catch { case e: Exception => fail(s"Exception thrown generating encoder", e) } val convertedData = encoder.toRow(inputData) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.bind(schema) + val boundEncoder = encoder.resolve(schema).bind(schema) val convertedBack = try boundEncoder.fromRow(convertedData) catch { case e: Exception => fail( @@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} + |Encoder: + |$boundEncoder | """.stripMargin, e) } if (!c(inputData, convertedBack)) { - val types = - convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + val types = convertedBack match { + case c: Product => + c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + case other => other.getClass.getName + } + val encodedData = try { convertedData.toSeq(encoder.schema).zip(encoder.schema).map { @@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite { |${encoder.schema.treeString} | |Extract Expressions: - |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")} - | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} - | + |$boundEncoder """.stripMargin) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 32d9b0b1d9888..aa817a037ef5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -267,7 +267,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution) + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 96213c7630400..e0ab5f593e933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types.StructType @@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType * @since 1.6.0 */ @Experimental -class Dataset[T] private[sql]( +class Dataset[T] private( @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution)( - implicit val encoder: Encoder[T]) extends Serializable { + @transient val queryExecution: QueryExecution, + unresolvedEncoder: Encoder[T]) extends Serializable { + + /** The encoder for this [[Dataset]] that has been resolved to its output schema. 
*/ + private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { + case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) + case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") + } private implicit def classTag = encoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = - this(sqlContext, new QueryExecution(sqlContext, plan)) + this(sqlContext, new QueryExecution(sqlContext, plan), encoder) /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ def schema: StructType = encoder.schema @@ -76,7 +83,9 @@ class Dataset[T] private[sql]( * TODO: document binding rules * @since 1.6.0 */ - def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]]) + def as[U : Encoder]: Dataset[U] = { + new Dataset(sqlContext, queryExecution, encoderFor[U]) + } /** * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have @@ -103,7 +112,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = implicitly[Encoder[T]] + val tEnc = encoderFor[T] val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => val bound = tEnc.bind(input) @@ -150,9 +159,9 @@ class Dataset[T] private[sql]( sqlContext, MapPartitions[T, U]( func, - implicitly[Encoder[T]], - implicitly[Encoder[U]], - implicitly[Encoder[U]].schema.toAttributes, + encoderFor[T], + encoderFor[U], + encoderFor[U].schema.toAttributes, logicalPlan)) } @@ -209,8 +218,8 @@ class Dataset[T] private[sql]( val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( - implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns), - implicitly[Encoder[T]].bind(inputPlan.output), + encoderFor[K].resolve(withGroupingKey.newColumns), + encoderFor[T].bind(inputPlan.output), executed, inputPlan.output, withGroupingKey.newColumns) @@ -220,6 +229,18 @@ class Dataset[T] private[sql]( * Typed Relational * * ****************** */ + /** + * Selects a set of column based expressions. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + * @group dfops + * @since 1.3.0 + */ + // Copied from Dataframe to make sure we don't have invalid overloads. + @scala.annotation.varargs + def select(cols: Column*): DataFrame = toDF().select(cols: _*) + /** * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. * @@ -233,88 +254,64 @@ class Dataset[T] private[sql]( new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) } - // Codegen - // scalastyle:off - - /** sbt scalaShell; println(Seq(1).toDS().genSelect) */ - private def genSelect: String = { - (2 to 5).map { n => - val types = (1 to n).map(i =>s"U$i").mkString(", ") - val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ") - val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ") - val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ") - s""" - |/** - | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - | * @since 1.6.0 - | */ - |def select[$types]($args): Dataset[($types)] = { - | implicit val te = new Tuple${n}Encoder($encoders) - | new Dataset[($types)](sqlContext, - | Project( - | $schema :: Nil, - | logicalPlan)) - |} - | - """.stripMargin - }.mkString("\n") + /** + * Internal helper function for building typed selects that return tuples. 
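Editor's note: because the constructor now resolves the supplied encoder against the analyzed output, as[U] can simply swap in encoderFor[U] over the same query execution. A usage sketch mirroring the "as tuple" test added later in this patch (and the DataFrame.as change below); it assumes sqlContext.implicits._ is in scope, as in DatasetSuite.

  val df = Seq(("a", 1), ("b", 2)).toDF("a", "b")
  val ds: Dataset[(String, Int)] = df.as[(String, Int)]   // tuple encoder resolved against columns a, b
  ds.collect()                                            // ("a", 1), ("b", 2)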
For simplicity and + * code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { + val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } + val unresolvedPlan = Project(aliases, logicalPlan) + val execution = new QueryExecution(sqlContext, unresolvedPlan) + // Rebind the encoders to the nested schema that will be produced by the select. + val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { + case (e: ExpressionEncoder[_], a) if !e.flat => + e.nested(a.toAttribute).resolve(execution.analyzed.output) + case (e, a) => + e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) + } + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = { - implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder) - new Dataset[(U1, U2)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = { - implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder) - new Dataset[(U1, U2, U3)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2, U3]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = { - implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder) - new Dataset[(U1, U2, U3, U4)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2, U3, U4]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3], + c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
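Editor's note: selectUntyped above aliases each TypedColumn to _1.._N, re-resolves the column encoders against the projected schema, and wraps them in a tuple encoder; the arity-specific overloads only cast the result. Usage then looks like the "select 2" test later in this patch, assuming the test's imports (sqlContext.implicits._ and functions.expr):

  val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
  val pairs: Dataset[(String, Int)] =
    ds.select(expr("_1").as[String], expr("_2").as[Int])  // typed select routed through selectUntyped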
* @since 1.6.0 */ - def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = { - implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder) - new Dataset[(U1, U2, U3, U4, U5)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil, - logicalPlan)) - } - - // scalastyle:on + def select[U1, U2, U3, U4, U5]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3], + c4: TypedColumn[U4], + c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * * Set operations * @@ -360,6 +357,48 @@ class Dataset[T] private[sql]( */ def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + /* ****** * + * Joins * + * ****** */ + + /** + * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * true. + * + * This is similar to the relation `join` function with one important difference in the + * result schema. Since `joinWith` preserves objects present on either side of the join, the + * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * + * This type of join can be useful both for preserving type-safety with the original object + * types as well as working with relational data where either side of the join has column + * names in common. + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + val left = this.logicalPlan + val right = other.logicalPlan + + val leftData = this.encoder match { + case e if e.flat => Alias(left.output.head, "_1")() + case _ => Alias(CreateStruct(left.output), "_1")() + } + val rightData = other.encoder match { + case e if e.flat => Alias(right.output.head, "_2")() + case _ => Alias(CreateStruct(right.output), "_2")() + } + val leftEncoder = + if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) + val rightEncoder = + if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(leftEncoder, rightEncoder) + + withPlan[(T, U)](other) { (left, right) => + Project( + leftData :: rightData :: Nil, + Join(left, right, Inner, Some(condition.expr))) + } + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -380,13 +419,10 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan))) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R]( - sqlContext, - sqlContext.executePlan( - f(logicalPlan, other.logicalPlan))) + new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5e7198f974389..2cb94430e6178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -34,7 +34,7 @@ import 
org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -491,7 +491,7 @@ class SQLContext private[sql]( def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { - val enc = implicitly[Encoder[T]] + val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index af8474df0de80..f460a86414c41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() - implicit def newIntEncoder: Encoder[Int] = new IntEncoder() - implicit def newLongEncoder: Encoder[Long] = new LongEncoder() - implicit def newStringEncoder: Encoder[String] = new StringEncoder() + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2bb3dba5bd2ba..89938471ee381 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ @@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl */ case class MapPartitions[T, U]( func: 
Iterator[T] => Iterator[U], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -337,8 +337,8 @@ case class MapPartitions[T, U]( */ case class AppendColumns[T, U]( func: T => U, - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -363,9 +363,9 @@ case class AppendColumns[T, U]( */ case class MapGroups[K, T, U]( func: (K, Iterator[T]) => Iterator[U], - kEncoder: Encoder[K], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 08496249c60cc..aebb390a1d15d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { data: _*) } + test("as tuple") { + val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") + checkAnswer( + data.as[(String, Int)], + ("a", 1), ("b", 2)) + } + test("as case class / collect") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] checkAnswer( @@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } - test("select 3") { + test("select 2") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( ds.select( expr("_1").as[String], - expr("_2").as[Int], - expr("_2 + 1").as[Int]), - ("a", 1, 2), ("b", 2, 3), ("c", 3, 4)) + expr("_2").as[Int]) : Dataset[(String, Int)], + ("a", 1), ("b", 2), ("c", 3)) + } + + test("select 2, primitive and tuple") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("struct(_2, _2)").as[(Int, Int)]), + ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3))) + } + + test("select 2, primitive and class") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("named_struct('a', _1, 'b', _2)").as[ClassData]), + ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) + } + + test("select 2, primitive and class, fields reordered") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkDecoding( + ds.select( + expr("_1").as[String], + expr("named_struct('b', _2, 'a', _1)").as[ClassData]), + ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3))) } test("filter") { @@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } + test("joinWith, flat schema") { + val ds1 = Seq(1, 2, 3).toDS().as("a") + val ds2 = Seq(1, 2).toDS().as("b") + + checkAnswer( + ds1.joinWith(ds2, $"a.value" === $"b.value"), + (1, 1), (2, 2)) + } + + test("joinWith, expression condition") { + val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + val ds2 = Seq(("a", 1), ("b", 2)).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"_1" === $"a"), + (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + } + + test("joinWith tuple with primitive, expression") { + val ds1 = Seq(1, 1, 2).toDS() + val ds2 = 
Seq(("a", 1), ("b", 2)).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"value" === $"_2"), + (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))) + } + + test("joinWith class with primitive, toDF") { + val ds1 = Seq(1, 1, 2).toDS() + val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() + + checkAnswer( + ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"), + Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil) + } + + test("multi-level joinWith") { + val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") + val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") + val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") + + checkAnswer( + ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), + ((("a", 1), ("a", 1)), ("a", 1)), + ((("b", 2), ("b", 2)), ("b", 2))) + + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupBy(v => (1, v._2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index aba567512fe32..73e02eb0d9574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder} +import org.apache.spark.sql.catalyst.encoders.Encoder abstract class QueryTest extends PlanTest { @@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest { } } - protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = { + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. + * - Special handling is done based on whether the query plan should be expected to return + * the results in sorted order. + * - This function also checks to make sure that the schema for serializing the expected answer + * matches that produced by the dataset (i.e. does manual construction of object match + * the constructed encoder for cases like joins, etc). Note that this means that it will fail + * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead + * which performs a subset of the checks done by this function. 
+ */ + protected def checkAnswer[T : Encoder]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + + checkDecoding(ds, expectedAnswer: _*) + } + + protected def checkDecoding[T]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val decoded = try ds.collect().toSet catch { + case e: Exception => + fail( + s""" + |Exception collecting dataset as objects + |${ds.encoder} + |${ds.encoder.constructExpression.treeString} + |${ds.queryExecution} + """.stripMargin, e) + } + + if (decoded != expectedAnswer.toSet) { + fail( + s"""Decoded objects do not match expected objects: + |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted} + |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted} + |${ds.encoder.constructExpression.treeString} + """.stripMargin) + } } /** From 9dba5fb2b59174cefde5b62a5c892fe5925bea38 Mon Sep 17 00:00:00 2001 From: vectorijk Date: Tue, 27 Oct 2015 13:55:03 -0700 Subject: [PATCH 059/324] [SPARK-10024][PYSPARK] Python API RF and GBT related params clear up implement {RandomForest, GBT, TreeEnsemble, TreeClassifier, TreeRegressor}Params for Python API in pyspark/ml/{classification, regression}.py Author: vectorijk Closes #9233 from vectorijk/spark-10024. --- python/pyspark/ml/classification.py | 182 +++------------- python/pyspark/ml/regression.py | 324 ++++++++++++---------------- 2 files changed, 168 insertions(+), 338 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 88815e561f572..4cbe7fbd482da 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,7 +19,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.ml.regression import ( - RandomForestParams, DecisionTreeModel, TreeEnsembleModels) + RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc @@ -205,8 +205,34 @@ class TreeClassifierParams(object): """ supportedImpurities = ["entropy", "gini"] + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + + ", ".join(supportedImpurities)) + + def __init__(self): + super(TreeClassifierParams, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = Param(self, "impurity", "Criterion used for information " + + "gain calculation (case-insensitive). Supported options: " + + ", ".join(self.supportedImpurities)) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self -class GBTParams(object): + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. """ @@ -216,7 +242,7 @@ class GBTParams(object): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - HasCheckpointInterval): + TreeClassifierParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. 
@@ -250,11 +276,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1.0 """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", @@ -269,11 +290,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(DecisionTreeClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") @@ -299,19 +315,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeClassificationModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - @inherit_doc class DecisionTreeClassificationModel(DecisionTreeModel): @@ -323,7 +326,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel): @inherit_doc class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, HasRawPredictionCol, HasProbabilityCol, - DecisionTreeParams, HasCheckpointInterval): + RandomForestParams, TreeClassifierParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for classification. @@ -357,19 +360,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1.0 """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. 
Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", @@ -386,23 +376,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(RandomForestClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) - #: param for Fraction of the training data used for learning each decision tree, - # in range (0, 1] - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: param for Number of trees to train (>= 1) - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") - #: param for The number of features to consider for splits at each tree node - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto") @@ -429,58 +402,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - def setNumTrees(self, value): - """ - Sets the value of :py:attr:`numTrees`. - """ - self._paramMap[self.numTrees] = value - return self - - def getNumTrees(self): - """ - Gets the value of numTrees or its default value. - """ - return self.getOrDefault(self.numTrees) - - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - self._paramMap[self.featureSubsetStrategy] = value - return self - - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. 
- """ - return self.getOrDefault(self.featureSubsetStrategy) - class RandomForestClassificationModel(TreeEnsembleModels): """ @@ -490,7 +411,7 @@ class RandomForestClassificationModel(TreeEnsembleModels): @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - DecisionTreeParams, HasCheckpointInterval): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for classification. @@ -522,12 +443,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + - "contribution of each estimator") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", @@ -547,15 +462,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - #: Fraction of the training data used for learning each decision tree, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of - # each estimator - self.stepSize = Param(self, "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator") self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1) @@ -593,32 +499,6 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - def setStepSize(self, value): - """ - Sets the value of :py:attr:`stepSize`. - """ - self._paramMap[self.stepSize] = value - return self - - def getStepSize(self): - """ - Gets the value of stepSize or its default value. - """ - return self.getOrDefault(self.stepSize) - class GBTClassificationModel(TreeEnsembleModels): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index eb5f4bd6d70b4..eeb18b3e9d290 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -260,21 +260,127 @@ def predictions(self): return self._call_java("predictions") -class TreeRegressorParams(object): +class TreeEnsembleParams(DecisionTreeParams): + """ + Mixin for Decision Tree-based ensemble algorithms parameters. 
+ """ + + # a placeholder to make it appear in the generated doc + subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " + + "used for learning each decision tree, in range (0, 1].") + + def __init__(self): + super(TreeEnsembleParams, self).__init__() + #: param for Fraction of the training data, in range (0, 1]. + self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " + + "used for learning each decision tree, in range (0, 1].") + + @since("1.4.0") + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + @since("1.4.0") + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + +class TreeRegressorParams(Params): """ Private class to track supported impurity measures. """ + supportedImpurities = ["variance"] + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + + ", ".join(supportedImpurities)) + def __init__(self): + super(TreeRegressorParams, self).__init__() + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = Param(self, "impurity", "Criterion used for information " + + "gain calculation (case-insensitive). Supported options: " + + ", ".join(self.supportedImpurities)) -class RandomForestParams(object): + @since("1.4.0") + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + @since("1.4.0") + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +class RandomForestParams(TreeEnsembleParams): """ Private class to track supported random forest parameters. """ + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + # a placeholder to make it appear in the generated doc + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies)) + + def __init__(self): + super(RandomForestParams, self).__init__() + #: param for Number of trees to train (>= 1). + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).") + #: param for The number of features to consider for splits at each tree node. + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(self.supportedFeatureSubsetStrategies)) + + @since("1.4.0") + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self._paramMap[self.numTrees] = value + return self + + @since("1.4.0") + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. 
+ """ + self._paramMap[self.featureSubsetStrategy] = value + return self -class GBTParams(object): + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. """ @@ -283,7 +389,7 @@ class GBTParams(object): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -309,11 +415,6 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -326,11 +427,6 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") @@ -355,21 +451,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) - @since("1.4.0") - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - @since("1.4.0") - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - @inherit_doc class DecisionTreeModel(JavaModel): @@ -422,7 +503,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel): @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, - DecisionTreeParams, HasCheckpointInterval): + RandomForestParams, TreeRegressorParams, HasCheckpointInterval): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for regression. @@ -447,54 +528,26 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi .. versionadded:: 1.4.0 """ - # a placeholder to make it appear in the generated doc - impurity = Param(Params._dummy(), "impurity", - "Criterion used for information gain calculation (case-insensitive). 
" + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=None): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", numTrees=20, \ - featureSubsetStrategy="auto", seed=None) + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ + featureSubsetStrategy="auto") """ super(RandomForestRegressor, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) - #: param for Criterion used for information gain calculation (case-insensitive). - self.impurity = \ - Param(self, "impurity", - "Criterion used for information gain calculation (case-insensitive). " + - "Supported options: " + ", ".join(TreeRegressorParams.supportedImpurities)) - #: param for Fraction of the training data used for learning each decision tree, - # in range (0, 1] - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: param for Number of trees to train (>= 1) - self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") - #: param for The number of features to consider for splits at each tree node - self.featureSubsetStrategy = \ - Param(self, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. 
Supported " + - "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="variance", numTrees=20, featureSubsetStrategy="auto") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -502,13 +555,15 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, - impurity="variance", numTrees=20, featureSubsetStrategy="auto"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ - impurity="variance", numTrees=20, featureSubsetStrategy="auto") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \ + featureSubsetStrategy="auto") Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -517,66 +572,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) - @since("1.4.0") - def setImpurity(self, value): - """ - Sets the value of :py:attr:`impurity`. - """ - self._paramMap[self.impurity] = value - return self - - @since("1.4.0") - def getImpurity(self): - """ - Gets the value of impurity or its default value. - """ - return self.getOrDefault(self.impurity) - - @since("1.4.0") - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - @since("1.4.0") - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - @since("1.4.0") - def setNumTrees(self, value): - """ - Sets the value of :py:attr:`numTrees`. - """ - self._paramMap[self.numTrees] = value - return self - - @since("1.4.0") - def getNumTrees(self): - """ - Gets the value of numTrees or its default value. - """ - return self.getOrDefault(self.numTrees) - - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - self._paramMap[self.featureSubsetStrategy] = value - return self - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. 
- """ - return self.getOrDefault(self.featureSubsetStrategy) - class RandomForestRegressionModel(TreeEnsembleModels): """ @@ -588,7 +583,7 @@ class RandomForestRegressionModel(TreeEnsembleModels): @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - DecisionTreeParams, HasCheckpointInterval): + GBTParams, HasCheckpointInterval, HasStepSize, HasSeed): """ `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` learning algorithm for regression. @@ -617,23 +612,17 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - subsamplingRate = Param(Params._dummy(), "subsamplingRate", - "Fraction of the training data used for learning each decision tree, " + - "in range (0, 1].") - stepSize = Param(Params._dummy(), "stepSize", - "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + - "contribution of each estimator") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", - maxIter=20, stepSize=0.1): + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) @@ -641,18 +630,9 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) - #: Fraction of the training data used for learning each decision tree, in range (0, 1]. - self.subsamplingRate = Param(self, "subsamplingRate", - "Fraction of the training data used for learning each " + - "decision tree, in range (0, 1].") - #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of - # each estimator - self.stepSize = Param(self, "stepSize", - "Step size (a.k.a. 
learning rate) in interval (0, 1] for shrinking " + - "the contribution of each estimator") self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -660,13 +640,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="squared", maxIter=20, stepSize=0.1): + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="squared", maxIter=20, stepSize=0.1) + maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1) Sets params for Gradient Boosted Tree Regression. """ kwargs = self.setParams._input_kwargs @@ -690,36 +670,6 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) - @since("1.4.0") - def setSubsamplingRate(self, value): - """ - Sets the value of :py:attr:`subsamplingRate`. - """ - self._paramMap[self.subsamplingRate] = value - return self - - @since("1.4.0") - def getSubsamplingRate(self): - """ - Gets the value of subsamplingRate or its default value. - """ - return self.getOrDefault(self.subsamplingRate) - - @since("1.4.0") - def setStepSize(self, value): - """ - Sets the value of :py:attr:`stepSize`. - """ - self._paramMap[self.stepSize] = value - return self - - @since("1.4.0") - def getStepSize(self): - """ - Gets the value of stepSize or its default value. - """ - return self.getOrDefault(self.stepSize) - class GBTRegressionModel(TreeEnsembleModels): """ @@ -783,7 +733,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None): + quantilesCol=None) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( From 4f030b9e82172659d250281782ac573cbd1438fc Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 27 Oct 2015 16:01:26 -0700 Subject: [PATCH 060/324] [SPARK-11324][STREAMING] Flag for closing Write Ahead Logs after a write Currently the Write Ahead Log in Spark Streaming flushes data as writes need to be made. S3 does not support flushing of data, data is written once the stream is actually closed. In case of failure, the data for the last minute (default rolling interval) will not be properly written. Therefore we need a flag to close the stream after the write, so that we achieve read after write consistency. cc tdas zsxwing Author: Burak Yavuz Closes #9285 from brkyvz/caw-wal. 
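As a usage illustration (not part of the patch itself): the sketch below shows how an application might opt in to the new behaviour once this change lands. The two closeFileAfterWrite configuration keys are exactly the ones introduced in the diff that follows; the master, application name, batch interval and S3 bucket are placeholders, and spark.streaming.receiver.writeAheadLog.enable is the pre-existing switch that turns the receiver-side WAL on in the first place.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setMaster("local[2]")                                            // placeholder master
      .setAppName("wal-on-s3-example")                                  // placeholder name
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")     // existing setting
      // New flags from this patch: close the WAL file after every write, so that object
      // stores such as S3, which only persist data when a stream is closed, keep each record.
      .set("spark.streaming.receiver.writeAheadLog.closeFileAfterWrite", "true")
      .set("spark.streaming.driver.writeAheadLog.closeFileAfterWrite", "true")

    val ssc = new StreamingContext(conf, Seconds(10))
    ssc.checkpoint("s3a://my-bucket/streaming-checkpoints")              // placeholder bucket

The trade-off is more, smaller log files, since a file is closed (and reopened on the next write) per record instead of per rolling interval, which is why both flags default to false and are only worth enabling on file systems that lack a working flush.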
--- .../util/FileBasedWriteAheadLog.scala | 6 +++- .../streaming/util/WriteAheadLogUtils.scala | 15 ++++++++- .../streaming/util/WriteAheadLogSuite.scala | 32 +++++++++++++++---- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 9f4a4d6806ab5..bc3f2486c21fd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -47,7 +47,8 @@ private[streaming] class FileBasedWriteAheadLog( logDirectory: String, hadoopConf: Configuration, rollingIntervalSecs: Int, - maxFailures: Int + maxFailures: Int, + closeFileAfterWrite: Boolean ) extends WriteAheadLog with Logging { import FileBasedWriteAheadLog._ @@ -80,6 +81,9 @@ private[streaming] class FileBasedWriteAheadLog( while (!succeeded && failures < maxFailures) { try { fileSegment = getLogWriter(time).write(byteBuffer) + if (closeFileAfterWrite) { + resetWriter() + } succeeded = true } catch { case ex: Exception => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 7f6ff12c58d47..0ea970e61b694 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -31,11 +31,15 @@ private[streaming] object WriteAheadLogUtils extends Logging { val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs" val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures" + val RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = + "spark.streaming.receiver.writeAheadLog.closeFileAfterWrite" val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class" val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = + "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" val DEFAULT_ROLLING_INTERVAL_SECS = 60 val DEFAULT_MAX_FAILURES = 3 @@ -60,6 +64,14 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { + if (isDriver) { + conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) + } else { + conf.getBoolean(RECEIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) + } + } + /** * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. 
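Before the next hunk, a minimal sketch (again, not part of the patch) of what the new flag changes at the FileBasedWriteAheadLog level; it mirrors the "close after write flag" test added to WriteAheadLogSuite further down. The log directory is a placeholder, and because the class is private[streaming], the snippet only compiles from code living in the org.apache.spark.streaming.util package.

    import java.nio.ByteBuffer

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkConf

    // With closeFileAfterWrite = true, resetWriter() runs after every successful write,
    // so each record ends up in its own closed log file even though the 60-second
    // rolling interval has not elapsed.
    val wal = new FileBasedWriteAheadLog(
      new SparkConf(), "/tmp/wal-sketch", new Configuration(),
      60 /* rollingIntervalSecs */, 3 /* maxFailures */, closeFileAfterWrite = true)

    Seq("a", "b", "c").foreach { record =>
      wal.write(ByteBuffer.wrap(record.getBytes("UTF-8")), System.currentTimeMillis())
    }
    wal.close()   // three writes -> three log files under /tmp/wal-sketch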
@@ -113,7 +125,8 @@ private[streaming] object WriteAheadLogUtils extends Logging { } }.getOrElse { new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf, - getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver)) + getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), + shouldCloseFileAfterWrite(sparkConf, isDriver)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 5e49fd00769ad..93ae41a3d2ecd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -203,6 +203,21 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { assert(writtenData === dataToWrite) } + test("FileBasedWriteAheadLog - close after write flag") { + // Write data with rotation using WriteAheadLog class + val numFiles = 3 + val dataToWrite = Seq.tabulate(numFiles)(_.toString) + // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, + closeFileAfterWrite = true) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size === numFiles) + val writtenData = logFiles.flatMap { file => readDataManually(file)} + assert(writtenData === dataToWrite) + } + test("FileBasedWriteAheadLog - read rotating logs") { // Write data manually for testing reading through WriteAheadLog val writtenData = (1 to 10).map { i => @@ -296,8 +311,8 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { assert(!nonexistentTempPath.exists()) val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog( - new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1) + val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath, + new Configuration(), 1, 1, closeFileAfterWrite = false) assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") wal.read(writtenSegment.head) assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") @@ -356,14 +371,16 @@ object WriteAheadLogSuite { logDirectory: String, data: Seq[String], manualClock: ManualClock = new ManualClock, - closeLog: Boolean = true - ): FileBasedWriteAheadLog = { + closeLog: Boolean = true, + clockAdvanceTime: Int = 500, + closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => - manualClock.advance(500) + manualClock.advance(clockAdvanceTime) wal.write(item, manualClock.getTimeMillis()) } if (closeLog) wal.close() @@ -418,7 +435,8 @@ object WriteAheadLogSuite { /** Read all the data in the log file in a directory using the WriteAheadLog class. 
*/ def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite = false) val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data From 9fbd75ab5d46612e52116ec5b9ced70715cf26b5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 27 Oct 2015 16:14:33 -0700 Subject: [PATCH 061/324] =?UTF-8?q?[SPARK-11212][CORE][STREAMING]=20Make?= =?UTF-8?q?=20preferred=20locations=20support=20ExecutorCacheTaskLocation?= =?UTF-8?q?=20and=20update=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … ReceiverTracker and ReceiverSchedulingPolicy to use it This PR includes the following changes: 1. Add a new preferred location format, `executor__` (e.g., "executor_localhost_2"), to support specifying the executor locations for RDD. 2. Use the new preferred location format in `ReceiverTracker` to optimize the starting time of Receivers when there are multiple executors in a host. The goal of this PR is to enable the streaming scheduler to place receivers (which run as tasks) in specific executors. Basically, I want to have more control on the placement of the receivers such that they are evenly distributed among the executors. We tried to do this without changing the core scheduling logic. But it does not allow specifying particular executor as preferred location, only at the host level. So if there are two executors in the same host, and I want two receivers to run on them (one on each executor), I cannot specify that. Current code only specifies the host as preference, which may end up launching both receivers on the same executor. We try to work around it but restarting a receiver when it does not launch in the desired executor and hope that next time it will be started in the right one. But that cause lots of restarts, and delays in correctly launching the receiver. So this change, would allow the streaming scheduler to specify the exact executor as the preferred location. Also this is not exposed to the user, only the streaming scheduler uses this. Author: zsxwing Closes #9181 from zsxwing/executor-location. --- .../apache/spark/scheduler/TaskLocation.scala | 17 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 1 + .../receiver/ReceiverSupervisorImpl.scala | 5 +- .../scheduler/ReceiverSchedulingPolicy.scala | 110 +++++++++++------- .../streaming/scheduler/ReceiverTracker.scala | 93 ++++++++------- .../scheduler/ReceiverTrackingInfo.scala | 9 +- .../ReceiverSchedulingPolicySuite.scala | 110 +++++++++++------- .../scheduler/ReceiverTrackerSuite.scala | 4 +- 8 files changed, 217 insertions(+), 132 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 1b65926f5c749..1eb6c1614fc0b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -31,7 +31,9 @@ private[spark] sealed trait TaskLocation { */ private [spark] case class ExecutorCacheTaskLocation(override val host: String, executorId: String) - extends TaskLocation + extends TaskLocation { + override def toString: String = s"${TaskLocation.executorLocationTag}${host}_$executorId" +} /** * A location on a host. @@ -53,6 +55,9 @@ private[spark] object TaskLocation { // confusion. 
See RFC 952 and RFC 1123 for information about the format of hostnames. val inMemoryLocationTag = "hdfs_cache_" + // Identify locations of executors with this prefix. + val executorLocationTag = "executor_" + def apply(host: String, executorId: String): TaskLocation = { new ExecutorCacheTaskLocation(host, executorId) } @@ -65,7 +70,15 @@ private[spark] object TaskLocation { def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { - new HostTaskLocation(str) + if (str.startsWith(executorLocationTag)) { + val splits = str.split("_") + if (splits.length != 3) { + throw new IllegalArgumentException("Illegal executor location format: " + str) + } + new ExecutorCacheTaskLocation(splits(1), splits(2)) + } else { + new HostTaskLocation(str) + } } else { new HDFSCacheTaskLocation(hstr) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 695523cc8aa3a..cd6bf723e70cb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -779,6 +779,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg test("Test TaskLocation for different host type.") { assert(TaskLocation("host1") === HostTaskLocation("host1")) assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) + assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) } def createTaskResult(id: Int): DirectTaskResult[Int] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 59ef58d232ee7..167f56aa42281 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -47,7 +47,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { - private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val host = SparkEnv.get.blockManager.blockManagerId.host + private val executorId = SparkEnv.get.blockManager.blockManagerId.executorId private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { @@ -179,7 +180,7 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, hostPort, endpoint) + streamId, receiver.getClass.getSimpleName, host, executorId, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index d2b0be7f4a9c5..234bc8660da8a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -20,8 +20,8 @@ package org.apache.spark.streaming.scheduler import scala.collection.Map import scala.collection.mutable +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import 
org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils /** * A class that tries to schedule receivers with evenly distributed. There are two phases for @@ -29,23 +29,23 @@ import org.apache.spark.util.Utils * * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. - * It will try to schedule receivers with evenly distributed. ReceiverTracker should update its - * receiverTrackingInfoMap according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledExecutors` for each receiver will set to an executor list that - * contains the scheduled locations. Then when a receiver is starting, it will send a register - * request and `ReceiverTracker.registerReceiver` will be called. In - * `ReceiverTracker.registerReceiver`, if a receiver's scheduled executors is set, it should check - * if the location of this receiver is one of the scheduled executors, if not, the register will + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker should + * update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should check + * if the location of this receiver is one of the scheduled locations, if not, the register will * be rejected. * - The second phase is local scheduling when a receiver is restarting. There are two cases of * receiver restarting: * - If a receiver is restarting because it's rejected due to the real location and the scheduled - * executors mismatching, in other words, it fails to start in one of the locations that + * locations mismatching, in other words, it fails to start in one of the locations that * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are - * still alive in the list of scheduled executors, then use them to launch the receiver job. - * - If a receiver is restarting without a scheduled executors list, or the executors in the list + * still alive in the list of scheduled locations, then use them to launch the receiver job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should - * not set `ReceiverTrackingInfo.scheduledExecutors` for this executor, instead, it should clear + * not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it should clear * it. Then when this receiver is registering, we can know this is a local scheduling, and * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching * location is matching. @@ -69,9 +69,12 @@ private[streaming] class ReceiverSchedulingPolicy { * * * This method is called when we start to launch receivers at the first time. 
+ * + * @return a map for receivers and their scheduled locations */ def scheduleReceivers( - receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + receivers: Seq[Receiver[_]], + executors: Seq[ExecutorCacheTaskLocation]): Map[Int, Seq[TaskLocation]] = { if (receivers.isEmpty) { return Map.empty } @@ -80,16 +83,16 @@ private[streaming] class ReceiverSchedulingPolicy { return receivers.map(_.streamId -> Seq.empty).toMap } - val hostToExecutors = executors.groupBy(executor => Utils.parseHostPort(executor)._1) - val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val hostToExecutors = executors.groupBy(_.host) + val scheduledLocations = Array.fill(receivers.length)(new mutable.ArrayBuffer[TaskLocation]) + val numReceiversOnExecutor = mutable.HashMap[ExecutorCacheTaskLocation, Int]() // Set the initial value to 0 executors.foreach(e => numReceiversOnExecutor(e) = 0) // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. for (i <- 0 until receivers.length) { - // Note: preferredLocation is host but executors are host:port + // Note: preferredLocation is host but executors are host_executorId receivers(i).preferredLocation.foreach { host => hostToExecutors.get(host) match { case Some(executorsOnHost) => @@ -97,7 +100,7 @@ private[streaming] class ReceiverSchedulingPolicy { // this host val leastScheduledExecutor = executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) - scheduledExecutors(i) += leastScheduledExecutor + scheduledLocations(i) += leastScheduledExecutor numReceiversOnExecutor(leastScheduledExecutor) = numReceiversOnExecutor(leastScheduledExecutor) + 1 case None => @@ -106,17 +109,20 @@ private[streaming] class ReceiverSchedulingPolicy { // 1. This executor is not up. But it may be up later. // 2. This executor is dead, or it's not a host in the cluster. // Currently, simply add host to the scheduled executors. - scheduledExecutors(i) += host + + // Note: host could be `HDFSCacheTaskLocation`, so use `TaskLocation.apply` to handle + // this case + scheduledLocations(i) += TaskLocation(host) } } } // For those receivers that don't have preferredLocation, make sure we assign at least one // executor to them. - for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + for (scheduledLocationsForOneReceiver <- scheduledLocations.filter(_.isEmpty)) { // Select the executor that has the least receivers val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) - scheduledExecutorsForOneReceiver += leastScheduledExecutor + scheduledLocationsForOneReceiver += leastScheduledExecutor numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 } @@ -124,22 +130,22 @@ private[streaming] class ReceiverSchedulingPolicy { val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) for (executor <- idleExecutors) { // Assign an idle executor to the receiver that has least candidate executors. - val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + val leastScheduledExecutors = scheduledLocations.minBy(_.size) leastScheduledExecutors += executor } - receivers.map(_.streamId).zip(scheduledExecutors).toMap + receivers.map(_.streamId).zip(scheduledLocations).toMap } /** - * Return a list of candidate executors to run the receiver. 
If the list is empty, the caller can + * Return a list of candidate locations to run the receiver. If the list is empty, the caller can * run this receiver in arbitrary executor. * * This method tries to balance executors' load. Here is the approach to schedule executors * for a receiver. *
    *
  1. - * If preferredLocation is set, preferredLocation should be one of the candidate executors. + * If preferredLocation is set, preferredLocation should be one of the candidate locations. *
  2. *
  3. * Every executor will be assigned to a weight according to the receivers running or @@ -163,40 +169,58 @@ private[streaming] class ReceiverSchedulingPolicy { receiverId: Int, preferredLocation: Option[String], receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], - executors: Seq[String]): Seq[String] = { + executors: Seq[ExecutorCacheTaskLocation]): Seq[TaskLocation] = { if (executors.isEmpty) { return Seq.empty } // Always try to schedule to the preferred locations - val scheduledExecutors = mutable.Set[String]() - scheduledExecutors ++= preferredLocation - - val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => - receiverTrackingInfo.state match { - case ReceiverState.INACTIVE => Nil - case ReceiverState.SCHEDULED => - val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get - // The probability that a scheduled receiver will run in an executor is - // 1.0 / scheduledLocations.size - scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) - case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) - } - }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + val scheduledLocations = mutable.Set[TaskLocation]() + // Note: preferredLocation could be `HDFSCacheTaskLocation`, so use `TaskLocation.apply` to + // handle this case + scheduledLocations ++= preferredLocation.map(TaskLocation(_)) + + val executorWeights: Map[ExecutorCacheTaskLocation, Double] = { + receiverTrackingInfoMap.values.flatMap(convertReceiverTrackingInfoToExecutorWeights) + .groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + } val idleExecutors = executors.toSet -- executorWeights.keys if (idleExecutors.nonEmpty) { - scheduledExecutors ++= idleExecutors + scheduledLocations ++= idleExecutors } else { // There is no idle executor. So select all executors that have the minimum weight. val sortedExecutors = executorWeights.toSeq.sortBy(_._2) if (sortedExecutors.nonEmpty) { val minWeight = sortedExecutors(0)._2 - scheduledExecutors ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + scheduledLocations ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) } else { // This should not happen since "executors" is not empty } } - scheduledExecutors.toSeq + scheduledLocations.toSeq + } + + /** + * This method tries to convert a receiver tracking info to executor weights. Every executor will + * be assigned to a weight according to the receivers running or scheduling on it: + * + * - If a receiver is running on an executor, it contributes 1.0 to the executor's weight. + * - If a receiver is scheduled to an executor but has not yet run, it contributes + * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight. 
+ */ + private def convertReceiverTrackingInfoToExecutorWeights( + receiverTrackingInfo: ReceiverTrackingInfo): Seq[(ExecutorCacheTaskLocation, Double)] = { + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledLocations = receiverTrackingInfo.scheduledLocations.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledLocations.filter(_.isInstanceOf[ExecutorCacheTaskLocation]).map { location => + location.asInstanceOf[ExecutorCacheTaskLocation] -> (1.0 / scheduledLocations.size) + } + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 2ce80d618b0a3..b183d856f50c3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,20 +17,21 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, CountDownLatch} +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap import scala.concurrent.ExecutionContext import scala.language.existentials import scala.util.{Failure, Success} -import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{TaskLocation, ExecutorCacheTaskLocation} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.{Utils, ThreadUtils, SerializableConfiguration} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils, Utils} /** Enumeration to identify current state of a Receiver */ @@ -47,7 +48,8 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - hostPort: String, + host: String, + executorId: String, receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -235,7 +237,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - hostPort: String, + host: String, + executorId: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress ): Boolean = { @@ -247,18 +250,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false return false } - val scheduledExecutors = receiverTrackingInfos(streamId).scheduledExecutors - val accetableExecutors = if (scheduledExecutors.nonEmpty) { + val scheduledLocations = receiverTrackingInfos(streamId).scheduledLocations + val acceptableExecutors = if (scheduledLocations.nonEmpty) { // This receiver is registering and it's scheduled by - // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledExecutors" to check it. - scheduledExecutors.get + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledLocations" to check it. + scheduledLocations.get } else { // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. 
scheduleReceiver(streamId) } - if (!accetableExecutors.contains(hostPort)) { + def isAcceptable: Boolean = acceptableExecutors.exists { + case loc: ExecutorCacheTaskLocation => loc.executorId == executorId + case loc: TaskLocation => loc.host == host + } + + if (!isAcceptable) { // Refuse it since it's scheduled to a wrong executor false } else { @@ -266,8 +274,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false val receiverTrackingInfo = ReceiverTrackingInfo( streamId, ReceiverState.ACTIVE, - scheduledExecutors = None, - runningExecutor = Some(hostPort), + scheduledLocations = None, + runningExecutor = Some(ExecutorCacheTaskLocation(host, executorId)), name = Some(name), endpoint = Some(receiverEndpoint)) receiverTrackingInfos.put(streamId, receiverTrackingInfo) @@ -338,25 +346,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } - private def scheduleReceiver(receiverId: Int): Seq[String] = { + private def scheduleReceiver(receiverId: Int): Seq[TaskLocation] = { val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) - val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + val scheduledLocations = schedulingPolicy.rescheduleReceiver( receiverId, preferredLocation, receiverTrackingInfos, getExecutors) - updateReceiverScheduledExecutors(receiverId, scheduledExecutors) - scheduledExecutors + updateReceiverScheduledExecutors(receiverId, scheduledLocations) + scheduledLocations } private def updateReceiverScheduledExecutors( - receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + receiverId: Int, scheduledLocations: Seq[TaskLocation]): Unit = { val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { case Some(oldInfo) => oldInfo.copy(state = ReceiverState.SCHEDULED, - scheduledExecutors = Some(scheduledExecutors)) + scheduledLocations = Some(scheduledLocations)) case None => ReceiverTrackingInfo( receiverId, ReceiverState.SCHEDULED, - Some(scheduledExecutors), + Some(scheduledLocations), runningExecutor = None) } receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) @@ -370,13 +378,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** * Get the list of executors excluding driver */ - private def getExecutors: Seq[String] = { + private def getExecutors: Seq[ExecutorCacheTaskLocation] = { if (ssc.sc.isLocal) { - Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + val blockManagerId = ssc.sparkContext.env.blockManager.blockManagerId + Seq(ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId)) } else { ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location - }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + }.map { case (blockManagerId, _) => + ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId) + }.toSeq } } @@ -431,9 +442,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receive: PartialFunction[Any, Unit] = { // Local messages case StartAllReceivers(receivers) => - val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + val scheduledLocations = schedulingPolicy.scheduleReceivers(receivers, getExecutors) for (receiver <- receivers) { - val executors = 
scheduledExecutors(receiver.streamId) + val executors = scheduledLocations(receiver.streamId) updateReceiverScheduledExecutors(receiver.streamId, executors) receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation startReceiver(receiver, executors) @@ -441,14 +452,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false case RestartReceiver(receiver) => // Old scheduled executors minus the ones that are not active any more val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) - val scheduledExecutors = if (oldScheduledExecutors.nonEmpty) { + val scheduledLocations = if (oldScheduledExecutors.nonEmpty) { // Try global scheduling again oldScheduledExecutors } else { val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) - // Clear "scheduledExecutors" to indicate we are going to do local scheduling + // Clear "scheduledLocations" to indicate we are going to do local scheduling val newReceiverInfo = oldReceiverInfo.copy( - state = ReceiverState.INACTIVE, scheduledExecutors = None) + state = ReceiverState.INACTIVE, scheduledLocations = None) receiverTrackingInfos(receiver.streamId) = newReceiverInfo schedulingPolicy.rescheduleReceiver( receiver.streamId, @@ -458,7 +469,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } // Assume there is one receiver restarting at one time, so we don't need to update // receiverTrackingInfos - startReceiver(receiver, scheduledExecutors) + startReceiver(receiver, scheduledLocations) case c: CleanupOldBlocks => receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) case UpdateReceiverRateLimit(streamUID, newRate) => @@ -472,9 +483,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { // Remote messages - case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => + case RegisterReceiver(streamId, typ, host, executorId, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress) + registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) @@ -493,13 +504,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** * Return the stored scheduled executors that are still alive. 
*/ - private def getStoredScheduledExecutors(receiverId: Int): Seq[String] = { + private def getStoredScheduledExecutors(receiverId: Int): Seq[TaskLocation] = { if (receiverTrackingInfos.contains(receiverId)) { - val scheduledExecutors = receiverTrackingInfos(receiverId).scheduledExecutors - if (scheduledExecutors.nonEmpty) { + val scheduledLocations = receiverTrackingInfos(receiverId).scheduledLocations + if (scheduledLocations.nonEmpty) { val executors = getExecutors.toSet // Only return the alive executors - scheduledExecutors.get.filter(executors) + scheduledLocations.get.filter { + case loc: ExecutorCacheTaskLocation => executors(loc) + case loc: TaskLocation => true + } } else { Nil } @@ -511,7 +525,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** * Start a receiver along with its scheduled executors */ - private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + private def startReceiver( + receiver: Receiver[_], + scheduledLocations: Seq[TaskLocation]): Unit = { def shouldStartReceiver: Boolean = { // It's okay to start when trackerState is Initialized or Started !(isTrackerStopping || isTrackerStopped) @@ -546,13 +562,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + // Create the RDD using the scheduledLocations to run the receiver in a Spark job val receiverRDD: RDD[Receiver[_]] = - if (scheduledExecutors.isEmpty) { + if (scheduledLocations.isEmpty) { ssc.sc.makeRDD(Seq(receiver), 1) } else { - val preferredLocations = - scheduledExecutors.map(hostPort => Utils.parseHostPort(hostPort)._1).distinct + val preferredLocations = scheduledLocations.map(_.toString).distinct ssc.sc.makeRDD(Seq(receiver -> preferredLocations)) } receiverRDD.setName(s"Receiver $receiverId") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index 043ff4d0ff054..ab0a84f05214d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.scheduler.ReceiverState._ private[streaming] case class ReceiverErrorInfo( @@ -28,7 +29,7 @@ private[streaming] case class ReceiverErrorInfo( * * @param receiverId the unique receiver id * @param state the current Receiver state - * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param scheduledLocations the scheduled locations provided by ReceiverSchedulingPolicy * @param runningExecutor the running executor if the receiver is active * @param name the receiver name * @param endpoint the receiver endpoint. 
It can be used to send messages to the receiver @@ -37,8 +38,8 @@ private[streaming] case class ReceiverErrorInfo( private[streaming] case class ReceiverTrackingInfo( receiverId: Int, state: ReceiverState, - scheduledExecutors: Option[Seq[String]], - runningExecutor: Option[String], + scheduledLocations: Option[Seq[TaskLocation]], + runningExecutor: Option[ExecutorCacheTaskLocation], name: Option[String] = None, endpoint: Option[RpcEndpointRef] = None, errorInfo: Option[ReceiverErrorInfo] = None) { @@ -47,7 +48,7 @@ private[streaming] case class ReceiverTrackingInfo( receiverId, name.getOrElse(""), state == ReceiverState.ACTIVE, - location = runningExecutor.getOrElse(""), + location = runningExecutor.map(_.host).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index b2a51d72bac2b..05b4e66c63ac6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -20,73 +20,96 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, HostTaskLocation, TaskLocation} class ReceiverSchedulingPolicySuite extends SparkFunSuite { val receiverSchedulingPolicy = new ReceiverSchedulingPolicy test("rescheduleReceiver: empty executors") { - val scheduledExecutors = + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) - assert(scheduledExecutors === Seq.empty) + assert(scheduledLocations === Seq.empty) } test("rescheduleReceiver: receiver preferredLocation") { + val executors = Seq(ExecutorCacheTaskLocation("host2", "2")) val receiverTrackingInfoMap = Map( 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( - 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) - assert(scheduledExecutors.toSet === Set("host1", "host2")) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors) + assert(scheduledLocations.toSet === Set(HostTaskLocation("host1"), executors(0))) } test("rescheduleReceiver: return all idle executors if there are any idle executors") { - val executors = Seq("host1", "host2", "host3", "host4", "host5") - // host3 is idle + val executors = (1 to 5).map(i => ExecutorCacheTaskLocation(s"host$i", s"$i")) + // executor 1 is busy, others are idle. 
val receiverTrackingInfoMap = Map( - 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some(executors(0)))) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( 1, None, receiverTrackingInfoMap, executors) - assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + assert(scheduledLocations.toSet === executors.tail.toSet) } test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { - val executors = Seq("host1", "host2", "host3", "host4", "host5") + val executors = Seq( + ExecutorCacheTaskLocation("host1", "1"), + ExecutorCacheTaskLocation("host2", "2"), + ExecutorCacheTaskLocation("host3", "3"), + ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5") + ) // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 val receiverTrackingInfoMap = Map( - 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), - 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), - 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None), - 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, Some(Seq("host4", "host5")), None)) - val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, + Some(ExecutorCacheTaskLocation("host1", "1"))), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host2", "2"), ExecutorCacheTaskLocation("host3", "3"))), + None), + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host1", "1"), ExecutorCacheTaskLocation("host3", "3"))), + None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, + Some(Seq(ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5"))), None)) + val scheduledLocations = receiverSchedulingPolicy.rescheduleReceiver( 4, None, receiverTrackingInfoMap, executors) - assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + val expectedScheduledLocations = Set( + ExecutorCacheTaskLocation("host2", "2"), + ExecutorCacheTaskLocation("host4", "4"), + ExecutorCacheTaskLocation("host5", "5") + ) + assert(scheduledLocations.toSet === expectedScheduledLocations) } test("scheduleReceivers: " + "schedule receivers evenly when there are more receivers than executors") { val receivers = (0 until 6).map(new RateTestReceiver(_)) - val executors = (10000 until 10003).map(port => s"localhost:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 2 receivers running on each executor and each receiver has one executor - scheduledExecutors.foreach { case (receiverId, executors) => - assert(executors.size == 1) - numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + scheduledLocations.foreach { case (receiverId, locations) => + assert(locations.size == 1) + 
assert(locations(0).isInstanceOf[ExecutorCacheTaskLocation]) + numReceiversOnExecutor(locations(0)) = numReceiversOnExecutor.getOrElse(locations(0), 0) + 1 } assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) } - test("scheduleReceivers: " + "schedule receivers evenly when there are more executors than receivers") { val receivers = (0 until 3).map(new RateTestReceiver(_)) - val executors = (10000 until 10006).map(port => s"localhost:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 6).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 1 receiver running on each executor and each receiver has two executors - scheduledExecutors.foreach { case (receiverId, executors) => - assert(executors.size == 2) - executors.foreach { l => + scheduledLocations.foreach { case (receiverId, locations) => + assert(locations.size == 2) + locations.foreach { l => + assert(l.isInstanceOf[ExecutorCacheTaskLocation]) numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 } } @@ -96,34 +119,41 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) - val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ - (10003 until 10006).map(port => s"localhost2:${port}") - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) - val numReceiversOnExecutor = mutable.HashMap[String, Int]() + val executors = (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString)) ++ + (3 until 6).map(executorId => + ExecutorCacheTaskLocation("localhost2", executorId.toString)) + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[TaskLocation, Int]() // There should be 1 receiver running on each executor and each receiver has 1 executor - scheduledExecutors.foreach { case (receiverId, executors) => + scheduledLocations.foreach { case (receiverId, executors) => assert(executors.size == 1) executors.foreach { l => + assert(l.isInstanceOf[ExecutorCacheTaskLocation]) numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 } } assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) // Make sure we schedule the receivers to their preferredLocations val executorsForReceiversWithPreferredLocation = - scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + scheduledLocations.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) // We can simply check the executor set because we only know each receiver only has 1 executor assert(executorsForReceiversWithPreferredLocation.toSet === - (10000 until 10003).map(port => s"localhost:${port}").toSet) + (0 until 3).map(executorId => + ExecutorCacheTaskLocation("localhost", executorId.toString) + ).toSet) } test("scheduleReceivers: return empty if no receiver") { - assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + val scheduledLocations 
= receiverSchedulingPolicy. + scheduleReceivers(Seq.empty, Seq(ExecutorCacheTaskLocation("localhost", "1"))) + assert(scheduledLocations.isEmpty) } test("scheduleReceivers: return empty scheduled executors if no executors") { val receivers = (0 until 3).map(new RateTestReceiver(_)) - val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) - scheduledExecutors.foreach { case (receiverId, executors) => + val scheduledLocations = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledLocations.foreach { case (receiverId, executors) => assert(executors.isEmpty) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index fda86aef457d4..3bd8d086abf7f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -99,8 +99,8 @@ class ReceiverTrackerSuite extends TestSuiteBase { output.register() ssc.start() eventually(timeout(10 seconds), interval(10 millis)) { - // If preferredLocations is set correctly, receiverTaskLocality should be NODE_LOCAL - assert(receiverTaskLocality === TaskLocality.NODE_LOCAL) + // If preferredLocations is set correctly, receiverTaskLocality should be PROCESS_LOCAL + assert(receiverTaskLocality === TaskLocality.PROCESS_LOCAL) } } } From b960a890561eaf3795b93c621bd95be81e56f5b7 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 27 Oct 2015 16:55:10 -0700 Subject: [PATCH 062/324] [SPARK-11178] Improving naming around task failures. Commit af3bc59d1f5d9d952c2d7ad1af599c49f1dbdaf0 introduced new functionality so that if an executor dies for a reason that's not caused by one of the tasks running on the executor (e.g., due to pre-emption), Spark doesn't count the failure towards the maximum number of failures for the task. That commit introduced some vague naming that this commit attempts to fix; in particular: (1) The variable "isNormalExit", which was used to refer to cases where the executor died for a reason unrelated to the tasks running on the machine, has been renamed (and reversed) to "exitCausedByApp". The problem with the existing name is that it's not clear (at least to me!) what it means for an exit to be "normal"; the new name is intended to make the purpose of this variable more clear. (2) The variable "shouldEventuallyFailJob" has been renamed to "countTowardsTaskFailures". This variable is used to determine whether a task's failure should be counted towards the maximum number of failures allowed for a task before the associated Stage is aborted. The problem with the existing name is that it can be confused with implying that the task's failure should immediately cause the stage to fail because it is somehow fatal (this is the case for a fetch failure, for example: if a task fails because of a fetch failure, there's no point in retrying, and the whole stage should be failed). Author: Kay Ousterhout Closes #9164 from kayousterhout/SPARK-11178. 
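To make the renamed semantics concrete, a small self-contained sketch of the intended behaviour follows. It is illustration only and not part of this patch: FailureCountingSketch, FailureReason, ExecutorLost and handleFailedTask are invented stand-ins, not the real Spark classes. The sketch only shows how a countTowardsTaskFailures flag gates the per-task failure counter, so that executor exits unrelated to the application (such as preemption) never push a task over its failure limit.

object FailureCountingSketch {
  // Simplified stand-in for a task failure reason: true means the failure was caused by
  // the application's own tasks and should count toward the per-task failure limit.
  sealed trait FailureReason { def countTowardsTaskFailures: Boolean = true }

  // Losing an executor counts only when the exit was caused by the app, so a
  // preempted executor does not use up the task's failure budget.
  case class ExecutorLost(exitCausedByApp: Boolean) extends FailureReason {
    override def countTowardsTaskFailures: Boolean = exitCausedByApp
  }

  def main(args: Array[String]): Unit = {
    val maxTaskFailures = 4
    val failures = scala.collection.mutable.Map.empty[Int, Int].withDefaultValue(0)

    def handleFailedTask(taskIndex: Int, reason: FailureReason): Unit = {
      if (reason.countTowardsTaskFailures) {
        failures(taskIndex) += 1
        if (failures(taskIndex) >= maxTaskFailures) {
          println(s"Task $taskIndex failed $maxTaskFailures times; aborting the stage")
        }
      } else {
        println(s"Task $taskIndex lost its executor for a reason unrelated to the app; not counted")
      }
    }

    // Preemption: the task is rescheduled but the failure is not counted.
    handleFailedTask(0, ExecutorLost(exitCausedByApp = false))
    // Four app-caused failures of the same task exhaust the budget and abort the stage.
    (1 to maxTaskFailures).foreach(_ => handleFailedTask(0, ExecutorLost(exitCausedByApp = true)))
  }
}

Truly fatal cases such as fetch failures go through a different path: as described above, there is no point in retrying them, so the stage is failed directly rather than through this counter.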
--- .../org/apache/spark/TaskEndReason.scala | 22 ++++++++++----- .../spark/scheduler/ExecutorLossReason.scala | 9 ++++--- .../spark/scheduler/TaskSetManager.scala | 16 ++++++----- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../cluster/YarnSchedulerBackend.scala | 8 +++--- .../cluster/mesos/MesosSchedulerBackend.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 9 +++---- .../spark/scheduler/TaskSetManagerSuite.scala | 10 ++++--- .../apache/spark/util/JsonProtocolSuite.scala | 6 ++--- .../spark/deploy/yarn/YarnAllocator.scala | 27 ++++++++++--------- 11 files changed, 66 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 9335c5f4160bf..18278b292ff5a 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,7 +53,13 @@ sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. */ def toErrorString: String - def shouldEventuallyFailJob: Boolean = true + /** + * Whether this task failure should be counted towards the maximum number of times the task is + * allowed to fail before the stage is aborted. Set to false in cases where the task's failure + * was unrelated to the task; for example, if the task failed because the executor it was running + * on was killed. + */ + def countTowardsTaskFailures: Boolean = true } /** @@ -208,7 +214,7 @@ case class TaskCommitDenied( * towards failing the stage. This is intended to prevent spurious stage failures in cases * where many speculative tasks are launched and denied to commit. */ - override def shouldEventuallyFailJob: Boolean = false + override def countTowardsTaskFailures: Boolean = false } /** @@ -217,14 +223,18 @@ case class TaskCommitDenied( * the task crashed the JVM. 
*/ @DeveloperApi -case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false) +case class ExecutorLostFailure(execId: String, exitCausedByApp: Boolean = true) extends TaskFailedReason { override def toErrorString: String = { - val exitBehavior = if (isNormalExit) "normally" else "abnormally" - s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + val exitBehavior = if (exitCausedByApp) { + "caused by one of the running tasks" + } else { + "unrelated to the running tasks" + } + s"ExecutorLostFailure (executor ${execId} exited due to an issue ${exitBehavior})" } - override def shouldEventuallyFailJob: Boolean = !isNormalExit + override def countTowardsTaskFailures: Boolean = exitCausedByApp } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 0a98c69b89ea5..33edf25043850 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -28,12 +28,15 @@ class ExecutorLossReason(val message: String) extends Serializable { } private[spark] -case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String) +case class ExecutorExited(exitCode: Int, exitCausedByApp: Boolean, reason: String) extends ExecutorLossReason(reason) private[spark] object ExecutorExited { - def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = { - ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode)) + def apply(exitCode: Int, exitCausedByApp: Boolean): ExecutorExited = { + ExecutorExited( + exitCode, + exitCausedByApp, + ExecutorExitCode.explainExitCode(exitCode)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 987800d3d1f1e..9b3fad9012abc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -704,9 +704,10 @@ private[spark] class TaskSetManager( } ef.exception - case e: ExecutorLostFailure if e.isNormalExit => + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor" + - s" exited normally. Not marking the task as failed.") + "exited for a reason unrelated to the task. 
Not counting this failure towards the " + + "maximum number of failures for the task.") None case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others @@ -724,7 +725,7 @@ private[spark] class TaskSetManager( addPendingTask(index) if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] - && reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) { + && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -797,11 +798,12 @@ private[spark] class TaskSetManager( } } for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - val isNormalExit: Boolean = reason match { - case exited: ExecutorExited => exited.isNormalExit - case _ => false + val exitCausedByApp: Boolean = reason match { + case exited: ExecutorExited => exited.exitCausedByApp + case _ => true } - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit)) + handleFailedTask( + tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 4652df32efa74..8103efa7302e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -98,7 +98,7 @@ private[spark] object CoarseGrainedClusterMessages { hostToLocalTaskCount: Map[String, Int]) extends CoarseGrainedClusterMessage - // Check if an executor was force-killed but for a normal reason. + // Check if an executor was force-killed but for a reason unrelated to the running tasks. // This could be the case if the executor is preempted, for instance. case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a4214c496166d..05d9bc92f228b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -137,7 +137,7 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code, isNormalExit = false, message) + case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 38218b9c08fd8..e483688edef5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -111,10 +111,10 @@ private[spark] abstract class YarnSchedulerBackend( * immediately. 
* * In YARN's case however it is crucial to talk to the application master and ask why the - * executor had exited. In particular, the executor may have exited due to the executor - * having been preempted. If the executor "exited normally" according to the application - * master then we pass that information down to the TaskSetManager to inform the - * TaskSetManager that tasks on that lost executor should not count towards a job failure. + * executor had exited. If the executor exited for some reason unrelated to the running tasks + * (e.g., preemption), according to the application master, then we pass that information down + * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should + * not count towards a job failure. * * TODO there's a race condition where while we are querying the ApplicationMaster for * the executor loss reason, there is the potential that tasks will be scheduled on diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 6196176c7cc33..aaffac604a885 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -394,7 +394,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status, isNormalExit = false)) + recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a06dc6f709d33..ad6615c1124d0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -367,9 +367,9 @@ private[spark] object JsonProtocol { ("Job ID" -> taskCommitDenied.jobID) ~ ("Partition ID" -> taskCommitDenied.partitionID) ~ ("Attempt Number" -> taskCommitDenied.attemptNumber) - case ExecutorLostFailure(executorId, isNormalExit) => + case ExecutorLostFailure(executorId, exitCausedByApp) => ("Executor ID" -> executorId) ~ - ("Normal Exit" -> isNormalExit) + ("Exit Caused By App" -> exitCausedByApp) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -810,10 +810,9 @@ private[spark] object JsonProtocol { val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). 
- map(_.extract[Boolean]) + val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown"), isNormalExit.getOrElse(false)) + ExecutorLostFailure(executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true)) case `unknownReason` => UnknownReason } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index cd6bf723e70cb..ecc18fc6e15b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -511,7 +511,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } - test("Executors are added but exit normally while running tasks") { + test("Executors exit for reason unrelated to currently running tasks") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) val taskSet = FakeTask.createTaskSet(4, @@ -526,11 +526,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.executorAdded() assert(manager.resourceOffer("exec1", "host1", ANY).isDefined) sched.removeExecutor("execA") - manager.executorLost("execA", "host1", ExecutorExited(143, true, "Normal termination")) + manager.executorLost( + "execA", + "host1", + ExecutorExited(143, false, "Terminated for reason unrelated to running tasks")) assert(!sched.taskSetsFailed.contains(taskSet.id)) assert(manager.resourceOffer("execC", "host2", ANY).isDefined) sched.removeExecutor("execC") - manager.executorLost("execC", "host2", ExecutorExited(1, false, "Abnormal termination")) + manager.executorLost( + "execC", "host2", ExecutorExited(1, true, "Terminated due to issue with running tasks")) assert(sched.taskSetsFailed.contains(taskSet.id)) } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f9572921f43cb..86137f259c13d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -603,10 +603,10 @@ class JsonProtocolSuite extends SparkFunSuite { assert(jobId1 === jobId2) assert(partitionId1 === partitionId2) assert(attemptNumber1 === attemptNumber2) - case (ExecutorLostFailure(execId1, isNormalExit1), - ExecutorLostFailure(execId2, isNormalExit2)) => + case (ExecutorLostFailure(execId1, exit1CausedByApp), + ExecutorLostFailure(execId2, exit2CausedByApp)) => assert(execId1 === execId2) - assert(isNormalExit1 === isNormalExit2) + assert(exit1CausedByApp === exit2CausedByApp) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 1deaa3743ddfa..875bbd4e4e3d5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -445,40 +445,41 @@ private[yarn] class YarnAllocator( // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit. 
val exitStatus = completedContainer.getExitStatus - val (isNormalExit, containerExitReason) = exitStatus match { + val (exitCausedByApp, containerExitReason) = exitStatus match { case ContainerExitStatus.SUCCESS => - (true, s"Executor for container $containerId exited normally.") + (false, s"Executor for container $containerId exited because of a YARN event (e.g., " + + "pre-emption) and not because of an error in the running job.") case ContainerExitStatus.PREEMPTED => - // Preemption should count as a normal exit, since YARN preempts containers merely - // to do resource sharing, and tasks that fail due to preempted executors could + // Preemption is not the fault of the running tasks, since YARN preempts containers + // merely to do resource sharing, and tasks that fail due to preempted executors could // just as easily finish on any other executor. See SPARK-8167. - (true, s"Container ${containerId}${onHostStr} was preempted.") + (false, s"Container ${containerId}${onHostStr} was preempted.") // Should probably still count memory exceeded exit codes towards task failures case VMEM_EXCEEDED_EXIT_CODE => - (false, memLimitExceededLogMessage( + (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, VMEM_EXCEEDED_PATTERN)) case PMEM_EXCEEDED_EXIT_CODE => - (false, memLimitExceededLogMessage( + (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case unknown => numExecutorsFailed += 1 - (false, "Container marked as failed: " + containerId + onHostStr + + (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) } - if (isNormalExit) { - logInfo(containerExitReason) - } else { + if (exitCausedByApp) { logWarning(containerExitReason) + } else { + logInfo(containerExitReason) } - ExecutorExited(0, isNormalExit, containerExitReason) + ExecutorExited(0, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed - ExecutorExited(completedContainer.getExitStatus, isNormalExit = true, + ExecutorExited(completedContainer.getExitStatus, exitCausedByApp = false, s"Container $containerId exited from explicit termination request.") } From d9c6039897236c3f1e4503aa95c5c9b07b32eadd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 27 Oct 2015 20:26:38 -0700 Subject: [PATCH 063/324] [SPARK-10484] [SQL] Optimize the cartesian join with broadcast join for some cases In some cases, we can broadcast the smaller relation in cartesian join, which improve the performance significantly. Author: Cheng Hao Closes #8652 from chenghao-intel/cartesian. --- .../spark/sql/execution/SparkPlanner.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 38 +++++--- .../joins/BroadcastNestedLoopJoin.scala | 7 +- .../org/apache/spark/sql/JoinSuite.scala | 92 +++++++++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 3 +- ... 
JOIN #1-0-abfc0b99ee357f71639f6162345fe8e | 20 ++++ ...JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd | 20 ++++ ...JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 | 20 ++++ ...JOIN #4-0-45f8602d257655322b7d18cad09f6a0f | 20 ++++ .../sql/hive/execution/HiveQuerySuite.scala | 54 +++++++++++ 10 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b346f43faebe2..0f98fe88b2101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { EquiJoinSelection :: InMemoryScans :: BasicOperators :: + BroadcastNestedLoop :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) + DefaultJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 637deff4e2202..ee9716285316a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => 
execution.Filter(condition, @@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index efef8c8a8b96a..05d20f511aef8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer @@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => + case Inner => + // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case + left.output ++ right.output + case x => // TODO support the Left Semi Join throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b1fb06815868c..a9ca46cab067d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() + def statisticSizeInByte(df: DataFrame): BigInt = { + df.queryExecution.optimizedPlan.statistics.sizeInBytes + } + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") @@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext { sql("UNCACHE TABLE testData") } + test("cross join with broadcast") { + sql("CACHE TABLE testData") + + val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + + // we set the threshold is greater than statistic of the cached table testData + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + + assert(statisticSizeInByte(sqlContext.table("testData2")) > + sqlContext.conf.autoBroadcastJoinThreshold) + + assert(statisticSizeInByte(sqlContext.table("testData")) < + sqlContext.conf.autoBroadcastJoinThreshold) + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[LeftSemiJoinHash]), + ("SELECT * 
FROM testData LEFT SEMI JOIN testData2", + classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2 + """.stripMargin), + Row("2", 1, 1) :: + Row("2", 1, 2) :: + Row("2", 2, 1) :: + Row("2", 2, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + } + + sql("UNCACHE TABLE testData") + } + test("left semi join") { val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c328734df316b..83a81cf5d1fcf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -588,8 +588,9 @@ class HiveContext private[hive]( LeftSemiJoin, EquiJoinSelection, BasicOperators, + BroadcastNestedLoop, CartesianProduct, - BroadcastNestedLoopJoin + DefaultJoin ) } diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e new file mode 100644 index 0000000000000..0bb9399af0c45 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +306 0 +306 0 +306 0 +307 0 +307 0 +307 0 +307 0 +307 0 +307 0 +308 0 +308 0 +308 0 +309 0 
+309 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2878500453141..b52f7d4b57899 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin + import scala.util.Try import org.scalatest.BeforeAndAfter @@ -69,6 +71,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + // Testing the Broadcast based join for cartesian join (cross join) + // We assume that the Broadcast Join Threshold will works since the src is a small table + private val spark_10484_1 = """ + | SELECT a.key, b.key + | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY b.key, a.key + | LIMIT 20 + """.stripMargin + private val spark_10484_2 = """ + | SELECT a.key, b.key + | FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_3 = """ + | SELECT a.key, b.key + | FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_4 = """ + | SELECT a.key, b.key + | 
FROM src a JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1", + spark_10484_1) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2", + spark_10484_2) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3", + spark_10484_3) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4", + spark_10484_4) + + test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") { + def assertBroadcastNestedLoopJoin(sqlText: String): Unit = { + assert(sql(sqlText).queryExecution.sparkPlan.collect { + case _: BroadcastNestedLoopJoin => 1 + }.nonEmpty) + } + + assertBroadcastNestedLoopJoin(spark_10484_1) + assertBroadcastNestedLoopJoin(spark_10484_2) + assertBroadcastNestedLoopJoin(spark_10484_3) + assertBroadcastNestedLoopJoin(spark_10484_4) + } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", """ SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP From 826e1e304b57abbc56b8b7ffd663d53942ab3c7c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 27 Oct 2015 23:07:37 -0700 Subject: [PATCH 064/324] [SPARK-11302][MLLIB] 2) Multivariate Gaussian Model with Covariance matrix returns incorrect answer in some cases Fix computation of root-sigma-inverse in multivariate Gaussian; add a test and fix related Python mixture model test. Supersedes https://github.com/apache/spark/pull/9293 Author: Sean Owen Closes #9309 from srowen/SPARK-11302.2. --- .../stat/distribution/MultivariateGaussian.scala | 8 ++++---- .../distribution/MultivariateGaussianSuite.scala | 15 +++++++++++++++ python/pyspark/mllib/clustering.py | 4 ++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 92a5af708d04b..0724af93088c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -56,7 +56,7 @@ class MultivariateGaussian @Since("1.3.0") ( /** * Compute distribution dependent constants: - * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t + * rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants @@ -104,11 +104,11 @@ class MultivariateGaussian @Since("1.3.0") ( * * sigma = U * D * U.t * inv(Sigma) = U * inv(D) * U.t - * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) + * = (D^{-1/2}^ * U.t).t * (D^{-1/2}^ * U.t) * * and thus * - * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ + * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U.t * (x-mu))^2^ * * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). 
Singular values are considered @@ -130,7 +130,7 @@ class MultivariateGaussian @Since("1.3.0") ( // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) + (pinvS * u.t, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => throw new IllegalArgumentException("Covariance matrix has no non-zero singular values") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index aa60deb665aeb..6e7a003475458 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -65,4 +65,19 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) } + + test("SPARK-11302") { + val x = Vectors.dense(629, 640, 1.7188, 618.19) + val mu = Vectors.dense( + 1055.3910505836575, 1070.489299610895, 1.39020554474708, 1040.5907503867697) + val sigma = Matrices.dense(4, 4, Array( + 166769.00466698944, 169336.6705268059, 12.820670788921873, 164243.93314092053, + 169336.6705268059, 172041.5670061245, 21.62590020524533, 166678.01075856484, + 12.820670788921873, 21.62590020524533, 0.872524191943962, 4.283255814732373, + 164243.93314092053, 166678.01075856484, 4.283255814732373, 161848.9196719207)) + val dist = new MultivariateGaussian(mu, sigma) + // Agrees with R's dmvnorm: 7.154782e-05 + assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9) + } + } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c451df17cf264..d1c3755a785f2 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -236,9 +236,9 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() - >>> labels[0]==labels[1]==labels[2] + >>> labels[0]==labels[1] True - >>> labels[3]==labels[4] + >>> labels[2]==labels[3]==labels[4] True .. versionadded:: 1.3.0 From 82c1c5772817785709b0289f7d836beba812c791 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 27 Oct 2015 23:41:42 -0700 Subject: [PATCH 065/324] [MINOR][ML] fix compile warns This fixes some compile time warnings. Author: Xiangrui Meng Closes #9319 from mengxr/mllib-compile-warn-20151027. --- .../org/apache/spark/ml/regression/LinearRegression.scala | 2 +- .../spark/ml/classification/LogisticRegressionSuite.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 573a61a6eabdf..c3ee8b3bc1ba0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,7 +145,7 @@ class LinearRegression(override val uid: String) // Extract the number of features before deciding optimization solver. 
val numFeatures = dataset.select(col($(featuresCol))).limit(1).map { case Row(features: Vector) => features.size - }.toArray()(0) + }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 5186c4e2be64d..e0a795e5e0b00 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.classification +import scala.language.existentials import scala.util.Random import org.apache.spark.SparkFunSuite @@ -24,7 +25,7 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ From 5f1cee6f158adb1f9f485ed1d529c56bace68adc Mon Sep 17 00:00:00 2001 From: Nakul Jindal Date: Wed, 28 Oct 2015 01:02:03 -0700 Subject: [PATCH 066/324] [SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquare.Instance WeightedLeastSquares now uses the common Instance class in ml.feature instead of a private one. Author: Nakul Jindal Closes #9325 from nakul02/SPARK-11332_refactor_WeightedLeastSquares_dot_Instance. --- .../spark/ml/optim/WeightedLeastSquares.scala | 25 ++++++------------- .../ml/regression/LinearRegression.scala | 4 +-- .../ml/optim/WeightedLeastSquaresSuite.scala | 10 ++++---- 3 files changed, 15 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index d7eaa5a9268ff..3d64f7f296137 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD @@ -121,16 +122,6 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { - /** - * Case class for weighted observations. - * @param w weight, must be positive - * @param a features - * @param b label - */ - case class Instance(w: Double, a: Vector, b: Double) { - require(w >= 0.0, s"Weight cannot be negative: $w.") - } - /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -168,8 +159,8 @@ private[ml] object WeightedLeastSquares { * Adds an instance. 
*/ def add(instance: Instance): this.type = { - val Instance(w, a, b) = instance - val ak = a.size + val Instance(l, w, f) = instance + val ak = f.size if (!initialized) { init(ak) } @@ -177,11 +168,11 @@ private[ml] object WeightedLeastSquares { count += 1L wSum += w wwSum += w * w - bSum += w * b - bbSum += w * b * b - BLAS.axpy(w, a, aSum) - BLAS.axpy(w * b, a, abSum) - BLAS.spr(w, a, aaSum) + bSum += w * l + bbSum += w * l * l + BLAS.axpy(w, f, aSum) + BLAS.axpy(w * l, f, abSum) + BLAS.spr(w, f, aaSum) this } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c3ee8b3bc1ba0..f663b9bd9ac73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -154,10 +154,10 @@ class LinearRegression(override val uid: String) "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) - val instances: RDD[WeightedLeastSquares.Instance] = dataset.select( + val instances: RDD[Instance] = dataset.select( col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => - WeightedLeastSquares.Instance(weight, features, label) + Instance(label, weight, features) } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 652f3adb984d3..b542ba3dc54d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.optim.WeightedLeastSquares.Instance +import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -38,10 +38,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext w <- c(1, 2, 3, 4) */ instances = sc.parallelize(Seq( - Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0), - Instance(2.0, Vectors.dense(1.0, 7.0), 19.0), - Instance(3.0, Vectors.dense(2.0, 11.0), 23.0), - Instance(4.0, Vectors.dense(3.0, 13.0), 29.0) + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) } From 075ce4914fdcbbcc7286c3c30cb940ed28d474d2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Oct 2015 13:58:52 +0100 Subject: [PATCH 067/324] [SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets) A simpler version of https://github.com/apache/spark/pull/9279, only support 2 datasets. Author: Wenchen Fan Closes #9324 from cloud-fan/cogroup2. 
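For reference, a usage sketch of the new cogroup API, adapted from the DatasetSuite test added in this patch. The SparkConf/SparkContext/SQLContext boilerplate (app name, local master, the standalone object) is illustrative only and assumes the Dataset implicits provided by sqlContext.implicits._, which the test exercises through its own test implicits.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CoGroupUsageSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative setup; any SQLContext with the Dataset implicits in scope would do.
    val sc = new SparkContext(new SparkConf().setAppName("cogroup-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
    val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()

    // For each distinct key, the function is passed all left values and all right values once.
    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { (key, left, right) =>
      Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
    }

    // Matches the expected answer of the DatasetSuite test below:
    // (1,a#), (2,#q), (3,abcfoo#w), (5,hello#er)
    cogrouped.collect().sortBy(_._1).foreach(println)
    sc.stop()
  }
}

Because the planner requires both children to be clustered and sorted by the grouping key, the CoGroupedIterator added below can walk the two sorted streams in a single pass and pair up each group's left and right iterators.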
--- .../sql/catalyst/expressions/UnsafeRow.java | 1 + .../plans/logical/basicOperators.scala | 39 ++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 20 +++++ .../sql/execution/CoGroupedIterator.scala | 89 +++++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 4 + .../spark/sql/execution/basicOperators.scala | 41 +++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++ .../execution/CoGroupedIteratorSuite.scala | 51 +++++++++++ 8 files changed, 257 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 850838af9be35..5ba14ebdb62a4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -591,6 +591,7 @@ public String toString() { build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } + build.deleteCharAt(build.length() - 1); build.append(']'); return build.toString(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d2d3db0a44484..4cb67aacf33ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -513,3 +513,42 @@ case class MapGroups[K, T, U]( override def missingInput: AttributeSet = AttributeSet.empty } +/** Factory for constructing new `CoGroup` nodes. */ +object CoGroup { + def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( + func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan): CoGroup[K, Left, Right, R] = { + CoGroup( + func, + encoderFor[K], + encoderFor[Left], + encoderFor[Right], + encoderFor[R], + encoderFor[R].schema.toAttributes, + leftGroup, + rightGroup, + left, + right) + } +} + +/** + * A relation produced by applying `func` to each grouping key and associated values from left and + * right children. + */ +case class CoGroup[K, Left, Right, R]( + func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + kEncoder: ExpressionEncoder[K], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], + rEncoder: ExpressionEncoder[R], + output: Seq[Attribute], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 89a16dd8b0acc..612f2b60cd405 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql]( sqlContext, MapGroups(f, groupingAttributes, logicalPlan)) } + + /** + * Applies the given function to each cogrouped data. 
For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. + */ + def cogroup[U, R : Encoder]( + other: GroupedDataset[K, U])( + f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { + implicit def uEnc: Encoder[U] = other.tEncoder + new Dataset[R]( + sqlContext, + CoGroup( + f, + this.groupingAttributes, + other.groupingAttributes, + this.logicalPlan, + other.logicalPlan)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala new file mode 100644 index 0000000000000..ce5827855e4aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering + +/** + * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a + * grouping key with its associated values from all [[GroupedIterator]]s. + * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key. + */ +class CoGroupedIterator( + left: Iterator[(InternalRow, Iterator[InternalRow])], + right: Iterator[(InternalRow, Iterator[InternalRow])], + groupingSchema: Seq[Attribute]) + extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] { + + private val keyOrdering = + GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema) + + private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _ + private var currentRightData: (InternalRow, Iterator[InternalRow]) = _ + + override def hasNext: Boolean = left.hasNext || right.hasNext + + override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + if (currentLeftData.eq(null) && left.hasNext) { + currentLeftData = left.next() + } + if (currentRightData.eq(null) && right.hasNext) { + currentRightData = right.next() + } + + assert(currentLeftData.ne(null) || currentRightData.ne(null)) + + if (currentLeftData.eq(null)) { + // left is null, right is not null, consume the right data. + rightOnly() + } else if (currentRightData.eq(null)) { + // left is not null, right is null, consume the left data. 
+ leftOnly() + } else if (currentLeftData._1 == currentRightData._1) { + // left and right have the same grouping key, consume both of them. + val result = (currentLeftData._1, currentLeftData._2, currentRightData._2) + currentLeftData = null + currentRightData = null + result + } else { + val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1) + assert(compare != 0) + if (compare < 0) { + // the grouping key of left is smaller, consume the left data. + leftOnly() + } else { + // the grouping key of right is smaller, consume the right data. + rightOnly() + } + } + } + + private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + val result = (currentLeftData._1, currentLeftData._2, Iterator.empty) + currentLeftData = null + result + } + + private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + val result = (currentRightData._1, Iterator.empty, currentRightData._2) + currentRightData = null + result + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ee9716285316a..32067266b516b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil + case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, + leftGroup, rightGroup, left, right) => + execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup, + planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 89938471ee381..d5a803f8c4b24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -390,3 +390,44 @@ case class MapGroups[K, T, U]( } } } + +/** + * Co-groups the data from left and right children, and calls the function with each group and 2 + * iterators containing all elements in the group from left and right side. + * The result of this function is encoded and flattened before being output. 
+ */ +case class CoGroup[K, Left, Right, R]( + func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + kEncoder: ExpressionEncoder[K], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], + rEncoder: ExpressionEncoder[R], + output: Seq[Attribute], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + val groupKeyEncoder = kEncoder.bind(leftGroup) + + new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { + case (key, leftResult, rightResult) => + val result = func( + groupKeyEncoder.fromRow(key), + leftResult.map(leftEnc.fromRow), + rightResult.map(rightEnc.fromRow)) + result.map(rEncoder.toRow) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index aebb390a1d15d..993e6d269ee03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { agged, ("a", 30), ("b", 3), ("c", 1)) } + + test("cogroup") { + val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() + val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala new file mode 100644 index 0000000000000..d1fe81947e9ea --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper + +class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("basic") { + val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator + val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator + val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) + val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + + val result = cogrouped.map { + case (key, leftData, rightData) => + assert(key.numFields == 1) + (key.getInt(0), leftData.toSeq, rightData.toSeq) + }.toSeq + assert(result == + (1, + Seq(create_row(1, "a"), create_row(1, "b")), + Seq(create_row(1, 2L))) :: + (2, + Seq(create_row(2, "c")), + Seq(create_row(2, 3L))) :: + (3, + Seq.empty, + Seq(create_row(3, 4L))) :: + Nil + ) + } +} From fd9e345ceeff385ba614a16d478097650caa98d0 Mon Sep 17 00:00:00 2001 From: "Mageswaran.D" Date: Wed, 28 Oct 2015 08:46:30 -0700 Subject: [PATCH 068/324] Typo in mllib-evaluation-metrics.md Recall by threshold snippet was using "precisionByThreshold" Author: Mageswaran.D Closes #9333 from Mageswaran1989/Typo_in_mllib-evaluation-metrics.md. --- docs/mllib-evaluation-metrics.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 2270f7a34b069..f73eff637dc36 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -141,7 +141,7 @@ precision.foreach { case (t, p) => } // Recall by threshold -val recall = metrics.precisionByThreshold +val recall = metrics.recallByThreshold recall.foreach { case (t, r) => println(s"Threshold: $t, Recall: $r") } @@ -1509,4 +1509,4 @@ print("Explained variance = %s" % metrics.explainedVariance) {% endhighlight %} - \ No newline at end of file + From fba9e95452ca0a9b589bc14b27c750c69f482b8d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 28 Oct 2015 08:50:21 -0700 Subject: [PATCH 069/324] [SPARK-11369][ML][R] SparkR glm should support setting standardize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SparkR glm currently support : ```formula, family = c(“gaussian”, “binomial”), data, lambda = 0, alpha = 0``` We should also support setting standardize which has been defined at [design documentation](https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit) Author: Yanbo Liang Closes #9331 from yanboliang/spark-11369. 
--- R/pkg/R/mllib.R | 4 ++-- .../src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 25615e805e03c..aadd5b8da5e3b 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -46,11 +46,11 @@ setClass("PipelineModel", representation(model = "jobj")) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, - solver = "auto") { + standardize = TRUE, solver = "auto") { family <- match.arg(family) model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha, solver) + alpha, standardize, solver) return(new("PipelineModel", model = model)) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index fec61fed3cb9c..21ebf6d916db7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -31,6 +31,7 @@ private[r] object SparkRWrappers { family: String, lambda: Double, alpha: Double, + standardize: Boolean, solver: String): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { @@ -38,11 +39,13 @@ private[r] object SparkRWrappers { .setRegParam(lambda) .setElasticNetParam(alpha) .setFitIntercept(formula.hasIntercept) + .setStandardization(standardize) .setSolver(solver) case "binomial" => new LogisticRegression() .setRegParam(lambda) .setElasticNetParam(alpha) .setFitIntercept(formula.hasIntercept) + .setStandardization(standardize) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) From f92b7b98e9998a6069996cc66ca26cbfa695fce5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 28 Oct 2015 08:54:20 -0700 Subject: [PATCH 070/324] [SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver [SPARK-10668](https://issues.apache.org/jira/browse/SPARK-10668) has provided ```WeightedLeastSquares``` solver("normal") in ```LinearRegression``` with L2 regularization in Scala and R, Python ML ```LinearRegression``` should also support setting solver("auto", "normal", "l-bfgs") Author: Yanbo Liang Closes #9328 from yanboliang/spark-11367. --- .../ml/param/_shared_params_code_gen.py | 4 ++- python/pyspark/ml/param/shared.py | 28 +++++++++++++++++++ python/pyspark/ml/regression.py | 27 ++++-------------- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 7143d56330bd6..070c5db01ae73 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -135,7 +135,9 @@ def get$Name(self): "values >= 0. The class with largest value p/t is predicted, where p is the original " + "probability of that class and t is the class' threshold.", None), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None)] + "all instance weights as 1.0.", None), + ("solver", "the solver algorithm for optimization. 
If this is not set or empty, " + + "default value is 'auto'.", "'auto'")] code = [] for name, doc, defaultValueStr in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 3a58ac87d6b65..4bdf2a8cc563f 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -597,6 +597,34 @@ def getWeightCol(self): return self.getOrDefault(self.weightCol) +class HasSolver(Params): + """ + Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + """ + + # a placeholder to make it appear in the generated doc + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + + def __init__(self): + super(HasSolver, self).__init__() + #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self._setDefault(solver='auto') + + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. + """ + self._paramMap[self.solver] = value + return self + + def getSolver(self): + """ + Gets the value of solver or its default value. + """ + return self.getOrDefault(self.solver) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index eeb18b3e9d290..dc68815556d4e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -33,7 +33,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization): + HasStandardization, HasSolver): """ Linear regression. @@ -50,7 +50,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction @@ -73,11 +73,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True): + standardization=True, solver="auto"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True) + standardization=True, solver="auto") """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -90,11 +90,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True): + standardization=True, solver="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True) + standardization=True, solver="auto") Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,21 +103,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - @since("1.4.0") - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - @since("1.4.0") - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ From 032748bb9add096e4691551ee73834f3e5363dd5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 28 Oct 2015 09:40:05 -0700 Subject: [PATCH 071/324] [SPARK-11377] [SQL] withNewChildren should not convert StructType to Seq This is minor, but I ran into while writing Datasets and while it wasn't needed for the final solution, it was super confusing so we should fix it. Basically we recurse into `Seq` to see if they have children. This breaks because we don't preserve the original subclass of `Seq` (and `StructType <:< Seq[StructField]`). Since a struct can never contain children, lets just not recurse into it. Author: Michael Armbrust Closes #9334 from marmbrus/structMakeCopy. 
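For illustration, a short standalone sketch (not part of the patch) of the subtyping quirk behind this fix: `StructType` extends `Seq[StructField]`, so rebuilding it through a generic `Seq` operation silently drops the subclass, which is why `withNewChildren` now skips struct types instead of recursing into them. Only `spark-catalyst` needs to be on the classpath.

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = new StructType()
  .add("i", IntegerType)
  .add("s", StringType)

// StructType <:< Seq[StructField], so treating it as "a Seq of children" and
// rebuilding it element-by-element yields a plain Seq[StructField] ...
val copied: Seq[StructField] = schema.map(identity)

// ... which is no longer a StructType, so a copy constructor expecting a
// StructType argument can no longer be fed with it.
println(copied.isInstanceOf[StructType]) // false
```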
--- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 7971e25188e8d..35f087baccdee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StructType, DataType} /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -176,6 +176,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer val newArgs = productIterator.map { + case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { case arg: TreeNode[_] if containsChild(arg) => @@ -337,6 +338,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { |Is otherCopyArgs specified correctly for $nodeName. |Exception message: ${e.getMessage} |ctor: $defaultCtor? + |types: ${newArgs.map(_.getClass).mkString(", ")} |args: ${newArgs.mkString(", ")} """.stripMargin) } From 5aa05219118e3d3525fb703a4716ae8e04f3da72 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 28 Oct 2015 14:28:38 -0700 Subject: [PATCH 072/324] [SPARK-11292] [SQL] Python API for text data source Adds DataFrameReader.text and DataFrameWriter.text. Author: Reynold Xin Closes #9259 from rxin/SPARK-11292. --- python/pyspark/sql/readwriter.py | 27 +++++++++++++++++++++++++-- python/test_support/sql/text-test.txt | 2 ++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 python/test_support/sql/text-test.txt diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 93832d4c713e5..97bd90c4db829 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -23,6 +23,7 @@ from py4j.java_gateway import JavaClass from pyspark import RDD, since +from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * @@ -193,10 +194,22 @@ def parquet(self, *paths): """ return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + @ignore_unicode_prefix + @since(1.6) + def text(self, path): + """Loads a text file and returns a [[DataFrame]] with a single string column named "text". + + Each line in the text file is a new row in the resulting DataFrame. + + >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') + >>> df.collect() + [Row(text=u'hello'), Row(text=u'this')] + """ + return self._df(self._jreader.text(path)) + @since(1.5) def orc(self, path): - """ - Loads an ORC file, returning the result as a :class:`DataFrame`. + """Loads an ORC file, returning the result as a :class:`DataFrame`. ::Note: Currently ORC support is only available together with :class:`HiveContext`. 
@@ -432,6 +445,16 @@ def parquet(self, path, mode=None, partitionBy=None): self.partitionBy(partitionBy) self._jwrite.parquet(path) + @since(1.6) + def text(self, path): + """Saves the content of the DataFrame in a text file at the specified path. + + The DataFrame must have only one column that is of string type. + Each row becomes a new line in the output file. + """ + self._jwrite.text(path) + + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. diff --git a/python/test_support/sql/text-test.txt b/python/test_support/sql/text-test.txt new file mode 100644 index 0000000000000..ae1e76c9e93a7 --- /dev/null +++ b/python/test_support/sql/text-test.txt @@ -0,0 +1,2 @@ +hello +this \ No newline at end of file From 20dfd46743401a528b70dfb7862e50ce9a3f8e02 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Oct 2015 15:57:01 -0700 Subject: [PATCH 073/324] [SPARK-11363] [SQL] LeftSemiJoin should be LeftSemi in SparkStrategies JIRA: https://issues.apache.org/jira/browse/SPARK-11363 In SparkStrategies some places use LeftSemiJoin. It should be LeftSemi. cc chenghao-intel liancheng Author: Liang-Chi Hsieh Closes #9318 from viirya/no-left-semi-join. --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 32067266b516b..86d1d390f1918 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -297,11 +297,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin => + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil case logical.Join( - left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin => + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil @@ -311,7 +311,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin => + case logical.Join(left, right, joinType, None) if joinType != LeftSemi => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, From e5b89978edf7fa52090116b9b5b53ddaeef08beb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 29 Oct 2015 11:34:54 +0800 Subject: [PATCH 074/324] [SPARK-11376][SQL] Removes duplicated `mutableRow` field This PR fixes a mistake in the code generated by `GenerateColumnAccessor`. 
Interestingly, although the code is illegal in Java (the class has two fields with the same name), Janino accepts it happily and accidentally works properly. Author: Cheng Lian Closes #9335 from liancheng/spark-11376.fix-generated-code. --- .../org/apache/spark/sql/columnar/GenerateColumnAccessor.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index d0f5bfa1cd7bc..7980a6f36d8ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -140,7 +140,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private int numRowsInBatch = 0; private scala.collection.Iterator input = null; - private MutableRow mutableRow = null; private DataType[] columnTypes = null; private int[] columnIndexes = null; @@ -156,7 +155,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { this.input = input; - this.mutableRow = mutableRow; this.columnTypes = columnTypes; this.columnIndexes = columnIndexes; } From 0cb7662d8683c913c4fff02e8fb0ec75261d9731 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Oct 2015 21:35:57 -0700 Subject: [PATCH 075/324] [SPARK-11351] [SQL] support hive interval literal Author: Wenchen Fan Closes #9304 from cloud-fan/interval. --- .../apache/spark/sql/catalyst/SqlParser.scala | 71 +++++++++++++------ .../spark/sql/catalyst/SqlParserSuite.scala | 52 ++++++++++++++ 2 files changed, 103 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 833368b7d5898..0fef04302714e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -322,7 +322,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val literal: Parser[Literal] = ( numericLiteral | booleanLiteral - | stringLit ^^ {case s => Literal.create(s, StringType) } + | stringLit ^^ { case s => Literal.create(s, StringType) } | intervalLiteral | NULL ^^^ Literal.create(null, NullType) ) @@ -349,13 +349,12 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val integral: Parser[String] = sign.? 
~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - private def intervalUnit(unitName: String) = - acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} + private def intervalUnit(unitName: String) = acceptIf { + case lexical.Identifier(str) => + val normalized = lexical.normalizeKeyword(str) + normalized == unitName || normalized == unitName + "s" + case _ => false + } {_ => "wrong interval unit"} protected lazy val month: Parser[Int] = integral <~ intervalUnit("month") ^^ { case num => num.toInt } @@ -396,21 +395,53 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } + private def intervalKeyword(keyword: String) = acceptIf { + case lexical.Identifier(str) => + lexical.normalizeKeyword(str) == keyword + case _ => false + } {_ => "wrong interval keyword"} + protected lazy val intervalLiteral: Parser[Literal] = - INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? ^^ { - case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ + ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ + intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromYearMonthString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ + intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromDayTimeString(s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("year", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("month", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("day", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("hour", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("minute", s)) + } + | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => + Literal(CalendarInterval.fromSingleUnitString("second", s)) + } + | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ + millisecond.? ~ microsecond.? 
^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) + if (!Seq(year, month, week, day, hour, minute, second, + millisecond, microsecond).exists(_.isDefined)) { + throw new AnalysisException( + "at least one time unit should be given for interval literal") } + val months = Seq(year, month).map(_.getOrElse(0)).sum + val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) + .map(_.getOrElse(0L)).sum + Literal(new CalendarInterval(months, microseconds)) + } + ) private def toNarrowestIntegerType(value: String): Any = { val bigIntValue = BigDecimal(value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 79b4846cb9544..ea28bfa021bed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} +import org.apache.spark.unsafe.types.CalendarInterval private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -74,4 +75,55 @@ class SqlParserSuite extends PlanTest { OneRowRelation) comparePlans(parsed, expected) } + + test("support hive interval literal") { + def checkInterval(sql: String, result: CalendarInterval): Unit = { + val parsed = SqlParser.parse(sql) + val expected = Project( + UnresolvedAlias( + Literal(result) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + def checkYearMonth(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' YEAR TO MONTH", + CalendarInterval.fromYearMonthString(lit)) + } + + def checkDayTime(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' DAY TO SECOND", + CalendarInterval.fromDayTimeString(lit)) + } + + def checkSingleUnit(lit: String, unit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' $unit", + CalendarInterval.fromSingleUnitString(unit, lit)) + } + + checkYearMonth("123-10") + checkYearMonth("496-0") + checkYearMonth("-2-3") + checkYearMonth("-123-0") + + checkDayTime("99 11:22:33.123456789") + checkDayTime("-99 11:22:33.123456789") + checkDayTime("10 9:8:7.123456789") + checkDayTime("1 0:0:0") + checkDayTime("-1 0:0:0") + checkDayTime("1 0:0:1") + + for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { + checkSingleUnit("7", unit) + checkSingleUnit("-7", unit) + checkSingleUnit("0", unit) + } + + checkSingleUnit("13.123456789", "second") + checkSingleUnit("-13.123456789", "second") + } } From 3dfa4ea526c881373eeffe541bc378d1fa598129 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Oct 2015 21:45:00 -0700 Subject: [PATCH 
076/324] [SPARK-11322] [PYSPARK] Keep full stack trace in captured exception JIRA: https://issues.apache.org/jira/browse/SPARK-11322 As reported by JoshRosen in [databricks/spark-redshift/issues/89](https://github.com/databricks/spark-redshift/issues/89#issuecomment-149828308), the exception-masking behavior sometimes makes debugging harder. To deal with this issue, we should keep full stack trace in the captured exception. Author: Liang-Chi Hsieh Closes #9283 from viirya/py-exception-stacktrace. --- python/pyspark/sql/tests.py | 6 ++++++ python/pyspark/sql/utils.py | 19 +++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6356d4bd6669b..4c03a0d4ffe93 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1079,6 +1079,12 @@ def test_capture_illegalargument_exception(self): df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", lambda: df.select(sha2(df.a, 1024)).collect()) + try: + df.select(sha2(df.a, 1024)).collect() + except IllegalArgumentException as e: + self.assertRegexpMatches(e.desc, "1024 is not in the permitted values") + self.assertRegexpMatches(e.stackTrace, + "org.apache.spark.sql.functions") def test_with_column_with_existing_name(self): keys = self.df.withColumn("key", self.df.key).select("key").collect() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 0f795ca35b38a..c4fda8bd3b891 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -18,13 +18,22 @@ import py4j -class AnalysisException(Exception): +class CapturedException(Exception): + def __init__(self, desc, stackTrace): + self.desc = desc + self.stackTrace = stackTrace + + def __str__(self): + return repr(self.desc) + + +class AnalysisException(CapturedException): """ Failed to analyze a SQL query plan. """ -class IllegalArgumentException(Exception): +class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. """ @@ -36,10 +45,12 @@ def deco(*a, **kw): return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: s = e.java_exception.toString() + stackTrace = '\n\t at '.join(map(lambda x: x.toString(), + e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): - raise AnalysisException(s.split(': ', 1)[1]) + raise AnalysisException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): - raise IllegalArgumentException(s.split(': ', 1)[1]) + raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise return deco From 87f28fc24003ad60c52f899d10f38032631624dc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Oct 2015 11:17:03 +0100 Subject: [PATCH 077/324] [SPARK-11379][SQL] ExpressionEncoder can't handle top level primitive type correctly For inner primitive type(e.g. inside `Product`), we use `schemaFor` to get the catalyst type for it, https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala#L403. However, for top level primitive type, we use `dataTypeFor`, which is wrong. Author: Wenchen Fan Closes #9337 from cloud-fan/encoder. 
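For illustration, a minimal sketch (not part of the patch) of the user-facing shape of the fix, written against the public 1.6 Dataset API, which may slightly postdate this exact commit: a Dataset of a top-level primitive such as `String` round-trips once the encoder resolves its single column against the catalyst schema. Assumes a `spark-shell` session with `sqlContext.implicits._` imported.

```scala
import sqlContext.implicits._

// The element type is a top-level primitive, so the encoder's BoundReference
// must use the catalyst schema type (StringType), not the raw JVM ObjectType.
val ds = Seq("hello", "world").toDS()
ds.collect().foreach(println) // hello, world
```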
--- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9cbb7c2ffdc76..0b8a8abd02d67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -170,7 +170,7 @@ trait ScalaReflection { .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true)) + def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index a374da4da1f08..b0dacf7f555e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -57,6 +57,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { encodeDecodeTest(false) encodeDecodeTest(1.toShort) encodeDecodeTest(1.toByte) + encodeDecodeTest("hello") encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) From f79ebf2a9e99575908dad6f7a14c8cfcffdebd91 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Oct 2015 11:49:45 +0100 Subject: [PATCH 078/324] [SPARK-11370] [SQL] fix a bug in GroupedIterator and create unit test for it Before this PR, user has to consume the iterator of one group before process next group, or we will get into infinite loops. Author: Wenchen Fan Closes #9330 from cloud-fan/group. --- .../spark/sql/execution/GroupedIterator.scala | 99 ++++++++++++------- .../sql/execution/GroupedIteratorSuite.scala | 82 +++++++++++++++ 2 files changed, 144 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala index 10742cf7348f8..6a8850129f1ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -27,7 +27,7 @@ object GroupedIterator { keyExpressions: Seq[Expression], inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = { if (input.hasNext) { - new GroupedIterator(input, keyExpressions, inputSchema) + new GroupedIterator(input.buffered, keyExpressions, inputSchema) } else { Iterator.empty } @@ -64,7 +64,7 @@ object GroupedIterator { * @param inputSchema The schema of the rows in the `input` iterator. */ class GroupedIterator private( - input: Iterator[InternalRow], + input: BufferedIterator[InternalRow], groupingExpressions: Seq[Expression], inputSchema: Seq[Attribute]) extends Iterator[(InternalRow, Iterator[InternalRow])] { @@ -83,10 +83,17 @@ class GroupedIterator private( /** Holds a copy of an input row that is in the current group. 
*/ var currentGroup = currentRow.copy() - var currentIterator: Iterator[InternalRow] = null + assert(keyOrdering.compare(currentGroup, currentRow) == 0) + var currentIterator = createGroupValuesIterator() - // Return true if we already have the next iterator or fetching a new iterator is successful. + /** + * Return true if we already have the next iterator or fetching a new iterator is successful. + * + * Note that, if we get the iterator by `next`, we should consume it before call `hasNext`, + * because we will consume the input data to skip to next group while fetching a new iterator, + * thus make the previous iterator empty. + */ def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator def next(): (InternalRow, Iterator[InternalRow]) = { @@ -96,46 +103,64 @@ class GroupedIterator private( ret } - def fetchNextGroupIterator(): Boolean = { - if (currentRow != null || input.hasNext) { - val inputIterator = new Iterator[InternalRow] { - // Return true if we have a row and it is in the current group, or if fetching a new row is - // successful. - def hasNext = { - (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) || - fetchNextRowInGroup() - } + private def fetchNextGroupIterator(): Boolean = { + assert(currentIterator == null) + + if (currentRow == null && input.hasNext) { + currentRow = input.next() + } + + if (currentRow == null) { + // These is no data left, return false. + false + } else { + // Skip to next group. + while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) { + currentRow = input.next() + } + + if (keyOrdering.compare(currentGroup, currentRow) == 0) { + // We are in the last group, there is no more groups, return false. + false + } else { + // Now the `currentRow` is the first row of next group. + currentGroup = currentRow.copy() + currentIterator = createGroupValuesIterator() + true + } + } + } + + private def createGroupValuesIterator(): Iterator[InternalRow] = { + new Iterator[InternalRow] { + def hasNext: Boolean = currentRow != null || fetchNextRowInGroup() + + def next(): InternalRow = { + assert(hasNext) + val res = currentRow + currentRow = null + res + } - def fetchNextRowInGroup(): Boolean = { - if (currentRow != null || input.hasNext) { + private def fetchNextRowInGroup(): Boolean = { + assert(currentRow == null) + + if (input.hasNext) { + // The inner iterator should NOT consume the input into next group, here we use `head` to + // peek the next input, to see if we should continue to process it. + if (keyOrdering.compare(currentGroup, input.head) == 0) { + // Next input is in the current group. Continue the inner iterator. currentRow = input.next() - if (keyOrdering.compare(currentGroup, currentRow) == 0) { - // The row is in the current group. Continue the inner iterator. - true - } else { - // We got a row, but its not in the right group. End this inner iterator and prepare - // for the next group. - currentIterator = null - currentGroup = currentRow.copy() - false - } + true } else { - // There is no more input so we are done. + // Next input is not in the right group. End this inner iterator. false } - } - - def next(): InternalRow = { - assert(hasNext) // Ensure we have fetched the next row. - val res = currentRow - currentRow = null - res + } else { + // There is no more data, return false. 
+ false } } - currentIterator = inputIterator - true - } else { - false } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala new file mode 100644 index 0000000000000..e7a08481cfa80 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} + +class GroupedIteratorSuite extends SparkFunSuite { + + test("basic") { + val schema = new StructType().add("i", IntegerType).add("s", StringType) + val encoder = RowEncoder(schema) + val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0)), schema.toAttributes) + + val result = grouped.map { + case (key, data) => + assert(key.numFields == 1) + key.getInt(0) -> data.map(encoder.fromRow).toSeq + }.toSeq + + assert(result == + 1 -> Seq(input(0), input(1)) :: + 2 -> Seq(input(2)) :: Nil) + } + + test("group by 2 columns") { + val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) + val encoder = RowEncoder(schema) + + val input = Seq( + Row(1, 2L, "a"), + Row(1, 2L, "b"), + Row(1, 3L, "c"), + Row(2, 1L, "d"), + Row(3, 2L, "e")) + + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) + + val result = grouped.map { + case (key, data) => + assert(key.numFields == 2) + (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq) + }.toSeq + + assert(result == + (1, 2L, Seq(input(0), input(1))) :: + (1, 3L, Seq(input(2))) :: + (2, 1L, Seq(input(3))) :: + (3, 2L, Seq(input(4))) :: Nil) + } + + test("do nothing to the value iterator") { + val schema = new StructType().add("i", IntegerType).add("s", StringType) + val encoder = RowEncoder(schema) + val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) + val grouped = GroupedIterator(input.iterator.map(encoder.toRow), + Seq('i.int.at(0)), schema.toAttributes) + + assert(grouped.length == 2) + } +} From f304f9c9a1c954b3b5786f90bb13f543637d3192 Mon Sep 17 00:00:00 2001 From: tedyu Date: Thu, 29 Oct 2015 15:02:13 +0100 Subject: [PATCH 079/324] [SPARK-11318] Include hive profile in make-distribution.sh command Author: tedyu Closes #9281 from tedyu/master. 
--- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 743643cbcc62f..4f73adb85446c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -38,7 +38,7 @@ To create a Spark distribution like those distributed by the to be runnable, use `make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: - ./make-distribution.sh --name custom-spark --tgz -Phadoop-2.4 -Pyarn + ./make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn For more information on usage, run `./make-distribution.sh --help` From 3bb2a8d7508b507edfcc21bd20912b0ff4a0a248 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 29 Oct 2015 15:11:00 +0100 Subject: [PATCH 080/324] [SPARK-11388][BUILD] Fix self closing tags. Java 8 javadoc does not like self closing tags: ```

<p/>```, ```<br/>
    ```, ... This PR fixes those. Author: Herman van Hovell Closes #9339 from hvanhovell/SPARK-11388. --- .../java/org/apache/spark/launcher/SparkAppHandle.java | 4 ++-- .../java/org/apache/spark/launcher/SparkLauncher.java | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java index 2896a91d5e793..13dd9f1739fb6 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -19,7 +19,7 @@ /** * A handle to a running Spark application. - *

    + * <p>
    * Provides runtime information about the underlying Spark application, and actions to control it. * * @since 1.6.0 @@ -110,7 +110,7 @@ public interface Listener { * Callback for changes in the handle's state. * * @param handle The updated handle. - * @see {@link SparkAppHandle#getState()} + * @see SparkAppHandle#getState() */ void stateChanged(SparkAppHandle handle); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 5d74b37033a51..dd1c93af6ca4c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -350,7 +350,7 @@ public SparkLauncher setVerbose(boolean verbose) { /** * Launches a sub-process that will start the configured Spark application. - *

    + * <p>
    * The {@link #startApplication(SparkAppHandle.Listener...)} method is preferred when launching * Spark, since it provides better control of the child application. * @@ -362,16 +362,16 @@ public Process launch() throws IOException { /** * Starts a Spark application. - *

    + * <p>
    * This method returns a handle that provides information about the running application and can * be used to do basic interaction with it. - *

    + * <p>
    * The returned handle assumes that the application will instantiate a single SparkContext * during its lifetime. Once that context reports a final state (one that indicates the * SparkContext has stopped), the handle will not perform new state transitions, so anything * that happens after that cannot be monitored. If the underlying application is launched as * a child process, {@link SparkAppHandle#kill()} can still be used to kill the child process. - *

    + * <p>
    * Currently, all applications are launched as child processes. The child's stdout and stderr * are merged and written to a logger (see java.util.logging). The logger's name * can be defined by setting {@link #CHILD_PROCESS_LOGGER_NAME} in the app's configuration. If From f7a51deebad1b4c3b970a051f25d286110b94438 Mon Sep 17 00:00:00 2001 From: xin Wu Date: Thu, 29 Oct 2015 07:42:46 -0700 Subject: [PATCH 081/324] [SPARK-11246] [SQL] Table cache for Parquet broken in 1.5 The root cause is that when spark.sql.hive.convertMetastoreParquet=true by default, the cached InMemoryRelation of the ParquetRelation can not be looked up from the cachedData of CacheManager because the key comparison fails even though it is the same LogicalPlan representing the Subquery that wraps the ParquetRelation. The solution in this PR is overriding the LogicalPlan.sameResult function in Subquery case class to eliminate subquery node first before directly comparing the child (ParquetRelation), which will find the key to the cached InMemoryRelation. Author: xin Wu Closes #9326 from xwu0226/spark-11246-commit. --- .../sql/execution/datasources/LogicalRelation.scala | 5 +++++ .../org/apache/spark/sql/hive/CachedTableSuite.scala | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 783252e0a297f..219dae88e515d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -62,6 +62,11 @@ case class LogicalRelation( case _ => false } + // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need + // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's + // expId can be different but the relation is still the same. 
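For illustration, a `spark-shell` style sketch (not part of the patch) of the reported scenario, mirroring the regression test added below. It assumes `sqlContext` is a `HiveContext` and that `spark.sql.hive.convertMetastoreParquet` is left at its default of `true`.

```scala
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan

sqlContext.sql("CREATE TABLE cachedTable STORED AS PARQUET AS SELECT 1")
sqlContext.cacheTable("cachedTable")

// With the sameResult fix the CacheManager lookup succeeds, so the physical
// plan reads the in-memory columnar cache instead of re-scanning Parquet.
val plan = sqlContext.sql("SELECT * FROM cachedTable").queryExecution.sparkPlan
assert(plan.collect { case scan: InMemoryColumnarTableScan => scan }.size == 1)

sqlContext.sql("DROP TABLE cachedTable")
```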
+ override lazy val cleanArgs: Seq[Any] = Seq(relation) + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = BigInt(relation.sizeInBytes) ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 9adb3780a2c55..5c2fc7d82ffbd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId @@ -203,4 +204,14 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { sql("DROP TABLE refreshTable") Utils.deleteRecursively(tempPath) } + + test("SPARK-11246 cache parquet table") { + sql("CREATE TABLE cachedTable STORED AS PARQUET AS SELECT 1") + + cacheTable("cachedTable") + val sparkPlan = sql("SELECT * FROM cachedTable").queryExecution.sparkPlan + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 1) + + sql("DROP TABLE cachedTable") + } } From 8185f038c13c72e1bea7b0921b84125b7a352139 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 29 Oct 2015 18:29:50 +0100 Subject: [PATCH 082/324] [SPARK-11188][SQL] Elide stacktraces in bin/spark-sql for AnalysisExceptions Only print the error message to the console for Analysis Exceptions in sql-shell. Author: Dilip Biswal Closes #9194 from dilipbiswal/spark-11188. --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 10 +++++++++- .../spark/sql/hive/thriftserver/SparkSQLDriver.scala | 11 ++++++++--- .../spark/sql/hive/thriftserver/CliSuite.scala | 12 ++++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index b5073961a1c84..62e912c69abc6 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ import java.util.{ArrayList => JArrayList, Locale} +import org.apache.spark.sql.AnalysisException + import scala.collection.JavaConverters._ import jline.console.ConsoleReader @@ -298,6 +300,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out + val err = sessionState.err val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) @@ -308,7 +311,12 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { ret = rc.getResponseCode if (ret != 0) { - console.printError(rc.getErrorMessage()) + // For analysis exception, only the error is printed out to the console. 
+ rc.getException() match { + case e : AnalysisException => + err.println(s"""Error in query: ${e.getMessage}""") + case _ => err.println(rc.getErrorMessage()) + } driver.close() return ret } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 2619286afc148..f1ec7238520ac 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.thriftserver import java.util.{Arrays, ArrayList => JArrayList, List => JList} +import org.apache.log4j.LogManager +import org.apache.spark.sql.AnalysisException import scala.collection.JavaConverters._ @@ -63,9 +65,12 @@ private[hive] class SparkSQLDriver( tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) } catch { - case cause: Throwable => - logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) + case ae: AnalysisException => + logDebug(s"Failed in [$command]", ae) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae) + case cause: Throwable => + logError(s"Failed in [$command]", cause) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause) } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 76d1591a235c2..3fa5c8528b602 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -58,7 +58,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { * @param timeout maximum time for the commands to complete * @param extraArgs any extra arguments * @param errorResponses a sequence of strings whose presence in the stdout of the forked process - * is taken as an immediate error condition. That is: if a line beginning + * is taken as an immediate error condition. That is: if a line containing * with one of these strings is found, fail the test immediately. * The default value is `Seq("Error:")` * @@ -104,7 +104,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } } else { errorResponses.foreach { r => - if (line.startsWith(r)) { + if (line.contains(r)) { foundAllExpectedAnswers.tryFailure( new RuntimeException(s"Failed with error line '$line'")) } @@ -219,4 +219,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "OK" ) } + + test("SPARK-11188 Analysis error reporting") { + runCliWithin(timeout = 2.minute, + errorResponses = Seq("AnalysisException"))( + "select * from nonexistent_table;" + -> "Error in query: Table not found: nonexistent_table;" + ) + } } From a01cbf5daac148f39cd97299780f542abc41d1e9 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 29 Oct 2015 11:58:39 -0700 Subject: [PATCH 083/324] [SPARK-10641][SQL] Add Skewness and Kurtosis Support Implementing skewness and kurtosis support based on following algorithm: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics Author: sethah Closes #9003 from sethah/SPARK-10641. 
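For illustration, a usage sketch (not part of the patch) of the new aggregates through the names registered below (`skewness`, `kurtosis`, `variance`, `var_pop`, `var_samp`); the `selectExpr`/SQL route relies only on the `FunctionRegistry` entries, while the diff also adds DataFrame wrappers in `functions.scala`. Assumes `sqlContext.implicits._` is in scope.

```scala
import sqlContext.implicits._

val points = Seq(1.0, 2.0, 3.0, 4.0, 10.0).map(Tuple1.apply).toDF("x")

// Higher-order statistics via the newly registered SQL function names.
points.selectExpr("skewness(x)", "kurtosis(x)", "var_samp(x)", "var_pop(x)").show()

// The same functions through a SQL query over a temporary table.
points.registerTempTable("points")
sqlContext.sql("SELECT skewness(x), kurtosis(x), variance(x) FROM points").show()
```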
--- .../catalyst/analysis/FunctionRegistry.scala | 5 + .../catalyst/analysis/HiveTypeCoercion.scala | 5 + .../spark/sql/catalyst/dsl/package.scala | 5 + .../expressions/aggregate/functions.scala | 329 ++++++++++++++++++ .../expressions/aggregate/utils.scala | 30 ++ .../sql/catalyst/expressions/aggregates.scala | 95 +++++ .../org/apache/spark/sql/GroupedData.scala | 65 ++++ .../org/apache/spark/sql/functions.scala | 115 +++++- .../spark/sql/DataFrameAggregateSuite.scala | 73 ++++ .../org/apache/spark/sql/QueryTest.scala | 48 +++ .../org/apache/spark/sql/SQLQuerySuite.scala | 63 +++- .../execution/HiveCompatibilitySuite.scala | 1 - 12 files changed, 823 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3dce6c1a27e85..ed9fcfe014f0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -189,6 +189,11 @@ object FunctionRegistry { expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), + expression[Variance]("variance"), + expression[VariancePop]("var_pop"), + expression[VarianceSamp]("var_samp"), + expression[Skewness]("skewness"), + expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 1140150f66864..3c675672dab85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -300,6 +300,11 @@ object HiveTypeCoercion { case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case Variance(e @ StringType()) => Variance(Cast(e, DoubleType)) + case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) + case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) + case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) + case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 27b3cd84b3846..787f67a297e33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -162,6 +162,11 @@ package object dsl { def stddev(e: Expression): Expression = Stddev(e) def stddev_pop(e: Expression): Expression = StddevPop(e) def stddev_samp(e: Expression): Expression = StddevSamp(e) + def variance(e: Expression): Expression = Variance(e) + def var_pop(e: Expression): Expression = VariancePop(e) + def var_samp(e: Expression): Expression = VarianceSamp(e) + def skewness(e: Expression): Expression = Skewness(e) + def kurtosis(e: Expression): Expression = Kurtosis(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 515246d344244..281404f285a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -930,3 +930,332 @@ object HyperLogLogPlusPlus { ) // scalastyle:on } + +/** + * A central moment is the expected value of a specified power of the deviation of a random + * variable from the mean. Central moments are often used to characterize the properties of about + * the shape of a distribution. + * + * This class implements online, one-pass algorithms for computing the central moments of a set of + * points. + * + * Behavior: + * - null values are ignored + * - returns `Double.NaN` when the column contains `Double.NaN` values + * + * References: + * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." + * 2015. http://arxiv.org/abs/1510.04923 + * + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + * Algorithms for calculating variance (Wikipedia)]] + * + * @param child to compute central moments of. + */ +abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { + + /** + * The central moment order to be computed. + */ + protected def momentOrder: Int + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** + * Size of aggregation buffer. + */ + private[this] val bufferSize = 5 + + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => + AttributeReference(s"M$i", DoubleType)() + } + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + // buffer offsets + private[this] val nOffset = mutableAggBufferOffset + private[this] val meanOffset = mutableAggBufferOffset + 1 + private[this] val secondMomentOffset = mutableAggBufferOffset + 2 + private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 + private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 + + // frequently used values for online updates + private[this] var delta = 0.0 + private[this] var deltaN = 0.0 + private[this] var delta2 = 0.0 + private[this] var deltaN2 = 0.0 + private[this] var n = 0.0 + private[this] var mean = 0.0 + private[this] var m2 = 0.0 + private[this] var m3 = 0.0 + private[this] var m4 = 0.0 + + /** + * Initialize all moments to zero. 
+ */ + override def initialize(buffer: MutableRow): Unit = { + for (aggIndex <- 0 until bufferSize) { + buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + } + } + + /** + * Update the central moments buffer. + */ + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = Cast(child, DoubleType).eval(input) + if (v != null) { + val updateValue = v match { + case d: Double => d + } + + n = buffer.getDouble(nOffset) + mean = buffer.getDouble(meanOffset) + + n += 1.0 + buffer.setDouble(nOffset, n) + delta = updateValue - mean + deltaN = delta / n + mean += deltaN + buffer.setDouble(meanOffset, mean) + + if (momentOrder >= 2) { + m2 = buffer.getDouble(secondMomentOffset) + m2 += delta * (delta - deltaN) + buffer.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + delta2 = delta * delta + deltaN2 = deltaN * deltaN + m3 = buffer.getDouble(thirdMomentOffset) + m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) + buffer.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + m4 = buffer.getDouble(fourthMomentOffset) + m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + + delta * (delta * delta2 - deltaN * deltaN2) + buffer.setDouble(fourthMomentOffset, m4) + } + } + } + + /** + * Merge two central moment buffers. + */ + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val n1 = buffer1.getDouble(nOffset) + val n2 = buffer2.getDouble(inputAggBufferOffset) + val mean1 = buffer1.getDouble(meanOffset) + val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + + var secondMoment1 = 0.0 + var secondMoment2 = 0.0 + + var thirdMoment1 = 0.0 + var thirdMoment2 = 0.0 + + var fourthMoment1 = 0.0 + var fourthMoment2 = 0.0 + + n = n1 + n2 + buffer1.setDouble(nOffset, n) + delta = mean2 - mean1 + deltaN = if (n == 0.0) 0.0 else delta / n + mean = mean1 + deltaN * n2 + buffer1.setDouble(mutableAggBufferOffset + 1, mean) + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + if (momentOrder >= 2) { + secondMoment1 = buffer1.getDouble(secondMomentOffset) + secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) + m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 + buffer1.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + thirdMoment1 = buffer1.getDouble(thirdMomentOffset) + thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) + m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * + (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) + buffer1.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + fourthMoment1 = buffer1.getDouble(fourthMomentOffset) + fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) + m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * + n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * + (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + + 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) + buffer1.setDouble(fourthMomentOffset, m4) + } + } + + /** + * Compute aggregate statistic from sufficient moments. + * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) + * needed to compute the aggregate stat. 
+ */ + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + + override final def eval(buffer: InternalRow): Any = { + val n = buffer.getDouble(nOffset) + val mean = buffer.getDouble(meanOffset) + val moments = Array.ofDim[Double](momentOrder + 1) + moments(0) = 1.0 + moments(1) = 0.0 + if (momentOrder >= 2) { + moments(2) = buffer.getDouble(secondMomentOffset) + } + if (momentOrder >= 3) { + moments(3) = buffer.getDouble(thirdMomentOffset) + } + if (momentOrder >= 4) { + moments(4) = buffer.getDouble(fourthMomentOffset) + } + + getStatistic(n, mean, moments) + } +} + +case class Variance(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "variance" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + + if (n == 0.0) Double.NaN else moments(2) / n + } +} + +case class VarianceSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "variance_samp" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + } +} + +case class VariancePop(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "variance_pop" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) Double.NaN else moments(2) / n + } +} + +case class Skewness(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + 
copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "skewness" + + override protected val momentOrder = 3 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m3 = moments(3) + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) + } + } +} + +case class Kurtosis(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "kurtosis" + + override protected val momentOrder = 4 + + // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m4 = moments(4) + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + n * m4 / (m2 * m2) - 3.0 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 12bdab0915801..c911ec53f1ba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -67,6 +67,12 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Kurtosis(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Kurtosis(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Last(child, ignoreNulls) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Last(child, ignoreNulls), @@ -85,6 +91,12 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Skewness(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Skewness(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Stddev(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Stddev(child), @@ -120,6 +132,24 @@ object Utils { aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), mode = aggregate.Complete, isDistinct = false) + + case expressions.Variance(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Variance(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.VariancePop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.VariancePop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.VarianceSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.VarianceSamp(child), + mode = aggregate.Complete, + isDistinct = false) } // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 70819be5af5b0..c1bab6d36ab29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -991,3 +991,98 @@ case class StddevFunction( } } } + +// placeholder +case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false + + override def prettyName: String = "kurtosis" + + override def toString: String = s"KURTOSIS($child)" +} + +// placeholder +case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false + + override def prettyName: String = "skewness" + + override def toString: String = s"SKEWNESS($child)" +} + +// placeholder +case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false + + override def prettyName: String = "variance" + + override def toString: String = s"VARIANCE($child)" +} + +// placeholder +case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false + + override def prettyName: String = "variance_pop" + + override def toString: String = s"VAR_POP($child)" +} + +// placeholder +case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { + + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + + "please set spark.sql.useAggregate2 = true") + } + + override def nullable: Boolean = false + + override def dataType: DoubleType.type = DoubleType + + override def foldable: Boolean = false + + override def prettyName: String = "variance_samp" + + override def toString: String = s"VAR_SAMP($child)" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 102b802ad0a0a..dc96384a4d28d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -127,7 +127,12 @@ class GroupedData protected[sql]( case "stddev" => Stddev case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp + case "variance" => Variance + case "var_pop" => VariancePop + case "var_samp" => VarianceSamp case "sum" => Sum + case "skewness" => Skewness + case "kurtosis" => Kurtosis case "count" | "size" => // Turn count(*) into count(1) (inputExpr: Expression) => inputExpr match { @@ -250,6 +255,30 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Average) } + /** + * Compute the skewness for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the skewness values for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def skewness(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Skewness) + } + + /** + * Compute the kurtosis for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the kurtosis values for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def kurtosis(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Kurtosis) + } + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. @@ -333,4 +362,40 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * Compute the sample variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def variance(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Variance) + } + + /** + * Compute the population variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def var_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(VariancePop) + } + + /** + * Compute the sample variance for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the variance for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def var_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(VarianceSamp) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 15c864a8ab641..c1737b1ef663c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -228,6 +228,22 @@ object functions { */ def first(columnName: String): Column = first(Column(columnName)) + /** + * Aggregate function: returns the kurtosis of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(e: Column): Column = Kurtosis(e.expr) + + /** + * Aggregate function: returns the kurtosis of the values in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * @@ -295,8 +311,24 @@ object functions { def min(columnName: String): Column = min(Column(columnName)) /** - * Aggregate function: returns the unbiased sample standard deviation - * of the expression in a group. + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(e: Column): Column = Skewness(e.expr) + + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. * * @group agg_funcs * @since 1.6.0 @@ -304,13 +336,13 @@ object functions { def stddev(e: Column): Column = Stddev(e.expr) /** - * Aggregate function: returns the population standard deviation of + * Aggregate function: returns the unbiased sample standard deviation of * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = StddevPop(e.expr) + def stddev(columnName: String): Column = stddev(Column(columnName)) /** * Aggregate function: returns the unbiased sample standard deviation of @@ -321,6 +353,33 @@ object functions { */ def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * @@ -353,6 +412,54 @@ object functions { */ def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(e: Column): Column = Variance(e.expr) + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(e: Column): Column = VarianceSamp(e.expr) + + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + + /** + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(e: Column): Column = VariancePop(e.expr) + + /** + * Aggregate function: returns the population variance of the values in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f5ef9ffd7f4f2..9b23977c765dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -221,4 +221,77 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { emptyTableData.agg(sumDistinct('a)), Row(null)) } + + test("moments") { + val absTol = 1e-8 + + val sparkVariance = testData2.agg(variance('a)) + val expectedVariance = Row(4.0 / 6.0) + checkAggregatesWithTol(sparkVariance, expectedVariance, absTol) + val sparkVariancePop = testData2.agg(var_pop('a)) + checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol) + + val sparkVarianceSamp = testData2.agg(var_samp('a)) + val expectedVarianceSamp = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol) + + val sparkSkewness = testData2.agg(skewness('a)) + val expectedSkewness = Row(0.0) + checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol) + + val sparkKurtosis = testData2.agg(kurtosis('a)) + val expectedKurtosis = Row(-1.5) + checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol) + + } + + test("zero moments") { + val emptyTableData = Seq((1, 2)).toDF("a", "b") + assert(emptyTableData.count() === 1) + + checkAnswer( + emptyTableData.agg(variance('a)), + Row(0.0)) + + checkAnswer( + emptyTableData.agg(var_samp('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_pop('a)), + Row(0.0)) + + checkAnswer( + emptyTableData.agg(skewness('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(kurtosis('a)), + Row(Double.NaN)) + } + + test("null moments") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() === 0) + + checkAnswer( + emptyTableData.agg(variance('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_samp('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(var_pop('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(skewness('a)), + Row(Double.NaN)) + + checkAnswer( + emptyTableData.agg(kurtosis('a)), + Row(Double.NaN)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 73e02eb0d9574..3c174efe73ffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -134,6 +134,32 @@ abstract class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param dataFrame the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param absTol the absolute tolerance between actual and expected answers. 
+ */ + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Seq[Row], + absTol: Double): Unit = { + // TODO: catch exceptions in data frame execution + val actualAnswer = dataFrame.collect() + require(actualAnswer.length == expectedAnswer.length, + s"actual num rows ${actualAnswer.length} != expected num of rows ${expectedAnswer.length}") + + actualAnswer.zip(expectedAnswer).foreach { + case (actualRow, expectedRow) => + QueryTest.checkAggregatesWithTol(actualRow, expectedRow, absTol) + } + } + + protected def checkAggregatesWithTol(dataFrame: DataFrame, + expectedAnswer: Row, + absTol: Double): Unit = { + checkAggregatesWithTol(dataFrame, Seq(expectedAnswer), absTol) + } + /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. */ @@ -214,6 +240,28 @@ object QueryTest { return None } + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * @param actualAnswer the actual result in a [[Row]]. + * @param expectedAnswer the expected result in a[[Row]]. + * @param absTol the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + require(actualAnswer.length == expectedAnswer.length, + s"actual answer length ${actualAnswer.length} != " + + s"expected answer length ${expectedAnswer.length}") + + // TODO: support other numeric types besides Double + // TODO: support struct types? + actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { + case (actual: Double, expected: Double) => + assert(math.abs(actual - expected) < absTol, + s"actual answer $actual not within $absTol of correct answer $expected") + case (actual, expected) => + assert(actual == expected, s"$actual did not equal $expected") + } + } + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f5ae3ae49b460..5a616fac0bc2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -523,8 +523,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 1, 6, 3) + sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3) ) } @@ -717,14 +718,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev") { checkAnswer( sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4/5.0)) + Row(math.sqrt(4.0 / 5.0)) ) } test("stddev_pop") { checkAnswer( sql("SELECT STDDEV_POP(a) FROM testData2"), - Row(math.sqrt(4/6.0)) + Row(math.sqrt(4.0 / 6.0)) ) } @@ -735,10 +736,60 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("var_samp") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") + val expectedAnswer = Row(4.0 / 5.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + + test("variance") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") + val expectedAnswer = Row(4.0 / 6.0) + 
checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + + test("var_pop") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") + val expectedAnswer = Row(4.0 / 6.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + + test("skewness") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT skewness(a) FROM testData2") + val expectedAnswer = Row(0.0) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + + test("kurtosis") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") + val expectedAnswer = Row(-1.5) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + test("stddev agg") { checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) + } + + test("variance agg") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" + + "FROM testData2 GROUP BY a") + val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0)) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + } + + test("skewness and kurtosis agg") { + val absTol = 1e-8 + val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a") + val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0)) + checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } test("inner join where, one match per row") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index eed9e436f9af7..9e357bf348c94 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -467,7 +467,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_orderby1", "escape_sortby1", "explain_rearrange", - "fetch_aggregation", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", From f21ef8dbb2ef6526b8b47e18b9d8d91dd520c086 Mon Sep 17 00:00:00 2001 From: teramonagi Date: Thu, 29 Oct 2015 10:54:04 -0700 Subject: [PATCH 084/324] [SPARK-10532][EC2] Added --profile option to specify the name of profile "profiles" give us the way that you can specify the set of credentials you want to use when you initialize a connection to AWS. You can keep multiple sets of credentials in the same credentials files using different profile names. For example, you can use --profile option to do that when you use "aws cli tool". http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html Author: teramonagi Closes #8696 from teramonagi/SPARK-10532. 
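For illustration, a minimal boto sketch of what the new option does; the profile name, region, and credentials section below are made-up examples, and the exact configuration file and syntax depend on your boto/AWS setup:

```python
# Hypothetical usage sketch mirroring the spark_ec2.py change below: boto
# resolves the named profile from its credential/config files, so one machine
# can hold several credential sets. "dev" and "us-west-2" are example values.
#
# A matching profile would live in your boto or AWS credentials configuration,
# e.g. a section like:
#   [dev]
#   aws_access_key_id = AKIA...
#   aws_secret_access_key = ...
from boto import ec2

profile = "dev"        # what --profile would pass through as opts.profile
region = "us-west-2"   # what --region would pass through as opts.region

if profile is None:
    conn = ec2.connect_to_region(region)
else:
    conn = ec2.connect_to_region(region, profile_name=profile)

print(conn)  # an EC2Connection bound to the "dev" profile's credentials
```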
--- ec2/spark_ec2.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 3a2361c6d6d2b..9327e21e43db7 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -181,6 +181,10 @@ def parse_args(): parser.add_option( "-i", "--identity-file", help="SSH private key file to use for logging into instances") + parser.add_option( + "-p", "--profile", default=None, + help="If you have multiple profiles (AWS or boto config), you can configure " + + "additional, named profiles by using this option (default: %default)") parser.add_option( "-t", "--instance-type", default="m1.large", help="Type of instance to launch (default: %default). " + @@ -1315,7 +1319,10 @@ def real_main(): sys.exit(1) try: - conn = ec2.connect_to_region(opts.region) + if opts.profile is None: + conn = ec2.connect_to_region(opts.region) + else: + conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) except Exception as e: print((e), file=stderr) sys.exit(1) From 4f5e60c647d7d6827438721b7fabbc3a57b81023 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Thu, 29 Oct 2015 15:13:38 -0700 Subject: [PATCH 085/324] [SPARK-11236][CORE] Update Tachyon dependency from 0.7.1 -> 0.8.0. Upgrades the tachyon-client version to the latest release. No new dependencies are added and no spark facing APIs are changed. The removal of the `tachyon-underfs-s3` exclusion will enable users to use S3 out of the box and there are no longer any additional external dependencies added by the module. Author: Calvin Jia Closes #9204 from calvinjia/spark-11236. --- core/pom.xml | 6 +----- make-distribution.sh | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 319a50049a82d..dff40e91ad228 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -266,7 +266,7 @@ org.tachyonproject tachyon-client - 0.7.1 + 0.8.0 org.apache.hadoop @@ -288,10 +288,6 @@ org.tachyonproject tachyon-underfs-glusterfs - - org.tachyonproject - tachyon-underfs-s3 - diff --git a/make-distribution.sh b/make-distribution.sh index 24418ace26270..f6766784813c3 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,9 +33,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.7.1" +TACHYON_VERSION="0.8.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" +TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" MAKE_TGZ=false NAME=none @@ -240,10 +240,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx From 96cf87f66d47245b19e719cb83947042b21546fa Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 29 Oct 2015 16:36:52 -0700 Subject: [PATCH 086/324] [SPARK-11301] [SQL] fix case 
sensitivity for filter on partitioned columns Author: Wenchen Fan Closes #9271 from cloud-fan/filter. --- .../execution/datasources/DataSourceStrategy.scala | 12 +++++------- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index ffb4645b89321..af6626c897583 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -63,16 +63,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) if t.partitionSpec.partitionColumns.nonEmpty => // We divide the filter expressions into 3 parts - val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet + val partitionColumns = AttributeSet( + t.partitionColumns.map(c => l.output.find(_.name == c.name).get)) - // TODO this is case-sensitive - // Only prunning the partition keys - val partitionFilters = - filters.filter(_.references.map(_.name).toSet.subsetOf(partitionColumnNames)) + // Only pruning the partition keys + val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns)) // Only pushes down predicates that do not reference partition keys. - val pushedFilters = - filters.filter(_.references.map(_.name).toSet.intersect(partitionColumnNames).isEmpty) + val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) // Predicates with both partition keys and attributes val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 59565a6b13d40..c9d6e19d2ce93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -987,4 +987,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") assert(df.select($"src.i".cast(StringType)).columns.head === "i") } + + test("SPARK-11301: fix case sensitivity for filter on partitioned columns") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) + val df = sqlContext.read.parquet(path.getAbsolutePath) + checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) + } + } + } } From d89be0bf81029cd82008a959d191e1c7b6ceaa18 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 29 Oct 2015 21:01:10 -0700 Subject: [PATCH 087/324] [SPARK-11409][SPARKR] Enable url link in R doc for Persist Quick one line doc fix link is not clickable ![image](https://cloud.githubusercontent.com/assets/8969467/10833041/4e91dd7c-7e4c-11e5-8905-713b986dbbde.png) shivaram Author: felixcheung Closes #9363 from felixcheung/rpersistdoc. 
--- R/pkg/R/DataFrame.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c8944459542af..87a2c66ffd2a9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -357,7 +357,7 @@ setMethod("cache", #' #' Persist this DataFrame with the specified storage level. For details of the #' supported storage levels, refer to -#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' #' @param x The DataFrame to persist #' @rdname persist @@ -1572,7 +1572,7 @@ setMethod("merge", joinRes }) -#' +#' #' Creates a list of columns by replacing the intersected ones with aliases. #' The name of the alias column is formed by concatanating the original column name and a suffix. #' From 56419cf11f769c80f391b45dc41b3c7101cc5ff4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 29 Oct 2015 23:38:06 -0700 Subject: [PATCH 088/324] [SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

This PR introduces a mechanism to call spill() on those SQL operators that support spilling (for example, BytesToBytesMap, UnsafeExternalSorter and ShuffleExternalSorter) if there is not enough memory for execution. The preserved first page is no longer needed, so it has been removed. Other Spillable objects in Spark core (ExternalSorter and AppendOnlyMap) are not included in this PR, but they could benefit from this (by triggering others' spilling). The PrepareRDD may not be needed anymore and could be removed in a follow-up PR.

The following script fails with OOM before this PR; with it, the script finishes in 150 seconds with a 2G heap (it also works on the 1.5 branch, with a similar duration).

```python
sqlContext.setConf("spark.sql.shuffle.partitions", "1")
df = sqlContext.range(1<<25).selectExpr("id", "repeat(id, 2) as s")
df2 = df.select(df.id.alias('id2'), df.s.alias('s2'))
j = df.join(df2, df.id==df2.id2).groupBy(df.id).max("id", "id2")
j.explain()
print j.count()
```

For thread safety, here is what I've got:
1) Without calling spill(), the operators are only used by a single thread, so there are no safety problems.
2) spill() can be triggered in two cases: by the operator itself, or by other operators. We can check trigger == this in spill(), so it is still in the same thread and there are no safety problems.
3) If it is triggered by other operators (right now cache will not trigger spill()), we only spill the data to disk during the scanning stage (after building is finished), so the in-memory sorter or memory pages are read-only and we only need to synchronize the iterator and change it.
4) During scanning, the iterator only uses one record in one page; we can't free this page because the downstream is currently using it (via UnsafeRow or other objects). In BytesToBytesMap, we just skip the current page and dump all the others to disk. In UnsafeExternalSorter, we keep the page that is used by the current record (having the same baseObject) and free it when loading the next record. In ShuffleExternalSorter, spill() will not trigger during scanning.
5) In order to avoid deadlock, we don't call acquireMemory during spill (so we reuse the pointer array in InMemorySorter).

Author: Davies Liu Closes #9241 from davies/force_spill.
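To make the cooperation protocol concrete, a deliberately simplified, single-threaded plain-Python sketch of the idea follows: consumers register with a per-task manager, and when a request cannot be fully granted the manager asks the other consumers, then the requester itself, to spill. Names and numbers are illustrative only; the real mechanism is the Java code in MemoryConsumer and TaskMemoryManager below, which also handles pages, Tungsten allocation, and thread safety:

```python
# Simplified sketch of cooperative spilling (illustration only, not Spark code).
class MemoryManager(object):
    def __init__(self, total):
        self.total, self.used, self.consumers = total, 0, []

    def _grant(self, size):
        got = min(size, self.total - self.used)
        self.used += got
        return got

    def release(self, size):
        self.used -= size

    def acquire(self, required, requester):
        got = self._grant(required)
        if got < required:
            # Ask the other consumers to spill first, then the requester itself.
            others = [c for c in self.consumers if c is not requester and c.used > 0]
            for c in others + [requester]:
                if got >= required:
                    break
                if c.spill(required - got) > 0:
                    got += self._grant(required - got)
        return got

class Consumer(object):
    def __init__(self, manager):
        self.manager, self.used = manager, 0
        manager.consumers.append(self)

    def allocate(self, size):
        got = self.manager.acquire(size, self)
        self.used += got
        return got

    def spill(self, size):
        # A real operator would write its in-memory data to disk here.
        freed = min(size, self.used)
        self.used -= freed
        self.manager.release(freed)
        return freed

mm = MemoryManager(total=100)
a, b = Consumer(mm), Consumer(mm)
a.allocate(80)
print(b.allocate(50))  # prints 50: `a` is asked to spill 30 so `b` gets its full request
```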
--- .../apache/spark/memory/MemoryConsumer.java | 128 ++++++ .../spark/memory/TaskMemoryManager.java | 138 ++++-- .../shuffle/sort/ShuffleExternalSorter.java | 210 +++------ .../shuffle/sort/ShuffleInMemorySorter.java | 50 +- .../shuffle/sort/UnsafeShuffleWriter.java | 6 - .../spark/unsafe/map/BytesToBytesMap.java | 430 +++++++++++------- .../unsafe/sort/UnsafeExternalSorter.java | 426 ++++++++--------- .../unsafe/sort/UnsafeInMemorySorter.java | 60 ++- .../unsafe/sort/UnsafeSorterSpillReader.java | 6 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 2 +- .../apache/spark/memory/MemoryManager.scala | 9 +- .../spark/util/collection/Spillable.scala | 4 +- .../spark/memory/TaskMemoryManagerSuite.java | 77 +++- .../sort/PackedRecordPointerSuite.java | 30 +- .../sort/ShuffleInMemorySorterSuite.java | 6 +- .../sort/UnsafeShuffleWriterSuite.java | 38 +- .../map/AbstractBytesToBytesMapSuite.java | 149 +++++- .../sort/UnsafeExternalSorterSuite.java | 97 ++-- .../sort/UnsafeInMemorySorterSuite.java | 20 +- .../scala/org/apache/spark/FailureSuite.scala | 4 +- .../spark/memory/MemoryManagerSuite.scala | 60 +-- ...yManager.scala => TestMemoryManager.scala} | 32 +- .../execution/UnsafeExternalRowSorter.java | 7 +- .../UnsafeFixedWidthAggregationMap.java | 2 +- .../sql/execution/UnsafeKVExternalSorter.java | 19 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 22 +- .../UnsafeKVExternalSorterSuite.scala | 6 +- .../TungstenAggregationIteratorSuite.scala | 54 --- .../unsafe/memory/HeapMemoryAllocator.java | 9 +- .../unsafe/memory/UnsafeMemoryAllocator.java | 3 - 30 files changed, 1270 insertions(+), 834 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/MemoryConsumer.java rename core/src/test/scala/org/apache/spark/memory/{GrantEverythingMemoryManager.scala => TestMemoryManager.scala} (71%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java new file mode 100644 index 0000000000000..008799cc77395 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + + +import java.io.IOException; + +import org.apache.spark.unsafe.memory.MemoryBlock; + + +/** + * An memory consumer of TaskMemoryManager, which support spilling. 
+ */ +public abstract class MemoryConsumer { + + private final TaskMemoryManager taskMemoryManager; + private final long pageSize; + private long used; + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { + this.taskMemoryManager = taskMemoryManager; + this.pageSize = pageSize; + this.used = 0; + } + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { + this(taskMemoryManager, taskMemoryManager.pageSizeBytes()); + } + + /** + * Returns the size of used memory in bytes. + */ + long getUsed() { + return used; + } + + /** + * Force spill during building. + * + * For testing. + */ + public void spill() throws IOException { + spill(Long.MAX_VALUE, this); + } + + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + * + * This should be implemented by subclass. + * + * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). + * + * @param size the amount of memory should be released + * @param trigger the MemoryConsumer that trigger this spilling + * @return the amount of released memory in bytes + * @throws IOException + */ + public abstract long spill(long size, MemoryConsumer trigger) throws IOException; + + /** + * Acquire `size` bytes memory. + * + * If there is not enough memory, throws OutOfMemoryError. + */ + protected void acquireMemory(long size) { + long got = taskMemoryManager.acquireExecutionMemory(size, this); + if (got < size) { + taskMemoryManager.releaseExecutionMemory(got, this); + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + } + used += got; + } + + /** + * Release `size` bytes memory. + */ + protected void releaseMemory(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory(size, this); + } + + /** + * Allocate a memory block with at least `required` bytes. + * + * Throws IOException if there is not enough memory. + * + * @throws OutOfMemoryError + */ + protected MemoryBlock allocatePage(long required) { + MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + freePage(page); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } + used += page.size(); + return page; + } + + /** + * Free a memory block. + */ + protected void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 7b31c90dac666..4230575446d31 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -17,13 +17,18 @@ package org.apache.spark.memory; -import java.util.*; +import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashSet; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; /** * Manages the memory allocated by an individual task. 
@@ -100,6 +105,12 @@ public class TaskMemoryManager { */ private final boolean inHeap; + /** + * The size of memory granted to each consumer. + */ + @GuardedBy("this") + private final HashSet consumers; + /** * Construct a new TaskMemoryManager. */ @@ -107,23 +118,92 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); this.memoryManager = memoryManager; this.taskAttemptId = taskAttemptId; + this.consumers = new HashSet<>(); } /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call + * spill() of consumers to release more memory. + * * @return number of bytes successfully granted (<= N). */ - public long acquireExecutionMemory(long size) { - return memoryManager.acquireExecutionMemory(size, taskAttemptId); + public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + assert(required >= 0); + synchronized (this) { + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId); + + // try to release memory from other consumers first, then we can reduce the frequency of + // spilling, avoid to have too many spilled files. + if (got < required) { + // Call spill() on other consumers to release memory + for (MemoryConsumer c: consumers) { + if (c != null && c != consumer && c.getUsed() > 0) { + try { + long released = c.spill(required - got, consumer); + if (released > 0) { + logger.info("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + if (got >= required) { + break; + } + } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); + } + } + } + } + + // call spill() on itself + if (got < required && consumer != null) { + try { + long released = consumer.spill(required - got, consumer); + if (released > 0) { + logger.info("Task {} released {} from itself ({})", taskAttemptId, + Utils.bytesToString(released), consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + } + } catch (IOException e) { + logger.error("error while calling spill() on " + consumer, e); + throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " + + e.getMessage()); + } + } + + consumers.add(consumer); + logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); + return got; + } } /** - * Release N bytes of execution memory. + * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size) { + public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); memoryManager.releaseExecutionMemory(size, taskAttemptId); } + /** + * Dump the memory usage of all consumers. + */ + public void showMemoryUsage() { + logger.info("Memory used in task " + taskAttemptId); + synchronized (this) { + for (MemoryConsumer c: consumers) { + if (c.getUsed() > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed())); + } + } + } + } + + /** + * Return the page size in bytes. 
+ */ public long pageSizeBytes() { return memoryManager.pageSizeBytes(); } @@ -134,42 +214,40 @@ public long pageSizeBytes() { * * Returns `null` if there was not enough memory to allocate the page. */ - public MemoryBlock allocatePage(long size) { + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } + long acquired = acquireExecutionMemory(size, consumer); + if (acquired <= 0) { + return null; + } + final int pageNumber; synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { + releaseExecutionMemory(acquired, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } allocatedPages.set(pageNumber); } - final long acquiredExecutionMemory = acquireExecutionMemory(size); - if (acquiredExecutionMemory != size) { - releaseExecutionMemory(acquiredExecutionMemory); - synchronized (this) { - allocatedPages.clear(pageNumber); - } - return null; - } - final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size); + final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { - logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); } return page; } /** - * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ - public void freePage(MemoryBlock page) { + public void freePage(MemoryBlock page, MemoryConsumer consumer) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; assert(allocatedPages.get(page.pageNumber)); @@ -182,14 +260,14 @@ public void freePage(MemoryBlock page) { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize); + releaseExecutionMemory(pageSize, consumer); } /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. * - * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/ + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/ * @param offsetInPage an offset in this page which incorporates the base offset. In other words, * this should be the value that you would pass as the base offset into an * UNSAFE call (e.g. page.baseOffset() + something). @@ -261,17 +339,17 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { * value can be used to detect memory leaks. 
*/ public long cleanUpAllAllocatedMemory() { - long freedBytes = 0; - for (MemoryBlock page : pageTable) { - if (page != null) { - freedBytes += page.size(); - freePage(page); + synchronized (this) { + Arrays.fill(pageTable, null); + for (MemoryConsumer c: consumers) { + if (c != null && c.getUsed() > 0) { + // In case of failed task, it's normal to see leaked memory + logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + } } + consumers.clear(); } - - freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); - - return freedBytes; + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index f43236f41ae7b..400d8520019b9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -31,15 +31,15 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -58,23 +58,18 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class ShuffleExternalSorter { +final class ShuffleExternalSorter extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - private final int initialSize; private final int numPartitions; - private final int pageSizeBytes; - @VisibleForTesting - final int maxRecordSizeBytes; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; - private long numRecordsInsertedSinceLastSpill = 0; /** Force this sorter to spill when there are this many elements in memory. 
For testing only */ private final long numElementsForSpillThreshold; @@ -98,8 +93,7 @@ final class ShuffleExternalSorter { // These variables are reset after spilling: @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; - private long currentPagePosition = -1; - private long freeSpaceInCurrentPage = 0; + private long pageCursor = -1; public ShuffleExternalSorter( TaskMemoryManager memoryManager, @@ -108,42 +102,21 @@ public ShuffleExternalSorter( int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) throws IOException { + ShuffleWriteMetrics writeMetrics) { + super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + memoryManager.pageSizeBytes())); this.taskMemoryManager = memoryManager; this.blockManager = blockManager; this.taskContext = taskContext; - this.initialSize = initialSize; - this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); - this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes()); - this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; - initializeForWriting(); - - // preserve first page to ensure that we have at least one page to work with. Otherwise, - // other operators in the same task may starve this sorter (SPARK-9709). - acquireNewPageIfNecessary(pageSizeBytes); - } - - /** - * Allocates new sort data structures. Called when creating the sorter and after each spill. - */ - private void initializeForWriting() throws IOException { - // TODO: move this sizing calculation logic into a static method of sorter: - final long memoryRequested = initialSize * 8L; - final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested); - if (memoryAcquired != memoryRequested) { - taskMemoryManager.releaseExecutionMemory(memoryAcquired); - throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); - } - + acquireMemory(initialSize * 8L); this.inMemSorter = new ShuffleInMemorySorter(initialSize); - numRecordsInsertedSinceLastSpill = 0; + this.peakMemoryUsedBytes = getMemoryUsage(); } /** @@ -242,6 +215,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } + inMemSorter.reset(); + if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -266,9 +241,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { /** * Sort and spill the current records in response to memory pressure. */ - @VisibleForTesting - void spill() throws IOException { - assert(inMemSorter != null); + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) { + return 0L; + } + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -276,13 +254,9 @@ void spill() throws IOException { spills.size() > 1 ? 
" times" : " time"); writeSortedFile(false); - final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - - initializeForWriting(); + return spillSize; } private long getMemoryUsage() { @@ -312,18 +286,12 @@ private long freeMemory() { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - taskMemoryManager.freePage(block); memoryFreed += block.size(); - } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); + freePage(block); } allocatedPages.clear(); currentPage = null; - currentPagePosition = -1; - freeSpaceInCurrentPage = 0; + pageCursor = 0; return memoryFreed; } @@ -332,16 +300,16 @@ private long freeMemory() { */ public void cleanupResources() { freeMemory(); + if (inMemSorter != null) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(sorterMemoryUsage); + } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); - } } /** @@ -352,16 +320,27 @@ public void cleanupResources() { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); - final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray); - if (memoryAcquired < memoryToGrowPointerArray) { - taskMemoryManager.releaseExecutionMemory(memoryAcquired); - spill(); + long used = inMemSorter.getMemoryUsage(); + long needed = used + inMemSorter.getMemoryToExpand(); + try { + acquireMemory(needed); // could trigger spilling + } catch (OutOfMemoryError e) { + // should have trigger spilling + assert(inMemSorter.hasSpaceForAnotherRecord()); + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + releaseMemory(needed); } else { - inMemSorter.expandPointerArray(); - taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage); + try { + inMemSorter.expandPointerArray(); + releaseMemory(used); + } catch (OutOfMemoryError oom) { + // Just in case that JVM had run out of memory + releaseMemory(needed); + spill(); + } } } } @@ -370,96 +349,46 @@ private void growPointerArrayIfNecessary() throws IOException { * Allocates more memory in order to insert an additional record. This will request additional * memory from the memory manager and spill if the requested memory can not be obtained. * - * @param requiredSpace the required space in the data page, in bytes, including space for storing + * @param required the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records * that exceed the page size are handled via a different code path which uses * special overflow pages). 
*/ - private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { - growPointerArrayIfNecessary(); - if (requiredSpace > freeSpaceInCurrentPage) { - logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, - freeSpaceInCurrentPage); - // TODO: we should track metrics on the amount of space wasted when we roll over to a new page - // without using the free space at the end of the current page. We should also do this for - // BytesToBytesMap. - if (requiredSpace > pageSizeBytes) { - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - pageSizeBytes + ")"); - } else { - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - spill(); - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); - } + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) { + // TODO: try to find space in previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); } } /** * Write a record to the shuffle sorter. */ - public void insertRecord( - Object recordBaseObject, - long recordBaseOffset, - int lengthInBytes, - int partitionId) throws IOException { + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + throws IOException { - if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + // for tests + assert(inMemSorter != null); + if (inMemSorter.numRecords() > numElementsForSpillThreshold) { spill(); } growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. - final int totalSpaceRequired = lengthInBytes + 4; - - // --- Figure out where to insert the new record ---------------------------------------------- - - final MemoryBlock dataPage; - long dataPagePosition; - boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; - if (useOverflowPage) { - long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); - // The record is larger than the page size, so allocate a special overflow page just to hold - // that record. - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - spill(); - overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); - } - } - allocatedPages.add(overflowPage); - dataPage = overflowPage; - dataPagePosition = overflowPage.getBaseOffset(); - } else { - // The record is small enough to fit in a regular data page, but the current page might not - // have enough space to hold it (or no pages have been allocated yet). 
- acquireNewPageIfNecessary(totalSpaceRequired); - dataPage = currentPage; - dataPagePosition = currentPagePosition; - // Update bookkeeping information - freeSpaceInCurrentPage -= totalSpaceRequired; - currentPagePosition += totalSpaceRequired; - } - final Object dataPageBaseObject = dataPage.getBaseObject(); - - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); - dataPagePosition += 4; - Platform.copyMemory( - recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); - assert(inMemSorter != null); + final int required = length + 4; + acquireNewPageIfNecessary(required); + + assert(currentPage != null); + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + Platform.putInt(base, pageCursor, length); + pageCursor += 4; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); - numRecordsInsertedSinceLastSpill += 1; } /** @@ -475,6 +404,9 @@ public SpillInfo[] closeAndGetSpills() throws IOException { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index a8dee6c6101c1..e630575d1ae19 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -37,33 +37,51 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] pointerArray; + private long[] array; /** * The position in the pointer array where new records can be inserted. */ - private int pointerArrayInsertPosition = 0; + private int pos = 0; public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); - this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(ShuffleSortDataFormat.INSTANCE); + this.array = new long[initialSize]; + this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } - public void expandPointerArray() { - final long[] oldArray = pointerArray; + public int numRecords() { + return pos; + } + + public void reset() { + pos = 0; + } + + private int newLength() { // Guard against overflow: - final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; - pointerArray = new long[newLength]; - System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + return array.length <= Integer.MAX_VALUE / 2 ? 
(array.length * 2) : Integer.MAX_VALUE; + } + + /** + * Returns the memory needed to expand + */ + public long getMemoryToExpand() { + return ((long) (newLength() - array.length)) * 8; + } + + public void expandPointerArray() { + final long[] oldArray = array; + array = new long[newLength()]; + System.arraycopy(oldArray, 0, array, 0, oldArray.length); } public boolean hasSpaceForAnotherRecord() { - return pointerArrayInsertPosition + 1 < pointerArray.length; + return pos < array.length; } public long getMemoryUsage() { - return pointerArray.length * 8L; + return array.length * 8L; } /** @@ -78,15 +96,15 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (pointerArray.length == Integer.MAX_VALUE) { + if (array.length == Integer.MAX_VALUE) { throw new IllegalStateException("Sort pointer array has reached maximum size"); } else { expandPointerArray(); } } - pointerArray[pointerArrayInsertPosition] = + array[pos] = PackedRecordPointer.packPointer(recordPointer, partitionId); - pointerArrayInsertPosition++; + pos++; } /** @@ -118,7 +136,7 @@ public void loadNext() { * Return an iterator over record pointers in sorted order. */ public ShuffleSorterIterator getSortedIterator() { - sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + sorter.sort(array, 0, pos, SORT_COMPARATOR); + return new ShuffleSorterIterator(pos, array); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f6c5c944bd77b..e19b37864293c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -127,12 +127,6 @@ public UnsafeShuffleWriter( open(); } - @VisibleForTesting - public int maxRecordSizeBytes() { - assert(sorter != null); - return sorter.maxRecordSizeBytes; - } - private void updatePeakMemoryUsed() { // sorter can be null if this writer is closed if (sorter != null) { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index f035bdac810bd..e36709c6fc849 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -18,14 +18,20 @@ package org.apache.spark.unsafe.map; import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; -import java.util.List; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -33,7 +39,8 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; 
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -54,7 +61,7 @@ * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, * so we can pass records from this map directly into the sorter to sort records in place. */ -public final class BytesToBytesMap { +public final class BytesToBytesMap extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); @@ -62,27 +69,22 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; - /** - * Special record length that is placed after the last record in a data page. - */ - private static final int END_OF_PAGE_MARKER = -1; - private final TaskMemoryManager taskMemoryManager; /** * A linked list for tracking all allocated data pages so that we can free all of our memory. */ - private final List dataPages = new LinkedList(); + private final LinkedList dataPages = new LinkedList<>(); /** * The data page that will be used to store keys and values for new hashtable entries. When this * page becomes full, a new page will be allocated and this pointer will change to point to that * new page. */ - private MemoryBlock currentDataPage = null; + private MemoryBlock currentPage = null; /** - * Offset into `currentDataPage` that points to the location where new data can be inserted into + * Offset into `currentPage` that points to the location where new data can be inserted into * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; @@ -116,6 +118,11 @@ public final class BytesToBytesMap { // full base addresses in the page table for off-heap mode so that we can reconstruct the full // absolute memory addresses. + /** + * Whether or not the longArray can grow. We will not insert more elements if it's false. + */ + private boolean canGrowArray = true; + /** * A {@link BitSet} used to track location of the map where the key is set. * Size of the bitset should be half of the size of the long array. @@ -164,13 +171,20 @@ public final class BytesToBytesMap { private long peakMemoryUsedBytes = 0L; + private final BlockManager blockManager; + private volatile MapIterator destructiveIterator = null; + private LinkedList spillWriters = new LinkedList<>(); + public BytesToBytesMap( TaskMemoryManager taskMemoryManager, + BlockManager blockManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { + super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; + this.blockManager = blockManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -187,18 +201,13 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); - - // Acquire a new page as soon as we construct the map to ensure that we have at least - // one page to work with. Otherwise, other operators in the same task may starve this - // map (SPARK-9747). - acquireNewPage(); } public BytesToBytesMap( TaskMemoryManager taskMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, initialCapacity, pageSizeBytes, false); } public BytesToBytesMap( @@ -208,6 +217,7 @@ public BytesToBytesMap( boolean enablePerfMetrics) { this( taskMemoryManager, + SparkEnv.get() != null ? 
SparkEnv.get().blockManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -219,61 +229,153 @@ public BytesToBytesMap( */ public int numElements() { return numElements; } - public static final class BytesToBytesMapIterator implements Iterator { + public final class MapIterator implements Iterator { - private final int numRecords; - private final Iterator dataPagesIterator; + private int numRecords; private final Location loc; private MemoryBlock currentPage = null; - private int currentRecordNumber = 0; + private int recordsInPage = 0; private Object pageBaseObject; private long offsetInPage; // If this iterator destructive or not. When it is true, it frees each page as it moves onto // next one. private boolean destructive = false; - private BytesToBytesMap bmap; + private UnsafeSorterSpillReader reader = null; - private BytesToBytesMapIterator( - int numRecords, Iterator dataPagesIterator, Location loc, - boolean destructive, BytesToBytesMap bmap) { + private MapIterator(int numRecords, Location loc, boolean destructive) { this.numRecords = numRecords; - this.dataPagesIterator = dataPagesIterator; this.loc = loc; this.destructive = destructive; - this.bmap = bmap; - if (dataPagesIterator.hasNext()) { - advanceToNextPage(); + if (destructive) { + destructiveIterator = this; } } private void advanceToNextPage() { - if (destructive && currentPage != null) { - dataPagesIterator.remove(); - this.bmap.taskMemoryManager.freePage(currentPage); + synchronized (this) { + int nextIdx = dataPages.indexOf(currentPage) + 1; + if (destructive && currentPage != null) { + dataPages.remove(currentPage); + freePage(currentPage); + nextIdx --; + } + if (dataPages.size() > nextIdx) { + currentPage = dataPages.get(nextIdx); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); + offsetInPage += 4; + } else { + currentPage = null; + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + try { + reader = spillWriters.getFirst().getReader(blockManager); + recordsInPage = -1; + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } + } } - currentPage = dataPagesIterator.next(); - pageBaseObject = currentPage.getBaseObject(); - offsetInPage = currentPage.getBaseOffset(); } @Override public boolean hasNext() { - return currentRecordNumber != numRecords; + if (numRecords == 0) { + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + return numRecords > 0; } @Override public Location next() { - int totalLength = Platform.getInt(pageBaseObject, offsetInPage); - if (totalLength == END_OF_PAGE_MARKER) { + if (recordsInPage == 0) { advanceToNextPage(); - totalLength = Platform.getInt(pageBaseObject, offsetInPage); } - loc.with(currentPage, offsetInPage); - offsetInPage += 4 + totalLength; - currentRecordNumber++; - return loc; + numRecords--; + if (currentPage != null) { + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + loc.with(currentPage, offsetInPage); + offsetInPage += 4 + totalLength; + recordsInPage --; + return loc; 
+ } else { + assert(reader != null); + if (!reader.hasNext()) { + advanceToNextPage(); + } + try { + reader.loadNext(); + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } + loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength()); + return loc; + } + } + + public long spill(long numBytes) throws IOException { + synchronized (this) { + if (!destructive || dataPages.size() == 1) { + return 0L; + } + + // TODO: use existing ShuffleWriteMetrics + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + long released = 0L; + while (dataPages.size() > 0) { + MemoryBlock block = dataPages.getLast(); + // The currentPage is used, cannot be released + if (block == currentPage) { + break; + } + + Object base = block.getBaseObject(); + long offset = block.getBaseOffset(); + int numRecords = Platform.getInt(base, offset); + offset += 4; + final UnsafeSorterSpillWriter writer = + new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); + while (numRecords > 0) { + int length = Platform.getInt(base, offset); + writer.write(base, offset + 4, length, 0); + offset += 4 + length; + numRecords--; + } + writer.close(); + spillWriters.add(writer); + + dataPages.removeLast(); + released += block.size(); + freePage(block); + + if (released >= numBytes) { + break; + } + } + + return released; + } } @Override @@ -290,8 +392,8 @@ public void remove() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public BytesToBytesMapIterator iterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + public MapIterator iterator() { + return new MapIterator(numElements, loc, false); } /** @@ -304,8 +406,8 @@ public BytesToBytesMapIterator iterator() { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public BytesToBytesMapIterator destructiveIterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); + public MapIterator destructiveIterator() { + return new MapIterator(numElements, loc, true); } /** @@ -314,11 +416,8 @@ public BytesToBytesMapIterator destructiveIterator() { * * This function always return the same {@link Location} instance to avoid object allocation. */ - public Location lookup( - Object keyBaseObject, - long keyBaseOffset, - int keyRowLengthBytes) { - safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + public Location lookup(Object keyBase, long keyOffset, int keyLength) { + safeLookup(keyBase, keyOffset, keyLength, loc); return loc; } @@ -327,18 +426,14 @@ public Location lookup( * * This is a thread-safe version of `lookup`, could be used by multiple threads. 
*/ - public void safeLookup( - Object keyBaseObject, - long keyBaseOffset, - int keyRowLengthBytes, - Location loc) { + public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { assert(bitset != null); assert(longArray != null); if (enablePerfMetrics) { numKeyLookups++; } - final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); int pos = hashcode & mask; int step = 1; while (true) { @@ -354,16 +449,16 @@ public void safeLookup( if ((int) (stored) == hashcode) { // Full hash code matches. Let's compare the keys for equality. loc.with(pos, hashcode, true); - if (loc.getKeyLength() == keyRowLengthBytes) { + if (loc.getKeyLength() == keyLength) { final MemoryLocation keyAddress = loc.getKeyAddress(); - final Object storedKeyBaseObject = keyAddress.getBaseObject(); - final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final Object storedkeyBase = keyAddress.getBaseObject(); + final long storedkeyOffset = keyAddress.getBaseOffset(); final boolean areEqual = ByteArrayMethods.arrayEquals( - keyBaseObject, - keyBaseOffset, - storedKeyBaseObject, - storedKeyBaseOffset, - keyRowLengthBytes + keyBase, + keyOffset, + storedkeyBase, + storedkeyOffset, + keyLength ); if (areEqual) { return; @@ -410,18 +505,18 @@ private void updateAddressesAndSizes(long fullKeyAddress) { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object page, final long offsetInPage) { - long position = offsetInPage; - final int totalLength = Platform.getInt(page, position); + private void updateAddressesAndSizes(final Object base, final long offset) { + long position = offset; + final int totalLength = Platform.getInt(base, position); position += 4; - keyLength = Platform.getInt(page, position); + keyLength = Platform.getInt(base, position); position += 4; valueLength = totalLength - keyLength - 4; - keyMemoryLocation.setObjAndOffset(page, position); + keyMemoryLocation.setObjAndOffset(base, position); position += keyLength; - valueMemoryLocation.setObjAndOffset(page, position); + valueMemoryLocation.setObjAndOffset(base, position); } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -443,6 +538,19 @@ private Location with(MemoryBlock page, long offsetInPage) { return this; } + /** + * This is only used for spilling + */ + private Location with(Object base, long offset, int length) { + this.isDefined = true; + this.memoryPage = null; + keyLength = Platform.getInt(base, offset); + valueLength = length - 4 - keyLength; + keyMemoryLocation.setObjAndOffset(base, offset + 4); + valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength); + return this; + } + /** * Returns the memory page that contains the current record. * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. @@ -517,9 +625,9 @@ public int getValueLength() { * As an example usage, here's the proper way to store a new key: *

    *
    -     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
    +     *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
          *   if (!loc.isDefined()) {
    -     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
    +     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
          *       // handle failure to grow map (by spilling, for example)
          *     }
          *   }
    @@ -531,113 +639,59 @@ public int getValueLength() {
          * @return true if the put() was successful and false if the put() failed because memory could
          *         not be acquired.
          */
    -    public boolean putNewKey(
    -        Object keyBaseObject,
    -        long keyBaseOffset,
    -        int keyLengthBytes,
    -        Object valueBaseObject,
    -        long valueBaseOffset,
    -        int valueLengthBytes) {
    +    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
    +        Object valueBase, long valueOffset, int valueLength) {
           assert (!isDefined) : "Can only set value once for a key";
    -      assert (keyLengthBytes % 8 == 0);
    -      assert (valueLengthBytes % 8 == 0);
    +      assert (keyLength % 8 == 0);
    +      assert (valueLength % 8 == 0);
           assert(bitset != null);
           assert(longArray != null);
     
    -      if (numElements == MAX_CAPACITY) {
    -        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
    +      if (numElements == MAX_CAPACITY || !canGrowArray) {
    +        return false;
           }
     
           // Here, we'll copy the data into our data pages. Because we only store a relative offset from
           // the key address instead of storing the absolute address of the value, the key and value
           // must be stored in the same memory page.
           // (8 byte key length) (key) (value)
    -      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
    -
    -      // --- Figure out where to insert the new record ---------------------------------------------
    -
    -      final MemoryBlock dataPage;
    -      final Object dataPageBaseObject;
    -      final long dataPageInsertOffset;
    -      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
    -      if (useOverflowPage) {
    -        // The record is larger than the page size, so allocate a special overflow page just to hold
    -        // that record.
    -        final long overflowPageSize = requiredSize + 8;
    -        MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
    -        if (overflowPage == null) {
    -          logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
    +      final long recordLength = 8 + keyLength + valueLength;
    +      if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
    +        if (!acquireNewPage(recordLength + 4L)) {
               return false;
             }
    -        dataPages.add(overflowPage);
    -        dataPage = overflowPage;
    -        dataPageBaseObject = overflowPage.getBaseObject();
    -        dataPageInsertOffset = overflowPage.getBaseOffset();
    -      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
    -        // The record can fit in a data page, but either we have not allocated any pages yet or
    -        // the current page does not have enough space.
    -        if (currentDataPage != null) {
    -          // There wasn't enough space in the current page, so write an end-of-page marker:
    -          final Object pageBaseObject = currentDataPage.getBaseObject();
    -          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
    -          Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
    -        }
    -        if (!acquireNewPage()) {
    -          return false;
    -        }
    -        dataPage = currentDataPage;
    -        dataPageBaseObject = currentDataPage.getBaseObject();
    -        dataPageInsertOffset = currentDataPage.getBaseOffset();
    -      } else {
    -        // There is enough space in the current data page.
    -        dataPage = currentDataPage;
    -        dataPageBaseObject = currentDataPage.getBaseObject();
    -        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
           }
     
           // --- Append the key and value data to the current data page --------------------------------
    -
    -      long insertCursor = dataPageInsertOffset;
    -
    -      // Compute all of our offsets up-front:
    -      final long recordOffset = insertCursor;
    -      insertCursor += 4;
    -      final long keyLengthOffset = insertCursor;
    -      insertCursor += 4;
    -      final long keyDataOffsetInPage = insertCursor;
    -      insertCursor += keyLengthBytes;
    -      final long valueDataOffsetInPage = insertCursor;
    -      insertCursor += valueLengthBytes; // word used to store the value size
    -
    -      Platform.putInt(dataPageBaseObject, recordOffset,
    -        keyLengthBytes + valueLengthBytes + 4);
    -      Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
    -      // Copy the key
    -      Platform.copyMemory(
    -        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
    -      // Copy the value
    -      Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
    -        valueDataOffsetInPage, valueLengthBytes);
    -
    -      // --- Update bookeeping data structures -----------------------------------------------------
    -
    -      if (useOverflowPage) {
    -        // Store the end-of-page marker at the end of the data page
    -        Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
    -      } else {
    -        pageCursor += requiredSize;
    -      }
    -
    +      final Object base = currentPage.getBaseObject();
    +      long offset = currentPage.getBaseOffset() + pageCursor;
    +      final long recordOffset = offset;
    +      Platform.putInt(base, offset, keyLength + valueLength + 4);
    +      Platform.putInt(base, offset + 4, keyLength);
    +      offset += 8;
    +      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
    +      offset += keyLength;
    +      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
    +
    +      // --- Update bookkeeping data structures -----------------------------------------------------
    +      offset = currentPage.getBaseOffset();
    +      Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
    +      pageCursor += recordLength;
           numElements++;
           bitset.set(pos);
           final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
    -        dataPage, recordOffset);
    +        currentPage, recordOffset);
           longArray.set(pos * 2, storedKeyAddress);
           longArray.set(pos * 2 + 1, keyHashcode);
           updateAddressesAndSizes(storedKeyAddress);
           isDefined = true;
    +
           if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
    -        growAndRehash();
    +        try {
    +          growAndRehash();
    +        } catch (OutOfMemoryError oom) {
    +          canGrowArray = false;
    +        }
           }
           return true;
         }
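As a hedged caller-side sketch of the pattern from the Javadoc above (nothing here is part of the patch; the map is assumed to have been constructed elsewhere with a TaskMemoryManager and page size): keys and values are passed as a base object plus offset, their lengths must be multiples of 8 bytes, and a false return from putNewKey() now signals that the caller should spill or otherwise back off instead of expecting an exception.

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.map.BytesToBytesMap;

    // Illustrative helper; assumes `map` already exists.
    final class BytesToBytesMapUsageSketch {
      static boolean putIfAbsent(BytesToBytesMap map, long key, long value) {
        long[] keyBuf = new long[] { key };      // 8-byte, word-aligned key
        long[] valueBuf = new long[] { value };  // 8-byte, word-aligned value
        BytesToBytesMap.Location loc = map.lookup(keyBuf, Platform.LONG_ARRAY_OFFSET, 8);
        if (loc.isDefined()) {
          return true;  // key already present; its value bytes can be updated in place
        }
        // Returns false (rather than throwing) when no more memory can be acquired.
        return loc.putNewKey(keyBuf, Platform.LONG_ARRAY_OFFSET, 8,
                             valueBuf, Platform.LONG_ARRAY_OFFSET, 8);
      }
    }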
    @@ -647,18 +701,26 @@ public boolean putNewKey(
        * Acquire a new page from the memory manager.
        * @return whether there is enough space to allocate the new page.
        */
    -  private boolean acquireNewPage() {
    -    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
    -    if (newPage == null) {
    -      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
    +  private boolean acquireNewPage(long required) {
    +    try {
    +      currentPage = allocatePage(required);
    +    } catch (OutOfMemoryError e) {
           return false;
         }
    -    dataPages.add(newPage);
    -    pageCursor = 0;
    -    currentDataPage = newPage;
    +    dataPages.add(currentPage);
    +    Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
    +    pageCursor = 4;
         return true;
       }
     
    +  @Override
    +  public long spill(long size, MemoryConsumer trigger) throws IOException {
    +    if (trigger != this && destructiveIterator != null) {
    +      return destructiveIterator.spill(size);
    +    }
    +    return 0L;
    +  }
    +
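For reference, a sketch (again not part of the patch) of how a data page laid out by acquireNewPage() and putNewKey() above can be read back: the page begins with a 4-byte record count, and every record is a 4-byte total length, a 4-byte key length, the key bytes, then the value bytes. This is essentially what the MapIterator and the spill path walk.

    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    // Illustrative page walker matching the record layout written by putNewKey().
    final class DataPageWalker {
      static void walk(MemoryBlock page) {
        Object base = page.getBaseObject();
        long offset = page.getBaseOffset();
        int numRecords = Platform.getInt(base, offset);  // header written by acquireNewPage()
        offset += 4;
        for (int i = 0; i < numRecords; i++) {
          int totalLength = Platform.getInt(base, offset);      // keyLength + valueLength + 4
          int keyLength = Platform.getInt(base, offset + 4);
          int valueLength = totalLength - keyLength - 4;
          long keyOffset = offset + 8;                           // key bytes start here
          long valueOffset = keyOffset + keyLength;              // value bytes follow the key
          System.out.println("record " + i + ": key " + keyLength + " bytes at " + keyOffset
              + ", value " + valueLength + " bytes at " + valueOffset);
          offset += 4 + totalLength;                             // advance to the next record
        }
      }
    }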
       /**
        * Allocate new data structures for this map. When calling this outside of the constructor,
        * make sure to keep references to the old data structures so that you can free them.
    @@ -670,6 +732,7 @@ private void allocate(int capacity) {
         // The capacity needs to be divisible by 64 so that our bit set can be sized properly
         capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
         assert (capacity <= MAX_CAPACITY);
    +    acquireMemory(capacity * 16);
         longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
         bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
     
    @@ -677,6 +740,19 @@ private void allocate(int capacity) {
         this.mask = capacity - 1;
       }
     
    +  /**
    +   * Free the memory used by longArray.
    +   */
    +  public void freeArray() {
    +    updatePeakMemoryUsed();
    +    if (longArray != null) {
    +      long used = longArray.memoryBlock().size();
    +      longArray = null;
    +      releaseMemory(used);
    +      bitset = null;
    +    }
    +  }
    +
       /**
        * Free all allocated memory associated with this map, including the storage for keys and values
        * as well as the hash map array itself.
    @@ -684,16 +760,23 @@ private void allocate(int capacity) {
        * This method is idempotent and can be called multiple times.
        */
       public void free() {
    -    updatePeakMemoryUsed();
    -    longArray = null;
    -    bitset = null;
    +    freeArray();
    Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
         while (dataPagesIterator.hasNext()) {
           MemoryBlock dataPage = dataPagesIterator.next();
           dataPagesIterator.remove();
    -      taskMemoryManager.freePage(dataPage);
    +      freePage(dataPage);
         }
         assert(dataPages.isEmpty());
    +
    +    while (!spillWriters.isEmpty()) {
    +      File file = spillWriters.removeFirst().getFile();
    +      if (file != null && file.exists()) {
    +        if (!file.delete()) {
    +          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
    +        }
    +      }
    +    }
       }
     
       public TaskMemoryManager getTaskMemoryManager() {
    @@ -782,7 +865,13 @@ void growAndRehash() {
         final int oldCapacity = (int) oldBitSet.capacity();
     
         // Allocate the new data structures
    -    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
    +    try {
    +      allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
    +    } catch (OutOfMemoryError oom) {
    +      longArray = oldLongArray;
    +      bitset = oldBitSet;
    +      throw oom;
    +    }
     
         // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
         for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
    @@ -806,6 +895,7 @@ void growAndRehash() {
             }
           }
         }
    +    releaseMemory(oldLongArray.memoryBlock().size());
     
         if (enablePerfMetrics) {
           timeSpentResizingNs += System.nanoTime() - resizeStartTime;
    diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
    index e317ea391c556..49a5a4b13b70d 100644
    --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
    +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
    @@ -17,39 +17,34 @@
     
     package org.apache.spark.util.collection.unsafe.sort;
     
    +import javax.annotation.Nullable;
     import java.io.File;
     import java.io.IOException;
     import java.util.LinkedList;
     
    -import javax.annotation.Nullable;
    -
    -import scala.runtime.AbstractFunction0;
    -import scala.runtime.BoxedUnit;
    -
     import com.google.common.annotations.VisibleForTesting;
     import org.slf4j.Logger;
     import org.slf4j.LoggerFactory;
     
     import org.apache.spark.TaskContext;
     import org.apache.spark.executor.ShuffleWriteMetrics;
    +import org.apache.spark.memory.MemoryConsumer;
    +import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.storage.BlockManager;
    -import org.apache.spark.unsafe.array.ByteArrayMethods;
     import org.apache.spark.unsafe.Platform;
     import org.apache.spark.unsafe.memory.MemoryBlock;
    -import org.apache.spark.memory.TaskMemoryManager;
    +import org.apache.spark.util.TaskCompletionListener;
     import org.apache.spark.util.Utils;
     
     /**
      * External sorter based on {@link UnsafeInMemorySorter}.
      */
    -public final class UnsafeExternalSorter {
    +public final class UnsafeExternalSorter extends MemoryConsumer {
     
       private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
     
    -  private final long pageSizeBytes;
       private final PrefixComparator prefixComparator;
       private final RecordComparator recordComparator;
    -  private final int initialSize;
       private final TaskMemoryManager taskMemoryManager;
       private final BlockManager blockManager;
       private final TaskContext taskContext;
    @@ -69,14 +64,12 @@ public final class UnsafeExternalSorter {
  private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
     
       // These variables are reset after spilling:
    -  @Nullable private UnsafeInMemorySorter inMemSorter;
    -  // Whether the in-mem sorter is created internally, or passed in from outside.
    -  // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
    -  private boolean isInMemSorterExternal = false;
    +  @Nullable private volatile UnsafeInMemorySorter inMemSorter;
    +
       private MemoryBlock currentPage = null;
    -  private long currentPagePosition = -1;
    -  private long freeSpaceInCurrentPage = 0;
    +  private long pageCursor = -1;
       private long peakMemoryUsedBytes = 0;
    +  private volatile SpillableIterator readingIterator = null;
     
       public static UnsafeExternalSorter createWithExistingInMemorySorter(
           TaskMemoryManager taskMemoryManager,
    @@ -86,7 +79,7 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
           PrefixComparator prefixComparator,
           int initialSize,
           long pageSizeBytes,
    -      UnsafeInMemorySorter inMemorySorter) throws IOException {
    +      UnsafeInMemorySorter inMemorySorter) {
         return new UnsafeExternalSorter(taskMemoryManager, blockManager,
           taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
       }
    @@ -98,7 +91,7 @@ public static UnsafeExternalSorter create(
           RecordComparator recordComparator,
           PrefixComparator prefixComparator,
           int initialSize,
    -      long pageSizeBytes) throws IOException {
    +      long pageSizeBytes) {
         return new UnsafeExternalSorter(taskMemoryManager, blockManager,
           taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
       }
    @@ -111,60 +104,41 @@ private UnsafeExternalSorter(
           PrefixComparator prefixComparator,
           int initialSize,
           long pageSizeBytes,
    -      @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
    +      @Nullable UnsafeInMemorySorter existingInMemorySorter) {
    +    super(taskMemoryManager, pageSizeBytes);
         this.taskMemoryManager = taskMemoryManager;
         this.blockManager = blockManager;
         this.taskContext = taskContext;
         this.recordComparator = recordComparator;
         this.prefixComparator = prefixComparator;
    -    this.initialSize = initialSize;
         // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
         // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
         this.fileBufferSizeBytes = 32 * 1024;
    -    this.pageSizeBytes = pageSizeBytes;
    +    // TODO: metrics tracking + integration with shuffle write metrics
    +    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
         this.writeMetrics = new ShuffleWriteMetrics();
     
         if (existingInMemorySorter == null) {
    -      initializeForWriting();
    -      // Acquire a new page as soon as we construct the sorter to ensure that we have at
    -      // least one page to work with. Otherwise, other operators in the same task may starve
    -      // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
    -      acquireNewPage();
    +      this.inMemSorter =
    +        new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
    +      acquireMemory(inMemSorter.getMemoryUsage());
         } else {
    -      this.isInMemSorterExternal = true;
           this.inMemSorter = existingInMemorySorter;
+      // memory for the existing sorter will be acquired after the map is freed
         }
    +    this.peakMemoryUsedBytes = getMemoryUsage();
     
         // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
    // the end of the task. This is necessary to avoid memory leaks when the downstream operator
         // does not fully consume the sorter's output (e.g. sort followed by limit).
-    taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
    -      @Override
    -      public BoxedUnit apply() {
    -        cleanupResources();
    -        return null;
    +    taskContext.addTaskCompletionListener(
    +      new TaskCompletionListener() {
    +        @Override
    +        public void onTaskCompletion(TaskContext context) {
    +          cleanupResources();
    +        }
           }
    -    });
    -  }
    -
    -  // TODO: metrics tracking + integration with shuffle write metrics
    -  // need to connect the write metrics to task metrics so we count the spill IO somewhere.
    -
    -  /**
    -   * Allocates new sort data structures. Called when creating the sorter and after each spill.
    -   */
    -  private void initializeForWriting() throws IOException {
    -    // Note: Do not track memory for the pointer array for now because of SPARK-10474.
    -    // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to
    -    // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably
    -    // fails if all other memory is already occupied. It should be safe to not track the array
    -    // because its memory footprint is frequently much smaller than that of a page. This is a
    -    // temporary hack that we should address in 1.6.0.
    -    // TODO: track the pointer array memory!
    -    this.writeMetrics = new ShuffleWriteMetrics();
    -    this.inMemSorter =
    -      new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
    -    this.isInMemSorterExternal = false;
    +    );
       }
     
       /**
    @@ -173,14 +147,27 @@ private void initializeForWriting() throws IOException {
        */
       @VisibleForTesting
       public void closeCurrentPage() {
    -    freeSpaceInCurrentPage = 0;
    +    if (currentPage != null) {
    +      pageCursor = currentPage.getBaseOffset() + currentPage.size();
    +    }
       }
     
       /**
        * Sort and spill the current records in response to memory pressure.
        */
    -  public void spill() throws IOException {
    -    assert(inMemSorter != null);
    +  @Override
    +  public long spill(long size, MemoryConsumer trigger) throws IOException {
    +    if (trigger != this) {
    +      if (readingIterator != null) {
    +        return readingIterator.spill();
    +      }
    +      return 0L;
    +    }
    +
    +    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
    +      return 0L;
    +    }
    +
         logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
           Thread.currentThread().getId(),
           Utils.bytesToString(getMemoryUsage()),
    @@ -202,6 +189,8 @@ public void spill() throws IOException {
             spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
           }
           spillWriter.close();
    +
    +      inMemSorter.reset();
         }
     
         final long spillSize = freeMemory();
    @@ -210,7 +199,7 @@ public void spill() throws IOException {
         // written to disk. This also counts the space needed to store the sorter's pointer array.
         taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
     
    -    initializeForWriting();
    +    return spillSize;
       }
     
       /**
    @@ -246,7 +235,7 @@ public int getNumberOfAllocatedPages() {
       }
     
       /**
    -   * Free this sorter's in-memory data structures, including its data pages and pointer array.
    +   * Free this sorter's data pages.
        *
        * @return the number of bytes freed.
        */
    @@ -254,14 +243,12 @@ private long freeMemory() {
         updatePeakMemoryUsed();
         long memoryFreed = 0;
         for (MemoryBlock block : allocatedPages) {
    -      taskMemoryManager.freePage(block);
           memoryFreed += block.size();
    +      freePage(block);
         }
    -    // TODO: track in-memory sorter memory usage (SPARK-10474)
         allocatedPages.clear();
         currentPage = null;
    -    currentPagePosition = -1;
    -    freeSpaceInCurrentPage = 0;
    +    pageCursor = 0;
         return memoryFreed;
       }
     
    @@ -283,8 +270,15 @@ private void deleteSpillFiles() {
        * Frees this sorter's in-memory data structures and cleans up its spill files.
        */
       public void cleanupResources() {
    -    deleteSpillFiles();
    -    freeMemory();
    +    synchronized (this) {
    +      deleteSpillFiles();
    +      freeMemory();
    +      if (inMemSorter != null) {
    +        long used = inMemSorter.getMemoryUsage();
    +        inMemSorter = null;
    +        releaseMemory(used);
    +      }
    +    }
       }
     
       /**
    @@ -295,8 +289,28 @@ public void cleanupResources() {
       private void growPointerArrayIfNecessary() throws IOException {
         assert(inMemSorter != null);
         if (!inMemSorter.hasSpaceForAnotherRecord()) {
    -      // TODO: track the pointer array memory! (SPARK-10474)
    -      inMemSorter.expandPointerArray();
    +      long used = inMemSorter.getMemoryUsage();
    +      long needed = used + inMemSorter.getMemoryToExpand();
    +      try {
    +        acquireMemory(needed);  // could trigger spilling
    +      } catch (OutOfMemoryError e) {
     +        // should have triggered spilling
    +        assert(inMemSorter.hasSpaceForAnotherRecord());
    +        return;
    +      }
     +      // check whether spilling was triggered
    +      if (inMemSorter.hasSpaceForAnotherRecord()) {
    +        releaseMemory(needed);
    +      } else {
    +        try {
    +          inMemSorter.expandPointerArray();
    +          releaseMemory(used);
    +        } catch (OutOfMemoryError oom) {
     +          // Just in case the JVM really ran out of memory
    +          releaseMemory(needed);
    +          spill();
    +        }
    +      }
         }
       }
     
    @@ -304,101 +318,38 @@ private void growPointerArrayIfNecessary() throws IOException {
        * Allocates more memory in order to insert an additional record. This will request additional
        * memory from the memory manager and spill if the requested memory can not be obtained.
        *
    -   * @param requiredSpace the required space in the data page, in bytes, including space for storing
    +   * @param required the required space in the data page, in bytes, including space for storing
        *                      the record size. This must be less than or equal to the page size (records
        *                      that exceed the page size are handled via a different code path which uses
        *                      special overflow pages).
        */
    -  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
    -    assert (requiredSpace <= pageSizeBytes);
    -    if (requiredSpace > freeSpaceInCurrentPage) {
    -      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
    -        freeSpaceInCurrentPage);
    -      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
    -      // without using the free space at the end of the current page. We should also do this for
    -      // BytesToBytesMap.
    -      if (requiredSpace > pageSizeBytes) {
    -        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
    -          pageSizeBytes + ")");
    -      } else {
    -        acquireNewPage();
    -      }
    +  private void acquireNewPageIfNecessary(int required) {
    +    if (currentPage == null ||
    +      pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
    +      // TODO: try to find space on previous pages
    +      currentPage = allocatePage(required);
    +      pageCursor = currentPage.getBaseOffset();
    +      allocatedPages.add(currentPage);
         }
       }
     
    -  /**
    -   * Acquire a new page from the memory manager.
    -   *
    -   * If there is not enough space to allocate the new page, spill all existing ones
    -   * and try again. If there is still not enough space, report error to the caller.
    -   */
    -  private void acquireNewPage() throws IOException {
    -    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
    -    if (currentPage == null) {
    -      spill();
    -      currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
    -      if (currentPage == null) {
    -        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
    -      }
    -    }
    -    currentPagePosition = currentPage.getBaseOffset();
    -    freeSpaceInCurrentPage = pageSizeBytes;
    -    allocatedPages.add(currentPage);
    -  }
    -
       /**
        * Write a record to the sorter.
        */
    -  public void insertRecord(
    -      Object recordBaseObject,
    -      long recordBaseOffset,
    -      int lengthInBytes,
    -      long prefix) throws IOException {
    +  public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
    +    throws IOException {
     
         growPointerArrayIfNecessary();
         // Need 4 bytes to store the record length.
    -    final int totalSpaceRequired = lengthInBytes + 4;
    -
    -    // --- Figure out where to insert the new record ----------------------------------------------
    -
    -    final MemoryBlock dataPage;
    -    long dataPagePosition;
    -    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
    -    if (useOverflowPage) {
    -      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
    -      // The record is larger than the page size, so allocate a special overflow page just to hold
    -      // that record.
    -      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
    -      if (overflowPage == null) {
    -        spill();
    -        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
    -        if (overflowPage == null) {
    -          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
    -        }
    -      }
    -      allocatedPages.add(overflowPage);
    -      dataPage = overflowPage;
    -      dataPagePosition = overflowPage.getBaseOffset();
    -    } else {
    -      // The record is small enough to fit in a regular data page, but the current page might not
    -      // have enough space to hold it (or no pages have been allocated yet).
    -      acquireNewPageIfNecessary(totalSpaceRequired);
    -      dataPage = currentPage;
    -      dataPagePosition = currentPagePosition;
    -      // Update bookkeeping information
    -      freeSpaceInCurrentPage -= totalSpaceRequired;
    -      currentPagePosition += totalSpaceRequired;
    -    }
    -    final Object dataPageBaseObject = dataPage.getBaseObject();
    -
    -    // --- Insert the record ----------------------------------------------------------------------
    -
    -    final long recordAddress =
    -      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
    -    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
    -    dataPagePosition += 4;
    -    Platform.copyMemory(
    -      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
    +    final int required = length + 4;
    +    acquireNewPageIfNecessary(required);
    +
    +    final Object base = currentPage.getBaseObject();
    +    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
    +    Platform.putInt(base, pageCursor, length);
    +    pageCursor += 4;
    +    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
    +    pageCursor += length;
         assert(inMemSorter != null);
         inMemSorter.insertRecord(recordAddress, prefix);
       }
    @@ -411,59 +362,24 @@ public void insertRecord(
        *
        * record length = key length + value length + 4
        */
    -  public void insertKVRecord(
    -      Object keyBaseObj, long keyOffset, int keyLen,
    -      Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
    +  public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
    +      Object valueBase, long valueOffset, int valueLen, long prefix)
    +    throws IOException {
     
         growPointerArrayIfNecessary();
    -    final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
    -
    -    // --- Figure out where to insert the new record ----------------------------------------------
    -
    -    final MemoryBlock dataPage;
    -    long dataPagePosition;
    -    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
    -    if (useOverflowPage) {
    -      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
    -      // The record is larger than the page size, so allocate a special overflow page just to hold
    -      // that record.
    -      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
    -      if (overflowPage == null) {
    -        spill();
    -        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
    -        if (overflowPage == null) {
    -          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
    -        }
    -      }
    -      allocatedPages.add(overflowPage);
    -      dataPage = overflowPage;
    -      dataPagePosition = overflowPage.getBaseOffset();
    -    } else {
    -      // The record is small enough to fit in a regular data page, but the current page might not
    -      // have enough space to hold it (or no pages have been allocated yet).
    -      acquireNewPageIfNecessary(totalSpaceRequired);
    -      dataPage = currentPage;
    -      dataPagePosition = currentPagePosition;
    -      // Update bookkeeping information
    -      freeSpaceInCurrentPage -= totalSpaceRequired;
    -      currentPagePosition += totalSpaceRequired;
    -    }
    -    final Object dataPageBaseObject = dataPage.getBaseObject();
    -
    -    // --- Insert the record ----------------------------------------------------------------------
    -
    -    final long recordAddress =
    -      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
    -    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
    -    dataPagePosition += 4;
    -
    -    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen);
    -    dataPagePosition += 4;
    -
    -    Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
    -    dataPagePosition += keyLen;
    -
    -    Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
    +    final int required = keyLen + valueLen + 4 + 4;
    +    acquireNewPageIfNecessary(required);
    +
    +    final Object base = currentPage.getBaseObject();
    +    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
    +    Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
    +    pageCursor += 4;
    +    Platform.putInt(base, pageCursor, keyLen);
    +    pageCursor += 4;
    +    Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
    +    pageCursor += keyLen;
    +    Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
    +    pageCursor += valueLen;
     
         assert(inMemSorter != null);
         inMemSorter.insertRecord(recordAddress, prefix);
    @@ -475,10 +391,10 @@ public void insertKVRecord(
        */
       public UnsafeSorterIterator getSortedIterator() throws IOException {
         assert(inMemSorter != null);
    -    final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
    -    int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
    +    readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
    +    int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
         if (spillWriters.isEmpty()) {
    -      return inMemoryIterator;
    +      return readingIterator;
         } else {
           final UnsafeSorterSpillMerger spillMerger =
             new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
    @@ -486,9 +402,113 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
             spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
           }
           spillWriters.clear();
    -      spillMerger.addSpillIfNotEmpty(inMemoryIterator);
    +      spillMerger.addSpillIfNotEmpty(readingIterator);
     
           return spillMerger.getSortedIterator();
         }
       }
    +
    +  /**
     +   * An UnsafeSorterIterator that supports spilling.
    +   */
    +  class SpillableIterator extends UnsafeSorterIterator {
    +    private UnsafeSorterIterator upstream;
    +    private UnsafeSorterIterator nextUpstream = null;
    +    private MemoryBlock lastPage = null;
    +    private boolean loaded = false;
    +    private int numRecords = 0;
    +
    +    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
    +      this.upstream = inMemIterator;
    +      this.numRecords = inMemIterator.numRecordsLeft();
    +    }
    +
    +    public long spill() throws IOException {
    +      synchronized (this) {
    +        if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
    +          && numRecords > 0)) {
    +          return 0L;
    +        }
    +
    +        UnsafeInMemorySorter.SortedIterator inMemIterator =
    +          ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
    +
    +        final UnsafeSorterSpillWriter spillWriter =
    +          new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
    +        while (inMemIterator.hasNext()) {
    +          inMemIterator.loadNext();
    +          final Object baseObject = inMemIterator.getBaseObject();
    +          final long baseOffset = inMemIterator.getBaseOffset();
    +          final int recordLength = inMemIterator.getRecordLength();
    +          spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
    +        }
    +        spillWriter.close();
    +        spillWriters.add(spillWriter);
    +        nextUpstream = spillWriter.getReader(blockManager);
    +
    +        long released = 0L;
    +        synchronized (UnsafeExternalSorter.this) {
     +          // release all pages except the one currently being read
    +          for (MemoryBlock page : allocatedPages) {
    +            if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
    +              released += page.size();
    +              freePage(page);
    +            } else {
    +              lastPage = page;
    +            }
    +          }
    +          allocatedPages.clear();
    +        }
    +        return released;
    +      }
    +    }
    +
    +    @Override
    +    public boolean hasNext() {
    +      return numRecords > 0;
    +    }
    +
    +    @Override
    +    public void loadNext() throws IOException {
    +      synchronized (this) {
    +        loaded = true;
    +        if (nextUpstream != null) {
     +          // Just consumed the last record from the in-memory iterator
    +          if (lastPage != null) {
    +            freePage(lastPage);
    +            lastPage = null;
    +          }
    +          upstream = nextUpstream;
    +          nextUpstream = null;
    +
    +          assert(inMemSorter != null);
    +          long used = inMemSorter.getMemoryUsage();
    +          inMemSorter = null;
    +          releaseMemory(used);
    +        }
    +        numRecords--;
    +        upstream.loadNext();
    +      }
    +    }
    +
    +    @Override
    +    public Object getBaseObject() {
    +      return upstream.getBaseObject();
    +    }
    +
    +    @Override
    +    public long getBaseOffset() {
    +      return upstream.getBaseOffset();
    +    }
    +
    +    @Override
    +    public int getRecordLength() {
    +      return upstream.getRecordLength();
    +    }
    +
    +    @Override
    +    public long getKeyPrefix() {
    +      return upstream.getKeyPrefix();
    +    }
    +  }
     }
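The growPointerArrayIfNecessary() change above is the core of the new cooperative behaviour: before expanding the pointer array, the sorter reserves memory for the larger array, and that reservation may force this or another consumer to spill. Below is a minimal, self-contained sketch of the same acquire/expand/release pattern; MemoryBroker and GrowOrSpillSketch are invented names standing in for TaskMemoryManager and the sorter, not Spark's actual API, and the OutOfMemoryError/forced-spill handling of the real method is omitted.

import java.util.Arrays;

public class GrowOrSpillSketch {
  // Hypothetical broker standing in for TaskMemoryManager; it simply grants from a fixed budget
  // and, unlike Spark, never asks other consumers to spill.
  static class MemoryBroker {
    private long available;
    MemoryBroker(long budget) { this.available = budget; }
    void acquire(long bytes) {
      if (bytes > available) throw new OutOfMemoryError("only " + available + " bytes left");
      available -= bytes;
    }
    void release(long bytes) { available += bytes; }
  }

  private final MemoryBroker broker;
  private long[] array = new long[16];  // interleaved (pointer, prefix) slots
  private int pos = 0;

  GrowOrSpillSketch(MemoryBroker broker) {
    this.broker = broker;
    broker.acquire(array.length * 8L);  // the reservation always matches the array actually held
  }

  private boolean hasSpaceForAnotherRecord() { return pos + 2 <= array.length; }

  // Mirrors the shape of growPointerArrayIfNecessary(): reserve memory for the doubled array
  // first, then either keep the old array (a spill made room) or expand and drop the old
  // reservation, so the net reservation always equals the array currently in use.
  private void growIfNecessary() {
    if (hasSpaceForAnotherRecord()) return;
    long used = array.length * 8L;
    long needed = used + used;          // current usage plus the bytes the doubled array adds
    broker.acquire(needed);             // in Spark this call may trigger spilling
    if (hasSpaceForAnotherRecord()) {   // a spill freed records; keep the old array
      broker.release(needed);
    } else {
      array = Arrays.copyOf(array, array.length * 2);
      broker.release(used);             // only the new, larger reservation remains held
    }
  }

  void insert(long recordPointer, long keyPrefix) {
    growIfNecessary();
    array[pos++] = recordPointer;
    array[pos++] = keyPrefix;
  }

  public static void main(String[] args) {
    GrowOrSpillSketch sorter = new GrowOrSpillSketch(new MemoryBroker(1 << 20));
    for (int i = 0; i < 1000; i++) {
      sorter.insert(i, i);
    }
    System.out.println("inserted " + sorter.pos / 2 + " records");
  }
}

The key point the sketch preserves is that whichever reservation survives (old or new) matches the array the sorter actually holds, so a later free never releases more than was acquired.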
    diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
    index 5aad72c374c37..1480f0681ed9c 100644
    --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
    +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
    @@ -70,12 +70,12 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
         * Within this buffer, position {@code 2 * i} holds a pointer to the record at
        * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
        */
    -  private long[] pointerArray;
    +  private long[] array;
     
       /**
        * The position in the sort buffer where new records can be inserted.
        */
    -  private int pointerArrayInsertPosition = 0;
    +  private int pos = 0;
     
       public UnsafeInMemorySorter(
           final TaskMemoryManager memoryManager,
    @@ -83,37 +83,43 @@ public UnsafeInMemorySorter(
           final PrefixComparator prefixComparator,
           int initialSize) {
         assert (initialSize > 0);
    -    this.pointerArray = new long[initialSize * 2];
    +    this.array = new long[initialSize * 2];
         this.memoryManager = memoryManager;
         this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
         this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
       }
     
    +  public void reset() {
    +    pos = 0;
    +  }
    +
       /**
        * @return the number of records that have been inserted into this sorter.
        */
       public int numRecords() {
    -    return pointerArrayInsertPosition / 2;
    +    return pos / 2;
       }
     
    -  public long getMemoryUsage() {
    -    return pointerArray.length * 8L;
    +  private int newLength() {
    +    return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
    +  }
    +
    +  public long getMemoryToExpand() {
    +    return (long) (newLength() - array.length) * 8L;
       }
     
    -  static long getMemoryRequirementsForPointerArray(long numEntries) {
    -    return numEntries * 2L * 8L;
    +  public long getMemoryUsage() {
    +    return array.length * 8L;
       }
     
       public boolean hasSpaceForAnotherRecord() {
    -    return pointerArrayInsertPosition + 2 < pointerArray.length;
    +    return pos + 2 <= array.length;
       }
     
       public void expandPointerArray() {
    -    final long[] oldArray = pointerArray;
    -    // Guard against overflow:
    -    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
    -    pointerArray = new long[newLength];
    -    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
    +    final long[] oldArray = array;
    +    array = new long[newLength()];
    +    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
       }
     
       /**
    @@ -127,10 +133,10 @@ public void insertRecord(long recordPointer, long keyPrefix) {
         if (!hasSpaceForAnotherRecord()) {
           expandPointerArray();
         }
    -    pointerArray[pointerArrayInsertPosition] = recordPointer;
    -    pointerArrayInsertPosition++;
    -    pointerArray[pointerArrayInsertPosition] = keyPrefix;
    -    pointerArrayInsertPosition++;
    +    array[pos] = recordPointer;
    +    pos++;
    +    array[pos] = keyPrefix;
    +    pos++;
       }
     
       public static final class SortedIterator extends UnsafeSorterIterator {
    @@ -153,11 +159,25 @@ private SortedIterator(
           this.sortBuffer = sortBuffer;
         }
     
     +    public SortedIterator clone() {
    +      SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
    +      iter.position = position;
    +      iter.baseObject = baseObject;
    +      iter.baseOffset = baseOffset;
    +      iter.keyPrefix = keyPrefix;
    +      iter.recordLength = recordLength;
    +      return iter;
    +    }
    +
         @Override
         public boolean hasNext() {
           return position < sortBufferInsertPosition;
         }
     
    +    public int numRecordsLeft() {
    +      return (sortBufferInsertPosition - position) / 2;
    +    }
    +
         @Override
         public void loadNext() {
           // This pointer points to a 4-byte record length, followed by the record's bytes
    @@ -187,7 +207,7 @@ public void loadNext() {
        * {@code next()} will return the same mutable object.
        */
       public SortedIterator getSortedIterator() {
    -    sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
    -    return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
    +    sorter.sort(array, 0, pos / 2, sortComparator);
    +    return new SortedIterator(memoryManager, pos, array);
       }
     }
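The renaming from pointerArray/pointerArrayInsertPosition to array/pos keeps the layout described in the class javadoc: even slots hold record pointers, odd slots hold the 8-byte key prefixes, so pos advances by two per record, numRecords() is pos / 2, and hasSpaceForAnotherRecord() checks pos + 2 <= array.length. A tiny standalone illustration of that interleaving (plain Java, not Spark code):

public class PointerPrefixLayoutDemo {
  public static void main(String[] args) {
    long[] array = new long[8];   // room for four (pointer, prefix) pairs
    int pos = 0;

    long[][] records = { {0xCAFEL, 1L}, {0xBEEFL, 2L}, {0xF00DL, 3L} };
    for (long[] r : records) {
      array[pos++] = r[0];        // record pointer lands at even index 2 * i
      array[pos++] = r[1];        // key prefix lands at odd index 2 * i + 1
    }

    int numRecords = pos / 2;
    boolean hasSpaceForAnotherRecord = pos + 2 <= array.length;
    for (int i = 0; i < numRecords; i++) {
      System.out.printf("record %d: pointer=0x%X prefix=%d%n", i, array[2 * i], array[2 * i + 1]);
    }
    System.out.println("numRecords=" + numRecords
        + ", hasSpaceForAnotherRecord=" + hasSpaceForAnotherRecord);
  }
}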
    diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
    index 501dfe77d13cb..039e940a357ea 100644
    --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
    +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
    @@ -20,18 +20,18 @@
     import java.io.*;
     
     import com.google.common.io.ByteStreams;
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
     
     import org.apache.spark.storage.BlockId;
     import org.apache.spark.storage.BlockManager;
     import org.apache.spark.unsafe.Platform;
    -import org.slf4j.Logger;
    -import org.slf4j.LoggerFactory;
     
     /**
      * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
      * of the file format).
      */
    -final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
    +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
       private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
     
       private final File file;
    diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
    index e59a84ff8d118..234e21140a1dd 100644
    --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
    +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
    @@ -35,7 +35,7 @@
      *
      *   [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
      */
    -final class UnsafeSorterSpillWriter {
    +public final class UnsafeSorterSpillWriter {
     
       static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
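UnsafeSorterSpillWriter and UnsafeSorterSpillReader become public, presumably so that spilling code outside this package (such as the BytesToBytesMap changes exercised by the new spillInIterator test below) can use them. Their header comment documents the on-disk layout as [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]; the sketch below writes and re-reads that framing with plain Data*Streams to make it concrete. It deliberately skips the block-manager, serializer and compression wrapping the real classes go through, and the file and record contents are invented for the example.

import java.io.*;
import java.nio.charset.StandardCharsets;

public class SpillFileFormatSketch {
  public static void main(String[] args) throws IOException {
    File file = File.createTempFile("spill-format-sketch", ".spill");
    file.deleteOnExit();

    byte[][] records = {
      "apple".getBytes(StandardCharsets.UTF_8),
      "banana".getBytes(StandardCharsets.UTF_8)
    };
    long[] prefixes = {1L, 2L};

    // [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
      out.writeInt(records.length);
      for (int i = 0; i < records.length; i++) {
        out.writeInt(records[i].length);
        out.writeLong(prefixes[i]);
        out.write(records[i]);
      }
    }

    try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
      int numRecords = in.readInt();
      for (int i = 0; i < numRecords; i++) {
        int len = in.readInt();
        long prefix = in.readLong();
        byte[] data = new byte[len];
        in.readFully(data);
        System.out.println("prefix=" + prefix + " data=" + new String(data, StandardCharsets.UTF_8));
      }
    }
  }
}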
     
    diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
    index 6c9a71c3855b0..b0cf2696a397f 100644
    --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
    +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
    @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
     
     import com.google.common.annotations.VisibleForTesting
     
    +import org.apache.spark.util.Utils
     import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
     import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
     import org.apache.spark.unsafe.array.ByteArrayMethods
    @@ -215,8 +216,12 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
       final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
         val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
         if (curMem < numBytes) {
    -      throw new SparkException(
    -        s"Internal error: release called on $numBytes bytes but task only has $curMem")
    +      if (Utils.isTesting) {
    +        throw new SparkException(
    +          s"Internal error: release called on $numBytes bytes but task only has $curMem")
    +      } else {
    +        logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem")
    +      }
         }
         if (executionMemoryForTask.contains(taskAttemptId)) {
           executionMemoryForTask(taskAttemptId) -= numBytes
    diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
    index a76891acf0baf..9e002621a6909 100644
    --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
    +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
    @@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
         if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
           // Claim up to double our current memory from the shuffle memory pool
           val amountToRequest = 2 * currentMemory - myMemoryThreshold
    -      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
    +      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
           myMemoryThreshold += granted
           // If we were granted too little memory to grow further (either tryToAcquire returned 0,
           // or we already had more memory than myMemoryThreshold), spill the current collection
    @@ -107,7 +107,7 @@ private[spark] trait Spillable[C] extends Logging {
        */
       def releaseMemory(): Unit = {
         // The amount we requested does not include the initial memory tracking threshold
    -    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
    +    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
         myMemoryThreshold = initialMemoryThreshold
       }
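The Spillable changes only pass a null consumer through the new acquireExecutionMemory/releaseExecutionMemory signatures; the request-doubling policy in the surrounding comment is unchanged. A short worked example of how amountToRequest = 2 * currentMemory - myMemoryThreshold moves the threshold (written in Java for consistency with the other sketches here, not Spark code):

public class SpillableThresholdExample {
  public static void main(String[] args) {
    long myMemoryThreshold = 60L << 20;        // 60 MB already claimed for this collection
    long currentMemory = 100L << 20;           // collection is now estimated at 100 MB
    long amountToRequest = 2 * currentMemory - myMemoryThreshold;  // 140 MB
    long granted = amountToRequest;            // assume the pool grants the full request
    myMemoryThreshold += granted;              // 200 MB, i.e. exactly 2 * currentMemory
    boolean shouldSpill = currentMemory >= myMemoryThreshold;      // false: keep growing
    System.out.println("requested " + (amountToRequest >> 20) + " MB, new threshold "
        + (myMemoryThreshold >> 20) + " MB, shouldSpill=" + shouldSpill);
  }
}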
     
    diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
    index f381db0c62653..dab7b0592cb4e 100644
    --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
    +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
    @@ -17,6 +17,8 @@
     
     package org.apache.spark.memory;
     
    +import java.io.IOException;
    +
     import org.junit.Assert;
     import org.junit.Test;
     
    @@ -25,19 +27,40 @@
     
     public class TaskMemoryManagerSuite {
     
    +  class TestMemoryConsumer extends MemoryConsumer {
    +    TestMemoryConsumer(TaskMemoryManager memoryManager) {
    +      super(memoryManager);
    +    }
    +
    +    @Override
    +    public long spill(long size, MemoryConsumer trigger) throws IOException {
    +      long used = getUsed();
    +      releaseMemory(used);
    +      return used;
    +    }
    +
    +    void use(long size) {
    +      acquireMemory(size);
    +    }
    +
    +    void free(long size) {
    +      releaseMemory(size);
    +    }
    +  }
    +
       @Test
       public void leakedPageMemoryIsDetected() {
         final TaskMemoryManager manager = new TaskMemoryManager(
    -      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    -    manager.allocatePage(4096);  // leak memory
    +      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    +    manager.allocatePage(4096, null);  // leak memory
         Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
       }
     
       @Test
       public void encodePageNumberAndOffsetOffHeap() {
         final TaskMemoryManager manager = new TaskMemoryManager(
    -      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
    -    final MemoryBlock dataPage = manager.allocatePage(256);
    +      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
    +    final MemoryBlock dataPage = manager.allocatePage(256, null);
         // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
         // encode. This test exercises that corner-case:
         final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
    @@ -49,11 +72,53 @@ public void encodePageNumberAndOffsetOffHeap() {
       @Test
       public void encodePageNumberAndOffsetOnHeap() {
         final TaskMemoryManager manager = new TaskMemoryManager(
    -      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    -    final MemoryBlock dataPage = manager.allocatePage(256);
    +      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    +    final MemoryBlock dataPage = manager.allocatePage(256, null);
         final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
         Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
         Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
       }
     
    +  @Test
    +  public void cooperativeSpilling() {
    +    final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
    +    memoryManager.limit(100);
    +    final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
    +
    +    TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
    +    TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
    +    c1.use(100);
    +    assert(c1.getUsed() == 100);
    +    c2.use(100);
    +    assert(c2.getUsed() == 100);
    +    assert(c1.getUsed() == 0);  // spilled
    +    c1.use(100);
    +    assert(c1.getUsed() == 100);
    +    assert(c2.getUsed() == 0);  // spilled
    +
    +    c1.use(50);
    +    assert(c1.getUsed() == 50);  // spilled
    +    assert(c2.getUsed() == 0);
    +    c2.use(50);
    +    assert(c1.getUsed() == 50);
    +    assert(c2.getUsed() == 50);
    +
    +    c1.use(100);
    +    assert(c1.getUsed() == 100);
    +    assert(c2.getUsed() == 0);  // spilled
    +
    +    c1.free(20);
    +    assert(c1.getUsed() == 80);
    +    c2.use(10);
    +    assert(c1.getUsed() == 80);
    +    assert(c2.getUsed() == 10);
    +    c2.use(100);
    +    assert(c2.getUsed() == 100);
    +    assert(c1.getUsed() == 0);  // spilled
    +
    +    c1.free(0);
    +    c2.free(100);
    +    assert(manager.cleanUpAllAllocatedMemory() == 0);
    +  }
    +
     }
    diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
    index 7fb2f92ca80e8..9a43f1f3a9235 100644
    --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
    +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
    @@ -17,25 +17,29 @@
     
     package org.apache.spark.shuffle.sort;
     
    -import org.apache.spark.shuffle.sort.PackedRecordPointer;
    +import java.io.IOException;
    +
     import org.junit.Test;
    -import static org.junit.Assert.*;
     
     import org.apache.spark.SparkConf;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
    -import org.apache.spark.unsafe.memory.MemoryBlock;
    +import org.apache.spark.memory.TestMemoryManager;
     import org.apache.spark.memory.TaskMemoryManager;
    -import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
    +import org.apache.spark.unsafe.memory.MemoryBlock;
    +
    +import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
    +import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID;
    +import static org.junit.Assert.assertEquals;
    +import static org.junit.Assert.assertFalse;
     
     public class PackedRecordPointerSuite {
     
       @Test
    -  public void heap() {
    +  public void heap() throws IOException {
         final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
         final TaskMemoryManager memoryManager =
    -      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
    -    final MemoryBlock page0 = memoryManager.allocatePage(128);
    -    final MemoryBlock page1 = memoryManager.allocatePage(128);
    +      new TaskMemoryManager(new TestMemoryManager(conf), 0);
    +    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
    +    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
         final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
           page1.getBaseOffset() + 42);
         PackedRecordPointer packedPointer = new PackedRecordPointer();
    @@ -49,12 +53,12 @@ public void heap() {
       }
     
       @Test
    -  public void offHeap() {
    +  public void offHeap() throws IOException {
         final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
         final TaskMemoryManager memoryManager =
    -      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
    -    final MemoryBlock page0 = memoryManager.allocatePage(128);
    -    final MemoryBlock page1 = memoryManager.allocatePage(128);
    +      new TaskMemoryManager(new TestMemoryManager(conf), 0);
    +    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
    +    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
         final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
           page1.getBaseOffset() + 42);
         PackedRecordPointer packedPointer = new PackedRecordPointer();
    diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
    index 5049a5306ff21..2293b1bbc113e 100644
    --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
    +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
    @@ -26,7 +26,7 @@
     import org.apache.spark.HashPartitioner;
     import org.apache.spark.SparkConf;
     import org.apache.spark.unsafe.Platform;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
    +import org.apache.spark.memory.TestMemoryManager;
     import org.apache.spark.unsafe.memory.MemoryBlock;
     import org.apache.spark.memory.TaskMemoryManager;
     
    @@ -60,8 +60,8 @@ public void testBasicSorting() throws Exception {
         };
         final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
         final TaskMemoryManager memoryManager =
    -      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
    -    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
    +      new TaskMemoryManager(new TestMemoryManager(conf), 0);
    +    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
         final Object baseObject = dataPage.getBaseObject();
         final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
         final HashPartitioner hashPartitioner = new HashPartitioner(4);
    diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
    index d65926949c036..4763395d7d401 100644
    --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
    +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
    @@ -54,13 +54,14 @@
     import org.apache.spark.scheduler.MapStatus;
     import org.apache.spark.shuffle.IndexShuffleBlockResolver;
     import org.apache.spark.storage.*;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
    +import org.apache.spark.memory.TestMemoryManager;
     import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.util.Utils;
     
     public class UnsafeShuffleWriterSuite {
     
       static final int NUM_PARTITITONS = 4;
    +  TestMemoryManager memoryManager;
       TaskMemoryManager taskMemoryManager;
       final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
       File mergedOutputFile;
    @@ -106,10 +107,11 @@ public void setUp() throws IOException {
         partitionSizesInMergedFile = null;
         spillFilesCreated.clear();
         conf = new SparkConf()
    -      .set("spark.buffer.pageSize", "128m")
    +      .set("spark.buffer.pageSize", "1m")
           .set("spark.unsafe.offHeap", "false");
         taskMetrics = new TaskMetrics();
    -    taskMemoryManager =  new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
    +    memoryManager = new TestMemoryManager(conf);
    +    taskMemoryManager =  new TaskMemoryManager(memoryManager, 0);
     
         when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
         when(blockManager.getDiskWriter(
    @@ -344,9 +346,7 @@ private void testMergingSpills(
         }
         assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
     
    -    assertEquals(
    -      HashMultiset.create(dataToWrite),
    -      HashMultiset.create(readRecordsFromFile()));
    +    assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
         assertSpillFilesWereCleanedUp();
         ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
         assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
    @@ -398,20 +398,14 @@ public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
     
       @Test
       public void writeEnoughDataToTriggerSpill() throws Exception {
    -    taskMemoryManager = spy(taskMemoryManager);
    -    doCallRealMethod() // initialize sort buffer
    -      .doCallRealMethod() // allocate initial data page
    -      .doReturn(0L) // deny request to allocate new page
    -      .doCallRealMethod() // grant new sort buffer and data page
    -      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
    +    memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
          final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
          final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
    -    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
    -    for (int i = 0; i < 128 + 1; i++) {
    +    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
    +    for (int i = 0; i < 10 + 1; i++) {
            dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
         }
         writer.write(dataToWrite.iterator());
    -    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
         assertEquals(2, spillFilesCreated.size());
         writer.stop(true);
         readRecordsFromFile();
    @@ -426,19 +420,13 @@ public void writeEnoughDataToTriggerSpill() throws Exception {
     
       @Test
       public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
    -    taskMemoryManager = spy(taskMemoryManager);
    -    doCallRealMethod() // initialize sort buffer
    -      .doCallRealMethod() // allocate initial data page
    -      .doReturn(0L) // deny request to allocate new page
    -      .doCallRealMethod() // grant new sort buffer and data page
    -      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
    +    memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
          final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     +    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
         for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
            dataToWrite.add(new Tuple2<Object, Object>(i, i));
         }
         writer.write(dataToWrite.iterator());
    -    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
         assertEquals(2, spillFilesCreated.size());
         writer.stop(true);
         readRecordsFromFile();
    @@ -473,11 +461,11 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
          final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
          dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
         // We should be able to write a record that's right _at_ the max record size
    -    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
    +    final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
         new Random(42).nextBytes(atMaxRecordSize);
          dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
         // Inserting a record that's larger than the max record size
    -    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
    +    final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
         new Random(42).nextBytes(exceedsMaxRecordSize);
          dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
         writer.write(dataToWrite.iterator());
    @@ -524,7 +512,7 @@ public void testPeakMemoryUsed() throws Exception {
           for (int i = 0; i < numRecordsPerPage * 10; i++) {
              writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
             newPeakMemory = writer.getPeakMemoryUsedBytes();
    -        if (i % numRecordsPerPage == 0 && i != 0) {
    +        if (i % numRecordsPerPage == 0) {
               // The first page is allocated in constructor, another page will be allocated after
               // every numRecordsPerPage records (peak memory should change).
               assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
    diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
    index 6e52496cf933b..92bd45e5fa241 100644
    --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
    +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
    @@ -17,40 +17,117 @@
     
     package org.apache.spark.unsafe.map;
     
    -import java.lang.Exception;
    +import java.io.File;
    +import java.io.IOException;
    +import java.io.InputStream;
    +import java.io.OutputStream;
     import java.nio.ByteBuffer;
     import java.util.*;
     
    -import org.apache.spark.memory.TaskMemoryManager;
    -import org.junit.*;
    -import static org.hamcrest.Matchers.greaterThan;
    -import static org.junit.Assert.*;
    +import scala.Tuple2;
    +import scala.Tuple2$;
    +import scala.runtime.AbstractFunction1;
    +
    +import org.junit.After;
    +import org.junit.Assert;
    +import org.junit.Before;
    +import org.junit.Test;
    +import org.mockito.Mock;
    +import org.mockito.MockitoAnnotations;
    +import org.mockito.invocation.InvocationOnMock;
    +import org.mockito.stubbing.Answer;
     
     import org.apache.spark.SparkConf;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
    -import org.apache.spark.unsafe.array.ByteArrayMethods;
    -import org.apache.spark.unsafe.memory.*;
    +import org.apache.spark.executor.ShuffleWriteMetrics;
    +import org.apache.spark.memory.TestMemoryManager;
    +import org.apache.spark.memory.TaskMemoryManager;
    +import org.apache.spark.serializer.SerializerInstance;
    +import org.apache.spark.storage.*;
     import org.apache.spark.unsafe.Platform;
    +import org.apache.spark.unsafe.array.ByteArrayMethods;
    +import org.apache.spark.unsafe.memory.MemoryLocation;
    +import org.apache.spark.util.Utils;
    +
    +import static org.hamcrest.Matchers.greaterThan;
    +import static org.junit.Assert.assertEquals;
    +import static org.junit.Assert.assertFalse;
    +import static org.mockito.AdditionalAnswers.returnsSecondArg;
    +import static org.mockito.Answers.RETURNS_SMART_NULLS;
    +import static org.mockito.Matchers.any;
    +import static org.mockito.Matchers.anyInt;
    +import static org.mockito.Mockito.when;
     
     
     public abstract class AbstractBytesToBytesMapSuite {
     
       private final Random rand = new Random(42);
     
    -  private GrantEverythingMemoryManager memoryManager;
    +  private TestMemoryManager memoryManager;
       private TaskMemoryManager taskMemoryManager;
       private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
     
     +  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
    +  File tempDir;
    +
    +  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
    +  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
    +
     +  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
    +    @Override
    +    public OutputStream apply(OutputStream stream) {
    +      return stream;
    +    }
    +  }
    +
       @Before
       public void setup() {
         memoryManager =
    -      new GrantEverythingMemoryManager(
    +      new TestMemoryManager(
             new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
         taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
    +
    +    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
    +    spillFilesCreated.clear();
    +    MockitoAnnotations.initMocks(this);
    +    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     +    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
    +      @Override
     +      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
    +        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
    +        File file = File.createTempFile("spillFile", ".spill", tempDir);
    +        spillFilesCreated.add(file);
    +        return Tuple2$.MODULE$.apply(blockId, file);
    +      }
    +    });
    +    when(blockManager.getDiskWriter(
    +      any(BlockId.class),
    +      any(File.class),
    +      any(SerializerInstance.class),
    +      anyInt(),
     +      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
    +      @Override
    +      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
    +        Object[] args = invocationOnMock.getArguments();
    +
    +        return new DiskBlockObjectWriter(
    +          (File) args[1],
    +          (SerializerInstance) args[2],
    +          (Integer) args[3],
    +          new CompressStream(),
    +          false,
    +          (ShuffleWriteMetrics) args[4]
    +        );
    +      }
    +    });
    +    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
    +      .then(returnsSecondArg());
       }
     
       @After
       public void tearDown() {
    +    Utils.deleteRecursively(tempDir);
    +    tempDir = null;
    +
         Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
         if (taskMemoryManager != null) {
           long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
    @@ -415,9 +492,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
     
       @Test
       public void failureToAllocateFirstPage() {
    -    memoryManager.markExecutionAsOutOfMemory();
    +    memoryManager.limit(1024);  // longArray
         BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
    -    memoryManager.markExecutionAsOutOfMemory();
         try {
           final long[] emptyArray = new long[0];
           final BytesToBytesMap.Location loc =
    @@ -439,7 +515,7 @@ public void failureToGrow() {
           int i;
           for (i = 0; i < 127; i++) {
             if (i > 0) {
    -          memoryManager.markExecutionAsOutOfMemory();
    +          memoryManager.limit(0);
             }
             final long[] arr = new long[]{i};
             final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
    @@ -456,6 +532,44 @@ public void failureToGrow() {
         }
       }
     
    +  @Test
    +  public void spillInIterator() throws IOException {
    +    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
    +    try {
    +      int i;
    +      for (i = 0; i < 1024; i++) {
    +        final long[] arr = new long[]{i};
    +        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
    +        loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
    +      }
    +      BytesToBytesMap.MapIterator iter = map.iterator();
    +      for (i = 0; i < 100; i++) {
    +        iter.next();
    +      }
    +      // Non-destructive iterator is not spillable
    +      Assert.assertEquals(0, iter.spill(1024L * 10));
    +      for (i = 100; i < 1024; i++) {
    +        iter.next();
    +      }
    +
    +      BytesToBytesMap.MapIterator iter2 = map.destructiveIterator();
    +      for (i = 0; i < 100; i++) {
    +        iter2.next();
    +      }
    +      Assert.assertTrue(iter2.spill(1024) >= 1024);
    +      for (i = 100; i < 1024; i++) {
    +        iter2.next();
    +      }
    +      assertFalse(iter2.hasNext());
    +    } finally {
    +      map.free();
    +      for (File spillFile : spillFilesCreated) {
    +        assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
    +          spillFile.exists());
    +      }
    +    }
    +  }
    +
       @Test
       public void initialCapacityBoundsChecking() {
         try {
    @@ -500,7 +614,7 @@ public void testPeakMemoryUsed() {
               Platform.LONG_ARRAY_OFFSET,
               8);
             newPeakMemory = map.getPeakMemoryUsedBytes();
    -        if (i % numRecordsPerPage == 0 && i > 0) {
    +        if (i % numRecordsPerPage == 0) {
               // We allocated a new page for this record, so peak memory should change
               assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
             } else {
    @@ -519,11 +633,4 @@ public void testPeakMemoryUsed() {
         }
       }
     
    -  @Test
    -  public void testAcquirePageInConstructor() {
    -    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
    -    assertEquals(1, map.getNumDataPages());
    -    map.free();
    -  }
    -
     }
    diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
    index 94d50b94fde3f..cfead0e5924b8 100644
    --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
    +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
    @@ -36,28 +36,29 @@
     import org.mockito.MockitoAnnotations;
     import org.mockito.invocation.InvocationOnMock;
     import org.mockito.stubbing.Answer;
    -import static org.hamcrest.Matchers.greaterThanOrEqualTo;
    -import static org.junit.Assert.*;
    -import static org.mockito.AdditionalAnswers.returnsSecondArg;
    -import static org.mockito.Answers.RETURNS_SMART_NULLS;
    -import static org.mockito.Mockito.*;
     
     import org.apache.spark.SparkConf;
     import org.apache.spark.TaskContext;
     import org.apache.spark.executor.ShuffleWriteMetrics;
     import org.apache.spark.executor.TaskMetrics;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
    +import org.apache.spark.memory.TestMemoryManager;
    +import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.serializer.SerializerInstance;
     import org.apache.spark.storage.*;
     import org.apache.spark.unsafe.Platform;
    -import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.util.Utils;
     
    +import static org.hamcrest.Matchers.greaterThanOrEqualTo;
    +import static org.junit.Assert.*;
    +import static org.mockito.AdditionalAnswers.returnsSecondArg;
    +import static org.mockito.Answers.RETURNS_SMART_NULLS;
    +import static org.mockito.Mockito.*;
    +
     public class UnsafeExternalSorterSuite {
     
       final LinkedList<File> spillFilesCreated = new LinkedList<File>();
    -  final GrantEverythingMemoryManager memoryManager =
    -    new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
    +  final TestMemoryManager memoryManager =
    +    new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
       final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
       // Use integer comparison for comparing prefixes (which are partition ids, in this case)
       final PrefixComparator prefixComparator = new PrefixComparator() {
    @@ -86,7 +87,7 @@ public int compare(
       @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
     
     
    -  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
    +  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
     
       private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
         @Override
    @@ -233,7 +234,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception {
           insertNumber(sorter, numRecords - i);
         }
         assertEquals(1, sorter.getNumberOfAllocatedPages());
    -    memoryManager.markExecutionAsOutOfMemory();
    +    memoryManager.markExecutionAsOutOfMemoryOnce();
         // The insertion of this record should trigger a spill:
         insertNumber(sorter, 0);
         // Ensure that spill files were created
    @@ -311,6 +312,62 @@ public void sortingRecordsThatExceedPageSize() throws Exception {
         assertSpillFilesWereCleanedUp();
       }
     
    +  @Test
    +  public void forcedSpillingWithReadIterator() throws Exception {
    +    final UnsafeExternalSorter sorter = newSorter();
    +    long[] record = new long[100];
    +    int recordSize = record.length * 8;
    +    int n = (int) pageSizeBytes / recordSize * 3;
    +    for (int i = 0; i < n; i++) {
    +      record[0] = (long) i;
    +      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
    +    }
    +    assert(sorter.getNumberOfAllocatedPages() >= 2);
    +    UnsafeExternalSorter.SpillableIterator iter =
    +      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
    +    int lastv = 0;
    +    for (int i = 0; i < n / 3; i++) {
    +      iter.hasNext();
    +      iter.loadNext();
    +      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
    +      lastv = i;
    +    }
    +    assert(iter.spill() > 0);
    +    assert(iter.spill() == 0);
    +    assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv);
    +    for (int i = n / 3; i < n; i++) {
    +      iter.hasNext();
    +      iter.loadNext();
    +      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
    +    }
    +    sorter.cleanupResources();
    +    assertSpillFilesWereCleanedUp();
    +  }
    +
    +  @Test
    +  public void forcedSpillingWithNotReadIterator() throws Exception {
    +    final UnsafeExternalSorter sorter = newSorter();
    +    long[] record = new long[100];
    +    int recordSize = record.length * 8;
    +    int n = (int) pageSizeBytes / recordSize * 3;
    +    for (int i = 0; i < n; i++) {
    +      record[0] = (long) i;
    +      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
    +    }
    +    assert(sorter.getNumberOfAllocatedPages() >= 2);
    +    UnsafeExternalSorter.SpillableIterator iter =
    +      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
    +    assert(iter.spill() > 0);
    +    assert(iter.spill() == 0);
    +    for (int i = 0; i < n; i++) {
    +      iter.hasNext();
    +      iter.loadNext();
    +      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
    +    }
    +    sorter.cleanupResources();
    +    assertSpillFilesWereCleanedUp();
    +  }
    +
       @Test
       public void testPeakMemoryUsed() throws Exception {
         final long recordLengthBytes = 8;
    @@ -334,7 +391,7 @@ public void testPeakMemoryUsed() throws Exception {
             insertNumber(sorter, i);
             newPeakMemory = sorter.getPeakMemoryUsedBytes();
             // The first page is pre-allocated on instantiation
    -        if (i % numRecordsPerPage == 0 && i > 0) {
    +        if (i % numRecordsPerPage == 0) {
               // We allocated a new page for this record, so peak memory should change
               assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
             } else {
    @@ -358,21 +415,5 @@ public void testPeakMemoryUsed() throws Exception {
         }
       }
     
    -  @Test
    -  public void testReservePageOnInstantiation() throws Exception {
    -    final UnsafeExternalSorter sorter = newSorter();
    -    try {
    -      assertEquals(1, sorter.getNumberOfAllocatedPages());
    -      // Inserting a new record doesn't allocate more memory since we already have a page
    -      long peakMemory = sorter.getPeakMemoryUsedBytes();
    -      insertNumber(sorter, 100);
    -      assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
    -      assertEquals(1, sorter.getNumberOfAllocatedPages());
    -    } finally {
    -      sorter.cleanupResources();
    -      assertSpillFilesWereCleanedUp();
    -    }
    -  }
    -
     }
     
    diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
    index d5de56a0512f9..642f6585f8a15 100644
    --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
    +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
    @@ -20,17 +20,19 @@
     import java.util.Arrays;
     
     import org.junit.Test;
    -import static org.hamcrest.MatcherAssert.assertThat;
    -import static org.hamcrest.Matchers.*;
    -import static org.junit.Assert.*;
    -import static org.mockito.Mockito.mock;
     
     import org.apache.spark.HashPartitioner;
     import org.apache.spark.SparkConf;
    +import org.apache.spark.memory.TestMemoryManager;
    +import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.unsafe.Platform;
    -import org.apache.spark.memory.GrantEverythingMemoryManager;
     import org.apache.spark.unsafe.memory.MemoryBlock;
    -import org.apache.spark.memory.TaskMemoryManager;
    +
    +import static org.hamcrest.MatcherAssert.assertThat;
    +import static org.hamcrest.Matchers.greaterThanOrEqualTo;
    +import static org.hamcrest.Matchers.isIn;
    +import static org.junit.Assert.assertEquals;
    +import static org.mockito.Mockito.mock;
     
     public class UnsafeInMemorySorterSuite {
     
    @@ -44,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset,
       public void testSortingEmptyInput() {
         final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
           new TaskMemoryManager(
    -        new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
    +        new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
           mock(RecordComparator.class),
           mock(PrefixComparator.class),
           100);
    @@ -66,8 +68,8 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
           "Mango"
         };
         final TaskMemoryManager memoryManager = new TaskMemoryManager(
    -      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    -    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
    +      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
    +    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
         final Object baseObject = dataPage.getBaseObject();
         // Write the records into the data page:
         long position = dataPage.getBaseOffset();
    diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
    index 0242cbc9244a8..203dab934ca1f 100644
    --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
    @@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
         // cause is preserved
         val thrownDueToTaskFailure = intercept[SparkException] {
           sc.parallelize(Seq(0)).mapPartitions { iter =>
    -        TaskContext.get().taskMemoryManager().allocatePage(128)
    +        TaskContext.get().taskMemoryManager().allocatePage(128, null)
             throw new Exception("intentional task failure")
             iter
           }.count()
    @@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
         // If the task succeeded but memory was leaked, then the task should fail due to that leak
         val thrownDueToMemoryLeak = intercept[SparkException] {
           sc.parallelize(Seq(0)).mapPartitions { iter =>
    -        TaskContext.get().taskMemoryManager().allocatePage(128)
    +        TaskContext.get().taskMemoryManager().allocatePage(128, null)
             iter
           }.count()
         }
    diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
    index 1265087743a98..4a9479cf490fb 100644
    --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
    @@ -145,20 +145,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val manager = createMemoryManager(1000L)
         val taskMemoryManager = new TaskMemoryManager(manager, 0)
     
    -    assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
    -    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
    -    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
    -    assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
    -    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
    -    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
    +    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L)
    +    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
    +    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
    +    assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L)
    +    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
    +    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
     
    -    taskMemoryManager.releaseExecutionMemory(500L)
    -    assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
    -    assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
    +    taskMemoryManager.releaseExecutionMemory(500L, null)
    +    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L)
    +    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L)
     
         taskMemoryManager.cleanUpAllAllocatedMemory()
    -    assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
    -    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
    +    assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L)
    +    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
       }
     
       test("two tasks requesting full execution memory") {
    @@ -168,15 +168,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val futureTimeout: Duration = 20.seconds
     
         // Have both tasks request 500 bytes, then wait until both requests have been granted:
    -    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
    -    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
    +    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         assert(Await.result(t1Result1, futureTimeout) === 500L)
         assert(Await.result(t2Result1, futureTimeout) === 500L)
     
         // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
         // both now at 1 / N
    -    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
    -    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
    +    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         assert(Await.result(t1Result2, 200.millis) === 0L)
         assert(Await.result(t2Result2, 200.millis) === 0L)
       }
    @@ -188,15 +188,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val futureTimeout: Duration = 20.seconds
     
         // Have both tasks request 250 bytes, then wait until both requests have been granted:
    -    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
    -    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
    +    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) }
    +    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
         assert(Await.result(t1Result1, futureTimeout) === 250L)
         assert(Await.result(t2Result1, futureTimeout) === 250L)
     
         // Have both tasks each request 500 bytes more.
         // We should only grant 250 bytes to each of them on this second request
    -    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
    -    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
    +    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         assert(Await.result(t1Result2, futureTimeout) === 250L)
         assert(Await.result(t2Result2, futureTimeout) === 250L)
       }
    @@ -208,17 +208,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val futureTimeout: Duration = 20.seconds
     
         // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
    -    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
    +    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
         assert(Await.result(t1Result1, futureTimeout) === 1000L)
    -    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
    +    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
         // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
         // to make sure the other thread blocks for some time otherwise.
         Thread.sleep(300)
    -    t1MemManager.releaseExecutionMemory(250L)
    +    t1MemManager.releaseExecutionMemory(250L, null)
         // The memory freed from t1 should now be granted to t2.
         assert(Await.result(t2Result1, futureTimeout) === 250L)
         // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
    -    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
    +    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) }
         assert(Await.result(t2Result2, 200.millis) === 0L)
       }
     
    @@ -229,18 +229,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val futureTimeout: Duration = 20.seconds
     
         // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
    -    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
    +    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
         assert(Await.result(t1Result1, futureTimeout) === 1000L)
    -    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
         // to make sure the other thread blocks for some time otherwise.
         Thread.sleep(300)
         // t1 releases all of its memory, so t2 should be able to grab all of the memory
         t1MemManager.cleanUpAllAllocatedMemory()
         assert(Await.result(t2Result1, futureTimeout) === 500L)
    -    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         assert(Await.result(t2Result2, futureTimeout) === 500L)
    -    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
    +    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
         assert(Await.result(t2Result3, 200.millis) === 0L)
       }
     
    @@ -251,13 +251,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
         val t2MemManager = new TaskMemoryManager(memoryManager, 2)
         val futureTimeout: Duration = 20.seconds
     
    -    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
    +    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) }
         assert(Await.result(t1Result1, futureTimeout) === 700L)
     
    -    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
    +    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) }
         assert(Await.result(t2Result1, futureTimeout) === 300L)
     
    -    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
    +    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) }
         assert(Await.result(t1Result2, 200.millis) === 0L)
       }
     }
    diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
    similarity index 71%
    rename from core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
    rename to core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
    index fe102d8aeb2a5..77e43554ee27c 100644
    --- a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
    +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
    @@ -22,16 +22,22 @@ import scala.collection.mutable
     import org.apache.spark.SparkConf
     import org.apache.spark.storage.{BlockStatus, BlockId}
     
    -class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
    +class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
       private[memory] override def doAcquireExecutionMemory(
           numBytes: Long,
           evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
    -    if (oom) {
    -      oom = false
    +    if (oomOnce) {
    +      oomOnce = false
           0
    -    } else {
    +    } else if (available >= numBytes) {
           _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
    +      available -= numBytes
           numBytes
    +    } else {
    +      _executionMemoryUsed += available
    +      val grant = available
    +      available = 0
    +      grant
         }
       }
       override def acquireStorageMemory(
    @@ -42,13 +48,23 @@ class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf,
           blockId: BlockId,
           numBytes: Long,
           evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
    -  override def releaseStorageMemory(numBytes: Long): Unit = { }
    +  override def releaseExecutionMemory(numBytes: Long): Unit = {
    +    available += numBytes
    +    _executionMemoryUsed -= numBytes
    +  }
    +  override def releaseStorageMemory(numBytes: Long): Unit = {}
       override def maxExecutionMemory: Long = Long.MaxValue
       override def maxStorageMemory: Long = Long.MaxValue
     
    -  private var oom = false
    +  private var oomOnce = false
    +  private var available = Long.MaxValue
     
    -  def markExecutionAsOutOfMemory(): Unit = {
    -    oom = true
    +  def markExecutionAsOutOfMemoryOnce(): Unit = {
    +    oomOnce = true
       }
    +
    +  def limit(avail: Long): Unit = {
    +    available = avail
    +  }
    +
     }
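
    A minimal sketch (not part of the diffs above) of how a test might drive the renamed TestMemoryManager, assuming the helpers shown above (markExecutionAsOutOfMemoryOnce, limit) and the two-argument acquireExecutionMemory/releaseExecutionMemory form used elsewhere in this patch; the null consumer argument follows the existing suites, and the object name is illustrative only.

    import org.apache.spark.SparkConf
    import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}

    object TestMemoryManagerSketch {
      def main(args: Array[String]): Unit = {
        val memoryManager = new TestMemoryManager(
          new SparkConf().set("spark.unsafe.offHeap", "false"))
        val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)

        // With the default unlimited budget, a request is granted in full.
        assert(taskMemoryManager.acquireExecutionMemory(100L, null) == 100L)
        taskMemoryManager.releaseExecutionMemory(100L, null)

        // markExecutionAsOutOfMemoryOnce() makes exactly one upcoming grant return 0 bytes,
        // which UnsafeExternalSorterSuite uses to force a spill on the next insert.
        memoryManager.markExecutionAsOutOfMemoryOnce()

        // limit(n) caps the remaining budget; UnsafeFixedWidthAggregationMapSuite uses
        // limit(0) to simulate running completely out of execution memory.
        memoryManager.limit(0L)

        taskMemoryManager.cleanUpAllAllocatedMemory()
      }
    }
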
    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    index 810c74fd2fb96..f7063d1e5c829 100644
    --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
    @@ -96,15 +96,10 @@ void insertRow(UnsafeRow row) throws IOException {
         );
         numRowsInserted++;
         if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
    -      spill();
    +      sorter.spill();
         }
       }
     
    -  @VisibleForTesting
    -  void spill() throws IOException {
    -    sorter.spill();
    -  }
    -
       /**
        * Return the peak memory used so far, in bytes.
        */
    diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
    index 82c645df284de..889f97003450c 100644
    --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
    +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
    @@ -165,7 +165,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo
       public KVIterator<UnsafeRow, UnsafeRow> iterator() {
         return new KVIterator<UnsafeRow, UnsafeRow>() {
     
    -      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
    +      private final BytesToBytesMap.MapIterator mapLocationIterator =
             map.destructiveIterator();
           private final UnsafeRow key = new UnsafeRow();
           private final UnsafeRow value = new UnsafeRow();
    diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
    index 46301f0042954..845f2ae6859b7 100644
    --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
    +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
    @@ -17,13 +17,13 @@
     
     package org.apache.spark.sql.execution;
     
    -import java.io.IOException;
    -
     import javax.annotation.Nullable;
    +import java.io.IOException;
     
     import com.google.common.annotations.VisibleForTesting;
     
     import org.apache.spark.TaskContext;
    +import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
     import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
     import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
    @@ -33,7 +33,6 @@
     import org.apache.spark.unsafe.Platform;
     import org.apache.spark.unsafe.map.BytesToBytesMap;
     import org.apache.spark.unsafe.memory.MemoryBlock;
    -import org.apache.spark.memory.TaskMemoryManager;
     import org.apache.spark.util.collection.unsafe.sort.*;
     
     /**
    @@ -84,18 +83,16 @@ public UnsafeKVExternalSorter(
             /* initialSize */ 4096,
             pageSizeBytes);
         } else {
    -      // Insert the records into the in-memory sorter.
    -      // We will use the number of elements in the map as the initialSize of the
    -      // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize,
    -      // we will use 1 as its initial size if the map is empty.
    -      // TODO: track pointer array memory used by this in-memory sorter! (SPARK-10474)
    +      // The memory needed by UnsafeInMemorySorter should be less than that of the longArray in the map.
    +      map.freeArray();
    +      // The memory used by UnsafeInMemorySorter will be counted later (end of this block)
           final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
             taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
     
           // We cannot use the destructive iterator here because we are reusing the existing memory
           // pages in BytesToBytesMap to hold records during sorting.
           // The only new memory we are allocating is the pointer/prefix array.
    -      BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
    +      BytesToBytesMap.MapIterator iter = map.iterator();
           final int numKeyFields = keySchema.size();
           UnsafeRow row = new UnsafeRow();
           while (iter.hasNext()) {
    @@ -117,7 +114,7 @@ public UnsafeKVExternalSorter(
           }
     
           sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
    -        taskContext.taskMemoryManager(),
    +        taskMemoryManager,
             blockManager,
             taskContext,
             new KVComparator(ordering, keySchema.length()),
    @@ -128,6 +125,8 @@ public UnsafeKVExternalSorter(
     
           sorter.spill();
           map.free();
    +      // counting the memory used by UnsafeInMemorySorter
    +      taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
         }
       }
     
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
    index dbf4863b767bf..a38623623a441 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
    @@ -24,7 +24,7 @@ import scala.util.{Try, Random}
     import org.scalatest.Matchers
     
     import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
    -import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
    +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
     import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
     import org.apache.spark.sql.test.SharedSQLContext
    @@ -48,7 +48,7 @@ class UnsafeFixedWidthAggregationMapSuite
       private def emptyAggregationBuffer: InternalRow = InternalRow(0)
       private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
     
    -  private var memoryManager: GrantEverythingMemoryManager = null
    +  private var memoryManager: TestMemoryManager = null
       private var taskMemoryManager: TaskMemoryManager = null
     
       def testWithMemoryLeakDetection(name: String)(f: => Unit) {
    @@ -62,7 +62,7 @@ class UnsafeFixedWidthAggregationMapSuite
     
         test(name) {
           val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
    -      memoryManager = new GrantEverythingMemoryManager(conf)
    +      memoryManager = new TestMemoryManager(conf)
           taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
     
           TaskContext.setTaskContext(new TaskContextImpl(
    @@ -193,10 +193,6 @@ class UnsafeFixedWidthAggregationMapSuite
         // Convert the map into a sorter
         val sorter = map.destructAndCreateExternalSorter()
     
    -    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
    -      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
    -    }
    -
         // Add more keys to the sorter and make sure the results come out sorted.
         val additionalKeys = randomStrings(1024)
         val keyConverter = UnsafeProjection.create(groupKeySchema)
    @@ -208,7 +204,7 @@ class UnsafeFixedWidthAggregationMapSuite
           sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
     
           if ((i % 100) == 0) {
    -        memoryManager.markExecutionAsOutOfMemory()
    +        memoryManager.markExecutionAsOutOfMemoryOnce()
             sorter.closeCurrentPage()
           }
         }
    @@ -251,7 +247,7 @@ class UnsafeFixedWidthAggregationMapSuite
           sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
     
           if ((i % 100) == 0) {
    -        memoryManager.markExecutionAsOutOfMemory()
    +        memoryManager.markExecutionAsOutOfMemoryOnce()
             sorter.closeCurrentPage()
           }
         }
    @@ -294,16 +290,12 @@ class UnsafeFixedWidthAggregationMapSuite
         // Convert the map into a sorter. Right now, it contains one record.
         val sorter = map.destructAndCreateExternalSorter()
     
    -    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
    -      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
    -    }
    -
         // Add more keys to the sorter and make sure the results come out sorted.
         (1 to 4096).foreach { i =>
           sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
     
           if ((i % 100) == 0) {
    -        memoryManager.markExecutionAsOutOfMemory()
    +        memoryManager.markExecutionAsOutOfMemoryOnce()
             sorter.closeCurrentPage()
           }
         }
    @@ -342,7 +334,7 @@ class UnsafeFixedWidthAggregationMapSuite
           buf.setInt(0, str.length)
         }
         // Simulate running out of space
    -    memoryManager.markExecutionAsOutOfMemory()
    +    memoryManager.limit(0)
         val str = rand.nextString(1024)
         val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
         assert(buf == null)
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
    index 13dc1754c9ff0..7b80963ec8708 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
    @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
     import scala.util.Random
     
     import org.apache.spark._
    -import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
    +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
     import org.apache.spark.sql.{RandomDataGenerator, Row}
     import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
     import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
    @@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
           pageSize: Long,
           spill: Boolean): Unit = {
         val memoryManager =
    -      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
    +      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
         val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
         TaskContext.setTaskContext(new TaskContextImpl(
           stageId = 0,
    @@ -128,7 +128,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
           sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
           // 1% chance we will spill
           if (rand.nextDouble() < 0.01 && spill) {
    -        memoryManager.markExecutionAsOutOfMemory()
    +        memoryManager.markExecutionAsOutOfMemoryOnce()
             sorter.closeCurrentPage()
           }
         }
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
    deleted file mode 100644
    index 475037bd45379..0000000000000
    --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
    +++ /dev/null
    @@ -1,54 +0,0 @@
    -/*
    - * Licensed to the Apache Software Foundation (ASF) under one or more
    - * contributor license agreements.  See the NOTICE file distributed with
    - * this work for additional information regarding copyright ownership.
    - * The ASF licenses this file to You under the Apache License, Version 2.0
    - * (the "License"); you may not use this file except in compliance with
    - * the License.  You may obtain a copy of the License at
    - *
    - *    http://www.apache.org/licenses/LICENSE-2.0
    - *
    - * Unless required by applicable law or agreed to in writing, software
    - * distributed under the License is distributed on an "AS IS" BASIS,
    - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    - * See the License for the specific language governing permissions and
    - * limitations under the License.
    - */
    -
    -package org.apache.spark.sql.execution.aggregate
    -
    -import org.apache.spark._
    -import org.apache.spark.memory.TaskMemoryManager
    -import org.apache.spark.sql.catalyst.expressions._
    -import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
    -import org.apache.spark.sql.execution.metric.SQLMetrics
    -import org.apache.spark.sql.test.SharedSQLContext
    -
    -class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
    -
    -  test("memory acquired on construction") {
    -    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
    -    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
    -    TaskContext.setTaskContext(taskContext)
    -
    -    // Assert that a page is allocated before processing starts
    -    var iter: TungstenAggregationIterator = null
    -    try {
    -      val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
    -        () => new InterpretedMutableProjection(expr, schema)
    -      }
    -      val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
    -      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
    -        0, Seq.empty, newMutableProjection, Seq.empty, None,
    -        dummyAccum, dummyAccum, dummyAccum, dummyAccum)
    -      val numPages = iter.getHashMap.getNumDataPages
    -      assert(numPages === 1)
    -    } finally {
    -      // Clean up
    -      if (iter != null) {
    -        iter.free()
    -      }
    -      TaskContext.unset()
    -    }
    -  }
    -}
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    index ebe90d9e63d83..09847cec9c4ca 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
    @@ -23,6 +23,8 @@
     import java.util.LinkedList;
     import java.util.Map;
     
    +import org.apache.spark.unsafe.Platform;
    +
     /**
      * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
      */
    @@ -45,9 +47,6 @@ private boolean shouldPool(long size) {
     
       @Override
       public MemoryBlock allocate(long size) throws OutOfMemoryError {
    -    if (size % 8 != 0) {
    -      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
    -    }
         if (shouldPool(size)) {
           synchronized (this) {
             final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
    @@ -64,8 +63,8 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
             }
           }
         }
    -    long[] array = new long[(int) (size / 8)];
    -    return MemoryBlock.fromLongArray(array);
    +    long[] array = new long[(int) ((size + 7) / 8)];
    +    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
       }
     
       @Override
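
    The removed multiple-of-8 check above is replaced by rounding the backing long[] up to whole 8-byte words; a tiny self-contained sketch of that arithmetic (the object and method names are illustrative only):

    object WordRoundingSketch {
      // Mirrors the expression new long[(int) ((size + 7) / 8)] from the diff above:
      // the smallest number of 8-byte long words that covers a request of `size` bytes.
      def wordsFor(size: Long): Int = ((size + 7) / 8).toInt

      def main(args: Array[String]): Unit = {
        assert(wordsFor(16L) == 2) // exact multiple: two 8-byte words
        assert(wordsFor(17L) == 3) // rounds up: 24 bytes of backing array for 17 requested
        assert(wordsFor(1L) == 1)  // any non-zero request gets at least one word
        println("rounding checks passed")
      }
    }
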
    diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    index cda7826c8c99b..98ce711176e43 100644
    --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
    @@ -26,9 +26,6 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
     
       @Override
       public MemoryBlock allocate(long size) throws OutOfMemoryError {
    -    if (size % 8 != 0) {
    -      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
    -    }
         long address = Platform.allocateMemory(size);
         return new MemoryBlock(null, address, size);
       }
    
    From eb59b94c450fe6391d24d44ff7ea9bd4c6893af8 Mon Sep 17 00:00:00 2001
    From: Davies Liu 
    Date: Fri, 30 Oct 2015 00:36:20 -0700
    Subject: [PATCH 089/324] [SPARK-11417] [SQL] no @Override in codegen
    
    Older versions of Janino (>2.7) do not support @Override, so we should not use it in codegen.
    
    Author: Davies Liu 
    
    Closes #9372 from davies/no_override.
    ---
     .../catalyst/expressions/codegen/GenerateOrdering.scala    | 1 -
     .../catalyst/expressions/codegen/GeneratePredicate.scala   | 1 -
     .../catalyst/expressions/codegen/GenerateProjection.scala  | 7 -------
     3 files changed, 9 deletions(-)
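
    A small illustrative sketch (names hypothetical, not from the patch) of the constraint: the generated Java is assembled as a plain Scala string, so the fix is simply to leave @Override off the emitted method headers, as the hunks below do.

    object NoOverrideTemplateSketch {
      // Emit a compare() method header without @Override so that the older Janino
      // used by codegen can parse the generated source.
      def compareMethod(comparisons: String): String =
        s"""
           |public int compare(InternalRow a, InternalRow b) {
           |  InternalRow i = null;  // Holds current row being evaluated.
           |  $comparisons
           |  return 0;
           |}
         """.stripMargin

      def main(args: Array[String]): Unit =
        println(compareMethod("// per-column comparisons are spliced in here"))
    }
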
    
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
    index c2b420286f755..1af7c73cd4bf5 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
    @@ -126,7 +126,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
               ${initMutableStates(ctx)}
             }
     
    -        @Override
             public int compare(InternalRow a, InternalRow b) {
               InternalRow ${ctx.INPUT_ROW} = null;  // Holds current row being evaluated.
               $comparisons
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
    index ae6ffe6293c5d..457b4f08424a6 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
    @@ -55,7 +55,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
               ${initMutableStates(ctx)}
             }
     
    -        @Override
             public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
               ${eval.code}
               return !${eval.isNull} && ${eval.value};
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
    index dbcc9dc08408f..c0d313b2e1301 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
    @@ -82,7 +82,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
           if (cases.length > 0) {
             val getter = "get" + ctx.primitiveTypeName(jt)
             s"""
    -      @Override
           public $jt $getter(int i) {
             if (isNullAt(i)) {
               return ${ctx.defaultValue(jt)};
    @@ -107,7 +106,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
           if (cases.length > 0) {
             val setter = "set" + ctx.primitiveTypeName(jt)
             s"""
    -      @Override
           public void $setter(int i, $jt value) {
             nullBits[i] = false;
             switch (i) {
    @@ -169,7 +167,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
             ${initMutableStates(ctx)}
           }
     
    -      @Override
           public Object apply(Object r) {
             // GenerateProjection does not work with UnsafeRows.
             assert(!(r instanceof ${classOf[UnsafeRow].getName}));
    @@ -189,7 +186,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
             public void setNullAt(int i) { nullBits[i] = true; }
             public boolean isNullAt(int i) { return nullBits[i]; }
     
    -        @Override
             public Object genericGet(int i) {
               if (isNullAt(i)) return null;
               switch (i) {
    @@ -210,14 +206,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
             $specificAccessorFunctions
             $specificMutatorFunctions
     
    -        @Override
             public int hashCode() {
               int result = 37;
               $hashUpdates
               return result;
             }
     
    -        @Override
             public boolean equals(Object other) {
               if (other instanceof SpecificRow) {
                 SpecificRow row = (SpecificRow) other;
    @@ -227,7 +221,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
               return super.equals(other);
             }
     
    -        @Override
             public InternalRow copy() {
               Object[] arr = new Object[${expressions.length}];
               ${copyColumns}
    
    From 86d65265fcab7edab88a7bdb10acba47da95bcb3 Mon Sep 17 00:00:00 2001
    From: Lewuathe 
    Date: Fri, 30 Oct 2015 02:59:05 -0700
    Subject: [PATCH 090/324] [SPARK-11207] [ML] Add test cases for solver selection of LinearRegres…
    MIME-Version: 1.0
    Content-Type: text/plain; charset=UTF-8
    Content-Transfer-Encoding: 8bit
    
    …sion as a follow-up. This is the follow-up work of SPARK-10668.
    
    * Fix minor style issues.
    * Add a test case for checking whether the solver is selected properly.
    
    Author: Lewuathe 
    Author: lewuathe 
    
    Closes #9180 from Lewuathe/SPARK-11207.
    ---
     .../mllib/util/LinearDataGenerator.scala      |  54 +++++-
     .../ml/regression/LinearRegressionSuite.scala | 172 ++++++++++--------
     2 files changed, 144 insertions(+), 82 deletions(-)
    
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    index d0ba454f379a9..6ff07eed6cfd2 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
    @@ -77,13 +77,11 @@ object LinearDataGenerator {
           nPoints: Int,
           seed: Int,
           eps: Double = 0.1): Seq[LabeledPoint] = {
    -    generateLinearInput(intercept, weights,
    -      Array.fill[Double](weights.length)(0.0),
    -      Array.fill[Double](weights.length)(1.0 / 3.0),
    -      nPoints, seed, eps)}
    +    generateLinearInput(intercept, weights, Array.fill[Double](weights.length)(0.0),
    +      Array.fill[Double](weights.length)(1.0 / 3.0), nPoints, seed, eps)
    +  }
     
       /**
    -   *
        * @param intercept Data intercept
        * @param weights  Weights to be applied.
        * @param xMean the mean of the generated features. Lots of time, if the features are not properly
    @@ -104,16 +102,49 @@ object LinearDataGenerator {
           nPoints: Int,
           seed: Int,
           eps: Double): Seq[LabeledPoint] = {
    +    generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps, 0.0)
    +  }
    +
     
    +  /**
    +   * @param intercept Data intercept
    +   * @param weights  Weights to be applied.
    +   * @param xMean the mean of the generated features. Lots of time, if the features are not properly
    +   *              standardized, the algorithm with poor implementation will have difficulty
    +   *              to converge.
    +   * @param xVariance the variance of the generated features.
    +   * @param nPoints Number of points in sample.
    +   * @param seed Random seed
    +   * @param eps Epsilon scaling factor.
    +   * @param sparsity The ratio of zero elements. If it is 0.0, LabeledPoints with
    +   *                 DenseVector are returned.
    +   * @return Seq of input.
    +   */
    +  @Since("1.6.0")
    +  def generateLinearInput(
    +      intercept: Double,
    +      weights: Array[Double],
    +      xMean: Array[Double],
    +      xVariance: Array[Double],
    +      nPoints: Int,
    +      seed: Int,
    +      eps: Double,
    +      sparsity: Double): Seq[LabeledPoint] = {
    +    require(0.0 <= sparsity && sparsity <= 1.0)
         val rnd = new Random(seed)
         val x = Array.fill[Array[Double]](nPoints)(
           Array.fill[Double](weights.length)(rnd.nextDouble()))
     
    +    val sparseRnd = new Random(seed)
         x.foreach { v =>
           var i = 0
           val len = v.length
           while (i < len) {
    -        v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
    +        if (sparseRnd.nextDouble() < sparsity) {
    +          v(i) = 0.0
    +        } else {
    +          v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
    +        }
             i += 1
           }
         }
    @@ -121,7 +152,16 @@ object LinearDataGenerator {
         val y = x.map { xi =>
           blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
         }
    -    y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
    +
    +    y.zip(x).map { p =>
    +      if (sparsity == 0.0) {
    +        // Return LabeledPoints with DenseVector
    +        LabeledPoint(p._1, Vectors.dense(p._2))
    +      } else {
    +        // Return LabeledPoints with SparseVector
    +        LabeledPoint(p._1, Vectors.dense(p._2).toSparse)
    +      }
    +    }
       }
     
       /**
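
    A short sketch of calling the new generateLinearInput overload with the sparsity parameter, in the spirit of the datasetWithSparseFeature setup in the suite below; the sizes and the object name here are illustrative only.

    import scala.util.Random

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.util.LinearDataGenerator

    object SparseLinearInputSketch {
      def main(args: Array[String]): Unit = {
        val r = new Random(42)
        val featureSize = 10
        val points: Seq[LabeledPoint] = LinearDataGenerator.generateLinearInput(
          intercept = 0.0,
          weights = Seq.fill(featureSize)(r.nextDouble).toArray,
          xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
          xVariance = Seq.fill(featureSize)(r.nextDouble).toArray,
          nPoints = 5,
          seed = 42,
          eps = 0.1,
          sparsity = 0.7) // roughly 70% of the feature values are zeroed out

        // With sparsity > 0.0 the generator wraps each point's features in a SparseVector.
        points.foreach(p => println(s"${p.features.getClass.getSimpleName}: label=${p.label}"))
      }
    }
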
    diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    index a6e0c72ba9030..a2a5c0bbdcb90 100644
    --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
    @@ -32,8 +32,9 @@ import org.apache.spark.sql.{DataFrame, Row}
     class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       private val seed: Int = 42
    -  @transient var dataset: DataFrame = _
    -  @transient var datasetWithoutIntercept: DataFrame = _
    +  @transient var datasetWithDenseFeature: DataFrame = _
    +  @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
    +  @transient var datasetWithSparseFeature: DataFrame = _
     
       /*
          In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
    @@ -49,16 +50,29 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
        */
       override def beforeAll(): Unit = {
         super.beforeAll()
    -    dataset = sqlContext.createDataFrame(
    +    datasetWithDenseFeature = sqlContext.createDataFrame(
           sc.parallelize(LinearDataGenerator.generateLinearInput(
    -        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2))
    +        intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
    +        xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
         /*
            datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating
            training model without intercept
          */
    -    datasetWithoutIntercept = sqlContext.createDataFrame(
    +    datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame(
           sc.parallelize(LinearDataGenerator.generateLinearInput(
    -        0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, seed, 0.1), 2))
    +        intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
    +        xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
    +
    +    val r = new Random(seed)
    +    // When the feature size is larger than 4096, the normal optimizer is chosen
    +    // as the solver of linear regression in the case of "auto" mode.
    +    val featureSize = 4100
    +    datasetWithSparseFeature = sqlContext.createDataFrame(
    +      sc.parallelize(LinearDataGenerator.generateLinearInput(
    +        intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray,
    +        xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
    +        xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200,
    +        seed, eps = 0.1, sparsity = 0.7), 2))
       }
     
       test("params") {
    @@ -77,19 +91,19 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(lir.getFitIntercept)
         assert(lir.getStandardization)
         assert(lir.getSolver == "auto")
    -    val model = lir.fit(dataset)
    +    val model = lir.fit(datasetWithDenseFeature)
     
         // copied model must have the same parent.
         MLTestingUtils.checkCopy(model)
     
    -    model.transform(dataset)
    +    model.transform(datasetWithDenseFeature)
           .select("label", "prediction")
           .collect()
         assert(model.getFeaturesCol === "features")
         assert(model.getPredictionCol === "prediction")
         assert(model.intercept !== 0.0)
         assert(model.hasParent)
    -    val numFeatures = dataset.select("features").first().getAs[Vector](0).size
    +    val numFeatures = datasetWithDenseFeature.select("features").first().getAs[Vector](0).size
         assert(model.numFeatures === numFeatures)
       }
     
    @@ -98,8 +112,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           val trainer1 = new LinearRegression().setSolver(solver)
           // The result should be the same regardless of standardization without regularization
           val trainer2 = (new LinearRegression).setStandardization(false).setSolver(solver)
    -      val model1 = trainer1.fit(dataset)
    -      val model2 = trainer2.fit(dataset)
    +      val model1 = trainer1.fit(datasetWithDenseFeature)
    +      val model2 = trainer2.fit(datasetWithDenseFeature)
     
           /*
              Using the following R code to load the data and train the model using glmnet package.
    @@ -124,7 +138,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model2.intercept ~== interceptR relTol 1E-3)
           assert(model2.weights ~= weightsR relTol 1E-3)
     
    -      model1.transform(dataset).select("features", "prediction").collect().foreach {
    +      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
                 features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    @@ -139,10 +153,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           // Without regularization the results should be the same
           val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
             .setSolver(solver)
    -      val model1 = trainer1.fit(dataset)
    -      val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
    -      val model2 = trainer2.fit(dataset)
    -      val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
    +      val model1 = trainer1.fit(datasetWithDenseFeature)
    +      val modelWithoutIntercept1 = trainer1.fit(datasetWithDenseFeatureWithoutIntercept)
    +      val model2 = trainer2.fit(datasetWithDenseFeature)
    +      val modelWithoutIntercept2 = trainer2.fit(datasetWithDenseFeatureWithoutIntercept)
     
           /*
              weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
    @@ -186,19 +200,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
             .setSolver(solver).setStandardization(false)
     
    -      var model1: LinearRegressionModel = null
    -      var model2: LinearRegressionModel = null
    -
           // Normal optimizer is not supported with only L1 regularization case.
           if (solver == "normal") {
             intercept[IllegalArgumentException] {
    -            trainer1.fit(dataset)
    -            trainer2.fit(dataset)
    +            trainer1.fit(datasetWithDenseFeature)
    +            trainer2.fit(datasetWithDenseFeature)
               }
           } else {
    -        model1 = trainer1.fit(dataset)
    -        model2 = trainer2.fit(dataset)
    -
    +        val model1 = trainer1.fit(datasetWithDenseFeature)
    +        val model2 = trainer2.fit(datasetWithDenseFeature)
     
             /*
                weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
    @@ -230,11 +240,12 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             assert(model2.intercept ~== interceptR2 relTol 1E-3)
             assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -        model1.transform(dataset).select("features", "prediction").collect().foreach {
    -          case Row(features: DenseVector, prediction1: Double) =>
    -            val prediction2 =
    -              features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    -            assert(prediction1 ~== prediction2 relTol 1E-5)
    +        model1.transform(datasetWithDenseFeature).select("features", "prediction")
    +          .collect().foreach {
    +            case Row(features: DenseVector, prediction1: Double) =>
    +              val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) +
    +                model1.intercept
    +              assert(prediction1 ~== prediction2 relTol 1E-5)
             }
           }
         }
    @@ -247,18 +258,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
             .setFitIntercept(false).setStandardization(false).setSolver(solver)
     
    -      var model1: LinearRegressionModel = null
    -      var model2: LinearRegressionModel = null
    -
           // Normal optimizer is not supported with only L1 regularization case.
           if (solver == "normal") {
             intercept[IllegalArgumentException] {
    -            trainer1.fit(dataset)
    -            trainer2.fit(dataset)
    +            trainer1.fit(datasetWithDenseFeature)
    +            trainer2.fit(datasetWithDenseFeature)
               }
           } else {
    -        model1 = trainer1.fit(dataset)
    -        model2 = trainer2.fit(dataset)
    +        val model1 = trainer1.fit(datasetWithDenseFeature)
    +        val model2 = trainer2.fit(datasetWithDenseFeature)
     
             /*
                weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
    @@ -292,11 +300,12 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             assert(model2.intercept ~== interceptR2 absTol 1E-3)
             assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -        model1.transform(dataset).select("features", "prediction").collect().foreach {
    -          case Row(features: DenseVector, prediction1: Double) =>
    -            val prediction2 =
    -              features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    -            assert(prediction1 ~== prediction2 relTol 1E-5)
    +        model1.transform(datasetWithDenseFeature).select("features", "prediction")
    +          .collect().foreach {
    +            case Row(features: DenseVector, prediction1: Double) =>
    +              val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) +
    +                model1.intercept
    +              assert(prediction1 ~== prediction2 relTol 1E-5)
             }
           }
         }
    @@ -308,8 +317,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setSolver(solver)
           val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
             .setStandardization(false).setSolver(solver)
    -      val model1 = trainer1.fit(dataset)
    -      val model2 = trainer2.fit(dataset)
    +      val model1 = trainer1.fit(datasetWithDenseFeature)
    +      val model2 = trainer2.fit(datasetWithDenseFeature)
     
           /*
              weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
    @@ -342,7 +351,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model2.intercept ~== interceptR2 relTol 1E-3)
           assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -      model1.transform(dataset).select("features", "prediction").collect().foreach {
    +      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
                 features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    @@ -357,8 +366,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             .setFitIntercept(false).setSolver(solver)
           val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
             .setFitIntercept(false).setStandardization(false).setSolver(solver)
    -      val model1 = trainer1.fit(dataset)
    -      val model2 = trainer2.fit(dataset)
    +      val model1 = trainer1.fit(datasetWithDenseFeature)
    +      val model2 = trainer2.fit(datasetWithDenseFeature)
     
           /*
              weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
    @@ -392,7 +401,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model2.intercept ~== interceptR2 absTol 1E-3)
           assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -      model1.transform(dataset).select("features", "prediction").collect().foreach {
    +      model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach {
             case Row(features: DenseVector, prediction1: Double) =>
               val prediction2 =
                 features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    @@ -408,18 +417,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
             .setStandardization(false).setSolver(solver)
     
    -      var model1: LinearRegressionModel = null
    -      var model2: LinearRegressionModel = null
    -
           // Normal optimizer is not supported with non-zero elasticnet parameter.
           if (solver == "normal") {
             intercept[IllegalArgumentException] {
    -            trainer1.fit(dataset)
    -            trainer2.fit(dataset)
    +            trainer1.fit(datasetWithDenseFeature)
    +            trainer2.fit(datasetWithDenseFeature)
               }
           } else {
    -        model1 = trainer1.fit(dataset)
    -        model2 = trainer2.fit(dataset)
    +        val model1 = trainer1.fit(datasetWithDenseFeature)
    +        val model2 = trainer2.fit(datasetWithDenseFeature)
     
             /*
                weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
    @@ -452,10 +458,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             assert(model2.intercept ~== interceptR2 relTol 1E-3)
             assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -        model1.transform(dataset).select("features", "prediction").collect().foreach {
    +        model1.transform(datasetWithDenseFeature).select("features", "prediction")
    +          .collect().foreach {
               case Row(features: DenseVector, prediction1: Double) =>
    -            val prediction2 =
    -              features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    +            val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) +
    +              model1.intercept
                 assert(prediction1 ~== prediction2 relTol 1E-5)
             }
           }
    @@ -469,18 +476,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
             .setFitIntercept(false).setStandardization(false).setSolver(solver)
     
    -      var model1: LinearRegressionModel = null
    -      var model2: LinearRegressionModel = null
    -
           // Normal optimizer is not supported with non-zero elasticnet parameter.
           if (solver == "normal") {
             intercept[IllegalArgumentException] {
    -            trainer1.fit(dataset)
    -            trainer2.fit(dataset)
    +            trainer1.fit(datasetWithDenseFeature)
    +            trainer2.fit(datasetWithDenseFeature)
               }
           } else {
    -        model1 = trainer1.fit(dataset)
    -        model2 = trainer2.fit(dataset)
    +        val model1 = trainer1.fit(datasetWithDenseFeature)
    +        val model2 = trainer2.fit(datasetWithDenseFeature)
     
             /*
                weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
    @@ -514,10 +518,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
             assert(model2.intercept ~== interceptR2 absTol 1E-3)
             assert(model2.weights ~= weightsR2 relTol 1E-3)
     
    -        model1.transform(dataset).select("features", "prediction").collect().foreach {
    +        model1.transform(datasetWithDenseFeature).select("features", "prediction")
    +          .collect().foreach {
               case Row(features: DenseVector, prediction1: Double) =>
    -            val prediction2 =
    -              features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
    +            val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) +
    +              model1.intercept
                 assert(prediction1 ~== prediction2 relTol 1E-5)
             }
           }
    @@ -527,27 +532,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("linear regression model training summary") {
         Seq("auto", "l-bfgs", "normal").foreach { solver =>
           val trainer = new LinearRegression().setSolver(solver)
    -      val model = trainer.fit(dataset)
    +      val model = trainer.fit(datasetWithDenseFeature)
           val trainerNoPredictionCol = trainer.setPredictionCol("")
    -      val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset)
    -
    +      val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature)
     
           // Training results for the model should be available
           assert(model.hasSummary)
           assert(modelNoPredictionCol.hasSummary)
     
           // Schema should be a superset of the input dataset
    -      assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf(
    +      assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf(
             model.summary.predictions.schema.fieldNames.toSet))
           // Validate that we re-insert a prediction column for evaluation
           val modelNoPredictionColFieldNames
           = modelNoPredictionCol.summary.predictions.schema.fieldNames
    -      assert((dataset.schema.fieldNames.toSet).subsetOf(
    +      assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf(
             modelNoPredictionColFieldNames.toSet))
           assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
     
           // Residuals in [[LinearRegressionResults]] should equal those manually computed
    -      val expectedResiduals = dataset.select("features", "label")
    +      val expectedResiduals = datasetWithDenseFeature.select("features", "label")
             .map { case Row(features: DenseVector, label: Double) =>
             val prediction =
               features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
    @@ -585,6 +589,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
                 .objectiveHistory
                 .sliding(2)
                 .forall(x => x(0) >= x(1)))
    +      } else {
+        // To clarify that the normal solver is used here.
    +        assert(model.summary.objectiveHistory.length == 1)
    +        assert(model.summary.objectiveHistory(0) == 0.0)
           }
         }
       }
    @@ -592,10 +600,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       test("linear regression model testset evaluation summary") {
         Seq("auto", "l-bfgs", "normal").foreach { solver =>
           val trainer = new LinearRegression().setSolver(solver)
    -      val model = trainer.fit(dataset)
    +      val model = trainer.fit(datasetWithDenseFeature)
     
           // Evaluating on training dataset should yield results summary equal to training summary
    -      val testSummary = model.evaluate(dataset)
    +      val testSummary = model.evaluate(datasetWithDenseFeature)
           assert(model.summary.meanSquaredError ~== testSummary.meanSquaredError relTol 1E-5)
           assert(model.summary.r2 ~== testSummary.r2 relTol 1E-5)
           model.summary.residuals.select("residuals").collect()
    @@ -693,4 +701,18 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
           assert(model4a0.weights ~== model4b.weights absTol 1E-3)
         }
       }
    +
    +  test("linear regression model with l-bfgs with big feature datasets") {
    +    val trainer = new LinearRegression().setSolver("auto")
    +    val model = trainer.fit(datasetWithSparseFeature)
    +
    +    // Training results for the model should be available
    +    assert(model.hasSummary)
    +    // When LBFGS is used as optimizer, objective history can be restored.
    +    assert(
    +      model.summary
    +        .objectiveHistory
    +        .sliding(2)
    +        .forall(x => x(0) >= x(1)))
    +  }
     }
    
    From 59db9e9c382fab40aac0633f2c779bee8cf2025f Mon Sep 17 00:00:00 2001
    From: hyukjinkwon 
    Date: Fri, 30 Oct 2015 18:17:35 +0800
Subject: [PATCH 091/324] [SPARK-11103][SQL] Filter applied on merged Parquet
 schema with new column fails
    
When schema merging and predicate push-down are both enabled, queries fail because Parquet rejects pushed-down filters that reference columns missing from some of the files' schemas.
This is related to a Parquet issue (https://issues.apache.org/jira/browse/PARQUET-389).

For now, this PR simply disables predicate push-down whenever a merged schema is used.
    
    Author: hyukjinkwon 
    
    Closes #9327 from HyukjinKwon/SPARK-11103.
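
A rough repro sketch of the scenario (an illustration, not part of the patch; `sc` is
an existing SparkContext and pathOne/pathTwo are the two Parquet directories written
by the test added below):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    sqlContext.setConf("spark.sql.parquet.filterPushdown", "true")
    sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")

    // table1 has columns (a, b) and table2 has columns (c, b). Before this patch the
    // "c = 1" filter was pushed down to Parquet and failed on files that lack column
    // "c" (PARQUET-389); with the patch, push-down is skipped whenever schemas are
    // merged, so Spark evaluates the filter itself.
    val merged = sqlContext.read.parquet(pathOne, pathTwo)
    merged.filter("c = 1").show()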
    ---
     .../datasources/parquet/ParquetRelation.scala |  6 +++++-
     .../parquet/ParquetFilterSuite.scala          | 20 +++++++++++++++++++
     2 files changed, 25 insertions(+), 1 deletion(-)
    
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
    index 77d851ca486b3..44649a68b3c9b 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
    @@ -292,6 +292,10 @@ private[sql] class ParquetRelation(
         val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
         val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
     
+    // When schema merging is enabled and a filtered column does not exist in every file,
+    // Parquet throws an exception; this is a known Parquet issue (PARQUET-389).
    +    val safeParquetFilterPushDown = !shouldMergeSchemas && parquetFilterPushDown
    +
         // Parquet row group size. We will use this value as the value for
         // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value
         // of these flags are smaller than the parquet row group size.
    @@ -305,7 +309,7 @@ private[sql] class ParquetRelation(
             dataSchema,
             parquetBlockSize,
             useMetadataCache,
    -        parquetFilterPushDown,
    +        safeParquetFilterPushDown,
             assumeBinaryIsString,
             assumeInt96IsTimestamp) _
     
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
    index 13fdd555a4c71..b2101beb92224 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
    @@ -316,4 +316,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
           }
         }
       }
    +
    +  test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") {
    +    import testImplicits._
    +
    +    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
    +      SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") {
    +      withTempPath { dir =>
    +        var pathOne = s"${dir.getCanonicalPath}/table1"
    +        (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne)
    +        var pathTwo = s"${dir.getCanonicalPath}/table2"
    +        (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo)
    +
+        // If the "c = 1" filter gets pushed down, Parquet throws an exception because
+        // column "c" does not exist in the first table. This is a Parquet issue (PARQUET-389).
    +        checkAnswer(
    +          sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1"),
    +          (1 to 1).map(i => Row(i, i.toString, null)))
    +      }
    +    }
    +  }
     }
    
    From 14d08b99085d4e609aeae0cf54d4584e860eb552 Mon Sep 17 00:00:00 2001
    From: Wenchen Fan 
    Date: Fri, 30 Oct 2015 12:17:51 +0100
    Subject: [PATCH 092/324] [SPARK-11393] [SQL] CoGroupedIterator should respect
     the fact that GroupedIterator.hasNext is not idempotent
    
When we cogroup two `GroupedIterator`s in `CoGroupedIterator` and the right side is smaller, we consume the right data and keep the left data unchanged. Calling `hasNext` then calls `left.hasNext`, which makes the left `GroupedIterator` generate an extra group because the previous one has not been consumed yet.
    
    Author: Wenchen Fan 
    
    Closes #9346 from cloud-fan/cogroup and squashes the following commits:
    
    9be67c8 [Wenchen Fan] SPARK-11393
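
A minimal sketch (not Spark code) of the caching pattern the fix applies to each side
of the cogroup: `hasNext` advances the underlying iterator at most once per element
and caches the result, so it can be called repeatedly without side effects.

    class CachingIterator[T >: Null <: AnyRef](underlying: Iterator[T]) extends Iterator[T] {
      private var cached: T = null

      override def hasNext: Boolean = {
        // Pull the next element only if nothing is cached yet.
        if (cached == null && underlying.hasNext) {
          cached = underlying.next()
        }
        cached != null
      }

      override def next(): T = {
        assert(hasNext)  // safe now that hasNext has no extra side effects
        val result = cached
        cached = null
        result
      }
    }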
    ---
     .../sql/execution/CoGroupedIterator.scala     | 14 ++++++-----
     .../execution/CoGroupedIteratorSuite.scala    | 24 +++++++++++++++++++
     2 files changed, 32 insertions(+), 6 deletions(-)
    
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
    index ce5827855e4aa..663bc904f39c8 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
    @@ -38,17 +38,19 @@ class CoGroupedIterator(
       private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
       private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
     
    -  override def hasNext: Boolean = left.hasNext || right.hasNext
    -
    -  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    -    if (currentLeftData.eq(null) && left.hasNext) {
    +  override def hasNext: Boolean = {
    +    if (currentLeftData == null && left.hasNext) {
           currentLeftData = left.next()
         }
    -    if (currentRightData.eq(null) && right.hasNext) {
    +    if (currentRightData == null && right.hasNext) {
           currentRightData = right.next()
         }
     
    -    assert(currentLeftData.ne(null) || currentRightData.ne(null))
    +    currentLeftData != null || currentRightData != null
    +  }
    +
    +  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
    +    assert(hasNext)
     
         if (currentLeftData.eq(null)) {
           // left is null, right is not null, consume the right data.
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
    index d1fe81947e9ea..4ff96e6574cac 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
    @@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
           Nil
         )
       }
    +
    +  test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
    +    val leftInput = Seq(create_row(2, "a")).iterator
    +    val rightInput = Seq(create_row(1, 2L)).iterator
    +    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
    +    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
    +    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
    +
    +    val result = cogrouped.map {
    +      case (key, leftData, rightData) =>
    +        assert(key.numFields == 1)
    +        (key.getInt(0), leftData.toSeq, rightData.toSeq)
    +    }.toSeq
    +
    +    assert(result ==
    +      (1,
    +        Seq.empty,
    +        Seq(create_row(1, 2L))) ::
    +      (2,
    +        Seq(create_row(2, "a")),
    +        Seq.empty) ::
    +      Nil
    +    )
    +  }
     }
    
    From 0451b00148a294c665146563242d2fe2de943a02 Mon Sep 17 00:00:00 2001
    From: Iulian Dragos 
    Date: Fri, 30 Oct 2015 16:51:32 +0000
    Subject: [PATCH 093/324] [SPARK-10986][MESOS] Set the context class loader in
     the Mesos executor backend.
    
    See [SPARK-10986](https://issues.apache.org/jira/browse/SPARK-10986) for details.
    
    This fixes the `ClassNotFoundException` for Spark classes in the serializer.
    
    I am not sure this is the right way to handle the class loader, but I couldn't find any documentation on how the context class loader is used and who relies on it. It seems at least the serializer uses it to instantiate classes during deserialization.
    
    I am open to suggestions (I tried this fix on a real Mesos cluster and it *does* fix the issue).
    
    tnachen andrewor14
    
    Author: Iulian Dragos 
    
    Closes #9282 from dragos/issue/mesos-classloader.
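
A small sketch (an assumption about the mechanism, not Spark code) of why the context
class loader matters here: a serializer that resolves classes by name typically looks
them up through the calling thread's context class loader.

    object ClassLoaderSketch {
      // Resolve a class the way a serializer might during deserialization: via the
      // thread's context class loader, falling back to this class's own loader.
      // If the context class loader is never set (the situation this patch fixes),
      // Spark and user classes may not be found, causing ClassNotFoundException.
      def loadClassByName(name: String): Class[_] = {
        val loader = Option(Thread.currentThread().getContextClassLoader)
          .getOrElse(getClass.getClassLoader)
        Class.forName(name, false, loader)
      }
    }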
    ---
     .../org/apache/spark/executor/MesosExecutorBackend.scala     | 5 +++++
     1 file changed, 5 insertions(+)
    
    diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
    index 0474fd2ccc12e..c9f18ebc7f0ea 100644
    --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
    +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
    @@ -63,6 +63,11 @@ private[spark] class MesosExecutorBackend
     
         logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus")
         this.driver = driver
    +    // Set a context class loader to be picked up by the serializer. Without this call
    +    // the serializer would default to the null class loader, and fail to find Spark classes
    +    // See SPARK-10986.
    +    Thread.currentThread().setContextClassLoader(this.getClass.getClassLoader)
    +
         val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
           Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
         val conf = new SparkConf(loadDefaults = true).setAll(properties)
    
    From fab710a9171932f01ac81d100db8523dbd314925 Mon Sep 17 00:00:00 2001
    From: Sun Rui 
    Date: Fri, 30 Oct 2015 10:51:11 -0700
    Subject: [PATCH 094/324] [SPARK-11414][SPARKR] Forgot to update usage of
     'spark.sparkr.r.command' in RRDD in the PR for SPARK-10971.
    
    Author: Sun Rui 
    
    Closes #9368 from sun-rui/SPARK-11414.
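
A usage sketch of the resulting precedence (the Rscript paths are hypothetical): the
new key wins when both are set, and the deprecated key is still honored on its own.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.sparkr.r.command", "/usr/bin/Rscript")   // deprecated, still honored
      .set("spark.r.command", "/usr/local/bin/Rscript")    // preferred; takes precedence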
    ---
     core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 7 ++++++-
     1 file changed, 6 insertions(+), 1 deletion(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
    index 9d5bbb5d609f3..6b418e908cb53 100644
    --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
    +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
    @@ -392,7 +392,12 @@ private[r] object RRDD {
       }
     
       private def createRProcess(port: Int, script: String): BufferedStreamThread = {
    -    val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
    +    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
    +    // but kept here for backward compatibility.
    +    val sparkConf = SparkEnv.get.conf
    +    var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
    +    rCommand = sparkConf.get("spark.r.command", rCommand)
    +
         val rOptions = "--vanilla"
         val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
         val rExecScript = rLibDir + "/SparkR/worker/" + script
    
    From 40c77fb23a1ee0a4e69d735ee6247f83b7e13b92 Mon Sep 17 00:00:00 2001
    From: Sun Rui 
    Date: Fri, 30 Oct 2015 10:56:06 -0700
    Subject: [PATCH 095/324] [SPARK-11210][SPARKR] Add window functions into
     SparkR [step 2].
    
    Author: Sun Rui 
    
    Closes #9196 from sun-rui/SPARK-11210.
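
The SparkR wrappers added below delegate to the corresponding static methods on
org.apache.spark.sql.functions. A short Scala sketch of the same functions over a
window, assuming a DataFrame `df` with "dept" and "salary" columns and the 1.5-era
function names:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{denseRank, percentRank, rank, rowNumber}

    val w = Window.partitionBy("dept").orderBy(df("salary").desc)

    df.select(
      df("dept"), df("salary"),
      rank().over(w),         // gaps after ties (SQL RANK)
      denseRank().over(w),    // no gaps after ties (SQL DENSE_RANK)
      percentRank().over(w),  // (rank - 1) / (rows in partition - 1)
      rowNumber().over(w))    // sequential number starting at 1 per partition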
    ---
     R/pkg/NAMESPACE                  |  4 ++
     R/pkg/R/functions.R              | 92 ++++++++++++++++++++++++++++++++
     R/pkg/R/generics.R               | 16 ++++++
     R/pkg/inst/tests/test_sparkSQL.R |  5 ++
     4 files changed, 117 insertions(+)
    
    diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
    index b73bed3128242..cd9537a2655f0 100644
    --- a/R/pkg/NAMESPACE
    +++ b/R/pkg/NAMESPACE
    @@ -126,6 +126,7 @@ exportMethods("%in%",
                   "datediff",
                   "dayofmonth",
                   "dayofyear",
    +              "denseRank",
                   "desc",
                   "endsWith",
                   "exp",
    @@ -182,16 +183,19 @@ exportMethods("%in%",
                   "next_day",
                   "ntile",
                   "otherwise",
    +              "percentRank",
                   "pmod",
                   "quarter",
                   "rand",
                   "randn",
    +              "rank",
                   "regexp_extract",
                   "regexp_replace",
                   "reverse",
                   "rint",
                   "rlike",
                   "round",
    +              "rowNumber",
                   "rpad",
                   "rtrim",
                   "second",
    diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
    index 366290fe66276..d7fd279279137 100644
    --- a/R/pkg/R/functions.R
    +++ b/R/pkg/R/functions.R
    @@ -2038,6 +2038,28 @@ setMethod("cumeDist",
                 column(jc)
               })
     
    +#' denseRank
    +#' 
    +#' Window function: returns the rank of rows within a window partition, without any gaps.
    +#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking
    +#' sequence when there are ties. That is, if you were ranking a competition using denseRank
    +#' and had three people tie for second place, you would say that all three were in second
    +#' place and that the next person came in third.
    +#' 
    +#' This is equivalent to the DENSE_RANK function in SQL.
    +#'
    +#' @rdname denseRank
    +#' @name denseRank
    +#' @family window_funcs
    +#' @export
    +#' @examples \dontrun{denseRank()}
    +setMethod("denseRank",
    +          signature(x = "missing"),
    +          function() {
    +            jc <- callJStatic("org.apache.spark.sql.functions", "denseRank")
    +            column(jc)
    +          })
    +
     #' lag
     #'
     #' Window function: returns the value that is `offset` rows before the current row, and
    @@ -2111,3 +2133,73 @@ setMethod("ntile",
                 jc <- callJStatic("org.apache.spark.sql.functions", "ntile", as.integer(x))
                 column(jc)
               })
    +
    +#' percentRank
    +#'
    +#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
    +#' 
    +#' This is computed by:
    +#' 
    +#'   (rank of row in its partition - 1) / (number of rows in the partition - 1)
    +#'
    +#' This is equivalent to the PERCENT_RANK function in SQL.
    +#'
    +#' @rdname percentRank
    +#' @name percentRank
    +#' @family window_funcs
    +#' @export
    +#' @examples \dontrun{percentRank()}
    +setMethod("percentRank",
    +          signature(x = "missing"),
    +          function() {
    +            jc <- callJStatic("org.apache.spark.sql.functions", "percentRank")
    +            column(jc)
    +          })
    +
    +#' rank
    +#'
    +#' Window function: returns the rank of rows within a window partition.
    +#' 
    +#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking
    +#' sequence when there are ties. That is, if you were ranking a competition using denseRank
    +#' and had three people tie for second place, you would say that all three were in second
    +#' place and that the next person came in third.
    +#' 
    +#' This is equivalent to the RANK function in SQL.
    +#'
    +#' @rdname rank
    +#' @name rank
    +#' @family window_funcs
    +#' @export
    +#' @examples \dontrun{rank()}
    +setMethod("rank",
    +          signature(x = "missing"),
    +          function() {
    +            jc <- callJStatic("org.apache.spark.sql.functions", "rank")
    +            column(jc)
    +          })
    +
    +# Expose rank() in the R base package
    +setMethod("rank",
    +          signature(x = "ANY"),
    +          function(x, ...) {
    +            base::rank(x, ...)
    +          })
    +
    +#' rowNumber
    +#'
    +#' Window function: returns a sequential number starting at 1 within a window partition.
    +#' 
    +#' This is equivalent to the ROW_NUMBER function in SQL.
    +#'
    +#' @rdname rowNumber
    +#' @name rowNumber
    +#' @family window_funcs
    +#' @export
    +#' @examples \dontrun{rowNumber()}
    +setMethod("rowNumber",
    +          signature(x = "missing"),
    +          function() {
    +            jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber")
    +            column(jc)
    +          })
    diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
    index c11c3c8d3e150..0b35340e48e42 100644
    --- a/R/pkg/R/generics.R
    +++ b/R/pkg/R/generics.R
    @@ -742,6 +742,10 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") })
     #' @export
     setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
     
    +#' @rdname denseRank
    +#' @export
    +setGeneric("denseRank", function(x) { standardGeneric("denseRank") })
    +
     #' @rdname explode
     #' @export
     setGeneric("explode", function(x) { standardGeneric("explode") })
    @@ -878,6 +882,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
     #' @export
     setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
     
    +#' @rdname percentRank
    +#' @export
    +setGeneric("percentRank", function(x) { standardGeneric("percentRank") })
    +
     #' @rdname pmod
     #' @export
     setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
    @@ -894,6 +902,10 @@ setGeneric("rand", function(seed) { standardGeneric("rand") })
     #' @export
     setGeneric("randn", function(seed) { standardGeneric("randn") })
     
    +#' @rdname rank
    +#' @export
    +setGeneric("rank", function(x, ...) { standardGeneric("rank") })
    +
     #' @rdname regexp_extract
     #' @export
     setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") })
    @@ -911,6 +923,10 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") })
     #' @export
     setGeneric("rint", function(x, ...) { standardGeneric("rint") })
     
    +#' @rdname rowNumber
    +#' @export
    +setGeneric("rowNumber", function(x) { standardGeneric("rowNumber") })
    +
     #' @rdname rpad
     #' @export
     setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") })
    diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
    index e1d4499925fe7..b4a4d03b2643b 100644
    --- a/R/pkg/inst/tests/test_sparkSQL.R
    +++ b/R/pkg/inst/tests/test_sparkSQL.R
    @@ -831,6 +831,11 @@ test_that("column functions", {
       c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c)
       c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1)
       c13 <- cumeDist() + ntile(1)
    +  c14 <- denseRank() + percentRank() + rank() + rowNumber()
    +
    +  # Test if base::rank() is exposed
    +  expect_equal(class(rank())[[1]], "Column")
    +  expect_equal(rank(1:3), as.numeric(c(1:3)))
     
       df <- jsonFile(sqlContext, jsonPath)
       df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20)))
    
    From 729f983e66cf65da2e8f48c463ccde2b355240c4 Mon Sep 17 00:00:00 2001
    From: Jeff Zhang 
    Date: Fri, 30 Oct 2015 18:50:12 +0000
Subject: [PATCH 096/324] [SPARK-11342][TESTS] Allow to set hadoop profile when
 running dev/run_tests
    
    Author: Jeff Zhang 
    
    Closes #9295 from zjffdu/SPARK-11342.
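
With this change a local run can presumably pick the profile up from the environment,
e.g. `HADOOP_PROFILE=hadoop2.6 dev/run-tests.py`.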
    ---
     dev/run-tests.py | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/dev/run-tests.py b/dev/run-tests.py
    index 6b4b71073453d..9e1abb0697192 100755
    --- a/dev/run-tests.py
    +++ b/dev/run-tests.py
    @@ -486,7 +486,7 @@ def main():
         else:
             # else we're running locally and can use local settings
             build_tool = "sbt"
    -        hadoop_version = "hadoop2.3"
    +        hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop2.3")
             test_env = "local"
     
         print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version,
    
    From bb5a2af034196620d869fc9b1a400e014e718b8c Mon Sep 17 00:00:00 2001
    From: felixcheung 
    Date: Fri, 30 Oct 2015 13:51:32 -0700
    Subject: [PATCH 097/324] [SPARK-11340][SPARKR] Support setting driver
     properties when starting Spark from R programmatically or from RStudio
    
    Mapping spark.driver.memory from sparkEnvir to spark-submit commandline arguments.
    
    shivaram suggested that we possibly add other spark.driver.* properties - do we want to add all of those? I thought those could be set in SparkConf?
    sun-rui
    
    Author: felixcheung 
    
    Closes #9290 from felixcheung/rdrivermem.
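
For example, judging by the test added below, `list(spark.driver.memory = "512m")` in
`sparkEnvir` should translate into `--driver-memory "512m"` being prepended to the
spark-submit arguments ahead of `sparkr-shell`.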
    ---
     R/pkg/R/sparkR.R                | 45 +++++++++++++++++++++++++++++----
     R/pkg/inst/tests/test_context.R | 27 ++++++++++++++++++++
     docs/sparkr.md                  | 28 ++++++++++++++------
     3 files changed, 87 insertions(+), 13 deletions(-)
    
    diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
    index 043b0057bd04a..004d08e74e1cd 100644
    --- a/R/pkg/R/sparkR.R
    +++ b/R/pkg/R/sparkR.R
    @@ -77,7 +77,9 @@ sparkR.stop <- function() {
     
     #' Initialize a new Spark Context.
     #'
    -#' This function initializes a new SparkContext.
    +#' This function initializes a new SparkContext. For details on how to initialize
+#' and use SparkR, refer to the SparkR programming guide at
    +#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}.
     #'
     #' @param master The Spark master URL.
     #' @param appName Application name to register with cluster manager
    @@ -93,7 +95,7 @@ sparkR.stop <- function() {
     #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark",
     #'                  list(spark.executor.memory="1g"))
     #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark",
    -#'                  list(spark.executor.memory="1g"),
    +#'                  list(spark.executor.memory="4g"),
     #'                  list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"),
     #'                  c("jarfile1.jar","jarfile2.jar"))
     #'}
    @@ -123,16 +125,21 @@ sparkR.init <- function(
         uriSep <- "////"
       }
     
    +  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
    +
       existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
       if (existingPort != "") {
         backendPort <- existingPort
       } else {
         path <- tempfile(pattern = "backend_port")
    +    submitOps <- getClientModeSparkSubmitOpts(
    +        Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
    +        sparkEnvirMap)
         launchBackend(
             args = path,
             sparkHome = sparkHome,
             jars = jars,
    -        sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
    +        sparkSubmitOpts = submitOps,
             packages = sparkPackages)
         # wait atmost 100 seconds for JVM to launch
         wait <- 0.1
    @@ -171,8 +178,6 @@ sparkR.init <- function(
         sparkHome <- suppressWarnings(normalizePath(sparkHome))
       }
     
    -  sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
    -
       sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
       if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
         sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
    @@ -320,3 +325,33 @@ clearJobGroup <- function(sc) {
     cancelJobGroup <- function(sc, groupId) {
       callJMethod(sc, "cancelJobGroup", groupId)
     }
    +
    +sparkConfToSubmitOps <- new.env()
    +sparkConfToSubmitOps[["spark.driver.memory"]]           <- "--driver-memory"
    +sparkConfToSubmitOps[["spark.driver.extraClassPath"]]   <- "--driver-class-path"
    +sparkConfToSubmitOps[["spark.driver.extraJavaOptions"]] <- "--driver-java-options"
    +sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-path"
    +
    +# Utility function that returns Spark Submit arguments as a string
    +#
+# A few Spark application and runtime environment properties cannot take effect after the driver
+# JVM has started, as documented in:
+# http://spark.apache.org/docs/latest/configuration.html#application-properties
+# When starting SparkR without using spark-submit, for example from RStudio, add them to the
+# spark-submit command line if not already set in SPARKR_SUBMIT_ARGS so that they take effect.
    +getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) {
    +  envirToOps <- lapply(ls(sparkConfToSubmitOps), function(conf) {
    +    opsValue <- sparkEnvirMap[[conf]]
    +    # process only if --option is not already specified
    +    if (!is.null(opsValue) &&
    +        nchar(opsValue) > 1 &&
    +        !grepl(sparkConfToSubmitOps[[conf]], submitOps)) {
    +      # put "" around value in case it has spaces
    +      paste0(sparkConfToSubmitOps[[conf]], " \"", opsValue, "\" ")
    +    } else {
    +      ""
    +    }
    +  })
    +  # --option must be before the application class "sparkr-shell" in submitOps
    +  paste0(paste0(envirToOps, collapse = ""), submitOps)
    +}
    diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R
    index e99815ed1562c..80c1b89a4c627 100644
    --- a/R/pkg/inst/tests/test_context.R
    +++ b/R/pkg/inst/tests/test_context.R
    @@ -65,3 +65,30 @@ test_that("job group functions can be called", {
       cancelJobGroup(sc, "groupId")
       clearJobGroup(sc)
     })
    +
    +test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", {
    +  e <- new.env()
    +  e[["spark.driver.memory"]] <- "512m"
    +  ops <- getClientModeSparkSubmitOpts("sparkrmain", e)
    +  expect_equal("--driver-memory \"512m\" sparkrmain", ops)
    +
    +  e[["spark.driver.memory"]] <- "5g"
    +  e[["spark.driver.extraClassPath"]] <- "/opt/class_path" # nolint
    +  e[["spark.driver.extraJavaOptions"]] <- "-XX:+UseCompressedOops -XX:+UseCompressedStrings"
    +  e[["spark.driver.extraLibraryPath"]] <- "/usr/local/hadoop/lib" # nolint
    +  e[["random"]] <- "skipthis"
    +  ops2 <- getClientModeSparkSubmitOpts("sparkr-shell", e)
    +  # nolint start
    +  expect_equal(ops2, paste0("--driver-class-path \"/opt/class_path\" --driver-java-options \"",
    +                      "-XX:+UseCompressedOops -XX:+UseCompressedStrings\" --driver-library-path \"",
    +                      "/usr/local/hadoop/lib\" --driver-memory \"5g\" sparkr-shell"))
    +  # nolint end
    +
    +  e[["spark.driver.extraClassPath"]] <- "/" # too short
    +  ops3 <- getClientModeSparkSubmitOpts("--driver-memory 4g sparkr-shell2", e)
    +  # nolint start
    +  expect_equal(ops3, paste0("--driver-java-options \"-XX:+UseCompressedOops ",
    +                      "-XX:+UseCompressedStrings\" --driver-library-path \"/usr/local/hadoop/lib\"",
    +                      " --driver-memory 4g sparkr-shell2"))
    +  # nolint end
    +})
    diff --git a/docs/sparkr.md b/docs/sparkr.md
    index 7139d16b4a068..497a276679f3b 100644
    --- a/docs/sparkr.md
    +++ b/docs/sparkr.md
    @@ -29,7 +29,7 @@ All of the examples on this page use sample data included in R or the Spark dist
     The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
     You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name
     , any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`,
    -which can be created from the  SparkContext. If you are working from the SparkR shell, the
+which can be created from the SparkContext. If you are working from the `sparkR` shell, the
     `SQLContext` and `SparkContext` should already be created for you.
     
     {% highlight r %}
    @@ -37,17 +37,29 @@ sc <- sparkR.init()
     sqlContext <- sparkRSQL.init(sc)
     {% endhighlight %}
     
+If you are creating the `SparkContext` yourself instead of using the `sparkR` shell or `spark-submit`,
+you can also specify certain Spark driver properties. Normally these
+[Application properties](configuration.html#application-properties) and
+[Runtime Environment](configuration.html#runtime-environment) properties cannot be set programmatically,
+because the driver JVM process has already been started; in this case SparkR takes care of it for you.
+To set them, pass them as you would other configuration properties in the `sparkEnvir` argument to
+`sparkR.init()`.
    +
    +{% highlight r %}
    +sc <- sparkR.init("local[*]", "SparkR", "/home/spark", list(spark.driver.memory="2g"))
    +{% endhighlight %}
    +
     
     
     ## Creating DataFrames
     With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
     
     ### From local data frames
    -The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. 
    +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R.
     
     
    {% highlight r %} -df <- createDataFrame(sqlContext, faithful) +df <- createDataFrame(sqlContext, faithful) # Displays the content of the DataFrame to stdout head(df) @@ -96,7 +108,7 @@ printSchema(people)
    The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` +to a Parquet file using `write.df`
    {% highlight r %} @@ -139,7 +151,7 @@ Here we include some basic examples and a complete list can be found in the [API
    {% highlight r %} # Create the DataFrame -df <- createDataFrame(sqlContext, faithful) +df <- createDataFrame(sqlContext, faithful) # Get basic information about the DataFrame df @@ -152,7 +164,7 @@ head(select(df, df$eruptions)) ##2 1.800 ##3 3.333 -# You can also pass in column name as strings +# You can also pass in column name as strings head(select(df, "eruptions")) # Filter the DataFrame to only retain rows with wait times shorter than 50 mins @@ -166,7 +178,7 @@ head(filter(df, df$waiting < 50))
    -### Grouping, Aggregation +### Grouping, Aggregation SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below @@ -194,7 +206,7 @@ head(arrange(waiting_counts, desc(waiting_counts$count))) ### Operating on Columns -SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions.
    {% highlight r %} From 45029bfdea42eb8964f2ba697859687393d2a558 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 30 Oct 2015 15:47:40 -0700 Subject: [PATCH 098/324] [SPARK-11423] remove MapPartitionsWithPreparationRDD Since we do not need to preserve a page before calling compute(), MapPartitionsWithPreparationRDD is not needed anymore. This PR basically revert #8543, #8511, #8038, #8011 Author: Davies Liu Closes #9381 from davies/remove_prepare2. --- .../rdd/MapPartitionsWithPreparationRDD.scala | 66 ---------------- .../spark/rdd/ZippedPartitionsRDD.scala | 13 --- ...MapPartitionsWithPreparationRDDSuite.scala | 66 ---------------- project/MimaExcludes.scala | 6 +- .../UnsafeFixedWidthAggregationMap.java | 9 +-- .../aggregate/TungstenAggregate.scala | 79 +++++++------------ .../TungstenAggregationIterator.scala | 78 ++++++++---------- .../org/apache/spark/sql/execution/sort.scala | 28 ++----- 8 files changed, 75 insertions(+), 270 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala delete mode 100644 core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala deleted file mode 100644 index 417ff5278db2a..0000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - -import org.apache.spark.{Partition, Partitioner, TaskContext} - -/** - * An RDD that applies a user provided function to every partition of the parent RDD, and - * additionally allows the user to prepare each partition before computing the parent partition. - */ -private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag]( - prev: RDD[T], - preparePartition: () => M, - executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner: Option[Partitioner] = { - if (preservesPartitioning) firstParent[T].partitioner else None - } - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - // In certain join operations, prepare can be called on the same partition multiple times. - // In this case, we need to ensure that each call to compute gets a separate prepare argument. - private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] - - /** - * Prepare a partition for a single call to compute. 
- */ - def prepare(): Unit = { - preparedArguments += preparePartition() - } - - /** - * Prepare a partition before computing it from its parent. - */ - override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val prepared = - if (preparedArguments.isEmpty) { - preparePartition() - } else { - preparedArguments.remove(0) - } - val parentIterator = firstParent[T].iterator(partition, context) - executePartition(context, partition.index, prepared, parentIterator) - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 70bf04de6400d..4333a679c8aae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -73,16 +73,6 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( super.clearDependencies() rdds = null } - - /** - * Call the prepare method of every parent that has one. - * This is needed for reserving execution memory in advance. - */ - protected def tryPrepareParents(): Unit = { - rdds.collect { - case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare() - } - } } private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( @@ -94,7 +84,6 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -118,7 +107,6 @@ private[spark] class ZippedPartitionsRDD3 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), @@ -146,7 +134,6 @@ private[spark] class ZippedPartitionsRDD4 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala deleted file mode 100644 index e281e817e493d..0000000000000 --- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import scala.collection.mutable - -import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext} - -class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext { - - test("prepare called before parent partition is computed") { - sc = new SparkContext("local", "test") - - // Have the parent partition push a number to the list - val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter => - TestObject.things.append(20) - iter - } - - // Push a different number during the prepare phase - val preparePartition = () => { TestObject.things.append(10) } - - // Push yet another number during the execution phase - val executePartition = ( - taskContext: TaskContext, - partitionIndex: Int, - notUsed: Unit, - parentIterator: Iterator[Int]) => { - TestObject.things.append(30) - TestObject.things.iterator - } - - // Verify that the numbers are pushed in the order expected - val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( - parent, preparePartition, executePartition) - val result = rdd.collect() - assert(result === Array(10, 20, 30)) - - TestObject.things.clear() - // Zip two of these RDDs, both should be prepared before the parent is executed - val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( - parent, preparePartition, executePartition) - val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect() - assert(result2 === Array(10, 10, 20, 30, 20, 30)) - } - -} - -private object TestObject { - val things = new mutable.ListBuffer[Int] -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b5e661d3ecfa8..8282f7ea62400 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -107,7 +107,11 @@ object MimaExcludes { "org.apache.spark.sql.SQLContext.createSession") ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.preferredNodeLocationData_=") + "org.apache.spark.SparkContext.preferredNodeLocationData_="), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 889f97003450c..d4b6d75b4d981 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,9 +19,8 @@ import java.io.IOException; -import com.google.common.annotations.VisibleForTesting; - import org.apache.spark.SparkEnv; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -31,7 +30,6 @@ import org.apache.spark.unsafe.Platform; 
import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.memory.TaskMemoryManager; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -218,11 +216,6 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } - @VisibleForTesting - public int getNumDataPages() { - return map.getNumDataPages(); - } - /** * Free the memory associated with this map. This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 0d3a4b36c161b..15616915f7364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.TaskContext -import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.types.StructType case class TungstenAggregate( @@ -84,59 +83,39 @@ case class TungstenAggregate( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - /** - * Set up the underlying unsafe data structures used before computing the parent partition. - * This makes sure our iterator is not starved by other operators in the same task. - */ - def preparePartition(): TungstenAggregationIterator = { - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - testFallbackStartsAt, - numInputRows, - numOutputRows, - dataSize, - spillSize) - } + child.execute().mapPartitions { iter => - /** Compute a partition using the iterator already set up previously. */ - def executePartition( - context: TaskContext, - partitionIndex: Int, - aggregationIterator: TungstenAggregationIterator, - parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { - val hasInput = parentIterator.hasNext - if (!hasInput) { - // We're not using the underlying map, so we just can free it here - aggregationIterator.free() - if (groupingExpressions.isEmpty) { + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator.empty + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter, + testFallbackStartsAt, + numInputRows, + numOutputRows, + dataSize, + spillSize) + if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator.empty + aggregationIterator } - } else { - aggregationIterator.start(parentIterator) - aggregationIterator } } - - // Note: we need to set up the iterator in each partition before computing the - // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). - val resultRdd = { - new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( - child.execute(), preparePartition, executePartition, preservesPartitioning = true) - } - resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index fb2fc98e34fbe..713a4db0cd59b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -74,6 +74,8 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -85,6 +87,7 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric, @@ -92,9 +95,6 @@ class TungstenAggregationIterator( spillSize: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { - // The parent partition iterator, to be initialized later in `start` - private[this] var inputIter: Iterator[InternalRow] = null - /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -486,15 +486,11 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) - // Exposed for testing - private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap - // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). 
private def processInputs(): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. @@ -526,7 +522,6 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -567,15 +562,11 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { - assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() - // Step 2: Free the memory used by the map. - hashMap.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, + // Step 2: If we have aggregate function with mode Partial or Complete, // we need to process input rows to get aggregation buffer. // So, later in the sort-based aggregation iterator, we can do merge. // If aggregate functions are with mode Final and PartialMerge, @@ -770,31 +761,27 @@ class TungstenAggregationIterator( /** * Start processing input rows. - * Only after this method is called will this iterator be non-empty. */ - def start(parentIter: Iterator[InternalRow]): Unit = { - inputIter = parentIter - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() - } + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. 
+ mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() } } @@ -868,13 +855,16 @@ class TungstenAggregationIterator( * Generate a output row when there is no input and there is no grouping expression. */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - assert(groupingExpressions.isEmpty) - assert(inputIter == null) - generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) - } - - /** Free memory used in the underlying map. */ - def free(): Unit = { - hashMap.free() + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. + val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index dd92dda480601..1a3832a698b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -77,6 +77,7 @@ case class Sort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ + case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, @@ -106,11 +107,7 @@ case class TungstenSort( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - /** - * Set up the sorter in each partition before computing the parent partition. - * This makes sure our sorter is not starved by other sorters used in the same task. - */ - def preparePartition(): UnsafeExternalRowSorter = { + child.execute().mapPartitions { iter => val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -131,33 +128,20 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter - } - /** Compute a partition using the sorter already set up previously. */ - def executePartition( - taskContext: TaskContext, - partitionIndex: Int, - sorter: UnsafeExternalRowSorter, - parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = { // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. 
val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) dataSize += sorter.getPeakMemoryUsage spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore - taskContext.internalMetricsToAccumulators( + TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator } - - // Note: we need to set up the external sorter in each partition before computing - // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709). - new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter]( - child.execute(), preparePartition, executePartition, preservesPartitioning = true) } } From e8ec2a7b01cc86329a6fbafc3d371bdfd79fc1d6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 30 Oct 2015 16:12:33 -0700 Subject: [PATCH 099/324] Revert "[SPARK-11236][CORE] Update Tachyon dependency from 0.7.1 -> 0.8.0." This reverts commit 4f5e60c647d7d6827438721b7fabbc3a57b81023. --- core/pom.xml | 6 +++++- make-distribution.sh | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index dff40e91ad228..319a50049a82d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -266,7 +266,7 @@ org.tachyonproject tachyon-client - 0.8.0 + 0.7.1 org.apache.hadoop @@ -288,6 +288,10 @@ org.tachyonproject tachyon-underfs-glusterfs + + org.tachyonproject + tachyon-underfs-s3 + diff --git a/make-distribution.sh b/make-distribution.sh index f6766784813c3..24418ace26270 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,9 +33,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.8.0" +TACHYON_VERSION="0.7.1" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" +TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" MAKE_TGZ=false NAME=none @@ -240,10 +240,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx From 69b9e4b3c2f929e3df55f5e71875c03bb9712948 Mon Sep 17 00:00:00 2001 From: Nakul Jindal Date: Fri, 30 Oct 2015 17:12:24 -0700 Subject: [PATCH 100/324] [SPARK-11385] [ML] foreachActive made public in MLLib's vector API Made foreachActive public in MLLib's vector API Author: Nakul Jindal Closes #9362 from nakul02/SPARK-11385_foreach_for_mllib_linalg_vector. 
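Since the patch only changes visibility and adds `@Since("1.6.0")`, a short usage sketch may help reviewers see what callers gain. This is a hypothetical example written against the signature in the diff below; the sample vectors are made up and are not part of the patch.

```scala
// Iterating only over the stored entries of MLlib vectors via the
// now-public foreachActive((index, value) => Unit).
import org.apache.spark.mllib.linalg.Vectors

object ForeachActiveExample {
  def main(args: Array[String]): Unit = {
    val dense = Vectors.dense(1.0, 0.0, 3.0)
    val sparse = Vectors.sparse(5, Array(1, 3), Array(2.0, 4.0))

    // Dense vectors treat every slot as active, including explicit zeros.
    dense.foreachActive((i, v) => println(s"dense[$i] = $v"))

    // Sparse vectors visit only the explicitly stored entries (indices 1 and 3).
    sparse.foreachActive((i, v) => println(s"sparse[$i] = $v"))
  }
}
```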
--- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index dcdc614455d34..bd9badc03c345 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -123,7 +123,8 @@ sealed trait Vector extends Serializable { * the vector with type `Int`, and the second parameter is the corresponding value * with type `Double`. */ - private[spark] def foreachActive(f: (Int, Double) => Unit) + @Since("1.6.0") + def foreachActive(f: (Int, Double) => Unit): Unit /** * Number of active entries. An "active entry" is an element which is explicitly stored, @@ -570,7 +571,8 @@ class DenseVector @Since("1.0.0") ( new DenseVector(values.clone()) } - private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + @Since("1.6.0") + override def foreachActive(f: (Int, Double) => Unit): Unit = { var i = 0 val localValuesSize = values.length val localValues = values @@ -700,7 +702,8 @@ class SparseVector @Since("1.0.0") ( private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) - private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + @Since("1.6.0") + override def foreachActive(f: (Int, Double) => Unit): Unit = { var i = 0 val localValuesSize = values.length val localIndices = indices From 3c471885dc4f86bea95ab542e0d48d22ae748404 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 30 Oct 2015 20:05:07 -0700 Subject: [PATCH 101/324] [SPARK-11434][SPARK-11103][SQL] Fix test ": Filter applied on merged Parquet schema with new column fails" https://issues.apache.org/jira/browse/SPARK-11434 Author: Yin Huai Closes #9387 from yhuai/SPARK-11434. --- .../execution/datasources/parquet/ParquetFilterSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b2101beb92224..f88ddc77a6a4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -323,15 +323,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { withTempPath { dir => - var pathOne = s"${dir.getCanonicalPath}/table1" + val pathOne = s"${dir.getCanonicalPath}/table1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne) - var pathTwo = s"${dir.getCanonicalPath}/table2" + val pathTwo = s"${dir.getCanonicalPath}/table2" (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) // If the "c = 1" filter gets pushed down, this query will throw an exception which // Parquet emits. This is a Parquet issue (PARQUET-389). 
checkAnswer( - sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1"), + sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a"), (1 to 1).map(i => Row(i, i.toString, null))) } } From 97b3c8fb470f0d3c1cdb1aeb27f675e695442e87 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Sat, 31 Oct 2015 11:10:37 +0000 Subject: [PATCH 102/324] [SPARK-11226][SQL] Empty line in json file should be skipped Currently the empty line in json file will be parsed into Row with all null field values. But in json, "{}" represents a json object, empty line is supposed to be skipped. Make a trivial change for this. Author: Jeff Zhang Closes #9211 from zjffdu/SPARK-11226. --- .../datasources/json/JacksonParser.scala | 46 ++++++++++--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +++++ .../datasources/json/JsonSuite.scala | 3 -- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index b2e52011a7276..4f53eeb081b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -245,29 +245,33 @@ private[sql] object JacksonParser { val factory = new JsonFactory() iter.flatMap { record => - try { - Utils.tryWithResource(factory.createParser(record)) { parser => - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of the file " + - "(or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") + if (record.trim.isEmpty) { + Nil + } else { + try { + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + sys.error( + s"Failed to parse record $record. 
Please make sure that each line of " + + "the file (or each string in the RDD) is a valid JSON object or " + + "an array of JSON objects.") + } } + } catch { + case _: JsonProcessingException => + failedRecord(record) } - } catch { - case _: JsonProcessingException => - failedRecord(record) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a616fac0bc2d..5413ef1287da1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -225,6 +225,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row("1"), Row("2"))) } + test("SPARK-11226 Skip empty line in json file") { + sqlContext.read.json( + sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) + .registerTempTable("d") + + checkAnswer( + sql("select count(1) from d"), + Seq(Row(3))) + } + test("SPARK-8828 sum should return null if all input values are null") { withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index d3fd409291f29..28b8f02bdf87f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -959,7 +959,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempTable("jsonTable") { val jsonDF = sqlContext.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") - val schema = StructType( StructField("_unparsed", StringType, true) :: StructField("a", StringType, true) :: @@ -976,7 +975,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { |FROM jsonTable """.stripMargin), Row(null, null, null, "{") :: - Row(null, null, null, "") :: Row(null, null, null, """{"a":1, b:2}""") :: Row(null, null, null, """{"a":{, b:3}""") :: Row("str_a_4", "str_b_4", "str_c_4", null) :: @@ -1001,7 +999,6 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { |WHERE _unparsed IS NOT NULL """.stripMargin), Row("{") :: - Row("") :: Row("""{"a":1, b:2}""") :: Row("""{"a":{, b:3}""") :: Row("]") :: Nil From ac4118db2dda802b936bb7a18a08844846c71285 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 31 Oct 2015 10:47:22 -0700 Subject: [PATCH 103/324] [SPARK-11424] Guard against double-close() of RecordReaders **TL;DR**: We can rule out one rare but potential cause of input stream corruption via defensive programming. ## Background [MAPREDUCE-5918](https://issues.apache.org/jira/browse/MAPREDUCE-5918) is a bug where an instance of a decompressor ends up getting placed into a pool multiple times. Since the pool is backed by a list instead of a set, this can lead to the same decompressor being used in different places at the same time, which is not safe because those decompressors will overwrite each other's buffers. Sometimes this buffer sharing will lead to exceptions but other times it will might silently result in invalid / garbled input. That Hadoop bug is fixed in Hadoop 2.7 but is still present in many Hadoop versions that we wish to support. 
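To make the pooling hazard concrete, here is a toy Scala illustration (it is not Hadoop's actual CodecPool code) of why a list-backed pool tolerates double-returns while a set-backed pool would not:

```scala
// A list-backed pool happily stores the same decompressor twice, so two
// later borrowers can end up sharing one instance; a set-backed pool
// deduplicates the second return.
import scala.collection.mutable

object PoolSketch {
  final class FakeDecompressor  // stand-in for a Hadoop Decompressor

  def main(args: Array[String]): Unit = {
    val d = new FakeDecompressor

    val listPool = mutable.ListBuffer.empty[FakeDecompressor]
    listPool += d
    listPool += d                 // accidental second return goes unnoticed
    println(listPool.size)        // 2: the same instance can be handed out twice

    val setPool = mutable.Set.empty[FakeDecompressor]
    setPool += d
    setPool += d                  // second return is a no-op
    println(setPool.size)         // 1
  }
}
```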
As a result, I think that we should try to work around this issue in Spark via defensive programming to prevent RecordReaders from being closed multiple times.

So far, I've had a hard time coming up with explanations of exactly how double-`close()`s occur in practice, but I do have a couple of explanations that work on paper. For instance, it looks like https://github.com/apache/spark/pull/7424, added in 1.5, introduces at least one extremely rare corner-case path where Spark could double-close() a LineRecordReader instance in a way that triggers the bug. Here are the steps involved in the bad execution that I brainstormed up:

* [The task has finished reading input, so we call close()](https://github.com/apache/spark/blob/v1.5.1/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala#L168).
* [While handling the close call and trying to close the reader, reader.close() throws an exception](https://github.com/apache/spark/blob/v1.5.1/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala#L190)
* We don't set `reader = null` after handling this exception, so the [TaskCompletionListener also ends up calling NewHadoopRDD.close()](https://github.com/apache/spark/blob/v1.5.1/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala#L156), which, in turn, closes the record reader again.

In this hypothetical situation, `LineRecordReader.close()` could [fail with an exception if its InputStream failed to close](https://github.com/apache/hadoop/blob/release-1.2.1/src/mapred/org/apache/hadoop/mapred/LineRecordReader.java#L212). I googled for "Exception in RecordReader.close()" and it looks like it's possible for a closed Hadoop FileSystem to trigger an error there: [SPARK-757](https://issues.apache.org/jira/browse/SPARK-757), [SPARK-2491](https://issues.apache.org/jira/browse/SPARK-2491). Looking at [SPARK-3052](https://issues.apache.org/jira/browse/SPARK-3052), it seems like it's possible to get spurious exceptions there when there is an error reading from Hadoop. If the Hadoop FileSystem were to get into an error state _right_ after reading the last record, then it looks like we could hit the bug here in 1.5.

## The fix

This patch guards against these issues by modifying `HadoopRDD.close()` and `NewHadoopRDD.close()` so that they set `reader = null` even if an exception occurs in the `reader.close()` call. In addition, I modified `NextIterator.closeIfNeeded()` to guard against double-close if the first `close()` call throws an exception. I don't have an easy way to test this, since I haven't been able to reproduce the bug that prompted this patch, but these changes seem safe and seem to rule out the on-paper reproductions that I was able to brainstorm up.

Author: Josh Rosen

Closes #9382 from JoshRosen/hadoop-decompressor-pooling-fix and squashes the following commits:

5ec97d7 [Josh Rosen] Add SqlNewHadoopRDD.unsetInputFileName() that I accidentally deleted.
ae46cf4 [Josh Rosen] Merge remote-tracking branch 'origin/master' into hadoop-decompressor-pooling-fix
087aa63 [Josh Rosen] Guard against double-close() of RecordReaders.
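The two close() guards this patch applies can be condensed into a short sketch (these are simplified stand-ins, not the patched Spark classes themselves):

```scala
// Pattern 1: null out the reader in a finally block so a second close()
// call becomes a no-op even when reader.close() throws.
// Pattern 2: flip the closed flag before calling close() so a throwing
// close() can never be retried against the underlying resource.
trait RecordReaderLike { def close(): Unit }

class GuardedReaderHolder(private var reader: RecordReaderLike) {
  def closeReader(): Unit = {
    if (reader != null) {
      try {
        reader.close()
      } catch {
        case e: Exception => println(s"Exception in RecordReader.close(): $e")
      } finally {
        reader = null
      }
    }
  }
}

abstract class SimpleNextIterator {
  private var closed = false
  protected def close(): Unit

  def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true   // set before close(), mirroring the NextIterator change
      close()
    }
  }
}
```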
--- .../org/apache/spark/rdd/HadoopRDD.scala | 23 +++++---- .../org/apache/spark/rdd/NewHadoopRDD.scala | 44 ++++++++--------- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 47 ++++++++++--------- .../org/apache/spark/util/NextIterator.scala | 4 +- 4 files changed, 66 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 77b57132b9f1f..d841f05ec52cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -251,8 +251,21 @@ class HadoopRDD[K, V]( } override def close() { - try { - reader.close() + if (reader != null) { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. + try { + reader.close() + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } finally { + reader = null + } if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit] || @@ -266,12 +279,6 @@ class HadoopRDD[K, V]( logWarning("Unable to get input size to set InputMetrics for task", e) } } - } catch { - case e: Exception => { - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2872b93b8730e..9c4b70844bdbe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -184,30 +184,32 @@ class NewHadoopRDD[K, V]( } private def close() { - try { - if (reader != null) { - // Close reader and release it + if (reader != null) { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. + try { reader.close() - reader = null - - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. 
- try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) } - } + } finally { + reader = null } - } catch { - case e: Exception => { - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 0228c54e0511c..264dae7f39085 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -189,32 +189,35 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } private def close() { - try { - if (reader != null) { + if (reader != null) { + SqlNewHadoopRDD.unsetInputFileName() + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic + // corruption issues when reading compressed input. + try { reader.close() - reader = null - - SqlNewHadoopRDD.unsetInputFileName() - - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + } catch { + case e: Exception => + if (!ShutdownHookManager.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) } - } + } finally { + reader = null } - } catch { - case e: Exception => - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. 
+ try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) } + } } } } diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala index e5c732a5a559b..0b505a576768c 100644 --- a/core/src/main/scala/org/apache/spark/util/NextIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -60,8 +60,10 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] { */ def closeIfNeeded() { if (!closed) { - close() + // Note: it's important that we set closed = true before calling close(), since setting it + // afterwards would permit us to call close() multiple times if close() threw an exception. closed = true + close() } } From fc27dfbf0f8d3f96c70e27d88f7d0316c97ddb1e Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 31 Oct 2015 12:55:33 -0700 Subject: [PATCH 104/324] [SPARK-11024][SQL] Optimize NULL in by folding it to Literal(null) Add a rule in optimizer to convert NULL [NOT] IN (expr1,...,expr2) to Literal(null). This is a follow up defect to SPARK-8654 cloud-fan Can you please take a look ? Author: Dilip Biswal Closes #9348 from dilipbiswal/spark_11024. --- .../sql/catalyst/optimizer/Optimizer.scala | 5 ++ .../catalyst/optimizer/OptimizeInSuite.scala | 51 ++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d37f43888fd4f..338c5193cb7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -417,6 +417,11 @@ object NullPropagation extends Rule[LogicalPlan] { case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } + + // If the value expression is NULL then transform the In expression to + // Literal(null) + case In(Literal(null, _), list) => Literal.create(null, BooleanType) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 6f7b5b9572e22..48cab01ac1004 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -35,7 +35,8 @@ class OptimizeInSuite extends PlanTest { val batches = Batch("AnalysisNodes", Once, EliminateSubQueries) :: - Batch("ConstantFolding", Once, + Batch("ConstantFolding", FixedPoint(10), + NullPropagation, ConstantFolding, BooleanSimplification, OptimizeIn) :: Nil @@ -82,4 +83,52 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") { + val originalQuery = + testRelation + .where(In(Literal.create(null, NullType), Seq(Literal(1), Literal(2)))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: Inset optimization disabled as " + + "list expression contains 
attribute)") { + val originalQuery = + testRelation + .where(In(Literal.create(null, StringType), Seq(Literal(1), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: Inset optimization disabled as " + + "list expression contains attribute - select)") { + val originalQuery = + testRelation + .select(In(Literal.create(null, StringType), + Seq(Literal(1), UnresolvedAttribute("b"))).as("a")).analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Literal.create(null, BooleanType).as("a")) + .analyze + + comparePlans(optimized, correctAnswer) + } + } From 40d3c6797a3dfd037eb69b2bcd336d8544deddf5 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Sat, 31 Oct 2015 18:23:15 -0700 Subject: [PATCH 105/324] [SPARK-11265][YARN] YarnClient can't get tokens to talk to Hive 1.2.1 in a secure cluster This is a fix for SPARK-11265; the introspection code to get Hive delegation tokens failing on Spark 1.5.1+, due to changes in the Hive codebase Author: Steve Loughran Closes #9232 from steveloughran/stevel/patches/SPARK-11265-hive-tokens. --- yarn/pom.xml | 25 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 51 +------------ .../deploy/yarn/YarnSparkHadoopUtil.scala | 73 +++++++++++++++++++ .../yarn/YarnSparkHadoopUtilSuite.scala | 29 ++++++++ 4 files changed, 129 insertions(+), 49 deletions(-) diff --git a/yarn/pom.xml b/yarn/pom.xml index 3eadacba13e18..989b820bec9ef 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -162,6 +162,31 @@ jersey-server test + + + + ${hive.group} + hive-exec + test + + + ${hive.group} + hive-metastore + test + + + org.apache.thrift + libthrift + test + + + org.apache.thrift + libfb303 + test + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4954b6180902e..a3f33d80184a3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1337,55 +1337,8 @@ object Client extends Logging { conf: Configuration, credentials: Credentials) { if (shouldGetTokens(sparkConf, "hive") && UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") - val hiveConf = hiveConfClass.newInstance() - - val hiveConfGet = (param: String) => Option(hiveConfClass - .getMethod("get", classOf[java.lang.String]) - .invoke(hiveConf, param)) - - val metastore_uri = hiveConfGet("hive.metastore.uris") - - // Check for local metastore - if (metastore_uri != None && metastore_uri.get.toString.size > 0) { - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hive = hiveClass.getMethod("get").invoke(null, hiveConf.asInstanceOf[Object]) - - val metastore_kerberos_principal_conf_var = mirror.classLoader - .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") - .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString - - val principal = hiveConfGet(metastore_kerberos_principal_conf_var) - - val username = Option(UserGroupInformation.getCurrentUser().getUserName) - if (principal != None && username != None) { - val tokenStr = hiveClass.getMethod("getDelegationToken", - 
classOf[java.lang.String], classOf[java.lang.String]) - .invoke(hive, username.get, principal.get).asInstanceOf[java.lang.String] - - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - credentials.addToken(new Text("hive.server2.delegation.token"), hive2Token) - logDebug("Added hive.Server2.delegation.token to conf.") - hiveClass.getMethod("closeCurrent").invoke(null) - } else { - logError("Username or principal == NULL") - logError(s"""username=${username.getOrElse("(NULL)")}""") - logError(s"""principal=${principal.getOrElse("(NULL)")}""") - throw new IllegalArgumentException("username and/or principal is equal to null!") - } - } else { - logDebug("HiveMetaStore configured in localmode") - } - } catch { - case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e: Exception => { logError("Unexpected Exception " + e) - throw new RuntimeException("Unexpected exception", e) - } + YarnSparkHadoopUtil.get.obtainTokenForHiveMetastore(conf).foreach { + credentials.addToken(new Text("hive.server2.delegation.token"), _) } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f276e7efde9d7..5924daf3ece49 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -22,14 +22,17 @@ import java.util.regex.Matcher import java.util.regex.Pattern import scala.collection.mutable.HashMap +import scala.reflect.runtime._ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.token.Token import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -142,6 +145,76 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) ConverterUtils.toContainerId(containerIdString) } + + /** + * Obtains token for the Hive metastore, using the current user as the principal. + * Some exceptions are caught and downgraded to a log message. + * @param conf hadoop configuration; the Hive configuration will be based on this + * @return a token, or `None` if there's no need for a token (no metastore URI or principal + * in the config), or if a binding exception was caught and downgraded. + */ + def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { + try { + obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + } catch { + case e: ClassNotFoundException => + logInfo(s"Hive class not found $e") + logDebug("Hive class not found", e) + None + } + } + + /** + * Inner routine to obtains token for the Hive metastore; exceptions are raised on any problem. + * @param conf hadoop configuration; the Hive configuration will be based on this. 
+ * @param username the username of the principal requesting the delegating token. + * @return a delegation token + */ + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, + username: String): Option[Token[DelegationTokenIdentifier]] = { + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) + + // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down + // to a Configuration and used without reflection + val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + // using the (Configuration, Class) constructor allows the current configuratin to be included + // in the hive config. + val ctor = hiveConfClass.getDeclaredConstructor(classOf[Configuration], + classOf[Object].getClass) + val hiveConf = ctor.newInstance(conf, hiveConfClass).asInstanceOf[Configuration] + val metastoreUri = hiveConf.getTrimmed("hive.metastore.uris", "") + + // Check for local metastore + if (metastoreUri.nonEmpty) { + require(username.nonEmpty, "Username undefined") + val principalKey = "hive.metastore.kerberos.principal" + val principal = hiveConf.getTrimmed(principalKey, "") + require(principal.nonEmpty, "Hive principal $principalKey undefined") + logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val closeCurrent = hiveClass.getMethod("closeCurrent") + try { + // get all the instance methods before invoking any + val getDelegationToken = hiveClass.getMethod("getDelegationToken", + classOf[String], classOf[String]) + val getHive = hiveClass.getMethod("get", hiveConfClass) + + // invoke + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } finally { + Utils.tryLogNonFatalError { + closeCurrent.invoke(null) + } + } + } else { + logDebug("HiveMetaStore configured in localmode") + None + } + } } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e1c67db76571f..9132c56a91754 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} +import java.lang.reflect.InvocationTargetException import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -245,4 +247,31 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging System.clearProperty("SPARK_YARN_MODE") } } + + test("Obtain tokens For HiveMetastore") { + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.kerberos.principal", "bob") + // thrift picks up on port 0 and bails out, without trying to talk to endpoint + hadoopConf.set("hive.metastore.uris", "http://localhost:0") + val util = new YarnSparkHadoopUtil + 
assertNestedHiveException(intercept[InvocationTargetException] { + util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + }) + // expect exception trapping code to unwind this hive-side exception + assertNestedHiveException(intercept[InvocationTargetException] { + util.obtainTokenForHiveMetastore(hadoopConf) + }) + } + + def assertNestedHiveException(e: InvocationTargetException): Throwable = { + val inner = e.getCause + if (inner == null) { + fail("No inner cause", e) + } + if (!inner.isInstanceOf[HiveException]) { + fail("Not a hive exception", inner) + } + inner + } + } From aa494a9c2ebd59baec47beb434cd09bf3f188218 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 31 Oct 2015 21:16:09 -0700 Subject: [PATCH 106/324] [SPARK-11117] [SPARK-11345] [SQL] Makes all HadoopFsRelation data sources produce UnsafeRow This PR fixes two issues: 1. `PhysicalRDD.outputsUnsafeRows` is always `false` Thus a `ConvertToUnsafe` operator is often required even if the underlying data source relation does output `UnsafeRow`. 1. Internal/external row conversion for `HadoopFsRelation` is kinda messy Currently we're using `HadoopFsRelation.needConversion` and [dirty type erasure hacks][1] to indicate whether the relation outputs external row or internal row and apply external-to-internal conversion when necessary. Basically, all builtin `HadoopFsRelation` data sources, i.e. Parquet, JSON, ORC, and Text output `InternalRow`, while typical external `HadoopFsRelation` data sources, e.g. spark-avro and spark-csv, output `Row`. This PR adds a `private[sql]` interface method `HadoopFsRelation.buildInternalScan`, which by default invokes `HadoopFsRelation.buildScan` and converts `Row`s to `UnsafeRow`s (which are also `InternalRow`s). All builtin `HadoopFsRelation` data sources override this method and directly output `UnsafeRow`s. In this way, now `HadoopFsRelation` always produces `UnsafeRow`s. Thus `PhysicalRDD.outputsUnsafeRows` can be properly set by checking whether the underlying data source is a `HadoopFsRelation`. A remaining question is that, can we assume that all non-builtin `HadoopFsRelation` data sources output external rows? At least all well known ones do so. However it's possible that some users implemented their own `HadoopFsRelation` data sources that leverages `InternalRow` and thus all those unstable internal data representations. If this assumption is safe, we can deprecate `HadoopFsRelation.needConversion` and cleanup some more conversion code (like [here][2] and [here][3]). This PR supersedes #9125. Follow-ups: 1. Makes JSON and ORC data sources output `UnsafeRow` directly 1. Makes `HiveTableScan` output `UnsafeRow` directly This is related to 1 since ORC data source shares the same `Writable` unwrapping code with `HiveTableScan`. [1]: https://github.com/apache/spark/blob/v1.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala#L353 [2]: https://github.com/apache/spark/blob/v1.5.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala#L331-L335 [3]: https://github.com/apache/spark/blob/v1.5.1/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala#L630-L669 Author: Cheng Lian Closes #9305 from liancheng/spark-11345.unsafe-hadoop-fs-relation. 
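For reviewers, the default `buildInternalScan` added to `HadoopFsRelation` boils down to the following shape (condensed and lightly reformatted from the `interfaces.scala` hunk below; it is a method excerpt, not a standalone program):

```scala
// External data sources keep implementing buildScan (Row-based); the default
// buildInternalScan converts those Rows to UnsafeRows with a per-partition
// UnsafeProjection, so every HadoopFsRelation now emits UnsafeRow.
private[sql] def buildInternalScan(
    requiredColumns: Array[String],
    filters: Array[Filter],
    inputFiles: Array[FileStatus],
    broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
  val requiredSchema = StructType(requiredColumns.map(dataSchema.apply))
  val externalRows = buildScan(requiredColumns, filters, inputFiles, broadcastedConf)
  val internalRows =
    execution.RDDConversions.rowToRowRdd(externalRows, requiredSchema.map(_.dataType))

  internalRows.mapPartitions { iterator =>
    val unsafeProjection = UnsafeProjection.create(requiredSchema)  // one per partition
    iterator.map(unsafeProjection)
  }
}
```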
--- .../sql/columnar/GenerateColumnAccessor.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 8 ++-- .../datasources/DataSourceStrategy.scala | 30 ++++++++----- .../datasources/json/JSONRelation.scala | 18 +++++--- .../parquet/CatalystRecordMaterializer.scala | 2 +- .../parquet/CatalystRowConverter.scala | 8 +++- .../datasources/parquet/ParquetRelation.scala | 6 +-- .../datasources/text/DefaultSource.scala | 34 +++++++++----- .../apache/spark/sql/sources/interfaces.scala | 44 +++++++++++++++++-- .../parquet/ParquetQuerySuite.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 30 ++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 31 +++++++++++++ 12 files changed, 156 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala index 7980a6f36d8ea..ff9393b465b7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala @@ -34,7 +34,7 @@ abstract class ColumnarIterator extends Iterator[InternalRow] { /** * An helper class to update the fields of UnsafeRow, used by ColumnAccessor * - * WARNNING: These setter MUST be called in increasing order of ordinals. + * WARNING: These setter MUST be called in increasing order of ordinals. */ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 87bd92e00a2c1..7a466cf6a0a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation} import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} @@ -93,7 +93,9 @@ private[sql] case class LogicalRDD( private[sql] case class PhysicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], - extraInformation: String) extends LeafNode { + extraInformation: String, + override val outputsUnsafeRows: Boolean = false) + extends LeafNode { protected override def doExecute(): RDD[InternalRow] = rdd @@ -105,7 +107,7 @@ private[sql] object PhysicalRDD { output: Seq[Attribute], rdd: RDD[InternalRow], relation: BaseRelation): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString) + PhysicalRDD(output, rdd, relation.toString, relation.isInstanceOf[HadoopFsRelation]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index af6626c897583..65859865c8fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,21 +17,21 @@ package org.apache.spark.sql.execution.datasources 
-import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -106,8 +106,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l, projects, filters, - (a, f) => - toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil + (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil case l @ LogicalRelation(baseRelation: TableScan, _) => execution.PhysicalRDD.createFromDataSource( @@ -152,7 +151,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Don't scan any partition columns to save I/O. Here we are being optimistic and // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. - val dataRows = relation.buildScan( + val dataRows = relation.buildInternalScan( requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast) // Merges data values with partition values. @@ -161,7 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requiredDataColumns, partitionColumns, partitionValues, - toCatalystRDD(logicalRelation, requiredDataColumns, dataRows)) + dataRows) } val unionedRows = @@ -199,15 +198,24 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Builds `AttributeReference`s for all partition columns so that we can use them to project // required partition columns. Note that if a partition column appears in `requiredColumns`, // we should use the `AttributeReference` in `requiredColumns`. - val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap - val partitionColumns = partitionColumnSchema.toAttributes.map { a => - requiredColumnMap.getOrElse(a.name, a) + val partitionColumns = { + val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap + partitionColumnSchema.toAttributes.map { a => + requiredColumnMap.getOrElse(a.name, a) + } } val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { - val projection = UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) + // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and + // `UnsafeProjection`. Because the projection may also adjust column order. 
val mutableJoinedRow = new JoinedRow() - iterator.map(dataRow => projection(mutableJoinedRow(dataRow, partitionValues))) + val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues) + val unsafeProjection = + UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) + + iterator.map { unsafeDataRow => + unsafeProjection(mutableJoinedRow(unsafeDataRow, unsafePartitionValues)) + } } // This is an internal RDD whose call site the user should not be concerned with diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 5f104fca7d629..85b52f04c8d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -34,6 +34,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType @@ -122,14 +123,21 @@ private[sql] class JSONRelation( jsonSchema } - override def buildScan( + override private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[FileStatus]): RDD[Row] = { - JacksonParser( + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { + val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) + val rows = JacksonParser( inputRDD.getOrElse(createBaseRdd(inputPaths)), - StructType(requiredColumns.map(dataSchema(_))), - sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] + requiredDataSchema, + sqlContext.conf.columnNameOfCorruptRecord) + + rows.mapPartitions { iterator => + val unsafeProjection = UnsafeProjection.create(requiredDataSchema) + iterator.map(unsafeProjection) + } } override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index ed9e0aa65977b..eeead9f5d88a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -35,7 +35,7 @@ private[parquet] class CatalystRecordMaterializer( private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - override def getCurrentRecord: InternalRow = rootConverter.currentRow + override def getCurrentRecord: InternalRow = rootConverter.currentRecord override def getRootConverter: GroupConverter = rootConverter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index b16c46579f7c5..1f653cd3d3cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -163,10 +163,14 @@ private[parquet] class CatalystRowConverter( override def setFloat(value: Float): Unit = row.setFloat(ordinal, value) } + private val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + + private val unsafeProjection = UnsafeProjection.create(catalystType) + /** - * Represents the converted row object once an entire Parquet record is converted. + * The [[UnsafeRow]] converted from an entire Parquet record. */ - val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) + def currentRecord: UnsafeRow = unsafeProjection(currentRow) // Converters for each field. private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 44649a68b3c9b..5a7c6b95b565f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -282,11 +282,11 @@ private[sql] class ParquetRelation( } } - override def buildScan( + override def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString @@ -361,7 +361,7 @@ private[sql] class ParquetRelation( id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) } } - }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index ab26c57ad1923..52c4421d7e87e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -25,16 +25,20 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} +import org.apache.spark.sql.columnar.MutableUnsafeRow import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration /** * A 
data source for reading text files. @@ -79,8 +83,12 @@ private[sql] class TextRelation( /** This is an internal data source that outputs internal row format. */ override val needConversion: Boolean = false - /** Read path. */ - override def buildScan(inputPaths: Array[FileStatus]): RDD[Row] = { + + override private[sql] def buildInternalScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) val paths = inputPaths.map(_.getPath).sortBy(_.toUri) @@ -92,17 +100,19 @@ private[sql] class TextRelation( sqlContext.sparkContext.hadoopRDD( conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) .mapPartitions { iter => - var buffer = new Array[Byte](1024) - val row = new GenericMutableRow(1) + val bufferHolder = new BufferHolder + val unsafeRowWriter = new UnsafeRowWriter + val unsafeRow = new UnsafeRow + iter.map { case (_, line) => - if (line.getLength > buffer.length) { - buffer = new Array[Byte](line.getLength) - } - System.arraycopy(line.getBytes, 0, buffer, 0, line.getLength) - row.update(0, UTF8String.fromBytes(buffer, 0, line.getLength)) - row + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.initialize(bufferHolder, 1) + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize()) + unsafeRow } - }.asInstanceOf[RDD[Row]] + } } /** Write path. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index a9a013e936fd3..7a553511483ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -585,11 +585,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - final private[sql] def buildScan( + final private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val inputStatuses = inputPaths.flatMap { input => val path = new Path(input) @@ -604,7 +604,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } } - buildScan(requiredColumns, filters, inputStatuses, broadcastedConf) + buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf) } /** @@ -740,6 +740,44 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio buildScan(requiredColumns, filters, inputFiles) } + /** + * For a non-partitioned relation, this method builds an `RDD[InternalRow]` containing all rows + * within this relation. For partitioned relations, this method is called for each selected + * partition, and builds an `RDD[InternalRow]` containing all rows within that single partition. + * + * Note: + * + * 1. Rows contained in the returned `RDD[InternalRow]` are assumed to be `UnsafeRow`s. + * 2. This interface is subject to change in future. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. 
The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the + * overhead of broadcasting the Configuration for every Hadoop RDD. + */ + private[sql] def buildInternalScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { + val requiredSchema = StructType(requiredColumns.map(dataSchema.apply)) + val internalRows = { + val externalRows = buildScan(requiredColumns, filters, inputFiles, broadcastedConf) + execution.RDDConversions.rowToRowRdd(externalRows, requiredSchema.map(_.dataType)) + } + + internalRows.mapPartitions { iterator => + val unsafeProjection = UnsafeProjection.create(requiredSchema) + iterator.map(unsafeProjection) + } + } + /** * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can * be put here. For example, user defined output committer can be configured here diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index baff7f5752a75..70fae32b7e7a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{TableIdentifier, InternalRow} import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index d1f30e188eafb..45de567039760 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -19,21 +19,20 @@ package org.apache.spark.sql.hive.orc import java.util.Properties -import scala.collection.JavaConverters._ - import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit, OrcStruct} import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfoUtils, StructTypeInfo} +import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => 
MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} @@ -199,12 +198,13 @@ private[sql] class OrcRelation( partitionColumns) } - override def buildScan( + override private[sql] def buildInternalScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[FileStatus]): RDD[Row] = { + inputPaths: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] + OrcTableScan(output, this, filters, inputPaths).execute() } override def prepareJobForWrite(job: Job): OutputWriterFactory = { @@ -253,16 +253,17 @@ private[orc] case class OrcTableScan( path: String, conf: Configuration, iterator: Iterator[Writable], - nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[InternalRow] = { + nonPartitionKeyAttrs: Seq[Attribute]): Iterator[InternalRow] = { val deserializer = new OrcSerde val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(StructType.fromAttributes(nonPartitionKeyAttrs)) // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty // partition since we know that this file is empty. 
maybeStructOI.map { soi => - val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.zipWithIndex.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip @@ -280,7 +281,7 @@ private[orc] case class OrcTableScan( } i += 1 } - mutableRow: InternalRow + unsafeProjection(mutableRow) } }.getOrElse { Iterator.empty @@ -322,13 +323,8 @@ private[orc] case class OrcTableScan( val wrappedConf = new SerializableConfiguration(conf) rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => - val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) - fillObject( - split.getPath.toString, - wrappedConf.value, - iterator.map(_._2), - attributes.zipWithIndex, - mutableRow) + val writableIterator = iterator.map(_._2) + fillObject(split.getPath.toString, wrappedConf.value, writableIterator, attributes) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index e3605bb3f6bf0..100b97137cff0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,6 +27,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -687,6 +688,36 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } + + test("HadoopFsRelation produces UnsafeRow") { + withTempTable("test_unsafe") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(3).write.format(dataSourceName).save(path) + sqlContext.read + .format(dataSourceName) + .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) + .load(path) + .registerTempTable("test_unsafe") + + val df = sqlContext.sql( + """SELECT COUNT(*) + |FROM test_unsafe a JOIN test_unsafe b + |WHERE a.id = b.id + """.stripMargin) + + val plan = df.queryExecution.executedPlan + + assert( + plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, + s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): + |$plan + """.stripMargin) + + checkAnswer(df, Row(3)) + } + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when From 643c49c75ee95243fd19ae73b5170e6e6e212b8d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 1 Nov 2015 12:25:49 +0000 Subject: [PATCH 107/324] [SPARK-11305][DOCS] Remove Third-Party Hadoop Distributions Doc Page Remove Hadoop third party distro page, and move Hadoop cluster config info to configuration page CC pwendell Author: Sean Owen Closes #9298 from srowen/SPARK-11305. 
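The section this patch moves into `docs/configuration.md` amounts to pointing `HADOOP_CONF_DIR` at the directory
holding the Hadoop client configs. A minimal `conf/spark-env.sh` sketch, assuming the common `/etc/hadoop/conf`
location mentioned in the moved text (the path is an example only and varies by distribution):

    # spark-env.sh -- make hdfs-site.xml and core-site.xml visible to Spark
    export HADOOP_CONF_DIR=/etc/hadoop/conf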
--- README.md | 5 +- docs/_layouts/global.html | 1 - docs/configuration.md | 15 +++ docs/hadoop-third-party-distributions.md | 117 ----------------------- docs/index.md | 1 - docs/programming-guide.md | 9 +- 6 files changed, 19 insertions(+), 129 deletions(-) delete mode 100644 docs/hadoop-third-party-distributions.md diff --git a/README.md b/README.md index 4116ef3563879..c0d6a946035a9 100644 --- a/README.md +++ b/README.md @@ -87,10 +87,7 @@ Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at ["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including -building for particular Hive and Hive Thriftserver distributions. See also -["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) -for guidance on building a Spark application that works with a particular -distribution. +building for particular Hive and Hive Thriftserver distributions. ## Configuration diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b4952fe97ca0e..467ff7a03fb70 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -112,7 +112,6 @@
                            <li><a href="job-scheduling.html">Job Scheduling</a></li>
                            <li><a href="security.html">Security</a></li>
                            <li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
-                           <li><a href="hadoop-third-party-distributions.html">3rd-Party Hadoop Distros</a></li>
                            <li class="divider"></li>
                            <li><a href="building-spark.html">Building Spark</a></li>
                            <li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li>
  11. diff --git a/docs/configuration.md b/docs/configuration.md index 682384d4249e0..c276e8e90decf 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1674,3 +1674,18 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. + +# Inheriting Hadoop Cluster Configuration + +If you plan to read and write from HDFS using Spark, there are two Hadoop configuration files that +should be included on Spark's classpath: + +* `hdfs-site.xml`, which provides default behaviors for the HDFS client. +* `core-site.xml`, which sets the default filesystem name. + +The location of these configuration files varies across CDH and HDP versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +configurations on-the-fly, but offer a mechanisms to download copies of them. + +To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` +to a location containing the configuration files. diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md deleted file mode 100644 index 795dd82a6be06..0000000000000 --- a/docs/hadoop-third-party-distributions.md +++ /dev/null @@ -1,117 +0,0 @@ ---- -layout: global -title: Third-Party Hadoop Distributions ---- - -Spark can run against all versions of Cloudera's Distribution Including Apache Hadoop (CDH) and -the Hortonworks Data Platform (HDP). There are a few things to keep in mind when using Spark -with these distributions: - -# Compile-time Hadoop Version - -When compiling Spark, you'll need to specify the Hadoop version by defining the `hadoop.version` -property. For certain versions, you will need to specify additional profiles. For more detail, -see the guide on [building with maven](building-spark.html#specifying-the-hadoop-version): - - mvn -Dhadoop.version=1.0.4 -DskipTests clean package - mvn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package - -The table below lists the corresponding `hadoop.version` code for each CDH/HDP release. Note that -some Hadoop releases are binary compatible across client versions. This means the pre-built Spark -distribution may "just work" without you needing to compile. That said, we recommend compiling with -the _exact_ Hadoop version you are running to avoid any compatibility errors. - - - - - - -
-<table>
-  <tr valign="top">
-    <td>
-      <h3>CDH Releases</h3>
-      <table class="table">
-        <tr><th>Release</th><th>Version code</th></tr>
-        <tr><td>CDH 4.X.X (YARN mode)</td><td>2.0.0-cdh4.X.X</td></tr>
-        <tr><td>CDH 4.X.X</td><td>2.0.0-mr1-cdh4.X.X</td></tr>
-      </table>
-    </td>
-    <td>
-      <h3>HDP Releases</h3>
-      <table class="table">
-        <tr><th>Release</th><th>Version code</th></tr>
-        <tr><td>HDP 1.3</td><td>1.2.0</td></tr>
-        <tr><td>HDP 1.2</td><td>1.1.2</td></tr>
-        <tr><td>HDP 1.1</td><td>1.0.3</td></tr>
-        <tr><td>HDP 1.0</td><td>1.0.3</td></tr>
-        <tr><td>HDP 2.0</td><td>2.2.0</td></tr>
-      </table>
-    </td>
-  </tr>
-</table>
    - -In SBT, the equivalent can be achieved by setting the the `hadoop.version` property: - - build/sbt -Dhadoop.version=1.0.4 assembly - -# Linking Applications to the Hadoop Version - -In addition to compiling Spark itself against the right version, you need to add a Maven dependency on that -version of `hadoop-client` to any Spark applications you run, so they can also talk to the HDFS version -on the cluster. If you are using CDH, you also need to add the Cloudera Maven repository. -This looks as follows in SBT: - -{% highlight scala %} -libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "" - -// If using CDH, also add Cloudera repo -resolvers += "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" -{% endhighlight %} - -Or in Maven: - -{% highlight xml %} - - - ... - - org.apache.hadoop - hadoop-client - [version] - - - - - - ... - - Cloudera repository - https://repository.cloudera.com/artifactory/cloudera-repos/ - - - - -{% endhighlight %} - -# Where to Run Spark - -As described in the [Hardware Provisioning](hardware-provisioning.html#storage-systems) guide, -Spark can run in a variety of deployment modes: - -* Using dedicated set of Spark nodes in your cluster. These nodes should be co-located with your - Hadoop installation. -* Running on the same nodes as an existing Hadoop installation, with a fixed amount memory and - cores dedicated to Spark on each node. -* Run Spark alongside Hadoop using a cluster resource manager, such as YARN or Mesos. - -These options are identical for those using CDH and HDP. - -# Inheriting Cluster Configuration - -If you plan to read and write from HDFS using Spark, there are two Hadoop configuration files that -should be included on Spark's classpath: - -* `hdfs-site.xml`, which provides default behaviors for the HDFS client. -* `core-site.xml`, which sets the default filesystem name. - -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create -configurations on-the-fly, but offer a mechanisms to download copies of them. - -To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` -to a location containing the configuration files. diff --git a/docs/index.md b/docs/index.md index c0dc2b8d7412a..f1d9e012c6cf0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -117,7 +117,6 @@ options for deployment: * [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware -* [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 22656fd7910c0..f823b89a4b5e9 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -34,8 +34,7 @@ To write a Spark application, you need to add a Maven dependency on Spark. Spark version = {{site.SPARK_VERSION}} In addition, if you wish to access an HDFS cluster, you need to add a dependency on -`hadoop-client` for your version of HDFS. Some common HDFS version tags are listed on the -[third party distributions](hadoop-third-party-distributions.html) page. 
+`hadoop-client` for your version of HDFS. groupId = org.apache.hadoop artifactId = hadoop-client @@ -66,8 +65,7 @@ To write a Spark application in Java, you need to add a dependency on Spark. Spa version = {{site.SPARK_VERSION}} In addition, if you wish to access an HDFS cluster, you need to add a dependency on -`hadoop-client` for your version of HDFS. Some common HDFS version tags are listed on the -[third party distributions](hadoop-third-party-distributions.html) page. +`hadoop-client` for your version of HDFS. groupId = org.apache.hadoop artifactId = hadoop-client @@ -93,8 +91,7 @@ This script will load Spark's Java/Scala libraries and allow you to submit appli You can also use `bin/pyspark` to launch an interactive Python shell. If you wish to access HDFS data, you need to use a build of PySpark linking -to your version of HDFS. Some common HDFS version tags are listed on the -[third party distributions](hadoop-third-party-distributions.html) page. +to your version of HDFS. [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. From fae9bbaede31ce3ff6326c1a2cbd12c52b3243d9 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 12:33:28 -0800 Subject: [PATCH 108/324] Take R types instead to map to JVM types, add check for NA to keep column --- R/pkg/R/DataFrame.R | 24 +++++++++++++++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index bc4c61d67421c..2fc7b6a67de38 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -293,13 +293,22 @@ setMethod("colnames<-", dataFrame(sdf) }) +rToScalaTypes <- new.env() +rToScalaTypes[["integer"]] <- "integer" # in R, integer is 32bit +rToScalaTypes[["numeric"]] <- "double" # in R, numeric == double which is 64bit +rToScalaTypes[["double"]] <- "double" +rToScalaTypes[["character"]] <- "string" +rToScalaTypes[["logical"]] <- "boolean" + #' coltypes #' #' Set the column types of a DataFrame. #' #' @name coltypes #' @param x (DataFrame) -#' @return value (character) A character vector with the target column types for the given DataFrame +#' @return value (character) A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. 
#' @rdname coltypes #' @aliases coltypes #' @export @@ -309,7 +318,8 @@ setMethod("colnames<-", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) -#' coltypes(df) <- c("string", "integer") +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") #'} setMethod("coltypes<-", signature(x = "DataFrame", value = "character"), @@ -321,7 +331,15 @@ setMethod("coltypes<-", } newCols <- lapply(seq_len(ncols), function(i) { col <- getColumn(x, cols[i]) - cast(col, value[i]) + if (!is.na(value[i])) { + stype <- rToScalaTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } }) nx <- select(x, newCols) dataFrame(nx@sdf) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3020776e3c6f8..f87b1e05e405c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -612,14 +612,20 @@ test_that("coltypes() set the column types", { expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) - coltypes(df) <- c("string", "integer") + coltypes(df) <- c("character", "integer") expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) value <- collect(df[, 2])[[3, 1]] expect_equal(value, collect(df1)[[3, 1]]) expect_equal(value, 22) - expect_error(coltypes(df) <- c("string"), + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("cast(name as string)", "string"), + c("cast(cast(age as int) as double)", "double"))) + + expect_error(coltypes(df) <- c("character"), "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) test_that("head() and first() return the correct data", { From dc7e399fc01e74f2ba28ebd945785cc0f7759ccd Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Sun, 1 Nov 2015 13:09:42 -0800 Subject: [PATCH 109/324] [SPARK-11338] [WEBUI] Prepend app links on HistoryPage with uiRoot path [SPARK-11338: HistoryPage not multi-tenancy enabled ...](https://issues.apache.org/jira/browse/SPARK-11338) - `HistoryPage.scala` ...prepending all page links with the web proxy (`uiRoot`) path - `HistoryServerSuite.scala` ...adding a test case to verify all site-relative links are prefixed when the environment variable `APPLICATION_WEB_PROXY_BASE` (or System property `spark.ui.proxyBase`) is set Author: Christian Kadner Closes #9291 from ckadner/SPARK-11338 and squashes the following commits: 01d2f35 [Christian Kadner] [SPARK-11338][WebUI] nit fixes d054bd7 [Christian Kadner] [SPARK-11338][WebUI] prependBaseUri in method makePageLink 8bcb3dc [Christian Kadner] [SPARK-11338][WebUI] Prepend application links on HistoryPage with uiRoot path --- .../spark/deploy/history/HistoryPage.scala | 9 ++++---- .../deploy/history/HistoryServerSuite.scala | 21 +++++++++++++++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index b347cb3be69f7..642d71b18c9e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -161,7 +161,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") 
info: ApplicationHistoryInfo, attempt: ApplicationAttemptInfo, isFirst: Boolean): Seq[Node] = { - val uiAddress = HistoryServer.getAttemptURI(info.id, attempt.attemptId) + val uiAddress = UIUtils.prependBaseUri(HistoryServer.getAttemptURI(info.id, attempt.attemptId)) val startTime = UIUtils.formatDate(attempt.startTime) val endTime = if (attempt.endTime > 0) UIUtils.formatDate(attempt.endTime) else "-" val duration = @@ -190,8 +190,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (renderAttemptIdColumn) { if (info.attempts.size > 1 && attempt.attemptId.isDefined) { - - {attempt.attemptId.get} + {attempt.attemptId.get} } else {   } @@ -218,9 +217,9 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = { - "/?" + Array( + UIUtils.prependBaseUri("/?" + Array( "page=" + linkPage, "showIncomplete=" + showIncomplete - ).mkString("&") + ).mkString("&")) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index e5b5e1bb65337..4b7fd4f13b692 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -261,7 +261,24 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers l <- links attrs <- l.attribute("href") } yield (attrs.toString) - justHrefs should contain(link) + justHrefs should contain (UIUtils.prependBaseUri(resource = link)) + } + + test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") + val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") + val page = new HistoryPage(server) + val request = mock[HttpServletRequest] + + // when + System.setProperty("spark.ui.proxyBase", uiRoot) + val response = page.render(request) + System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse("")) + + // then + val urls = response \\ "@href" map (_.toString) + val siteRelativeLinks = urls filter (_.startsWith("/")) + all (siteRelativeLinks) should startWith (uiRoot) } def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { From 49f71179726a70ce129ea1284cc83bd113f594f8 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 13:27:10 -0800 Subject: [PATCH 110/324] This seems to fix the Rd error - no idea why it worked before. --- R/pkg/R/generics.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 08f25b2cd0d01..ba61fc7c88cf6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -399,15 +399,15 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) -#' @rdname colnames +#' @rdname columns #' @export setGeneric("colnames", function(x) { standardGeneric("colnames") }) -#' @rdname colnames<- +#' @rdname columns #' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) -#' @rdname coltypes<- +#' @rdname columns #' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) From 046e32ed8467e0f46ffeca1a95d4d40017eb5bdb Mon Sep 17 00:00:00 2001 From: Nong Li Date: Sun, 1 Nov 2015 14:32:21 -0800 Subject: [PATCH 111/324] [SPARK-11410][SQL] Add APIs to provide functionality similar to Hive's DISTRIBUTE BY and SORT BY. DISTRIBUTE BY allows the user to hash partition the data by specified exprs. It also allows for optioning sorting within each resulting partition. There is no required relationship between the exprs for partitioning and sorting (i.e. one does not need to be a prefix of the other). This patch adds to APIs to DataFrames which can be used together to provide this functionality: 1. distributeBy() which partitions the data frame into a specified number of partitions using the partitioning exprs. 2. localSort() which sorts each partition using the provided sorting exprs. To get the DISTRIBUTE BY functionality, the user simply does: df.distributeBy(...).localSort(...) Author: Nong Li Closes #9364 from nongli/spark-11410. --- .../catalyst/plans/logical/partitioning.scala | 19 ++- .../org/apache/spark/sql/DataFrame.scala | 60 +++++++-- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 118 +++++++++++++++++- 4 files changed, 186 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 1f76b03bcb0f6..a5bdee1b854ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -31,10 +31,19 @@ case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData /** - * This method repartitions data using [[Expression]]s, and receives information about the - * number of partitions during execution. Used when a specific ordering or distribution is - * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * This method repartitions data using [[Expression]]s into `numPartitions`, and receives + * information about the number of partitions during execution. Used when a specific ordering or + * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like * `coalesce` and `repartition`. + * If `numPartitions` is not specified, the number of partitions will be the number set by + * `spark.sql.shuffle.partitions`. 
*/ -case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData +case class RepartitionByExpression( + partitionExpressions: Seq[Expression], + child: LogicalPlan, + numPartitions: Option[Int] = None) extends RedistributeData { + numPartitions match { + case Some(n) => require(n > 0, "numPartitions must be greater than 0.") + case None => // Ok + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index aa817a037ef5e..53ad3c0266cdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -241,6 +241,18 @@ class DataFrame private[sql]( sb.toString() } + private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = global, logicalPlan) + } + override def toString: String = { try { schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") @@ -633,15 +645,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortExprs: Column*): DataFrame = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - Sort(sortOrder, global = true, logicalPlan) + sortInternal(true, sortExprs) } /** @@ -662,6 +666,44 @@ class DataFrame private[sql]( @scala.annotation.varargs def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting DataFrame is hash partitioned. + * @group dfops + * @since 1.6.0 + */ + def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = { + RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions)) + } + + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. The resulting DataFrame is hash partitioned. + * @group dfops + * @since 1.6.0 + */ + def distributeBy(partitionExprs: Seq[Column]): DataFrame = { + RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None) + } + + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def localSort(sortCol: String, sortCols: String*): DataFrame = localSort(sortCol, sortCols : _*) + + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def localSort(sortExprs: Column*): DataFrame = { + sortInternal(false, sortExprs) + } + /** * Selects column based on the column name and return it as a [[Column]]. * Note that the column name can also reference to a nested column like `a.b`. 
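A sketch of how the two new methods added above combine to give the `DISTRIBUTE BY ... SORT BY` behaviour
described in the commit message; `df`, the column names, and the partition count are placeholders, not part
of the patch:

    // Hash-partition the rows by "key" into 10 partitions, then sort each
    // partition by "value" without performing a global sort.
    val partitioned = df.distributeBy(Seq(df("key")), 10)
    val partitionedAndSorted = partitioned.localSort(df("value").asc)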
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 86d1d390f1918..f4464e0b916f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Strategy, execution} +import org.apache.spark.sql.{Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => @@ -455,8 +454,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil - case logical.RepartitionByExpression(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case logical.RepartitionByExpression(expressions, child, nPartitions) => + execution.Exchange(HashPartitioning( + expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c9d6e19d2ce93..6b86c5951b413 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -24,10 +24,14 @@ import scala.util.Random import org.scalatest.Matchers._ +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestData.TestData2 +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -997,4 +1001,116 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } } + + /** + * Verifies that there is no Exchange between the Aggregations for `df` + */ + private def verifyNonExchangingAgg(df: DataFrame) = { + var atFirstAgg: Boolean = false + df.queryExecution.executedPlan.foreach { + case agg: TungstenAggregate => { + atFirstAgg = !atFirstAgg + } + case _ => { + if (atFirstAgg) { + fail("Should not have operators between the two aggregations") + } + } + } + } + + /** + * Verifies that there is an Exchange between the Aggregations for `df` + */ + private def verifyExchangingAgg(df: DataFrame) = { + var atFirstAgg: Boolean = false + df.queryExecution.executedPlan.foreach { + case agg: 
TungstenAggregate => { + if (atFirstAgg) { + fail("Should not have back to back Aggregates") + } + atFirstAgg = true + } + case e: Exchange => atFirstAgg = false + case _ => + } + } + + test("distributeBy and localSort") { + val original = testData.repartition(1) + assert(original.rdd.partitions.length == 1) + val df = original.distributeBy(Column("key") :: Nil, 5) + assert(df.rdd.partitions.length == 5) + checkAnswer(original.select(), df.select()) + + val df2 = original.distributeBy(Column("key") :: Nil, 10) + assert(df2.rdd.partitions.length == 10) + checkAnswer(original.select(), df2.select()) + + // Group by the column we are distributed by. This should generate a plan with no exchange + // between the aggregates + val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count() + verifyNonExchangingAgg(df3) + verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil) + .groupBy("key", "value").count()) + + // Grouping by just the first distributeBy expr, need to exchange. + verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil) + .groupBy("key").count()) + + val data = sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData2(i % 10, i))).toDF() + + // Distribute and order by. + val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc) + // Walk each partition and verify that it is sorted descending and does not contain all + // the values. + df4.rdd.foreachPartition(p => { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach(r => { + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue < v) throw new SparkException("Partition is not ordered.") + if (v + 1 != previousValue) allSequential = false + } + previousValue = v + }) + if (allSequential) throw new SparkException("Partition should not be globally ordered") + }) + + // Distribute and order by with multiple order bys + val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc) + // Walk each partition and verify that it is sorted ascending + df5.rdd.foreachPartition(p => { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach(r => { + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue > v) throw new SparkException("Partition is not ordered.") + if (v - 1 != previousValue) allSequential = false + } + previousValue = v + }) + if (allSequential) throw new SparkException("Partition should not be all sequential") + }) + + // Distribute into one partition and order by. This partition should contain all the values. + val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc) + // Walk each partition and verify that it is sorted descending and not globally sorted. + df6.rdd.foreachPartition(p => { + var previousValue: Int = -1 + var allSequential: Boolean = true + p.foreach(r => { + val v: Int = r.getInt(1) + if (previousValue != -1) { + if (previousValue > v) throw new SparkException("Partition is not ordered.") + if (v - 1 != previousValue) allSequential = false + } + previousValue = v + }) + if (!allSequential) throw new SparkException("Partition should contain all sequential values") + }) + } } From cf04fdfe71abc395163a625cc1f99ec5e54cc07e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sun, 1 Nov 2015 14:42:18 -0800 Subject: [PATCH 112/324] [SPARK-11020][CORE] Wait for HDFS to leave safe mode before initializing HS. 
Large HDFS clusters may take a while to leave safe mode when starting; this change makes the HS wait for that before doing checks about its configuraton. This means the HS won't stop right away if HDFS is in safe mode and the configuration is not correct, but that should be a very uncommon situation. Author: Marcelo Vanzin Closes #9043 from vanzin/SPARK-11020. --- .../deploy/history/FsHistoryProvider.scala | 104 +++++++++++++++++- .../history/FsHistoryProviderSuite.scala | 65 +++++++++++ 2 files changed, 166 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 80bfda9dddb39..24aa386c7212b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -52,6 +53,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val NOT_STARTED = "" + // Interval between safemode checks. + private val SAFEMODE_CHECK_INTERVAL_S = conf.getTimeAsSeconds( + "spark.history.fs.safemodeCheck.interval", "5s") + // Interval between each check for event log updates private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") @@ -107,9 +112,57 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - initialize() + // Conf option used for testing the initialization code. + val initThread = if (!conf.getBoolean("spark.history.testing.skipInitialize", false)) { + initialize(None) + } else { + null + } + + private[history] def initialize(errorHandler: Option[Thread.UncaughtExceptionHandler]): Thread = { + if (!isFsInSafeMode()) { + startPolling() + return null + } + + // Cannot probe anything while the FS is in safe mode, so spawn a new thread that will wait + // for the FS to leave safe mode before enabling polling. This allows the main history server + // UI to be shown (so that the user can see the HDFS status). + // + // The synchronization in the run() method is needed because of the tests; mockito can + // misbehave if the test is modifying the mocked methods while the thread is calling + // them. + val initThread = new Thread(new Runnable() { + override def run(): Unit = { + try { + clock.synchronized { + while (isFsInSafeMode()) { + logInfo("HDFS is still in safe mode. 
Waiting...") + val deadline = clock.getTimeMillis() + + TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S) + clock.waitTillTime(deadline) + } + } + startPolling() + } catch { + case _: InterruptedException => + } + } + }) + initThread.setDaemon(true) + initThread.setName(s"${getClass().getSimpleName()}-init") + initThread.setUncaughtExceptionHandler(errorHandler.getOrElse( + new Thread.UncaughtExceptionHandler() { + override def uncaughtException(t: Thread, e: Throwable): Unit = { + logError("Error initializing FsHistoryProvider.", e) + System.exit(1) + } + })) + initThread.start() + initThread + } - private def initialize(): Unit = { + private def startPolling(): Unit = { // Validate the log directory. val path = new Path(logDir) if (!fs.exists(path)) { @@ -170,7 +223,21 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString) + override def getConfig(): Map[String, String] = { + val safeMode = if (isFsInSafeMode()) { + Map("HDFS State" -> "In safe mode, application logs not available.") + } else { + Map() + } + Map("Event log directory" -> logDir.toString) ++ safeMode + } + + override def stop(): Unit = { + if (initThread != null && initThread.isAlive()) { + initThread.interrupt() + initThread.join() + } + } /** * Builds the application list based on the current contents of the log directory. @@ -585,6 +652,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Checks whether HDFS is in safe mode. The API is slightly different between hadoop 1 and 2, + * so we have to resort to ugly reflection (as usual...). + * + * Note that DistributedFileSystem is a `@LimitedPrivate` class, which for all practical reasons + * makes it more public than not. + */ + private[history] def isFsInSafeMode(): Boolean = fs match { + case dfs: DistributedFileSystem => + isFsInSafeMode(dfs) + case _ => + false + } + + // For testing. 
+ private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { + val hadoop1Class = "org.apache.hadoop.hdfs.protocol.FSConstants$SafeModeAction" + val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" + val actionClass: Class[_] = + try { + getClass().getClassLoader().loadClass(hadoop2Class) + } catch { + case _: ClassNotFoundException => + getClass().getClassLoader().loadClass(hadoop1Class) + } + + val action = actionClass.getField("SAFEMODE_GET").get(null) + val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) + method.invoke(dfs, action).asInstanceOf[Boolean] + } + } private[history] object FsHistoryProvider { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 73cff89544dc3..833aab14ca2da 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -24,13 +24,19 @@ import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} import scala.io.Source +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.Path +import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ +import org.mockito.Matchers.any +import org.mockito.Mockito.{doReturn, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ @@ -407,6 +413,65 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("provider correctly checks whether fs is in safe mode") { + val provider = spy(new FsHistoryProvider(createTestConf())) + val dfs = mock(classOf[DistributedFileSystem]) + // Asserts that safe mode is false because we can't really control the return value of the mock, + // since the API is different between hadoop 1 and 2. + assert(!provider.isFsInSafeMode(dfs)) + } + + test("provider waits for safe mode to finish before initializing") { + val clock = new ManualClock() + val conf = createTestConf().set("spark.history.testing.skipInitialize", "true") + val provider = spy(new FsHistoryProvider(conf, clock)) + doReturn(true).when(provider).isFsInSafeMode() + + val initThread = provider.initialize(None) + try { + provider.getConfig().keys should contain ("HDFS State") + + clock.setTime(5000) + provider.getConfig().keys should contain ("HDFS State") + + // Synchronization needed because of mockito. 
+ clock.synchronized { + doReturn(false).when(provider).isFsInSafeMode() + clock.setTime(10000) + } + + eventually(timeout(1 second), interval(10 millis)) { + provider.getConfig().keys should not contain ("HDFS State") + } + } finally { + provider.stop() + } + } + + test("provider reports error after FS leaves safe mode") { + testDir.delete() + val clock = new ManualClock() + val conf = createTestConf().set("spark.history.testing.skipInitialize", "true") + val provider = spy(new FsHistoryProvider(conf, clock)) + doReturn(true).when(provider).isFsInSafeMode() + + val errorHandler = mock(classOf[Thread.UncaughtExceptionHandler]) + val initThread = provider.initialize(Some(errorHandler)) + try { + // Synchronization needed because of mockito. + clock.synchronized { + doReturn(false).when(provider).isFsInSafeMode() + clock.setTime(10000) + } + + eventually(timeout(1 second), interval(10 millis)) { + verify(errorHandler).uncaughtException(any(), any()) + } + } finally { + provider.stop() + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: From f8d93edec82eedab59d50aec06ca2de7e4cf14f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sun, 1 Nov 2015 15:57:42 -0800 Subject: [PATCH 113/324] [SPARK-11073][CORE][YARN] Remove akka dependency in secret key generation. Use standard JDK APIs for that (with a little help from Guava). Most of the changes here are in test code, since there were no tests specific to that part of the code. Author: Marcelo Vanzin Closes #9257 from vanzin/SPARK-11073. --- .../org/apache/spark/SecurityManager.scala | 72 +++++++++++-------- .../apache/spark/SecurityManagerSuite.scala | 23 +++++- .../spark/deploy/LogUrlsStandaloneSuite.scala | 13 +--- .../deploy/worker/WorkerArgumentsTest.scala | 28 +------- .../apache/spark/storage/LocalDirsSuite.scala | 16 +---- .../apache/spark/util/SparkConfWithEnv.scala | 34 +++++++++ .../deploy/yarn/YarnSparkHadoopUtil.scala | 3 +- .../yarn/YarnSparkHadoopUtilSuite.scala | 32 ++++++++- 8 files changed, 138 insertions(+), 83 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 746d2081d4393..64e483e384772 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,11 +17,13 @@ package org.apache.spark +import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} -import java.security.KeyStore +import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate import javax.net.ssl._ +import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text @@ -130,15 +132,16 @@ import org.apache.spark.util.Utils * * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. * - * For Yarn deployments, the secret is automatically generated using the Akka remote - * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed - * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels - * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn - * gets a different shared secret. 
On Yarn, the Spark UI gets configured to use the Hadoop Yarn - * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there - * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use - * filters to do authentication. That authentication then happens via the ResourceManager Proxy - * and Spark will use that to do authorization against the view acls. + * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop + * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to + * support different levels of protection. See the Hadoop documentation for more details. Each + * Spark application on YARN gets a different shared secret. + * + * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user + * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web + * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That + * authentication then happens via the ResourceManager Proxy and Spark will use that to do + * authorization against the view acls. * * For other Spark deployments, the shared secret must be specified via the * spark.authenticate.secret config. @@ -189,8 +192,7 @@ import org.apache.spark.util.Utils private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { - // key used to store the spark secret in the Hadoop UGI - private val sparkSecretLookupKey = "sparkCookie" + import SecurityManager._ private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 @@ -365,33 +367,38 @@ private[spark] class SecurityManager(sparkConf: SparkConf) * we throw an exception. */ private def generateSecretKey(): String = { - if (!isAuthenticationEnabled) return null - // first check to see if the secret is already set, else generate a new one if on yarn - val sCookie = if (SparkHadoopUtil.get.isYarnMode) { - val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) - if (secretKey != null) { - logDebug("in yarn mode, getting secret from credentials") - return new Text(secretKey).toString + if (!isAuthenticationEnabled) { + null + } else if (SparkHadoopUtil.get.isYarnMode) { + // In YARN mode, the secure cookie will be created by the driver and stashed in the + // user's credentials, where executors can get it. The check for an array of size 0 + // is because of the test code in YarnSparkHadoopUtilSuite. 
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY) + if (secretKey == null || secretKey.length == 0) { + logDebug("generateSecretKey: yarn mode, secret key from credentials is null") + val rnd = new SecureRandom() + val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE + val secret = new Array[Byte](length) + rnd.nextBytes(secret) + + val cookie = HashCodes.fromBytes(secret).toString() + SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie) + cookie } else { - logDebug("getSecretKey: yarn mode, secret key from credentials is null") + new Text(secretKey).toString } - val cookie = akka.util.Crypt.generateSecureCookie - // if we generated the secret then we must be the first so lets set it so t - // gets used by everyone else - SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) - logInfo("adding secret to credentials in yarn mode") - cookie } else { // user must have set spark.authenticate.secret config // For Master/Worker, auth secret is in conf; for Executors, it is in env variable - sys.env.get(SecurityManager.ENV_AUTH_SECRET) + Option(sparkConf.getenv(SecurityManager.ENV_AUTH_SECRET)) .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value - case None => throw new Exception("Error: a secret key must be specified via the " + - SecurityManager.SPARK_AUTH_SECRET_CONF + " config") + case None => + throw new IllegalArgumentException( + "Error: a secret key must be specified via the " + + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } - sCookie } /** @@ -475,6 +482,9 @@ private[spark] object SecurityManager { val SPARK_AUTH_CONF: String = "spark.authenticate" val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" // This is used to set auth secret to an executor's env variable. 
It should have the same - // value as SPARK_AUTH_SECERET_CONF set in SparkConf + // value as SPARK_AUTH_SECRET_CONF set in SparkConf val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" + + // key used to store the spark secret in the Hadoop UGI + val SECRET_LOOKUP_KEY = "sparkCookie" } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index f29160d834082..26b95c06789f7 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkConfWithEnv, Utils} class SecurityManagerSuite extends SparkFunSuite { @@ -223,5 +223,26 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.hostnameVerifier.isDefined === false) } + test("missing secret authentication key") { + val conf = new SparkConf().set("spark.authenticate", "true") + intercept[IllegalArgumentException] { + new SecurityManager(conf) + } + } + + test("secret authentication key") { + val key = "very secret key" + val conf = new SparkConf() + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, key) + assert(key === new SecurityManager(conf).getSecretKey()) + + val keyFromEnv = "very secret key from env" + val conf2 = new SparkConfWithEnv(Map(SecurityManager.ENV_AUTH_SECRET -> keyFromEnv)) + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, key) + assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) + } + } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 86eb41dd7e5d7..8dd31b4b6fdda 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -25,6 +25,7 @@ import scala.io.Source import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { @@ -53,17 +54,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { test("verify that log urls reflect SPARK_PUBLIC_DNS (SPARK-6175)") { val SPARK_PUBLIC_DNS = "public_dns" - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_PUBLIC_DNS") SPARK_PUBLIC_DNS - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf().set( + val conf = new SparkConfWithEnv(Map("SPARK_PUBLIC_DNS" -> SPARK_PUBLIC_DNS)).set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala index 15f7ca4a6dacc..637e78fda0193 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -19,7 +19,7 @@ package 
org.apache.spark.deploy.worker import org.apache.spark.{SparkConf, SparkFunSuite} - +import org.apache.spark.util.SparkConfWithEnv class WorkerArgumentsTest extends SparkFunSuite { @@ -34,18 +34,7 @@ class WorkerArgumentsTest extends SparkFunSuite { test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { val args = Array("spark://localhost:0000 ") - - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_WORKER_MEMORY") "50000" - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf() + val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "50000")) intercept[IllegalStateException] { new WorkerArguments(args, conf) } @@ -53,18 +42,7 @@ class WorkerArgumentsTest extends SparkFunSuite { test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { val args = Array("spark://localhost:0000 ") - - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_WORKER_MEMORY") "5G" - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } - val conf = new MySparkConf() + val conf = new SparkConfWithEnv(Map("SPARK_WORKER_MEMORY" -> "5G")) val workerArgs = new WorkerArguments(args, conf) assert(workerArgs.memory === 5120) } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index ac6fec56bbf4f..cc50289c7b3ea 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} - +import org.apache.spark.util.SparkConfWithEnv /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. @@ -45,20 +45,10 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 assert(!new File("/NONEXISTENT_DIR").exists()) - // SPARK_LOCAL_DIRS is a valid directory: - class MySparkConf extends SparkConf(false) { - override def getenv(name: String): String = { - if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") - else super.getenv(name) - } - - override def clone: SparkConf = { - new MySparkConf().setAll(getAll) - } - } // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: - val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH") + val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) + .set("spark.local.dir", "/NONEXISTENT_PATH") assert(new File(Utils.getLocalDir(conf)).exists()) } diff --git a/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala new file mode 100644 index 0000000000000..ddd5edf4f7396 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SparkConfWithEnv.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkConf + +/** + * Customized SparkConf that allows env variables to be overridden. + */ +class SparkConfWithEnv(env: Map[String, String]) extends SparkConf(false) { + override def getenv(name: String): String = { + env.get(name).getOrElse(super.getenv(name)) + } + + override def clone: SparkConf = { + new SparkConfWithEnv(env).setAll(getAll) + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5924daf3ece49..561ad79ee0228 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.nio.charset.StandardCharsets.UTF_8 import java.util.regex.Matcher import java.util.regex.Pattern @@ -81,7 +82,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { override def addSecretKeyToUserCredentials(key: String, secret: String) { val creds = new Credentials() - creds.addSecretKey(new Text(key), secret.getBytes("utf-8")) + creds.addSecretKey(new Text(key), secret.getBytes(UTF_8)) addCurrentUserCredentials(creds) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 9132c56a91754..a70e66d39a64e 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -24,8 +24,10 @@ import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -263,7 +265,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging }) } - def assertNestedHiveException(e: InvocationTargetException): Throwable = { + private def assertNestedHiveException(e: InvocationTargetException): Throwable = { val inner = e.getCause if (inner == null) { fail("No inner cause", e) @@ -274,4 +276,32 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging inner } + // This test needs to live here because it depends on isYarnMode returning true, which can only + // happen in the YARN module. 
+ test("security manager token generation") { + try { + System.setProperty("SPARK_YARN_MODE", "true") + val initial = SparkHadoopUtil.get + .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) + assert(initial === null || initial.length === 0) + + val conf = new SparkConf() + .set(SecurityManager.SPARK_AUTH_CONF, "true") + .set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val sm = new SecurityManager(conf) + + val generated = SparkHadoopUtil.get + .getSecretKeyFromUserCredentials(SecurityManager.SECRET_LOOKUP_KEY) + assert(generated != null) + val genString = new Text(generated).toString() + assert(genString != "unused") + assert(sm.getSecretKey() === genString) + } finally { + // removeSecretKey() was only added in Hadoop 2.6, so instead we just set the secret + // to an empty string. + SparkHadoopUtil.get.addSecretKeyToUserCredentials(SecurityManager.SECRET_LOOKUP_KEY, "") + System.clearProperty("SPARK_YARN_MODE") + } + } + } From 8d3f04375a6010164f3efef28fa1105ce59eea16 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 16:33:40 -0800 Subject: [PATCH 114/324] fix test broken from column name change from cast --- R/pkg/inst/tests/test_sparkSQL.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index f87b1e05e405c..78d036a0c609c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -613,14 +613,13 @@ test_that("coltypes() set the column types", { df1 <- select(df, cast(df$age, "integer")) coltypes(df) <- c("character", "integer") - expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) value <- collect(df[, 2])[[3, 1]] expect_equal(value, collect(df1)[[3, 1]]) expect_equal(value, 22) coltypes(df) <- c(NA, "numeric") - expect_equal(dtypes(df), list(c("cast(name as string)", "string"), - c("cast(cast(age as int) as double)", "double"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) expect_error(coltypes(df) <- c("character"), "Length of type vector should match the number of columns for DataFrame") From 3e770a64a48c271c5829d2bcbdc1d6430cda2ac9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 1 Nov 2015 18:37:27 -0800 Subject: [PATCH 115/324] [SPARK-9298][SQL] Add pearson correlation aggregation function JIRA: https://issues.apache.org/jira/browse/SPARK-9298 This patch adds pearson correlation aggregation function based on `AggregateExpression2`. Author: Liang-Chi Hsieh Closes #8587 from viirya/corr_aggregation. 
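For illustration, here is a minimal standalone sketch (plain Scala, the object name is made up, and this is not code from the patch) of the single-pass update the new `Corr` imperative aggregate keeps in its buffer: running means `xAvg`/`yAvg`, the co-moment `Ck`, and the squared deviations `MkX`/`MkY`, with the final value `Ck / sqrt(MkX * MkY)` and no result when no rows were seen.

```scala
// Sketch of the Corr buffer update on a local collection (not Spark code).
object PearsonSketch {
  def corr(pairs: Seq[(Double, Double)]): Option[Double] = {
    var xAvg = 0.0; var yAvg = 0.0   // running means
    var ck = 0.0                     // co-moment: sum of (x - xAvg)(y - yAvg)
    var mkX = 0.0; var mkY = 0.0     // sums of squared deviations
    var count = 0L
    for ((x, y) <- pairs) {
      val deltaX = x - xAvg          // deltas against the old means
      val deltaY = y - yAvg
      count += 1
      xAvg += deltaX / count
      yAvg += deltaY / count
      ck += deltaX * (y - yAvg)      // old delta times new residual
      mkX += deltaX * (x - xAvg)
      mkY += deltaY * (y - yAvg)
    }
    if (count == 0) None
    else Some(ck / math.sqrt(mkX * mkY)).filterNot(_.isNaN)
  }

  def main(args: Array[String]): Unit = {
    // Same non-trivial case as the new AggregationQuerySuite test: expect roughly 0.9572339.
    val data = Seq.tabulate(20)(i => (i.toDouble, i * i - 2.0 * i + 3.5))
    println(corr(data))
  }
}
```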
--- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/aggregate/functions.scala | 159 ++++++++++++++++++ .../expressions/aggregate/utils.scala | 6 + .../sql/catalyst/expressions/aggregates.scala | 18 ++ .../org/apache/spark/sql/functions.scala | 18 ++ .../execution/HiveCompatibilitySuite.scala | 7 +- .../execution/AggregationQuerySuite.scala | 104 ++++++++++++ 7 files changed, 311 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed9fcfe014f0c..5f3ec74ac0d92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -178,6 +178,7 @@ object FunctionRegistry { // aggregate functions expression[Average]("avg"), + expression[Corr]("corr"), expression[Count]("count"), expression[First]("first"), expression[First]("first_value"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 281404f285a98..5d2eb7b017ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -23,6 +23,7 @@ import java.util import com.clearspring.analytics.hash.MurmurHash import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -524,6 +525,164 @@ case class Sum(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = Cast(currentSum, resultType) } +/** + * Compute Pearson correlation between two expressions. + * When applied on empty data (i.e., count is zero), it returns NULL. + * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + * + * @param left one of the expressions to compute correlation with. + * @param right another expression to compute correlation with. 
+ */ +case class Corr( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { + + def children: Seq[Expression] = Seq(left, right) + + def nullable: Boolean = false + + def dataType: DataType = DoubleType + + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + def inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) + + val aggBufferAttributes: Seq[AttributeReference] = Seq( + AttributeReference("xAvg", DoubleType)(), + AttributeReference("yAvg", DoubleType)(), + AttributeReference("Ck", DoubleType)(), + AttributeReference("MkX", DoubleType)(), + AttributeReference("MkY", DoubleType)(), + AttributeReference("count", LongType)()) + + // Local cache of mutableAggBufferOffset(s) that will be used in update and merge + private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 + private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 + private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 + private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 + private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 + + // Local cache of inputAggBufferOffset(s) that will be used in update and merge + private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 + private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 + private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 + private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 + private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def initialize(buffer: MutableRow): Unit = { + buffer.setDouble(mutableAggBufferOffset, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) + buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val leftEval = left.eval(input) + val rightEval = right.eval(input) + + if (leftEval != null && rightEval != null) { + val x = leftEval.asInstanceOf[Double] + val y = rightEval.asInstanceOf[Double] + + var xAvg = buffer.getDouble(mutableAggBufferOffset) + var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer.getLong(mutableAggBufferOffsetPlus5) + + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + + buffer.setDouble(mutableAggBufferOffset, xAvg) + buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) + 
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer.setLong(mutableAggBufferOffsetPlus5, count) + } + } + + // Merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) + + // We only go to merge two buffers if there is at least one record aggregated in buffer2. + // We don't need to check count in buffer1 because if count2 is more than zero, totalCount + // is more than zero too, then we won't get a divide by zero exception. + if (count2 > 0) { + var xAvg = buffer1.getDouble(mutableAggBufferOffset) + var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer1.getLong(mutableAggBufferOffsetPlus5) + + val xAvg2 = buffer2.getDouble(inputAggBufferOffset) + val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) + val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) + val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) + val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) + + val totalCount = count + count2 + val deltaX = xAvg - xAvg2 + val deltaY = yAvg - yAvg2 + Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 + xAvg = (xAvg * count + xAvg2 * count2) / totalCount + yAvg = (yAvg * count + yAvg2 * count2) / totalCount + MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 + MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 + count = totalCount + + buffer1.setDouble(mutableAggBufferOffset, xAvg) + buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer1.setLong(mutableAggBufferOffsetPlus5, count) + } + } + + override def eval(buffer: InternalRow): Any = { + val count = buffer.getLong(mutableAggBufferOffsetPlus5) + if (count > 0) { + val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + val corr = Ck / math.sqrt(MkX * MkY) + if (corr.isNaN) { + null + } else { + corr + } + } else { + null + } + } +} + // scalastyle:off /** * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. 
This class diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index c911ec53f1ba0..564174f9b64e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -127,6 +127,12 @@ object Utils { mode = aggregate.Complete, isDistinct = true) + case expressions.Corr(left, right) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Corr(left, right), + mode = aggregate.Complete, + isDistinct = false) + case expressions.ApproxCountDistinct(child, rsd) => aggregate.AggregateExpression2( aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index c1bab6d36ab29..bf59660c385ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -747,6 +747,24 @@ case class LastFunction( } } +/** + * Calculate Pearson Correlation Coefficient for the given columns. + * Only support AggregateExpression2. + * + */ +case class Corr(left: Expression, right: Expression) + extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { + override def nullable: Boolean = false + override def dataType: DoubleType.type = DoubleType + override def toString: String = s"CORRELATION($left, $right)" + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def newInstance(): AggregateFunction1 = { + throw new UnsupportedOperationException( + "Corr only supports the new AggregateExpression2 and can only be used " + + "when spark.sql.useAggregate2 = true") + } +} + // Compute standard deviation based on online algorithm specified here: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c1737b1ef663c..5a5c695e6ab3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -172,6 +172,24 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(column1: Column, column2: Column): Column = + Corr(column1.expr, column2.expr) + + /** + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @group agg_funcs + * @since 1.6.0 + */ + def corr(columnName1: String, columnName2: String): Column = + corr(Column(columnName1), Column(columnName2)) + /** * Aggregate function: returns the number of items in a group. 
* diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9e357bf348c94..6ed40b03975d0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -304,7 +304,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // classpath problems "compute_stats.*", - "udf_bitmap_.*" + "udf_bitmap_.*", + + // The difference between the double numbers generated by Hive and Spark + // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) + "udaf_corr" ) /** @@ -857,7 +861,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "type_cast_1", "type_widening", "udaf_collect_set", - "udaf_corr", "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f38a3f63c3b58..0cf0e0aab9eb2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate @@ -556,6 +557,109 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(0, null, 1, 1, null, 0) :: Nil) } + test("pearson correlation") { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr1 - 1.0) < 1e-12) + val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr2 + 1.0) < 1e-12) + // non-trivial example. 
To reproduce in python, use: + // >>> from scipy.stats import pearsonr + // >>> import numpy as np + // >>> a = np.array(range(20)) + // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)]) + // >>> pearsonr(a, b) + // (0.95723391394758572, 3.8902121417802199e-11) + // In R, use: + // > a <- 0:19 + // > b <- mapply(function(x) x * x - 2 * x + 3.5, a) + // > cor(a, b) + // [1] 0.957233913947585835 + val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") + val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) + + val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") + val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0) + assert(corr4 == Row(null)) + + val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") + val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + assert(math.abs(corr5 - 1.0) < 1e-12) + val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + assert(math.abs(corr6 + 1.0) < 1e-12) + + // Test for udaf_corr in HiveCompatibilitySuite + // udaf_corr has been blacklisted due to numerical errors + // We test it here: + // SELECT corr(b, c) FROM covar_tab WHERE a < 1; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a < 3; => NULL + // SELECT corr(b, c) FROM covar_tab WHERE a = 3; => NULL + // SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a; => + // 1 NULL + // 2 NULL + // 3 NULL + // 4 NULL + // 5 NULL + // 6 NULL + // SELECT corr(b, c) FROM covar_tab; => 0.6633880657639323 + + val covar_tab = Seq[(Integer, Integer, Integer)]( + (1, null, 15), + (2, 3, null), + (3, 7, 12), + (4, 4, 14), + (5, 8, 17), + (6, 2, 11)).toDF("a", "b", "c") + + covar_tab.registerTempTable("covar_tab") + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 1 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a < 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT corr(b, c) FROM covar_tab WHERE a = 3 + """.stripMargin), + Row(null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a + """.stripMargin), + Row(1, null) :: + Row(2, null) :: + Row(3, null) :: + Row(4, null) :: + Row(5, null) :: + Row(6, null) :: Nil) + + val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) + assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) + + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + val errorMessage = intercept[SparkException] { + val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") + val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + }.getMessage + assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + + "Corr only supports the new AggregateExpression2")) + } + } + test("test Last implemented based on AggregateExpression1") { // TODO: Remove this test once we remove AggregateExpression1. import org.apache.spark.sql.functions._ From e963070c13f56fbc2dfaf9f5d4e69d34afd0957c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Sun, 1 Nov 2015 23:52:50 -0800 Subject: [PATCH 116/324] [SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree* Author: Yu ISHIKAWA Closes #9402 from yu-iskw/SPARK-9722. 
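As a small aside on what the change buys, the sketch below (plain Scala, with `scala.util.Random` standing in for Spark's internal `XORShiftRandom` and a local `sample` standing in for `RDD.sample`; none of these names come from the patch) shows why deriving the sampling RNG from the caller's seed matters: the quantile sample becomes reproducible per seed instead of being fixed by the previous hard-coded seed of 1.

```scala
import scala.util.Random

// Illustration only: a local stand-in for RDD.sample(withReplacement = false, fraction, seed).
object SeededSampleSketch {
  def sample[T](data: Seq[T], fraction: Double, seed: Long): Seq[T] = {
    val rng = new Random(new Random(seed).nextInt()) // mirrors new XORShiftRandom(seed).nextInt()
    data.filter(_ => rng.nextDouble() < fraction)
  }

  def main(args: Array[String]): Unit = {
    val rows = 1 to 10000
    val a = sample(rows, 0.01, seed = 42L)
    val b = sample(rows, 0.01, seed = 42L)
    val c = sample(rows, 0.01, seed = 7L)
    println(a == b) // true: the same seed reproduces the same sample
    println(a == c) // almost surely false: a different seed varies the sample
  }
}
```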
--- .../org/apache/spark/ml/tree/impl/RandomForest.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 96d5652857e08..4a3b12d1440b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging { // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") - val splits = findSplits(retaggedInput, metadata) + val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplitsBins") logDebug("numBins: feature: number of bins") logDebug(Range(0, metadata.numFeatures).map { featureIndex => @@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging { * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param metadata Learning and dataset metadata + * @param seed random seed * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging { */ protected[tree] def findSplits( input: RDD[LabeledPoint], - metadata: DecisionTreeMetadata): Array[Array[Split]] = { + metadata: DecisionTreeMetadata, + seed : Long): Array[Array[Split]] = { logDebug("isMulticlass = " + metadata.isMulticlass) @@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect() + input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } else { new Array[LabeledPoint](0) } From e209fa271ae57dc8849f8b1241bf1ea7d6d3d62c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 2 Nov 2015 08:52:52 +0000 Subject: [PATCH 117/324] [SPARK-11271][SPARK-11016][CORE] Use Spark BitSet instead of RoaringBitmap to reduce memory usage JIRA: https://issues.apache.org/jira/browse/SPARK-11271 As reported in the JIRA ticket, when there are too many tasks, the memory usage of MapStatus will cause problem. Use BitSet instead of RoaringBitMap should be more efficient in memory usage. Author: Liang-Chi Hsieh Closes #9243 from viirya/mapstatus-bitset. 
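As a rough sizing illustration (the helper object below is made up, but the word count uses the same formula as `BitSet.bit2words` in this patch), a fixed-size bit set costs one `Long` per 64 tracked blocks, so marking the empty blocks of even a few hundred thousand reduce partitions stays in the tens of kilobytes per `HighlyCompressedMapStatus`.

```scala
// Back-of-the-envelope sizing for the empty-blocks bit set (illustrative, not from the patch).
object BitSetSizing {
  // Same formula as BitSet.bit2words: number of Longs needed to hold numBits bits.
  def words(numBits: Int): Int = ((numBits - 1) >> 6) + 1

  def main(args: Array[String]): Unit = {
    Seq(1000, 50000, 200000).foreach { reduceBlocks =>
      val bytes = words(reduceBlocks) * java.lang.Long.BYTES
      println(s"$reduceBlocks reduce blocks -> ${words(reduceBlocks)} longs (~$bytes bytes)")
    }
  }
}
```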
--- core/pom.xml | 4 -- .../apache/spark/scheduler/MapStatus.scala | 13 +++-- .../spark/serializer/KryoSerializer.scala | 10 +--- .../apache/spark/util/collection/BitSet.scala | 28 +++++++++-- .../serializer/KryoSerializerSuite.scala | 6 --- .../spark/util/collection/BitSetSuite.scala | 49 +++++++++++++++++++ pom.xml | 5 -- 7 files changed, 82 insertions(+), 33 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 319a50049a82d..1b6b13517bd56 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -173,10 +173,6 @@ net.jpountz.lz4 lz4 - - org.roaringbitmap - RoaringBitmap - commons-net commons-net diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1efce124c0a6b..180c8d1827e13 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} -import org.roaringbitmap.RoaringBitmap - import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.collection.BitSet import org.apache.spark.util.Utils /** @@ -133,7 +132,7 @@ private[spark] class CompressedMapStatus( private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, - private[this] var emptyBlocks: RoaringBitmap, + private[this] var emptyBlocks: BitSet, private[this] var avgSize: Long) extends MapStatus with Externalizable { @@ -146,7 +145,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { - if (emptyBlocks.contains(reduceId)) { + if (emptyBlocks.get(reduceId)) { 0 } else { avgSize @@ -161,7 +160,7 @@ private[spark] class HighlyCompressedMapStatus private ( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) - emptyBlocks = new RoaringBitmap() + emptyBlocks = new BitSet emptyBlocks.readExternal(in) avgSize = in.readLong() } @@ -177,15 +176,15 @@ private[spark] object HighlyCompressedMapStatus { // From a compression standpoint, it shouldn't matter whether we track empty or non-empty // blocks. From a performance standpoint, we benefit from tracking empty blocks because // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. 
- val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length + val emptyBlocks = new BitSet(totalNumBlocks) while (i < totalNumBlocks) { var size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 totalSize += size } else { - emptyBlocks.add(i) + emptyBlocks.set(i) } i += 1 } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index c5195c1143a8f..bc51d4f2820c8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -30,7 +30,6 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} -import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast @@ -39,7 +38,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{BitSet, CompactBuffer} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. @@ -363,12 +362,7 @@ private[serializer] object KryoSerializer { classOf[StorageLevel], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], - classOf[RoaringBitmap], - classOf[RoaringArray], - classOf[RoaringArray.Element], - classOf[Array[RoaringArray.Element]], - classOf[ArrayContainer], - classOf[BitmapContainer], + classOf[BitSet], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 7ab67fc3a2de9..85c5bdbfcebc0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,14 +17,21 @@ package org.apache.spark.util.collection +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import org.apache.spark.util.{Utils => UUtils} + + /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. */ -class BitSet(numBits: Int) extends Serializable { +class BitSet(private[this] var numBits: Int) extends Externalizable { - private val words = new Array[Long](bit2words(numBits)) - private val numWords = words.length + private var words = new Array[Long](bit2words(numBits)) + private def numWords = words.length + + def this() = this(0) /** * Compute the capacity (number of bits) that can be represented @@ -230,4 +237,19 @@ class BitSet(numBits: Int) extends Serializable { /** Return the number of longs it would take to hold numBits. 
*/ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 + + override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException { + out.writeInt(numBits) + words.foreach(out.writeLong(_)) + } + + override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException { + numBits = in.readInt() + words = new Array[Long](bit2words(numBits)) + var index = 0 + while (index < words.length) { + words(index) = in.readLong() + index += 1 + } + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e428414cf6e85..afe2e80358ca0 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -322,12 +322,6 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrationRequired", "true") - // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 - // values, and they use a bitmap (dense) if they have more than 4096 values, and an - // array (sparse) if they use less. So we just create two cases, one sparse and one dense. - // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly - // empty blocks - val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index 69dbfa9cd7141..b0db0988eeaab 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.util.collection +import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.util.{Utils => UUtils} class BitSetSuite extends SparkFunSuite { @@ -152,4 +155,50 @@ class BitSetSuite extends SparkFunSuite { assert(bitsetDiff.nextSetBit(85) === 85) assert(bitsetDiff.nextSetBit(86) === -1) } + + test("read and write externally") { + val tempDir = UUtils.createTempDir() + val outputFile = File.createTempFile("bits", null, tempDir) + + val fos = new FileOutputStream(outputFile) + val oos = new ObjectOutputStream(fos) + + // Create BitSet + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + + for (i <- 0 until 100) { + assert(!bitset.get(i)) + } + + setBits.foreach(i => bitset.set(i)) + + for (i <- 0 until 100) { + if (setBits.contains(i)) { + assert(bitset.get(i)) + } else { + assert(!bitset.get(i)) + } + } + assert(bitset.cardinality() === setBits.size) + + bitset.writeExternal(oos) + oos.close() + + val fis = new FileInputStream(outputFile) + val ois = new ObjectInputStream(fis) + + // Read BitSet from the file + val bitset2 = new BitSet(0) + bitset2.readExternal(ois) + + for (i <- 0 until 100) { + if (setBits.contains(i)) { + assert(bitset2.get(i)) + } else { + assert(!bitset2.get(i)) + } + } + assert(bitset2.cardinality() === setBits.size) + } } diff --git a/pom.xml b/pom.xml index 3dfc434fb553b..50c8f29cdbcd4 100644 --- a/pom.xml +++ b/pom.xml @@ -623,11 +623,6 @@ - - org.roaringbitmap - RoaringBitmap - 0.4.5 - commons-net commons-net From 
ea4a3e7d06dd4a0f669460513b27469c468214fb Mon Sep 17 00:00:00 2001 From: Yongjia Wang Date: Mon, 2 Nov 2015 08:59:35 +0000 Subject: [PATCH 118/324] [SPARK-11413][BUILD] Bump joda-time version to 2.9 for java 8 and s3 It's a known issue that joda-time before 2.8.1 is incompatible with java 1.8u60 or later, which causes S3 requests to fail. This affects Spark when using S3 as a data source. https://github.com/aws/aws-sdk-java/issues/444 Author: Yongjia Wang Closes #9379 from yongjiaw/SPARK-11413. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 50c8f29cdbcd4..762bfc7282335 100644 --- a/pom.xml +++ b/pom.xml @@ -176,7 +176,7 @@ 3.2.10 2.7.8 1.9 - 2.5 + 2.9 3.5.2 1.3.9 0.9.2 From 767522dc4e66dd26773d41d1576945187180d2b9 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Mon, 2 Nov 2015 21:31:10 +0800 Subject: [PATCH 119/324] [SPARK-10786][SQL] Take the whole statement to generate the CommandProcessor In the current implementation of `SparkSQLCLIDriver.scala`: `val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf)` `CommandProcessorFactory` only takes the first token of the statement, which makes it hard to distinguish the statements `delete jar xxx` and `delete from xxx`. So it is better to pass the whole statement to the `CommandProcessorFactory`. And in [HiveCommand](https://github.com/SaintBacchus/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/processors/HiveCommand.java#L76), it already has special handling for these two statements. ```java if(command.length > 1 && "from".equalsIgnoreCase(command[1])) { //special handling for SQL "delete from where..." return null; } ``` Author: huangzhaowei Closes #8895 from SaintBacchus/SPARK-10786. --- .../apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 62e912c69abc6..6419002a2aa89 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -290,7 +290,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(tokens, hconf) if (proc != null) { // scalastyle:off println From 74ba95228d71a6dc4e95fef19f41dabe7c363d9e Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 2 Nov 2015 23:07:30 +0800 Subject: [PATCH 120/324] [SPARK-11311][SQL] spark cannot describe temporary functions When describing a temporary function, Spark returns 'Unable to find function', which is not right. Author: Daoyuan Wang Closes #9277 from adrian-wang/functionreg.
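For illustration, the kind of session this change targets looks roughly like the sketch below (assuming a `HiveContext` in scope as `sqlContext`; the jar path and UDTF class are placeholders, whereas the regression test added here uses TestHive's `TestUDTF.jar`). With the change, `DESCRIBE FUNCTION` on a temporary function reports the function's class name instead of failing with 'Unable to find function'.

```scala
// Illustrative usage only; jar path and class name are placeholders.
sqlContext.sql("ADD JAR /path/to/some-udtf.jar")
sqlContext.sql("CREATE TEMPORARY FUNCTION my_udtf AS 'com.example.SomeGenericUDTF'")
// Previously the registry returned no ExpressionInfo for functions without a description
// annotation; now the function's class name is included in the output.
sqlContext.sql("DESCRIBE FUNCTION my_udtf").show()
```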
--- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 6 +++++- .../spark/sql/hive/execution/HiveQuerySuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 2ccad474b4f7a..0b5e863506142 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -119,7 +119,11 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) annotation.value(), annotation.extended())) } else { - None + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + name, + null, + null)) } }.getOrElse(None)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b52f7d4b57899..e597d6865f67a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -953,6 +953,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE t1") } + test("CREATE TEMPORARY FUNCTION") { + val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath + sql(s"ADD JAR $funcJar") + sql( + """CREATE TEMPORARY FUNCTION udtf_count2 AS + | 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'""".stripMargin) + assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) + sql("DROP TEMPORARY FUNCTION udtf_count2") + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") From a930e624eb9feb0f7d37d99dcb8178feb9c0f177 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 2 Nov 2015 10:23:30 -0800 Subject: [PATCH 121/324] [SPARK-9817][YARN] Improve the locality calculation of containers by taking pending container requests into consideration This is a follow-up PR to further improve the locality calculation by considering pending container requests. Since the locality preferences of tasks may shift from time to time, the current localities of pending container requests may not fully match the new preferences; this PR improves on that by removing outdated, unmatched container requests and replacing them with new requests. sryza please help to review, thanks a lot. Author: jerryshao Closes #8100 from jerryshao/SPARK-9817.
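For illustration, a standalone sketch (plain Scala; the names are made up, and host lists stand in for AMRMClient `ContainerRequest` objects) of the three-way split the allocator now applies to its pending requests, locality matched, locality unmatched, and locality free, before cancelling the latter two and recomputing placement with the new preferences.

```scala
// Sketch of the pending-request split (illustrative, not the allocator's actual types).
object PendingSplitSketch {
  type Req = Option[Seq[String]]  // Some(preferred hosts) or None when no locality was requested

  def split(preferredHosts: Set[String], pending: Seq[Req]): (Seq[Req], Seq[Req], Seq[Req]) = {
    val (withNodes, localityFree) = pending.partition(_.isDefined)
    val (matched, unmatched) = withNodes.partition(_.exists(_.exists(preferredHosts.contains)))
    (matched, unmatched, localityFree)
  }

  def main(args: Array[String]): Unit = {
    val preferred = Set("host1", "host2")
    val pending: Seq[Req] = Seq(Some(Seq("host1", "host3")), Some(Seq("host4")), None)
    val (matched, unmatched, free) = split(preferred, pending)
    // Matched requests keep counting toward the expected per-host containers; unmatched and
    // locality-free requests are removed and re-issued with the new locality preferences.
    println(s"matched=${matched.size} unmatched=${unmatched.size} free=${free.size}") // 1 1 1
  }
}
```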
--- .../spark/deploy/yarn/ApplicationMaster.scala | 2 +- ...yPreferredContainerPlacementStrategy.scala | 60 +++++++++++++-- .../spark/deploy/yarn/YarnAllocator.scala | 73 +++++++++++++++---- .../ContainerPlacementStrategySuite.scala | 38 ++++++++-- .../deploy/yarn/YarnAllocatorSuite.scala | 26 +++---- 5 files changed, 159 insertions(+), 40 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4b4d9990ce9f9..c6a6d7ac56bf3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -375,7 +375,7 @@ private[spark] class ApplicationMaster( } } try { - val numPendingAllocate = allocator.getNumPendingAllocate + val numPendingAllocate = allocator.getPendingAllocate.size val sleepInterval = if (numPendingAllocate > 0) { val currentAllocationInterval = diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index 081780204e424..2ec189de7c914 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -18,9 +18,11 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.{ArrayBuffer, HashMap, Set} +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf @@ -30,8 +32,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack /** * This strategy is calculating the optimal locality preferences of YARN containers by considering * the node ratio of pending tasks, number of required cores/containers and and locality of current - * existing containers. The target of this algorithm is to maximize the number of tasks that - * would run locally. + * existing and pending allocated containers. The target of this algorithm is to maximize the number + * of tasks that would run locally. * * Consider a situation in which we have 20 tasks that require (host1, host2, host3) * and 10 tasks that require (host1, host2, host4), besides each container has 2 cores @@ -91,6 +93,11 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( * @param numLocalityAwareTasks number of locality required tasks * @param hostToLocalTaskCount a map to store the preferred hostname and possible task * numbers running on it, used as hints for container allocation + * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the + * expected locality preference by considering the existing + * containers + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. 
* @return node localities and rack localities, each locality is an array of string, * the length of localities is the same as number of containers */ @@ -98,10 +105,12 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( numContainer: Int, numLocalityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int], - allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]], + localityMatchedPendingAllocations: Seq[ContainerRequest] ): Array[ContainerLocalityPreferences] = { val updatedHostToContainerCount = expectedHostToContainerCount( - numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap, + localityMatchedPendingAllocations) val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum // The number of containers to allocate, divided into two groups, one with preferred locality, @@ -158,20 +167,28 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( * @param localityAwareTasks number of locality aware tasks * @param hostToLocalTaskCount a map to store the preferred hostname and possible task * numbers running on it, used as hints for container allocation + * @param allocatedHostToContainersMap host to allocated containers map, used to calculate the + * expected locality preference by considering the existing + * containers + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. * @return a map with hostname as key and required number of containers on this host as value */ private def expectedHostToContainerCount( localityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int], - allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]], + localityMatchedPendingAllocations: Seq[ContainerRequest] ): Map[String, Int] = { val totalLocalTaskNum = hostToLocalTaskCount.values.sum + val pendingHostToContainersMap = pendingHostToContainerCount(localityMatchedPendingAllocations) + hostToLocalTaskCount.map { case (host, count) => val expectedCount = count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum - val existedCount = allocatedHostToContainersMap.get(host) - .map(_.size) - .getOrElse(0) + // Take the locality of pending containers into consideration + val existedCount = allocatedHostToContainersMap.get(host).map(_.size).getOrElse(0) + + pendingHostToContainersMap.getOrElse(host, 0.0) // If existing container can not fully satisfy the expected number of container, // the required container number is expected count minus existed count. Otherwise the @@ -179,4 +196,31 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) } } + + /** + * According to the locality ratio and number of container requests, calculate the host to + * possible number of containers for pending allocated containers. + * + * If current locality ratio of hosts is: Host1 : Host2 : Host3 = 20 : 20 : 10, + * and pending container requests is 3, so the possible number of containers on + * Host1 : Host2 : Host3 will be 1.2 : 1.2 : 0.6. + * @param localityMatchedPendingAllocations A sequence of pending container request which + * matches the localities of current required tasks. 
+ * @return a Map with hostname as key and possible number of containers on this host as value + */ + private def pendingHostToContainerCount( + localityMatchedPendingAllocations: Seq[ContainerRequest]): Map[String, Double] = { + val pendingHostToContainerCount = new HashMap[String, Int]() + localityMatchedPendingAllocations.foreach { cr => + cr.getNodes.asScala.foreach { n => + val count = pendingHostToContainerCount.getOrElse(n, 0) + 1 + pendingHostToContainerCount(n) = count + } + } + + val possibleTotalContainerNum = pendingHostToContainerCount.values.sum + val localityMatchedPendingNum = localityMatchedPendingAllocations.size.toDouble + pendingHostToContainerCount.mapValues(_ * localityMatchedPendingNum / possibleTotalContainerNum) + .toMap + } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 875bbd4e4e3d5..a0cf1b4aa469b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -157,15 +157,19 @@ private[yarn] class YarnAllocator( def getNumExecutorsFailed: Int = numExecutorsFailed /** - * Number of container requests that have not yet been fulfilled. + * A sequence of pending container requests that have not yet been fulfilled. */ - def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST) + def getPendingAllocate: Seq[ContainerRequest] = getPendingAtLocation(ANY_HOST) /** - * Number of container requests at the given location that have not yet been fulfilled. + * A sequence of pending container requests at the given location that have not yet been + * fulfilled. */ - private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala.map(_.size).sum + private def getPendingAtLocation(location: String): Seq[ContainerRequest] = { + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala + .flatMap(_.asScala) + .toSeq + } /** * Request as many executors from the ResourceManager as needed to reach the desired total. If @@ -251,20 +255,31 @@ private[yarn] class YarnAllocator( * Visible for testing. */ def updateResourceRequests(): Unit = { - val numPendingAllocate = getNumPendingAllocate + val pendingAllocate = getPendingAllocate + val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning - // TODO. Consider locality preferences of pending container requests. - // Since the last time we made container requests, stages have completed and been submitted, - // and that the localities at which we requested our pending executors - // no longer apply to our current needs. We should consider to remove all outstanding - // container requests and add requests anew each time to avoid this. if (missing > 0) { logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") + // Split the pending container request into three groups: locality matched list, locality + // unmatched list and non-locality list. Take the locality matched container request into + // consideration of container placement, treat as allocated containers. 
+ // For locality unmatched and locality free container requests, cancel these container + // requests, since required locality preference has been changed, recalculating using + // container placement strategy. + val (localityMatched, localityUnMatched, localityFree) = splitPendingAllocationsByLocality( + hostToLocalTaskCounts, pendingAllocate) + + // Remove the outdated container request and recalculate the requested container number + localityUnMatched.foreach(amClient.removeContainerRequest) + localityFree.foreach(amClient.removeContainerRequest) + val updatedNumContainer = missing + localityUnMatched.size + localityFree.size + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( - missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + updatedNumContainer, numLocalityAwareTasks, hostToLocalTaskCounts, + allocatedHostToContainersMap, localityMatched) for (locality <- containerLocalityPreferences) { val request = createContainerRequest(resource, locality.nodes, locality.racks) @@ -291,7 +306,7 @@ private[yarn] class YarnAllocator( * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. */ - protected def createContainerRequest( + private def createContainerRequest( resource: Resource, nodes: Array[String], racks: Array[String]): ContainerRequest = { @@ -535,6 +550,38 @@ private[yarn] class YarnAllocator( private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + /** + * Split the pending container requests into 3 groups based on current localities of pending + * tasks. + * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. + * @param pendingAllocations A sequence of pending allocation container request. + * @return A tuple of 3 sequences, first is a sequence of locality matched container + * requests, second is a sequence of locality unmatched container requests, and third is a + * sequence of locality free container requests. 
+ */ + private def splitPendingAllocationsByLocality( + hostToLocalTaskCount: Map[String, Int], + pendingAllocations: Seq[ContainerRequest] + ): (Seq[ContainerRequest], Seq[ContainerRequest], Seq[ContainerRequest]) = { + val localityMatched = ArrayBuffer[ContainerRequest]() + val localityUnMatched = ArrayBuffer[ContainerRequest]() + val localityFree = ArrayBuffer[ContainerRequest]() + + val preferredHosts = hostToLocalTaskCount.keySet + pendingAllocations.foreach { cr => + val nodes = cr.getNodes + if (nodes == null) { + localityFree += cr + } else if (nodes.asScala.toSet.intersect(preferredHosts).nonEmpty) { + localityMatched += cr + } else { + localityUnMatched += cr + } + } + + (localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq) + } } private object YarnAllocator { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala index b7fe4ccc67a38..afb4b691b52de 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.yarn +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite @@ -26,6 +27,9 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B private val yarnAllocatorSuite = new YarnAllocatorSuite import yarnAllocatorSuite._ + def createContainerRequest(nodes: Array[String]): ContainerRequest = + new ContainerRequest(containerResource, nodes, null, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) + override def beforeEach() { yarnAllocatorSuite.beforeEach() } @@ -44,7 +48,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array( Array("host3", "host4", "host5"), @@ -66,7 +71,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null, Array("host2", "host3"), Array("host2", "host3"))) @@ -86,7 +92,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) } @@ -105,7 +112,8 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B )) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 3, 15, 
Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null, null, null)) } @@ -118,8 +126,28 @@ class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with B handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( - 1, 0, Map.empty, handler.allocatedHostToContainersMap) + 1, 0, Map.empty, handler.allocatedHostToContainersMap, Seq.empty) assert(localities.map(_.nodes) === Array(null)) } + + test("allocate locality preferred containers by considering the localities of pending requests") { + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val pendingAllocationRequests = Seq( + createContainerRequest(Array("host2", "host3")), + createContainerRequest(Array("host1", "host4"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), + handler.allocatedHostToContainersMap, pendingAllocationRequests) + + assert(localities.map(_.nodes) === Array(Array("host3"))) + } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 5d05f514adde3..bd80036c5cfa7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -116,7 +116,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(1) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size should be (1) val container = createContainer("host1") handler.handleAllocatedContainers(Array(container)) @@ -134,7 +134,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host1") @@ -154,7 +154,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(2) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (2) + handler.getPendingAllocate.size should be (2) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -174,11 +174,11 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (3) + handler.getPendingAllocate.size should be (3) val container = createContainer("host1") 
handler.handleAllocatedContainers(Array(container)) @@ -189,18 +189,18 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size should be (1) } test("decrease total requested executors to less than currently running") { val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (3) + handler.getPendingAllocate.size should be (3) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -210,7 +210,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.updateResourceRequests() - handler.getNumPendingAllocate should be (0) + handler.getPendingAllocate.size should be (0) handler.getNumExecutorsRunning should be (2) } @@ -218,7 +218,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -233,14 +233,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.updateResourceRequests() handler.processCompletedContainers(statuses.toSeq) handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (1) + handler.getPendingAllocate.size should be (1) } test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (4) + handler.getPendingAllocate.size should be (4) val container1 = createContainer("host1") val container2 = createContainer("host2") @@ -255,7 +255,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.processCompletedContainers(statuses.toSeq) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) - handler.getNumPendingAllocate should be (2) + handler.getPendingAllocate.size should be (2) handler.getNumExecutorsFailed should be (2) handler.getNumUnexpectedContainerRelease should be (2) } From 71d1c907dec446db566b19f912159fd8f46deb7d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 2 Nov 2015 10:26:36 -0800 Subject: [PATCH 122/324] [SPARK-10997][CORE] Add "client mode" to netty rpc env. "Client mode" means the RPC env will not listen for incoming connections. This allows certain processes in the Spark stack (such as Executors or the YARN client-mode AM) to act as pure clients when using the netty-based RPC backend, reducing the number of sockets needed by the app and also the number of open ports.
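For illustration only, the new flag is used roughly like this inside Spark (these are private[spark] APIs; "sparkClient", "localhost", and the endpoint URI below are placeholder values, not code from this patch):

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv

val conf = new SparkConf()
val securityMgr = new SecurityManager(conf)

// With clientMode = true no server socket is bound, so rpcEnv.address may be null.
val rpcEnv = RpcEnv.create("sparkClient", "localhost", 0, conf, securityMgr,
  clientMode = true)

// The client-mode env can still look up and message endpoints that do listen,
// such as the driver's endpoint.
val driverRef = rpcEnv.setupEndpointRefByURI("spark://SomeEndpoint@driver-host:7077")
driverRef.send("hello")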
Client connections are also preferred when endpoints that actually have a listening socket are involved; so, for example, if a Worker connects to a Master and the Master needs to send a message to a Worker endpoint, that client connection will be used, even though the Worker is also listening for incoming connections. With this change, the workaround for SPARK-10987 isn't necessary anymore, and is removed. The AM connects to the driver in "client mode", and that connection is used for all driver <-> AM communication, and so the AM is properly notified when the connection goes down. Author: Marcelo Vanzin Closes #9210 from vanzin/SPARK-10997. --- .../scala/org/apache/spark/SparkEnv.scala | 7 +- .../CoarseGrainedExecutorBackend.scala | 20 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 8 +- .../apache/spark/rpc/netty/Dispatcher.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 245 ++++++++++-------- .../org/apache/spark/rpc/netty/Outbox.scala | 24 +- .../spark/rpc/netty/RpcEndpointAddress.scala | 24 +- .../cluster/CoarseGrainedClusterMessage.scala | 10 +- .../CoarseGrainedSchedulerBackend.scala | 18 +- .../cluster/YarnSchedulerBackend.scala | 2 - .../org/apache/spark/rpc/RpcEnvSuite.scala | 50 ++-- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 11 +- .../rpc/netty/NettyRpcAddressSuite.scala | 7 +- .../spark/rpc/netty/NettyRpcEnvSuite.scala | 9 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 8 +- network/yarn/pom.xml | 5 + .../spark/deploy/yarn/ApplicationMaster.scala | 6 +- 17 files changed, 266 insertions(+), 190 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 398e0936906a3..23ae9360f6a22 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -252,7 +252,8 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, + clientMode = !isDriver) val actorSystem: ActorSystem = if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem @@ -262,9 +263,11 @@ object SparkEnv extends Logging { } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // In the non-driver case, the RPC env's address may be null since it may not be listening + // for incoming connections. 
if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else { + } else if (rpcEnv.address != null) { conf.set("spark.executor.port", rpcEnv.address.port.toString) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a9c6a05ecd434..c2ebf30596215 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { - Utils.checkHostPort(hostPort, "Expected hostport") - var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive: PartialFunction[Any, Unit] = { - case RegisteredExecutor => + case RegisteredExecutor(hostname) => logInfo("Successfully registered with driver") - val (hostname, _) = Utils.parseHostPort(hostPort) executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => @@ -163,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname, port, executorConf, - new SecurityManager(executorConf)) + new SecurityManager(executorConf), + clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) @@ -188,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) - // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore. - val boundPort = env.conf.getInt("spark.executor.port", 0) - assert(boundPort != 0) - - // Start the CoarseGrainedExecutorBackend endpoint. - val sparkHostPort = hostname + ":" + boundPort + // SparkEnv will set spark.executor.port if the rpc env is listening for incoming + // connections (e.g., if it's using akka). Otherwise, the executor is running in + // client mode only, and does not accept incoming connections. 
+ val sparkHostPort = env.conf.getOption("spark.executor.port").map { port => + hostname + ":" + port + }.orNull env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 2c4a8b9a0a878..a560fd10cdf76 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -43,9 +43,10 @@ private[spark] object RpcEnv { host: String, port: Int, conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { + securityManager: SecurityManager, + clientMode: Boolean = false): RpcEnv = { // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) + val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) getRpcEnvFactory(conf).create(config) } } @@ -139,4 +140,5 @@ private[spark] case class RpcEnvConfig( name: String, host: String, port: Int, - securityManager: SecurityManager) + securityManager: SecurityManager, + clientMode: Boolean) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 7bf44a6565b61..eb25d6c7b721b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -55,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val addr = RpcEndpointAddress(nettyEnv.address, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) synchronized { if (stopped) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 284284eb805b7..09093819bb22c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -17,10 +17,12 @@ package org.apache.spark.rpc.netty import java.io._ +import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -29,6 +31,7 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal +import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -45,15 +48,14 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - // Override numConnectionsPerPeer to 1 for RPC. 
private val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) - private val transportContext = - new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) + private val transportContext = new TransportContext(transportConf, + new NettyRpcHandler(dispatcher, this)) private val clientFactory = { val bootstraps: java.util.List[TransportClientBootstrap] = @@ -95,7 +97,7 @@ private[netty] class NettyRpcEnv( } } - def start(port: Int): Unit = { + def startServer(port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) @@ -107,9 +109,9 @@ private[netty] class NettyRpcEnv( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } + @Nullable override lazy val address: RpcAddress = { - require(server != null, "NettyRpcEnv has not yet started") - RpcAddress(host, server.getPort) + if (server != null) RpcAddress(host, server.getPort()) else null } override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @@ -120,7 +122,7 @@ private[netty] class NettyRpcEnv( val addr = RpcEndpointAddress(uri) val endpointRef = new NettyRpcEndpointRef(conf, addr, this) val verifier = new NettyRpcEndpointRef( - conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this) verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) @@ -135,28 +137,34 @@ private[netty] class NettyRpcEnv( dispatcher.stop(endpointRef) } - private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = { - val targetOutbox = { - val outbox = outboxes.get(address) - if (outbox == null) { - val newOutbox = new Outbox(this, address) - val oldOutbox = outboxes.putIfAbsent(address, newOutbox) - if (oldOutbox == null) { - newOutbox + private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { + if (receiver.client != null) { + receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + } else { + require(receiver.address != null, + "Cannot send message to client endpoint with no listen address.") + val targetOutbox = { + val outbox = outboxes.get(receiver.address) + if (outbox == null) { + val newOutbox = new Outbox(this, receiver.address) + val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox) + if (oldOutbox == null) { + newOutbox + } else { + oldOutbox + } } else { - oldOutbox + outbox } + } + if (stopped.get) { + // It's possible that we put `targetOutbox` after stopping. So we need to clean it. + outboxes.remove(receiver.address) + targetOutbox.stop() } else { - outbox + targetOutbox.send(message) } } - if (stopped.get) { - // It's possible that we put `targetOutbox` after stopping. So we need to clean it. - outboxes.remove(address) - targetOutbox.stop() - } else { - targetOutbox.send(message) - } } private[netty] def send(message: RequestMessage): Unit = { @@ -174,17 +182,14 @@ private[netty] class NettyRpcEnv( }(ThreadUtils.sameThread) } else { // Message to a remote RPC endpoint. 
- postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { + postToOutbox(message.receiver, OutboxMessage(serialize(message), + (e) => { logWarning(s"Exception when sending $message", e) - } - - override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[Ack](response) + }, + (client, response) => { + val ack = deserialize[Ack](client, response) logDebug(s"Receive ack from ${ack.sender}") - } - })) + })) } } @@ -214,16 +219,14 @@ private[netty] class NettyRpcEnv( } }(ThreadUtils.sameThread) } else { - postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { + postToOutbox(message.receiver, OutboxMessage(serialize(message), + (e) => { if (!promise.tryFailure(e)) { logWarning("Ignore Exception", e) } - } - - override def onSuccess(response: Array[Byte]): Unit = { - val reply = deserialize[AskResponse](response) + }, + (client, response) => { + val reply = deserialize[AskResponse](client, response) if (reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { logWarning(s"Ignore failure: ${reply.reply}") @@ -231,8 +234,7 @@ private[netty] class NettyRpcEnv( } else if (!promise.trySuccess(reply.reply)) { logWarning(s"Ignore message: ${reply}") } - } - })) + })) } promise.future } @@ -243,9 +245,11 @@ private[netty] class NettyRpcEnv( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } - private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + NettyRpcEnv.currentClient.withValue(client) { + deserialize { () => + javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + } } } @@ -254,7 +258,7 @@ private[netty] class NettyRpcEnv( } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address.host, address.port, endpointName).toString + new RpcEndpointAddress(address, endpointName).toString override def shutdown(): Unit = { cleanup() @@ -297,6 +301,7 @@ private[netty] class NettyRpcEnv( deserializationAction() } } + } private[netty] object NettyRpcEnv extends Logging { @@ -312,6 +317,13 @@ private[netty] object NettyRpcEnv extends Logging { * }}} */ private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) + + /** + * Similar to `currentEnv`, this variable references the client instance associated with an + * RPC, in case it's needed to find out the remote address during deserialization. 
+ */ + private[netty] val currentClient = new DynamicVariable[TransportClient](null) + } private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { @@ -324,47 +336,68 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) - val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => - nettyEnv.start(actualPort) - (nettyEnv, actualPort) - } - try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 - } catch { - case NonFatal(e) => - nettyEnv.shutdown() - throw e + if (!config.clientMode) { + val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => + nettyEnv.startServer(actualPort) + (nettyEnv, actualPort) + } + try { + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + } catch { + case NonFatal(e) => + nettyEnv.shutdown() + throw e + } } + nettyEnv } } -private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf) +/** + * The NettyRpcEnv version of RpcEndpointRef. + * + * This class behaves differently depending on where it's created. On the node that "owns" the + * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance. + * + * On other machines that receive a serialized version of the reference, the behavior changes. The + * instance will keep track of the TransportClient that sent the reference, so that messages + * to the endpoint are sent over the client connection, instead of needing a new connection to + * be opened. + * + * The RpcAddress of this ref can be null; what that means is that the ref can only be used through + * a client connection, since the process hosting the endpoint is not listening for incoming + * connections. These refs should not be shared with 3rd parties, since they will not be able to + * send messages to the endpoint. + * + * @param conf Spark configuration. + * @param endpointAddress The address where the endpoint is listening. + * @param nettyEnv The RpcEnv associated with this ref. + * @param local Whether the referenced endpoint lives in the same process. 
+ */ +private[netty] class NettyRpcEndpointRef( + @transient private val conf: SparkConf, + endpointAddress: RpcEndpointAddress, + @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) with Serializable with Logging { - @transient @volatile private var nettyEnv: NettyRpcEnv = _ + @transient @volatile var client: TransportClient = _ - @transient @volatile private var _address: RpcEndpointAddress = _ + private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null + private val _name = endpointAddress.name - def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { - this(conf) - this._address = _address - this.nettyEnv = nettyEnv - } - - override def address: RpcAddress = _address.toRpcAddress + override def address: RpcAddress = if (_address != null) _address.rpcAddress else null private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - _address = in.readObject().asInstanceOf[RpcEndpointAddress] nettyEnv = NettyRpcEnv.currentEnv.value + client = NettyRpcEnv.currentClient.value } private def writeObject(out: ObjectOutputStream): Unit = { out.defaultWriteObject() - out.writeObject(_address) } - override def name: String = _address.name + override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() @@ -429,41 +462,43 @@ private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessa private[netty] case class RpcFailure(e: Throwable) /** - * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast - * network events and forward messages to [[Dispatcher]]. + * Dispatches incoming RPCs to registered endpoints. + * + * The handler keeps track of all client instances that communicate with it, so that the RpcEnv + * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e., + * one that is not listening for incoming connections, but rather needs to be contacted via the + * client socket). + * + * Events are sent on a per-connection basis, so if a client opens multiple connections to the + * RpcEnv, multiple connection / disconnection events will be created for that client (albeit + * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { - private type ClientAddress = RpcAddress - private type RemoteEnvAddress = RpcAddress - - // Store all client addresses and their NettyRpcEnv addresses. - // TODO: Is this even necessary? - @GuardedBy("this") - private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. 
+ private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() override def receive( - client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { - val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + client: TransportClient, + message: Array[Byte], + callback: RpcResponseCallback): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) - val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. - var dispatchRemoteProcessConnected = false - synchronized { - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time, fire "RemoteProcessConnected" - dispatchRemoteProcessConnected = true - } + if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } - if (dispatchRemoteProcessConnected) { - dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) - } - dispatcher.postRemoteMessage(requestMessage, callback) + val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) + val messageToDispatch = if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. + RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, + requestMessage.needReply) + } else { + requestMessage + } + dispatcher.postRemoteMessage(messageToDispatch, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager @@ -472,15 +507,7 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _)) - } - if (broadcastMessage.isEmpty) { - logError(cause.getMessage, cause) - } else { - dispatcher.postToAll(broadcastMessage.get) - } + dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -493,15 +520,9 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + clients.remove(client) nettyEnv.removeOutbox(clientAddr) - val messageOpt: Option[RemoteProcessDisconnected] = - synchronized { - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - Some(RemoteProcessDisconnected(remoteEnvAddress)) - } - } - messageOpt.foreach(dispatcher.postToAll) + dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". 
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 7d9d593b36241..2f6817f2eb935 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -26,7 +26,21 @@ import org.apache.spark.SparkException import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.RpcAddress -private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback) +private[netty] case class OutboxMessage(content: Array[Byte], + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, Array[Byte]) => Unit) { + + def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + _onSuccess(client, response) + } + } + +} private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { @@ -68,7 +82,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -108,7 +122,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.callback) + _client.sendRpc(message.content, message.createCallback(_client)) } else { assert(stopped == true) } @@ -181,7 +195,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message.callback.onFailure(e) + message._onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -215,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index 87b6236936817..d2e94f943aba5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -23,15 +23,25 @@ import org.apache.spark.rpc.RpcAddress /** * An address identifier for an RPC endpoint. * - * @param host host name of the remote process. - * @param port the port the remote RPC environment binds to. - * @param name name of the remote endpoint. + * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only + * connection and can only be reached via the client that sent the endpoint reference. + * + * @param rpcAddress The socket address of the endpint. + * @param name Name of the endpoint. 
*/ -private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { +private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { + + require(name != null, "RpcEndpoint name must be provided.") - def toRpcAddress: RpcAddress = RpcAddress(host, port) + def this(host: String, port: Int, name: String) = { + this(RpcAddress(host, port), name) + } - override val toString = s"spark://$name@$host:$port" + override val toString = if (rpcAddress != null) { + s"spark://$name@${rpcAddress.host}:${rpcAddress.port}" + } else { + s"spark-client://$name" + } } private[netty] object RpcEndpointAddress { @@ -51,7 +61,7 @@ private[netty] object RpcEndpointAddress { uri.getQuery != null) { throw new SparkException("Invalid Spark URL: " + sparkUrl) } - RpcEndpointAddress(host, port, name) + new RpcEndpointAddress(host, port, name) } catch { case e: java.net.URISyntaxException => throw new SparkException("Invalid Spark URL: " + sparkUrl, e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 8103efa7302e7..f3d0d85476772 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -38,7 +38,7 @@ private[spark] object CoarseGrainedClusterMessages { sealed trait RegisterExecutorResponse - case object RegisteredExecutor extends CoarseGrainedClusterMessage + case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage @@ -51,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages { hostPort: String, cores: Int, logUrls: Map[String, String]) - extends CoarseGrainedClusterMessage { - Utils.checkHostPort(hostPort, "Expected host port") - } + extends CoarseGrainedClusterMessage case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends CoarseGrainedClusterMessage @@ -107,8 +105,4 @@ private[spark] object CoarseGrainedClusterMessages { // Used internally by executors to shut themselves down. case object Shutdown extends CoarseGrainedClusterMessage - // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back - // to the AM. 
- case object DriverHello extends CoarseGrainedClusterMessage - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 55a564b5c8eac..439a11927026b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -131,16 +131,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { - logInfo("Registered executor: " + executorRef + " with ID " + executorId) - addressToExecutorId(executorRef.address) = executorId + // If the executor's rpc env is not listening for incoming connections, `hostPort` + // will be null, and the client connection should be used to contact the executor. + val executorAddress = if (executorRef.address != null) { + executorRef.address + } else { + context.senderAddress + } + logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId") + addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host, + cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -151,7 +157,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } // Note: some tests expect the reply to come after we put the executor in the map - context.reply(RegisteredExecutor) + context.reply(RegisteredExecutor(executorAddress.host)) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index e483688edef5f..cb24072d7d941 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -170,8 +170,6 @@ private[spark] abstract class YarnSchedulerBackend( case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") amEndpoint = Option(am) - // See SPARK-10987. 
- am.send(DriverHello) case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 3bead6395d384..834e4743df866 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -48,7 +48,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv test("send a message locally") { @volatile var message: String = null @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345) + val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,26 +497,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") val remoteAddress = anotherEnv.address rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List(("onConnected", remoteAddress))) + // anotherEnv is connected in client mode, so the remote address may be unknown depending on + // the implementation. 
Account for that when doing checks. + if (remoteAddress != null) { + assert(events === List(("onConnected", remoteAddress))) + } else { + assert(events.size === 1) + assert(events(0)._1 === "onConnected") + } } anotherEnv.shutdown() anotherEnv.awaitTermination() eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) + // Account for anotherEnv not having an address due to running in client mode. + if (remoteAddress != null) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress)) || + events === List( + ("onConnected", remoteAddress), + ("onDisconnected", remoteAddress))) + } else { + val eventNames = events.map(_._1) + assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || + eventNames === List("onConnected", "onDisconnected")) + } } } @@ -529,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -558,7 +572,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate.secret", "good") val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) try { @volatile var message: String = null @@ -589,7 +603,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate.secret", "good") val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 4aa75c9230b2c..6478ab51c4da2 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + override def createRpcEnv(conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) } test("setupEndpointRef: systemName, address, endpointName") { @@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) 
try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 973a07a0bde3a..56743ba650b41 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -22,8 +22,13 @@ import org.apache.spark.SparkFunSuite class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = RpcEndpointAddress("localhost", 12345, "test") + val addr = new RpcEndpointAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } + test("toString for client mode") { + val addr = RpcEndpointAddress(null, "test") + assert(addr.toString === "spark-client://test") + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index be19668e17c04..ce83087ec04d6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -22,8 +22,13 @@ import org.apache.spark.rpc._ class NettyRpcEnvSuite extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { - val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf)) + override def createRpcEnv( + conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { + val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf), + clientMode) new NettyRpcEnvFactory().create(config) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 5430e4c0c4d6c..f9d8e80c98b66 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[Array[Byte]]))(any())). + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). 
thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { @@ -42,7 +42,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } test("connectionTerminated") { @@ -57,9 +57,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.connectionTerminated(client) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( - RemoteProcessDisconnected(RpcAddress("localhost", 12345))) + RemoteProcessDisconnected(RpcAddress("localhost", 40000))) } } diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 541ed9a8d0ab6..e2360eff5cfe1 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -54,6 +54,11 @@ org.apache.hadoop hadoop-client + + org.slf4j + slf4j-api + provided + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c6a6d7ac56bf3..12ae350e4cef6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -321,7 +321,8 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) - rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr, + clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -574,9 +575,6 @@ private[spark] class ApplicationMaster( case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver.send(x) - - case DriverHello => - // SPARK-10987: no action needed for this message. } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { From f92f334ca47c03b980b06cf300aa652d0ffa1880 Mon Sep 17 00:00:00 2001 From: Jason White Date: Mon, 2 Nov 2015 10:49:06 -0800 Subject: [PATCH 123/324] [SPARK-11437] [PYSPARK] Don't .take when converting RDD to DataFrame with provided schema When creating a DataFrame from an RDD in PySpark, `createDataFrame` calls `.take(10)` to verify the first 10 rows of the RDD match the provided schema. Similar to https://issues.apache.org/jira/browse/SPARK-8070, but that issue affected cases where a schema was not provided. Verifying the first 10 rows is of limited utility and causes the DAG to be executed non-lazily. If necessary, I believe this verification should be done lazily on all rows. However, since the caller is providing a schema to follow, I think it's acceptable to simply fail if the schema is incorrect. marmbrus We chatted about this at SparkSummitEU. 
davies you made a similar change for the infer-schema path in https://github.com/apache/spark/pull/6606 Author: Jason White Closes #9392 from JasonMWhite/createDataFrame_without_take. --- python/pyspark/sql/context.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 79453658a167a..924bb6433de0e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -318,13 +318,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): struct.names[i] = name schema = struct - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: + elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data From b3aedca6b55c678e40a5961e2fd3af4cb8c52bba Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 2 Nov 2015 14:36:37 -0600 Subject: [PATCH 124/324] [SPARK-11456][TESTS] Remove deprecated junit.framework in Java tests Replace use of `junit.framework` with `org.junit`, and touch up tests in question Author: Sean Owen Closes #9411 from srowen/SPARK-11456. --- .../spark/unsafe/bitset/BitSetSuite.java | 4 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 11 +-- .../unsafe/types/CalendarIntervalSuite.java | 78 +++++++++---------- .../spark/unsafe/types/UTF8StringSuite.java | 74 +++++++++--------- 4 files changed, 84 insertions(+), 83 deletions(-) diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java index a93fc0ee297c4..14e38683df4ab 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import junit.framework.Assert; +import org.junit.Assert; import org.junit.Test; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -25,7 +25,7 @@ public class BitSetSuite { private static BitSet createBitSet(int capacity) { - assert capacity % 64 == 0; + Assert.assertEquals(0, capacity % 64); return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 2f8cb132ac8b4..e759cb33b3e6a 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -17,12 +17,13 @@ package org.apache.spark.unsafe.hash; +import java.nio.charset.StandardCharsets; import java.util.HashSet; import java.util.Random; import java.util.Set; -import junit.framework.Assert; import org.apache.spark.unsafe.Platform; +import org.junit.Assert; import org.junit.Test; /** @@ -56,7 +57,7 @@ public void randomizedStressTest() { Random rand = new Random(); // A set used to track collision rate. - Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int vint = rand.nextInt(); long lint = rand.nextLong(); @@ -76,7 +77,7 @@ public void randomizedStressTestBytes() { Random rand = new Random(); // A set used to track collision rate. 
- Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; @@ -98,10 +99,10 @@ public void randomizedStressTestBytes() { public void randomizedStressTestPaddedStrings() { int size = 64000; // A set used to track collision rate. - Set hashcodes = new HashSet(); + Set hashcodes = new HashSet<>(); for (int i = 0; i < size; i++) { int byteArrSize = 8; - byte[] strBytes = ("" + i).getBytes(); + byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 80d4982c4b576..9e69e264ff287 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -19,7 +19,7 @@ import org.junit.Test; -import static junit.framework.Assert.*; +import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.CalendarInterval.*; public class CalendarIntervalSuite { @@ -42,19 +42,19 @@ public void toStringTest() { CalendarInterval i; i = new CalendarInterval(34, 0); - assertEquals(i.toString(), "interval 2 years 10 months"); + assertEquals("interval 2 years 10 months", i.toString()); i = new CalendarInterval(-34, 0); - assertEquals(i.toString(), "interval -2 years -10 months"); + assertEquals("interval -2 years -10 months", i.toString()); i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); + assertEquals("interval 3 weeks 13 hours 123 microseconds", i.toString()); i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); - assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); + assertEquals("interval -3 weeks -13 hours -123 microseconds", i.toString()); i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); + assertEquals("interval 2 years 10 months 3 weeks 13 hours 123 microseconds", i.toString()); } @Test @@ -73,32 +73,32 @@ public void fromStringTest() { input = "interval -5 years 23 month"; CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); input = "interval -5 years 23 month "; - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); input = " interval -5 years 23 month "; - assertEquals(CalendarInterval.fromString(input), result); + assertEquals(fromString(input), result); // Error cases input = "interval 3month 1 hour"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "interval 3 moth 1 hour"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "interval"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = "int"; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = ""; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); input = 
null; - assertEquals(CalendarInterval.fromString(input), null); + assertNull(fromString(input)); } @Test @@ -108,15 +108,15 @@ public void fromYearMonthStringTest() { input = "99-10"; i = new CalendarInterval(99 * 12 + 10, 0L); - assertEquals(CalendarInterval.fromYearMonthString(input), i); + assertEquals(fromYearMonthString(input), i); input = "-8-10"; i = new CalendarInterval(-8 * 12 - 10, 0L); - assertEquals(CalendarInterval.fromYearMonthString(input), i); + assertEquals(fromYearMonthString(input), i); try { input = "99-15"; - CalendarInterval.fromYearMonthString(input); + fromYearMonthString(input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("month 15 outside range")); @@ -131,19 +131,19 @@ public void fromDayTimeStringTest() { input = "5 12:40:30.999999999"; i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); - assertEquals(CalendarInterval.fromDayTimeString(input), i); + assertEquals(fromDayTimeString(input), i); input = "10 0:12:0.888"; i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); - assertEquals(CalendarInterval.fromDayTimeString(input), i); + assertEquals(fromDayTimeString(input), i); input = "-3 0:0:0"; i = new CalendarInterval(0, -3 * MICROS_PER_DAY); - assertEquals(CalendarInterval.fromDayTimeString(input), i); + assertEquals(fromDayTimeString(input), i); try { input = "5 30:12:20"; - CalendarInterval.fromDayTimeString(input); + fromDayTimeString(input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("hour 30 outside range")); @@ -151,7 +151,7 @@ public void fromDayTimeStringTest() { try { input = "5 30-12"; - CalendarInterval.fromDayTimeString(input); + fromDayTimeString(input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("not match day-time format")); @@ -165,19 +165,19 @@ public void fromSingleUnitStringTest() { input = "12"; i = new CalendarInterval(12 * 12, 0L); - assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + assertEquals(fromSingleUnitString("year", input), i); input = "100"; i = new CalendarInterval(0, 100 * MICROS_PER_DAY); - assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + assertEquals(fromSingleUnitString("day", input), i); input = "1999.38888"; i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); - assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + assertEquals(fromSingleUnitString("second", input), i); try { input = String.valueOf(Integer.MAX_VALUE); - CalendarInterval.fromSingleUnitString("year", input); + fromSingleUnitString("year", input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("outside range")); @@ -185,7 +185,7 @@ public void fromSingleUnitStringTest() { try { input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); - CalendarInterval.fromSingleUnitString("hour", input); + fromSingleUnitString("hour", input); fail("Expected to throw an exception for the invalid input"); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains("outside range")); @@ -197,16 +197,16 @@ public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - 
CalendarInterval interval = CalendarInterval.fromString(input); - CalendarInterval interval2 = CalendarInterval.fromString(input2); + CalendarInterval interval = fromString(input); + CalendarInterval interval2 = fromString(input2); assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = CalendarInterval.fromString(input); - interval2 = CalendarInterval.fromString(input2); + interval = fromString(input); + interval2 = fromString(input2); assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @@ -216,25 +216,25 @@ public void subtractTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - CalendarInterval interval = CalendarInterval.fromString(input); - CalendarInterval interval2 = CalendarInterval.fromString(input2); + CalendarInterval interval = fromString(input); + CalendarInterval interval2 = fromString(input2); assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = CalendarInterval.fromString(input); - interval2 = CalendarInterval.fromString(input2); + interval = fromString(input); + interval2 = fromString(input2); assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); } - private void testSingleUnit(String unit, int number, int months, long microseconds) { + private static void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; CalendarInterval result = new CalendarInterval(months, microseconds); - assertEquals(CalendarInterval.fromString(input1), result); - assertEquals(CalendarInterval.fromString(input2), result); + assertEquals(fromString(input1), result); + assertEquals(fromString(input2), result); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 98aa8a2469a75..e21ffdcff9abf 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -24,13 +24,13 @@ import com.google.common.collect.ImmutableMap; import org.junit.Test; -import static junit.framework.Assert.*; +import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.UTF8String.*; public class UTF8StringSuite { - private void checkBasic(String str, int len) throws UnsupportedEncodingException { + private static void checkBasic(String str, int len) throws UnsupportedEncodingException { UTF8String s1 = fromString(str); UTF8String s2 = fromBytes(str.getBytes("utf8")); assertEquals(s1.numChars(), len); @@ -42,12 +42,12 @@ private void checkBasic(String str, int len) throws UnsupportedEncodingException assertEquals(s1.hashCode(), s2.hashCode()); - assertEquals(s1.compareTo(s2), 0); + assertEquals(0, s1.compareTo(s2)); - assertEquals(s1.contains(s2), true); - assertEquals(s2.contains(s1), true); - assertEquals(s1.startsWith(s1), true); - assertEquals(s1.endsWith(s1), true); + assertTrue(s1.contains(s2)); + assertTrue(s2.contains(s1)); + assertTrue(s1.startsWith(s1)); + assertTrue(s1.endsWith(s1)); } @Test @@ -59,8 +59,8 @@ public void basicTest() throws UnsupportedEncodingException { @Test public void 
emptyStringTest() { - assertEquals(fromString(""), EMPTY_UTF8); - assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(EMPTY_UTF8, fromString("")); + assertEquals(EMPTY_UTF8, fromBytes(new byte[0])); assertEquals(0, EMPTY_UTF8.numChars()); assertEquals(0, EMPTY_UTF8.numBytes()); } @@ -76,9 +76,9 @@ public void prefix() { byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; byte[] buf2 = {1, 2, 3}; - UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); - UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); - UTF8String str3 = UTF8String.fromBytes(buf2); + UTF8String str1 = fromBytes(buf1, 0, 3); + UTF8String str2 = fromBytes(buf1, 0, 8); + UTF8String str3 = fromBytes(buf2); assertTrue(str1.getPrefix() - str2.getPrefix() < 0); assertEquals(str1.getPrefix(), str3.getPrefix()); } @@ -98,7 +98,7 @@ public void compareTo() { assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); } - protected void testUpperandLower(String upper, String lower) { + protected static void testUpperandLower(String upper, String lower) { UTF8String us = fromString(upper); UTF8String ls = fromString(lower); assertEquals(ls, us.toLowerCase()); @@ -127,22 +127,22 @@ public void titleCase() { @Test public void concatTest() { assertEquals(EMPTY_UTF8, concat()); - assertEquals(null, concat((UTF8String) null)); + assertNull(concat((UTF8String) null)); assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); assertEquals(fromString("ab"), concat(fromString("ab"))); assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); - assertEquals(null, concat(fromString("a"), null, fromString("c"))); - assertEquals(null, concat(fromString("a"), null, null)); - assertEquals(null, concat(null, null, null)); + assertNull(concat(fromString("a"), null, fromString("c"))); + assertNull(concat(fromString("a"), null, null)); + assertNull(concat(null, null, null)); assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头"))); } @Test public void concatWsTest() { // Returns null if the separator is null - assertEquals(null, concatWs(null, (UTF8String)null)); - assertEquals(null, concatWs(null, fromString("a"))); + assertNull(concatWs(null, (UTF8String) null)); + assertNull(concatWs(null, fromString("a"))); // If separator is null, concatWs should skip all null inputs and never return null. 
UTF8String sep = fromString("哈哈"); @@ -381,16 +381,16 @@ public void split() { @Test public void levenshteinDistance() { - assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); - assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); - assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); - assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); - assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); - assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); - assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); - assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); - assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); - assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); + assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8)); + assertEquals(1, EMPTY_UTF8.levenshteinDistance(fromString("a"))); + assertEquals(7, fromString("aaapppp").levenshteinDistance(EMPTY_UTF8)); + assertEquals(1, fromString("frog").levenshteinDistance(fromString("fog"))); + assertEquals(3, fromString("fly").levenshteinDistance(fromString("ant"))); + assertEquals(7, fromString("elephant").levenshteinDistance(fromString("hippo"))); + assertEquals(7, fromString("hippo").levenshteinDistance(fromString("elephant"))); + assertEquals(8, fromString("hippo").levenshteinDistance(fromString("zzzzzzzz"))); + assertEquals(1, fromString("hello").levenshteinDistance(fromString("hallo"))); + assertEquals(4, fromString("世界千世").levenshteinDistance(fromString("千a世b"))); } @Test @@ -432,14 +432,14 @@ public void createBlankString() { @Test public void findInSet() { - assertEquals(fromString("ab").findInSet(fromString("ab")), 1); - assertEquals(fromString("a,b").findInSet(fromString("b")), 2); - assertEquals(fromString("abc,b,ab,c,def").findInSet(fromString("ab")), 3); - assertEquals(fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab")), 1); - assertEquals(fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab")), 4); - assertEquals(fromString(",ab,abc,b,ab,c,def").findInSet(fromString("")), 1); - assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab")), 4); - assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def")), 6); + assertEquals(1, fromString("ab").findInSet(fromString("ab"))); + assertEquals(2, fromString("a,b").findInSet(fromString("b"))); + assertEquals(3, fromString("abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(1, fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(4, fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(1, fromString(",ab,abc,b,ab,c,def").findInSet(fromString(""))); + assertEquals(4, fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab"))); + assertEquals(6, fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def"))); } @Test From 33ae7a35daa86c34f1f9f72f997e0c2d4cd8abec Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 2 Nov 2015 13:42:16 -0800 Subject: [PATCH 125/324] [SPARK-11358][MLLIB] deprecate runs in k-means This PR deprecates `runs` in k-means. `runs` introduces extra complexity and overhead in MLlib's k-means implementation. I haven't seen much usage with `runs` not equal to `1`. We don't have a unit test for it either. We can deprecate this method in 1.6, and void it in 1.7. It helps us simplify the implementation. 
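As a hedged illustration (not part of the patch), callers should simply drop `runs` and rely on its default; the toy data and app name below are made up:

```python
from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="KMeansRunsDeprecationSketch")
points = sc.parallelize([Vectors.dense([0.0, 0.0]), Vectors.dense([1.0, 1.0]),
                         Vectors.dense([9.0, 8.0]), Vectors.dense([8.0, 9.0])])

# Preferred after this change: leave runs at its default of 1.
model = KMeans.train(points, k=2, maxIterations=10)

# Passing runs != 1 still trains a model for now, but emits the deprecation
# warning added in this patch and is slated to become a no-op.
model = KMeans.train(points, k=2, maxIterations=10, runs=5)
```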
cc: srowen Author: Xiangrui Meng Closes #9322 from mengxr/SPARK-11358. --- .../main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 4 ++-- python/pyspark/mllib/clustering.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 7168aac32c997..2895db7c9061b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -107,7 +107,7 @@ class KMeans private ( * Number of runs of the algorithm to execute in parallel. */ @Since("1.4.0") - @Experimental + @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") def getRuns: Int = runs /** @@ -117,7 +117,7 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Since("0.8.0") - @Experimental + @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index d1c3755a785f2..8629aa5a17164 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,6 +17,7 @@ import sys import array as pyarray +import warnings if sys.version > '3': xrange = range @@ -170,6 +171,9 @@ class KMeans(object): def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None): """Train a k-means clustering model.""" + if runs != 1: + warnings.warn( + "Support for runs is deprecated in 1.6.0. This param will have no effect in 1.7.0.") clusterInitialModel = [] if initialModel is not None: if not isinstance(initialModel, KMeansModel): From db11ee5e56e5fac59895c772a9a87c5ac86888ef Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 2 Nov 2015 13:51:53 -0800 Subject: [PATCH 126/324] [SPARK-11371] Make "mean" an alias for "avg" operator From Reynold in the thread 'Exception when using some aggregate operators' (http://search-hadoop.com/m/q3RTt0xFr22nXB4/): I don't think these are bugs. The SQL standard for average is "avg", not "mean". Similarly, a distinct count is supposed to be written as "count(distinct col)", not "countDistinct(col)". We can, however, make "mean" an alias for "avg" to improve compatibility between DataFrame and SQL. Author: tedyu Closes #9332 from ted-yu/master. 
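A hedged end-to-end sketch (not part of the patch) of the new alias from PySpark, using a made-up temporary table and rows:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext(appName="MeanAliasSketch")
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([Row(key=1, value=10.0), Row(key=1, value=30.0),
                                 Row(key=2, value=-0.5)])
df.registerTempTable("agg1")

# Both queries now resolve to the same Average aggregate function.
sqlContext.sql("SELECT key, avg(value) FROM agg1 GROUP BY key").show()
sqlContext.sql("SELECT key, mean(value) FROM agg1 GROUP BY key").show()
```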
--- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/hive/execution/AggregationQuerySuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5f3ec74ac0d92..24c1a7b7ac5af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -185,6 +185,7 @@ object FunctionRegistry { expression[Last]("last"), expression[Last]("last_value"), expression[Max]("max"), + expression[Average]("mean"), expression[Min]("min"), expression[Stddev]("stddev"), expression[StddevPop]("stddev_pop"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0cf0e0aab9eb2..74061db0f28af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -298,6 +298,15 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + checkAnswer( + sqlContext.sql( + """ + |SELECT key, mean(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) + checkAnswer( sqlContext.sql( """ From 2804674a7af8f11eeb1280459bc9145815398eed Mon Sep 17 00:00:00 2001 From: Rishabh Bhardwaj Date: Mon, 2 Nov 2015 14:03:50 -0800 Subject: [PATCH 127/324] [SPARK-11383][DOCS] Replaced example code in mllib-naive-bayes.md/mllib-isotonic-regression.md using include_example I have made the required changes in mllib-naive-bayes.md/mllib-isotonic-regression.md and also verified them. Kindle Review it. Author: Rishabh Bhardwaj Closes #9353 from rishabhbhardwaj/SPARK-11383. --- docs/mllib-isotonic-regression.md | 124 +----------------- docs/mllib-naive-bayes.md | 89 +------------ .../mllib/JavaIsotonicRegressionExample.java | 86 ++++++++++++ .../examples/mllib/JavaNaiveBayesExample.java | 64 +++++++++ .../mllib/isotonic_regression_example.py | 56 ++++++++ .../main/python/mllib/naive_bayes_example.py | 56 ++++++++ .../mllib/IsotonicRegressionExample.scala | 66 ++++++++++ .../examples/mllib/NaiveBayesExample.scala | 57 ++++++++ 8 files changed, 391 insertions(+), 207 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java create mode 100644 examples/src/main/python/mllib/isotonic_regression_example.py create mode 100644 examples/src/main/python/mllib/naive_bayes_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index f91a697b31891..85f9226b43416 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -61,42 +61,8 @@ labels and real labels in the test set. 
Refer to the [`IsotonicRegression` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.IsotonicRegression) and [`IsotonicRegressionModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.regression.IsotonicRegressionModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} - -val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") - -// Create label, feature, weight tuples from input data with weight set to default value 1.0. -val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - (parts(0), parts(1), 1.0) -} - -// Split data into training (60%) and test (40%) sets. -val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0) -val test = splits(1) - -// Create isotonic regression model from training data. -// Isotonic parameter defaults to true so it is only shown for demonstration -val model = new IsotonicRegression().setIsotonic(true).run(training) - -// Create tuples of predicted and real labels. -val predictionAndLabel = test.map { point => - val predictedLabel = model.predict(point._2) - (predictedLabel, point._1) -} - -// Calculate mean squared error between predicted and real labels. -val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() -println("Mean Squared Error = " + meanSquaredError) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala %} -
    Data are read from a file where each line has a format label,feature i.e. 4710.28,500.00. The data are split to training and testing set. @@ -105,66 +71,8 @@ labels and real labels in the test set. Refer to the [`IsotonicRegression` Java docs](api/java/org/apache/spark/mllib/regression/IsotonicRegression.html) and [`IsotonicRegressionModel` Java docs](api/java/org/apache/spark/mllib/regression/IsotonicRegressionModel.html) for details on the API. -{% highlight java %} -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.IsotonicRegressionModel; -import scala.Tuple2; -import scala.Tuple3; - -JavaRDD data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt"); - -// Create label, feature, weight tuples from input data with weight set to default value 1.0. -JavaRDD> parsedData = data.map( - new Function>() { - public Tuple3 call(String line) { - String[] parts = line.split(","); - return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); - } - } -); - -// Split data into training (60%) and test (40%) sets. -JavaRDD>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L); -JavaRDD> training = splits[0]; -JavaRDD> test = splits[1]; - -// Create isotonic regression model from training data. -// Isotonic parameter defaults to true so it is only shown for demonstration -IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); - -// Create tuples of predicted and real labels. -JavaPairRDD predictionAndLabel = test.mapToPair( - new PairFunction, Double, Double>() { - @Override public Tuple2 call(Tuple3 point) { - Double predictedLabel = model.predict(point._2()); - return new Tuple2(predictedLabel, point._1()); - } - } -); - -// Calculate mean squared error between predicted and real labels. -Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( - new Function, Object>() { - @Override public Object call(Tuple2 pl) { - return Math.pow(pl._1() - pl._2(), 2); - } - } -).rdd()).mean(); - -System.out.println("Mean Squared Error = " + meanSquaredError); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java %}
    -
    Data are read from a file where each line has a format label,feature i.e. 4710.28,500.00. The data are split to training and testing set. @@ -173,32 +81,6 @@ labels and real labels in the test set. Refer to the [`IsotonicRegression` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.IsotonicRegression) and [`IsotonicRegressionModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.regression.IsotonicRegressionModel) for more details on the API. -{% highlight python %} -import math -from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel - -data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") - -# Create label, feature, weight tuples from input data with weight set to default value 1.0. -parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) - -# Split data into training (60%) and test (40%) sets. -training, test = parsedData.randomSplit([0.6, 0.4], 11) - -# Create isotonic regression model from training data. -# Isotonic parameter defaults to true so it is only shown for demonstration -model = IsotonicRegression.train(training) - -# Create tuples of predicted and real labels. -predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0])) - -# Calculate mean squared error between predicted and real labels. -meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean() -print("Mean Squared Error = " + str(meanSquaredError)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = IsotonicRegressionModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/isotonic_regression_example.py %}
    diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index f4f6a10c8299e..60ac6c7e5bb1a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -40,32 +40,8 @@ can be used for evaluation and prediction. Refer to the [`NaiveBayes` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes) and [`NaiveBayesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint - -val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") -val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -} -// Split data into training (60%) and test (40%). -val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) -val training = splits(0) -val test = splits(1) - -val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") - -val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) -val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = NaiveBayesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala %} -
    [NaiveBayes](api/java/org/apache/spark/mllib/classification/NaiveBayes.html) implements @@ -77,40 +53,8 @@ can be used for evaluation and prediction. Refer to the [`NaiveBayes` Java docs](api/java/org/apache/spark/mllib/classification/NaiveBayes.html) and [`NaiveBayesModel` Java docs](api/java/org/apache/spark/mllib/classification/NaiveBayesModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.classification.NaiveBayes; -import org.apache.spark.mllib.classification.NaiveBayesModel; -import org.apache.spark.mllib.regression.LabeledPoint; - -JavaRDD training = ... // training set -JavaRDD test = ... // test set - -final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); - -JavaPairRDD predictionAndLabel = - test.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -double accuracy = predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return pl._1().equals(pl._2()); - } - }).count() / (double) test.count(); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java %}
    -
    [NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial @@ -124,33 +68,6 @@ Note that the Python API does not yet support model save/load but will in the fu Refer to the [`NaiveBayes` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) and [`NaiveBayesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint - -def parseLine(line): - parts = line.split(',') - label = float(parts[0]) - features = Vectors.dense([float(x) for x in parts[1].split(' ')]) - return LabeledPoint(label, features) - -data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) - -# Split data aproximately into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 0) - -# Train a naive Bayes model. -model = NaiveBayes.train(training, 1.0) - -# Make prediction and test accuracy. -predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) -accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() - -# Save and load model -model.save(sc, "myModelPath") -sameModel = NaiveBayesModel.load(sc, "myModelPath") -{% endhighlight %} - +{% include_example python/mllib/naive_bayes_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java new file mode 100644 index 0000000000000..37e709b4cbc03 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; +import scala.Tuple3; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.regression.IsotonicRegression; +import org.apache.spark.mllib.regression.IsotonicRegressionModel; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaIsotonicRegressionExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaIsotonicRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // $example on$ + JavaRDD data = jsc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + + // Create label, feature, weight tuples from input data with weight set to default value 1.0. + JavaRDD> parsedData = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(","); + return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + } + } + ); + + // Split data into training (60%) and test (40%) sets. + JavaRDD>[] splits = parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD> training = splits[0]; + JavaRDD> test = splits[1]; + + // Create isotonic regression model from training data. + // Isotonic parameter defaults to true so it is only shown for demonstration + final IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + + // Create tuples of predicted and real labels. + JavaPairRDD predictionAndLabel = test.mapToPair( + new PairFunction, Double, Double>() { + @Override + public Tuple2 call(Tuple3 point) { + Double predictedLabel = model.predict(point._2()); + return new Tuple2(predictedLabel, point._1()); + } + } + ); + + // Calculate mean squared error between predicted and real labels. 
+ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( + new Function, Object>() { + @Override + public Object call(Tuple2 pl) { + return Math.pow(pl._1() - pl._2(), 2); + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + meanSquaredError); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); + IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java new file mode 100644 index 0000000000000..e6a5904bd71f0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.NaiveBayes; +import org.apache.spark.mllib.classification.NaiveBayesModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaNaiveBayesExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaNaiveBayesExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // $example on$ + String path = "data/mllib/sample_naive_bayes_data.txt"; + JavaRDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD(); + JavaRDD[] tmp = inputData.randomSplit(new double[]{0.6, 0.4}, 12345); + JavaRDD training = tmp[0]; // training set + JavaRDD test = tmp[1]; // test set + final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); + JavaPairRDD predictionAndLabel = + test.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + double accuracy = predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return pl._1().equals(pl._2()); + } + }).count() / (double) test.count(); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myNaiveBayesModel"); + NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), "target/tmp/myNaiveBayesModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/isotonic_regression_example.py 
b/examples/src/main/python/mllib/isotonic_regression_example.py new file mode 100644 index 0000000000000..89dc9f4b6611a --- /dev/null +++ b/examples/src/main/python/mllib/isotonic_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Isotonic Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +import math +from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonIsotonicRegressionExample") + + # $example on$ + data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + + # Create label, feature, weight tuples from input data with weight set to default value 1.0. + parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) + + # Split data into training (60%) and test (40%) sets. + training, test = parsedData.randomSplit([0.6, 0.4], 11) + + # Create isotonic regression model from training data. + # Isotonic parameter defaults to true so it is only shown for demonstration + model = IsotonicRegression.train(training) + + # Create tuples of predicted and real labels. + predictionAndLabel = test.map(lambda p: (model.predict(p[1]), p[0])) + + # Calculate mean squared error between predicted and real labels. + meanSquaredError = predictionAndLabel.map(lambda pl: math.pow((pl[0] - pl[1]), 2)).mean() + print("Mean Squared Error = " + str(meanSquaredError)) + + # Save and load model + model.save(sc, "target/tmp/myIsotonicRegressionModel") + sameModel = IsotonicRegressionModel.load(sc, "target/tmp/myIsotonicRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py new file mode 100644 index 0000000000000..a2e7dacf25491 --- /dev/null +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +NaiveBayes Example. 
+""" +from __future__ import print_function + +# $example on$ +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint + + +def parseLine(line): + parts = line.split(',') + label = float(parts[0]) + features = Vectors.dense([float(x) for x in parts[1].split(' ')]) + return LabeledPoint(label, features) +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonNaiveBayesExample") + + # $example on$ + data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) + + # Split data aproximately into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=0) + + # Train a naive Bayes model. + model = NaiveBayes.train(training, 1.0) + + # Make prediction and test accuracy. + predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label)) + accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + + # Save and load model + model.save(sc, "target/tmp/myNaiveBayesModel") + sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala new file mode 100644 index 0000000000000..52ac9ae7dd2d0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object IsotonicRegressionExample { + + def main(args: Array[String]) : Unit = { + + val conf = new SparkConf().setAppName("IsotonicRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + + // Create label, feature, weight tuples from input data with weight set to default value 1.0. + val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + (parts(0), parts(1), 1.0) + } + + // Split data into training (60%) and test (40%) sets. + val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0) + val test = splits(1) + + // Create isotonic regression model from training data. + // Isotonic parameter defaults to true so it is only shown for demonstration + val model = new IsotonicRegression().setIsotonic(true).run(training) + + // Create tuples of predicted and real labels. 
+ val predictionAndLabel = test.map { point => + val predictedLabel = model.predict(point._2) + (predictedLabel, point._1) + } + + // Calculate mean squared error between predicted and real labels. + val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() + println("Mean Squared Error = " + meanSquaredError) + + // Save and load model + model.save(sc, "target/tmp/myIsotonicRegressionModel") + val sameModel = IsotonicRegressionModel.load(sc, "target/tmp/myIsotonicRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala new file mode 100644 index 0000000000000..a7a47c2a3556a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object NaiveBayesExample { + + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("NaiveBayesExample") + val sc = new SparkContext(conf) + // $example on$ + val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") + val parsedData = data.map { line => + val parts = line.split(',') + LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) + } + + // Split data into training (60%) and test (40%). + val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) + val training = splits(0) + val test = splits(1) + + val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") + + val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) + val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() + + // Save and load model + model.save(sc, "target/tmp/myNaiveBayesModel") + val sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") + // $example off$ + } +} + +// scalastyle:on println From ecfb3e73fd0a99f0be96034710974e78b6f9d624 Mon Sep 17 00:00:00 2001 From: lihao Date: Mon, 2 Nov 2015 16:09:22 -0800 Subject: [PATCH 128/324] [SPARK-10286][ML][PYSPARK][DOCS] Add @since annotation to pyspark.ml.param and pyspark.ml.* Author: lihao Closes #9275 from lidinghao/SPARK-10286. 
--- python/pyspark/ml/evaluation.py | 20 ++++ python/pyspark/ml/feature.py | 164 ++++++++++++++++++++++++++++ python/pyspark/ml/param/__init__.py | 16 +++ python/pyspark/ml/pipeline.py | 30 +++++ 4 files changed, 230 insertions(+) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index cb3b07947e488..dcc1738ec518b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -17,6 +17,7 @@ from abc import abstractmethod, ABCMeta +from pyspark import since from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol @@ -31,6 +32,8 @@ class Evaluator(Params): """ Base class for evaluators that compute metrics from predictions. + + .. versionadded:: 1.4.0 """ __metaclass__ = ABCMeta @@ -46,6 +49,7 @@ def _evaluate(self, dataset): """ raise NotImplementedError() + @since("1.4.0") def evaluate(self, dataset, params=None): """ Evaluates the output with optional parameters. @@ -66,6 +70,7 @@ def evaluate(self, dataset, params=None): else: raise ValueError("Params must be a param map but got %s." % type(params)) + @since("1.5.0") def isLargerBetter(self): """ Indicates whether the metric returned by :py:meth:`evaluate` should be maximized @@ -114,6 +119,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.70... >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) 0.83... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -138,6 +145,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.4.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -145,6 +153,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.4.0") def getMetricName(self): """ Gets the value of metricName or its default value. @@ -152,6 +161,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.4.0") def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC"): """ @@ -180,6 +190,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) 2.649... + + .. versionadded:: 1.4.0 """ # Because we will maximize evaluation value (ref: `CrossValidator`), # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), @@ -205,6 +217,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.4.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -212,6 +225,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.4.0") def getMetricName(self): """ Gets the value of metricName or its default value. @@ -219,6 +233,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.4.0") def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse"): """ @@ -246,6 +261,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio 0.66... >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) 0.66... + + .. 
versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", @@ -271,6 +288,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", kwargs = self.__init__._input_kwargs self._set(**kwargs) + @since("1.5.0") def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. @@ -278,6 +296,7 @@ def setMetricName(self, value): self._paramMap[self.metricName] = value return self + @since("1.5.0") def getMetricName(self): """ Gets the value of metricName or its default value. @@ -285,6 +304,7 @@ def getMetricName(self): return self.getOrDefault(self.metricName) @keyword_only + @since("1.5.0") def setParams(self, predictionCol="prediction", labelCol="label", metricName="f1"): """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 55bde6d0ea4fb..c7b6dd926c3e8 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,6 +19,7 @@ if sys.version > '3': basestring = str +from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only @@ -51,6 +52,8 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"} >>> binarizer.transform(df, params).head().vector 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -71,6 +74,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, threshold=0.0, inputCol=None, outputCol=None): """ setParams(self, threshold=0.0, inputCol=None, outputCol=None) @@ -79,6 +83,7 @@ def setParams(self, threshold=0.0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -86,6 +91,7 @@ def setThreshold(self, value): self._paramMap[self.threshold] = value return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. @@ -114,6 +120,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 + + .. versionadded:: 1.3.0 """ # a placeholder to make it appear in the generated doc @@ -150,6 +158,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, splits=None, inputCol=None, outputCol=None): """ setParams(self, splits=None, inputCol=None, outputCol=None) @@ -158,6 +167,7 @@ def setParams(self, splits=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setSplits(self, value): """ Sets the value of :py:attr:`splits`. @@ -165,6 +175,7 @@ def setSplits(self, value): self._paramMap[self.splits] = value return self + @since("1.4.0") def getSplits(self): """ Gets the value of threshold or its default value. @@ -194,6 +205,8 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): ... >>> sorted(map(str, model.vocabulary)) ['a', 'b', 'c'] + + .. 
versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -242,6 +255,7 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outpu self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): """ setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) @@ -250,6 +264,7 @@ def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outp kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setMinTF(self, value): """ Sets the value of :py:attr:`minTF`. @@ -257,12 +272,14 @@ def setMinTF(self, value): self._paramMap[self.minTF] = value return self + @since("1.6.0") def getMinTF(self): """ Gets the value of minTF or its default value. """ return self.getOrDefault(self.minTF) + @since("1.6.0") def setMinDF(self, value): """ Sets the value of :py:attr:`minDF`. @@ -270,12 +287,14 @@ def setMinDF(self, value): self._paramMap[self.minDF] = value return self + @since("1.6.0") def getMinDF(self): """ Gets the value of minDF or its default value. """ return self.getOrDefault(self.minDF) + @since("1.6.0") def setVocabSize(self, value): """ Sets the value of :py:attr:`vocabSize`. @@ -283,6 +302,7 @@ def setVocabSize(self, value): self._paramMap[self.vocabSize] = value return self + @since("1.6.0") def getVocabSize(self): """ Gets the value of vocabSize or its default value. @@ -298,9 +318,12 @@ class CountVectorizerModel(JavaModel): .. note:: Experimental Model fitted by CountVectorizer. + + .. versionadded:: 1.6.0 """ @property + @since("1.6.0") def vocabulary(self): """ An array of terms in the vocabulary. @@ -331,6 +354,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol): >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2) >>> df3.head().origVec DenseVector([5.0, 8.0, 6.0]) + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -351,6 +376,7 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, inverse=False, inputCol=None, outputCol=None): """ setParams(self, inverse=False, inputCol=None, outputCol=None) @@ -359,6 +385,7 @@ def setParams(self, inverse=False, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setInverse(self, value): """ Sets the value of :py:attr:`inverse`. @@ -366,6 +393,7 @@ def setInverse(self, value): self._paramMap[self.inverse] = value return self + @since("1.6.0") def getInverse(self): """ Gets the value of inverse or its default value. @@ -390,6 +418,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([2.0, 2.0, 9.0]) >>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod DenseVector([4.0, 3.0, 15.0]) + + .. 
versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -410,6 +440,7 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, scalingVec=None, inputCol=None, outputCol=None): """ setParams(self, scalingVec=None, inputCol=None, outputCol=None) @@ -418,6 +449,7 @@ def setParams(self, scalingVec=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setScalingVec(self, value): """ Sets the value of :py:attr:`scalingVec`. @@ -425,6 +457,7 @@ def setScalingVec(self, value): self._paramMap[self.scalingVec] = value return self + @since("1.5.0") def getScalingVec(self): """ Gets the value of scalingVec or its default value. @@ -449,6 +482,8 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} >>> hashingTF.transform(df, params).head().vector SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) + + .. versionadded:: 1.3.0 """ @keyword_only @@ -463,6 +498,7 @@ def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.3.0") def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) @@ -490,6 +526,8 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): >>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"} >>> idf.fit(df, params).transform(df).head().vector DenseVector([0.2877, 0.0]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -510,6 +548,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): """ setParams(self, minDocFreq=0, inputCol=None, outputCol=None) @@ -518,6 +557,7 @@ def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMinDocFreq(self, value): """ Sets the value of :py:attr:`minDocFreq`. @@ -525,6 +565,7 @@ def setMinDocFreq(self, value): self._paramMap[self.minDocFreq] = value return self + @since("1.4.0") def getMinDocFreq(self): """ Gets the value of minDocFreq or its default value. @@ -540,6 +581,8 @@ class IDFModel(JavaModel): .. note:: Experimental Model fitted by IDF. + + .. versionadded:: 1.4.0 """ @@ -571,6 +614,8 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): |[2.0]| [1.0]| +-----+------+ ... + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -591,6 +636,7 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): """ setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) @@ -599,6 +645,7 @@ def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setMin(self, value): """ Sets the value of :py:attr:`min`. @@ -606,12 +653,14 @@ def setMin(self, value): self._paramMap[self.min] = value return self + @since("1.6.0") def getMin(self): """ Gets the value of min or its default value. 
""" return self.getOrDefault(self.min) + @since("1.6.0") def setMax(self, value): """ Sets the value of :py:attr:`max`. @@ -619,6 +668,7 @@ def setMax(self, value): self._paramMap[self.max] = value return self + @since("1.6.0") def getMax(self): """ Gets the value of max or its default value. @@ -634,6 +684,8 @@ class MinMaxScalerModel(JavaModel): .. note:: Experimental Model fitted by :py:class:`MinMaxScaler`. + + .. versionadded:: 1.6.0 """ @@ -668,6 +720,8 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -686,6 +740,7 @@ def __init__(self, n=2, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, n=2, inputCol=None, outputCol=None): """ setParams(self, n=2, inputCol=None, outputCol=None) @@ -694,6 +749,7 @@ def setParams(self, n=2, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setN(self, value): """ Sets the value of :py:attr:`n`. @@ -701,6 +757,7 @@ def setN(self, value): self._paramMap[self.n] = value return self + @since("1.5.0") def getN(self): """ Gets the value of n or its default value. @@ -726,6 +783,8 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"} >>> normalizer.transform(df, params).head().vector DenseVector([0.4286, -0.5714]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -744,6 +803,7 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, p=2.0, inputCol=None, outputCol=None): """ setParams(self, p=2.0, inputCol=None, outputCol=None) @@ -752,6 +812,7 @@ def setParams(self, p=2.0, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setP(self, value): """ Sets the value of :py:attr:`p`. @@ -759,6 +820,7 @@ def setP(self, value): self._paramMap[self.p] = value return self + @since("1.4.0") def getP(self): """ Gets the value of p or its default value. @@ -800,6 +862,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -818,6 +882,7 @@ def __init__(self, dropLast=True, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ setParams(self, dropLast=True, inputCol=None, outputCol=None) @@ -826,6 +891,7 @@ def setParams(self, dropLast=True, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setDropLast(self, value): """ Sets the value of :py:attr:`dropLast`. @@ -833,6 +899,7 @@ def setDropLast(self, value): self._paramMap[self.dropLast] = value return self + @since("1.4.0") def getDropLast(self): """ Gets the value of dropLast or its default value. 
@@ -858,6 +925,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) >>> px.setParams(outputCol="test").transform(df).head().test DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -877,6 +946,7 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, degree=2, inputCol=None, outputCol=None): """ setParams(self, degree=2, inputCol=None, outputCol=None) @@ -885,6 +955,7 @@ def setParams(self, degree=2, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setDegree(self, value): """ Sets the value of :py:attr:`degree`. @@ -892,6 +963,7 @@ def setDegree(self, value): self._paramMap[self.degree] = value return self + @since("1.4.0") def getDegree(self): """ Gets the value of degree or its default value. @@ -929,6 +1001,8 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -951,6 +1025,7 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, o self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) @@ -959,6 +1034,7 @@ def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMinTokenLength(self, value): """ Sets the value of :py:attr:`minTokenLength`. @@ -966,12 +1042,14 @@ def setMinTokenLength(self, value): self._paramMap[self.minTokenLength] = value return self + @since("1.4.0") def getMinTokenLength(self): """ Gets the value of minTokenLength or its default value. """ return self.getOrDefault(self.minTokenLength) + @since("1.4.0") def setGaps(self, value): """ Sets the value of :py:attr:`gaps`. @@ -979,12 +1057,14 @@ def setGaps(self, value): self._paramMap[self.gaps] = value return self + @since("1.4.0") def getGaps(self): """ Gets the value of gaps or its default value. """ return self.getOrDefault(self.gaps) + @since("1.4.0") def setPattern(self, value): """ Sets the value of :py:attr:`pattern`. @@ -992,6 +1072,7 @@ def setPattern(self, value): self._paramMap[self.pattern] = value return self + @since("1.4.0") def getPattern(self): """ Gets the value of pattern or its default value. @@ -1013,6 +1094,8 @@ class SQLTransformer(JavaTransformer): ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") >>> sqlTrans.transform(df).head() Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0) + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -1030,6 +1113,7 @@ def __init__(self, statement=None): self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, statement=None): """ setParams(self, statement=None) @@ -1038,6 +1122,7 @@ def setParams(self, statement=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setStatement(self, value): """ Sets the value of :py:attr:`statement`. 
@@ -1045,6 +1130,7 @@ def setStatement(self, value): self._paramMap[self.statement] = value return self + @since("1.6.0") def getStatement(self): """ Gets the value of statement or its default value. @@ -1070,6 +1156,8 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.4142]) >>> model.transform(df).collect()[1].scaled DenseVector([1.4142]) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -1090,6 +1178,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None): """ setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) @@ -1098,6 +1187,7 @@ def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setWithMean(self, value): """ Sets the value of :py:attr:`withMean`. @@ -1105,12 +1195,14 @@ def setWithMean(self, value): self._paramMap[self.withMean] = value return self + @since("1.4.0") def getWithMean(self): """ Gets the value of withMean or its default value. """ return self.getOrDefault(self.withMean) + @since("1.4.0") def setWithStd(self, value): """ Sets the value of :py:attr:`withStd`. @@ -1118,6 +1210,7 @@ def setWithStd(self, value): self._paramMap[self.withStd] = value return self + @since("1.4.0") def getWithStd(self): """ Gets the value of withStd or its default value. @@ -1133,9 +1226,12 @@ class StandardScalerModel(JavaModel): .. note:: Experimental Model fitted by StandardScaler. + + .. versionadded:: 1.4.0 """ @property + @since("1.5.0") def std(self): """ Standard deviation of the StandardScalerModel. @@ -1143,6 +1239,7 @@ def std(self): return self._call_java("std") @property + @since("1.5.0") def mean(self): """ Mean of the StandardScalerModel. @@ -1171,6 +1268,8 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + + .. versionadded:: 1.4.0 """ @keyword_only @@ -1185,6 +1284,7 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): """ setParams(self, inputCol=None, outputCol=None, handleInvalid="error") @@ -1202,8 +1302,11 @@ class StringIndexerModel(JavaModel): .. note:: Experimental Model fitted by StringIndexer. + + .. versionadded:: 1.4.0 """ @property + @since("1.5.0") def labels(self): """ Ordered list of labels, corresponding to indices to be assigned. @@ -1221,6 +1324,8 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): The index-string mapping is either from the ML attributes of the input column, or from user-supplied labels (which take precedence over ML attributes). See L{StringIndexer} for converting strings into indices. + + .. 
versionadded:: 1.6.0 """ # a placeholder to make the labels show up in generated doc @@ -1243,6 +1348,7 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, inputCol=None, outputCol=None, labels=None): """ setParams(self, inputCol=None, outputCol=None, labels=None) @@ -1251,6 +1357,7 @@ def setParams(self, inputCol=None, outputCol=None, labels=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setLabels(self, value): """ Sets the value of :py:attr:`labels`. @@ -1258,6 +1365,7 @@ def setLabels(self, value): self._paramMap[self.labels] = value return self + @since("1.6.0") def getLabels(self): """ Gets the value of :py:attr:`labels` or its default value. @@ -1271,6 +1379,8 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): A feature transformer that filters out stop words from input. Note: null values from input array are preserved unless adding null to stopWords explicitly. + + .. versionadded:: 1.6.0 """ # a placeholder to make the stopwords show up in generated doc stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") @@ -1297,6 +1407,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): """ @@ -1307,6 +1418,7 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setStopWords(self, value): """ Specify the stopwords to be filtered. @@ -1314,12 +1426,14 @@ def setStopWords(self, value): self._paramMap[self.stopWords] = value return self + @since("1.6.0") def getStopWords(self): """ Get the stopwords. """ return self.getOrDefault(self.stopWords) + @since("1.6.0") def setCaseSensitive(self, value): """ Set whether to do a case sensitive comparison over the stop words @@ -1327,6 +1441,7 @@ def setCaseSensitive(self, value): self._paramMap[self.caseSensitive] = value return self + @since("1.6.0") def getCaseSensitive(self): """ Get whether to do a case sensitive comparison over the stop words. @@ -1360,6 +1475,8 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.3.0 """ @keyword_only @@ -1373,6 +1490,7 @@ def __init__(self, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.3.0") def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") @@ -1398,6 +1516,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} >>> vecAssembler.transform(df, params).head().vector DenseVector([0.0, 1.0]) + + .. versionadded:: 1.4.0 """ @keyword_only @@ -1411,6 +1531,7 @@ def __init__(self, inputCols=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, inputCols=None, outputCol=None): """ setParams(self, inputCols=None, outputCol=None) @@ -1477,6 +1598,8 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> model2 = indexer.fit(df, params) >>> model2.transform(df).head().vector DenseVector([1.0, 0.0]) + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -1501,6 +1624,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, maxCategories=20, inputCol=None, outputCol=None): """ setParams(self, maxCategories=20, inputCol=None, outputCol=None) @@ -1509,6 +1633,7 @@ def setParams(self, maxCategories=20, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setMaxCategories(self, value): """ Sets the value of :py:attr:`maxCategories`. @@ -1516,6 +1641,7 @@ def setMaxCategories(self, value): self._paramMap[self.maxCategories] = value return self + @since("1.4.0") def getMaxCategories(self): """ Gets the value of maxCategories or its default value. @@ -1531,9 +1657,12 @@ class VectorIndexerModel(JavaModel): .. note:: Experimental Model fitted by VectorIndexer. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def numFeatures(self): """ Number of features, i.e., length of Vectors which this transforms. @@ -1541,6 +1670,7 @@ def numFeatures(self): return self._call_java("numFeatures") @property + @since("1.4.0") def categoryMaps(self): """ Feature value index. Keys are categorical feature indices (column indices). @@ -1573,6 +1703,8 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): >>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4]) >>> vs.transform(df).head().sliced DenseVector([2.3, 1.0]) + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -1600,6 +1732,7 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): """ setParams(self, inputCol=None, outputCol=None, indices=None, names=None): @@ -1608,6 +1741,7 @@ def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.6.0") def setIndices(self, value): """ Sets the value of :py:attr:`indices`. @@ -1615,12 +1749,14 @@ def setIndices(self, value): self._paramMap[self.indices] = value return self + @since("1.6.0") def getIndices(self): """ Gets the value of indices or its default value. """ return self.getOrDefault(self.indices) + @since("1.6.0") def setNames(self, value): """ Sets the value of :py:attr:`names`. @@ -1628,6 +1764,7 @@ def setNames(self, value): self._paramMap[self.names] = value return self + @since("1.6.0") def getNames(self): """ Gets the value of names or its default value. @@ -1666,6 +1803,8 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has ... >>> model.transform(doc).head().model DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -1699,6 +1838,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, inputCol=None, outputCol=None): """ @@ -1709,6 +1849,7 @@ def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setVectorSize(self, value): """ Sets the value of :py:attr:`vectorSize`. @@ -1716,12 +1857,14 @@ def setVectorSize(self, value): self._paramMap[self.vectorSize] = value return self + @since("1.4.0") def getVectorSize(self): """ Gets the value of vectorSize or its default value. """ return self.getOrDefault(self.vectorSize) + @since("1.4.0") def setNumPartitions(self, value): """ Sets the value of :py:attr:`numPartitions`. @@ -1729,12 +1872,14 @@ def setNumPartitions(self, value): self._paramMap[self.numPartitions] = value return self + @since("1.4.0") def getNumPartitions(self): """ Gets the value of numPartitions or its default value. """ return self.getOrDefault(self.numPartitions) + @since("1.4.0") def setMinCount(self, value): """ Sets the value of :py:attr:`minCount`. @@ -1742,6 +1887,7 @@ def setMinCount(self, value): self._paramMap[self.minCount] = value return self + @since("1.4.0") def getMinCount(self): """ Gets the value of minCount or its default value. @@ -1757,8 +1903,11 @@ class Word2VecModel(JavaModel): .. note:: Experimental Model fitted by Word2Vec. + + .. versionadded:: 1.4.0 """ + @since("1.5.0") def getVectors(self): """ Returns the vector representation of the words as a dataframe @@ -1766,6 +1915,7 @@ def getVectors(self): """ return self._call_java("getVectors") + @since("1.5.0") def findSynonyms(self, word, num): """ Find "num" number of words closest in similarity to "word". @@ -1794,6 +1944,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol): >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features DenseVector([1.648..., -4.013...]) + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -1811,6 +1963,7 @@ def __init__(self, k=None, inputCol=None, outputCol=None): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, k=None, inputCol=None, outputCol=None): """ setParams(self, k=None, inputCol=None, outputCol=None) @@ -1819,6 +1972,7 @@ def setParams(self, k=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -1826,6 +1980,7 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of k or its default value. @@ -1841,6 +1996,8 @@ class PCAModel(JavaModel): .. note:: Experimental Model fitted by PCA. + + .. versionadded:: 1.5.0 """ @@ -1879,6 +2036,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): |0.0|0.0| a| [0.0]| 0.0| +---+---+---+--------+-----+ ... + + .. 
versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -1896,6 +2055,7 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label"): self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, formula=None, featuresCol="features", labelCol="label"): """ setParams(self, formula=None, featuresCol="features", labelCol="label") @@ -1904,6 +2064,7 @@ def setParams(self, formula=None, featuresCol="features", labelCol="label"): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setFormula(self, value): """ Sets the value of :py:attr:`formula`. @@ -1911,6 +2072,7 @@ def setFormula(self, value): self._paramMap[self.formula] = value return self + @since("1.5.0") def getFormula(self): """ Gets the value of :py:attr:`formula`. @@ -1926,6 +2088,8 @@ class RFormulaModel(JavaModel): .. note:: Experimental Model fitted by :py:class:`RFormula`. + + .. versionadded:: 1.5.0 """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 2e0c63cb47b17..35c9b776a3d5e 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -18,6 +18,7 @@ from abc import ABCMeta import copy +from pyspark import since from pyspark.ml.util import Identifiable @@ -27,6 +28,8 @@ class Param(object): """ A param with self-contained documentation. + + .. versionadded:: 1.3.0 """ def __init__(self, parent, name, doc): @@ -56,6 +59,8 @@ class Params(Identifiable): """ Components that take parameters. This also provides an internal param map to store parameter values attached to the instance. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta @@ -72,6 +77,7 @@ def __init__(self): self._params = None @property + @since("1.3.0") def params(self): """ Returns all params ordered by name. The default implementation @@ -83,6 +89,7 @@ def params(self): [getattr(self, x) for x in dir(self) if x != "params"])) return self._params + @since("1.4.0") def explainParam(self, param): """ Explains a single param and returns its name, doc, and optional @@ -100,6 +107,7 @@ def explainParam(self, param): valueStr = "(" + ", ".join(values) + ")" return "%s: %s %s" % (param.name, param.doc, valueStr) + @since("1.4.0") def explainParams(self): """ Returns the documentation of all params with their optionally @@ -107,6 +115,7 @@ def explainParams(self): """ return "\n".join([self.explainParam(param) for param in self.params]) + @since("1.4.0") def getParam(self, paramName): """ Gets a param by its name. @@ -117,6 +126,7 @@ def getParam(self, paramName): else: raise ValueError("Cannot find param with name %s." % paramName) + @since("1.4.0") def isSet(self, param): """ Checks whether a param is explicitly set by user. @@ -124,6 +134,7 @@ def isSet(self, param): param = self._resolveParam(param) return param in self._paramMap + @since("1.4.0") def hasDefault(self, param): """ Checks whether a param has a default value. 
@@ -131,6 +142,7 @@ def hasDefault(self, param): param = self._resolveParam(param) return param in self._defaultParamMap + @since("1.4.0") def isDefined(self, param): """ Checks whether a param is explicitly set by user or has @@ -138,6 +150,7 @@ def isDefined(self, param): """ return self.isSet(param) or self.hasDefault(param) + @since("1.4.0") def hasParam(self, paramName): """ Tests whether this instance contains a param with a given @@ -146,6 +159,7 @@ def hasParam(self, paramName): param = self._resolveParam(paramName) return param in self.params + @since("1.4.0") def getOrDefault(self, param): """ Gets the value of a param in the user-supplied param map or its @@ -157,6 +171,7 @@ def getOrDefault(self, param): else: return self._defaultParamMap[param] + @since("1.4.0") def extractParamMap(self, extra=None): """ Extracts the embedded default param values and user-supplied @@ -175,6 +190,7 @@ def extractParamMap(self, extra=None): paramMap.update(extra) return paramMap + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 312a8502b3a2c..4475451edb781 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -17,6 +17,7 @@ from abc import ABCMeta, abstractmethod +from pyspark import since from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc @@ -26,6 +27,8 @@ class Estimator(Params): """ Abstract class for estimators that fit models to data. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta @@ -42,6 +45,7 @@ def _fit(self, dataset): """ raise NotImplementedError() + @since("1.3.0") def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -73,6 +77,8 @@ class Transformer(Params): """ Abstract class for transformers that transform one dataset into another. + + .. versionadded:: 1.3.0 """ __metaclass__ = ABCMeta @@ -88,6 +94,7 @@ def _transform(self, dataset): """ raise NotImplementedError() + @since("1.3.0") def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -113,6 +120,8 @@ def transform(self, dataset, params=None): class Model(Transformer): """ Abstract class for models that are fitted by estimators. + + .. versionadded:: 1.4.0 """ __metaclass__ = ABCMeta @@ -136,6 +145,8 @@ class Pipeline(Estimator): consists of fitted models and transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as an identity transformer. + + .. versionadded:: 1.3.0 """ @keyword_only @@ -151,6 +162,7 @@ def __init__(self, stages=None): kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + @since("1.3.0") def setStages(self, value): """ Set pipeline stages. @@ -161,6 +173,7 @@ def setStages(self, value): self._paramMap[self.stages] = value return self + @since("1.3.0") def getStages(self): """ Get pipeline stages. @@ -169,6 +182,7 @@ def getStages(self): return self._paramMap[self.stages] @keyword_only + @since("1.3.0") def setParams(self, stages=None): """ setParams(self, stages=None) @@ -204,7 +218,14 @@ def _fit(self, dataset): transformers.append(stage) return PipelineModel(transformers) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance. 
+ + :param extra: extra parameters + :returns: new instance + """ if extra is None: extra = dict() that = Params.copy(self, extra) @@ -216,6 +237,8 @@ def copy(self, extra=None): class PipelineModel(Model): """ Represents a compiled pipeline with transformers and fitted models. + + .. versionadded:: 1.3.0 """ def __init__(self, stages): @@ -227,7 +250,14 @@ def _transform(self, dataset): dataset = t.transform(dataset) return dataset + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance. + + :param extra: extra parameters + :returns: new instance + """ if extra is None: extra = dict() stages = [stage.copy(extra) for stage in self.stages] From ec03866a7ef2d0826520755d47c8c9480148a76c Mon Sep 17 00:00:00 2001 From: Dominik Dahlem Date: Mon, 2 Nov 2015 16:11:42 -0800 Subject: [PATCH 129/324] [SPARK-11343][ML] Allow float and double prediction/label columns in RegressionEvaluator mengxr, felixcheung This pull request just relaxes the type of the prediction/label columns to be float and double. Internally, these columns are casted to double. The other evaluators might need to be changed also. Author: Dominik Dahlem Closes #9296 from dahlem/ddahlem_regression_evaluator_double_predictions_27102015. --- .../spark/ml/evaluation/RegressionEvaluator.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 3fd34d8571017..ba012f444d3e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -23,7 +23,8 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, FloatType} /** * :: Experimental :: @@ -72,10 +73,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + val predictionType = schema($(predictionCol)).dataType + require(predictionType == FloatType || predictionType == DoubleType) + val labelType = schema($(labelCol)).dataType + require(labelType == FloatType || labelType == DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + val predictionAndLabels = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) .map { case Row(prediction: Double, label: Double) => (prediction, label) } From c020f7d9d43548d27ae4a9564ba38981fd530cb1 Mon Sep 17 00:00:00 2001 From: vectorijk Date: Mon, 2 Nov 2015 16:12:04 -0800 Subject: [PATCH 130/324] [SPARK-10592] [ML] [PySpark] Deprecate weights and use coefficients instead in ML models Deprecated in `LogisticRegression` and `LinearRegression` Author: vectorijk Closes #9311 from vectorijk/spark-10592. 
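The rename keeps source compatibility: in the Scala models below, `weights` survives as a deprecated alias that simply returns `coefficients` (see the `@deprecated("Use coefficients instead.", "1.6.0")` members in LogisticRegressionModel and LinearRegressionModel). For orientation, here is a rough Python sketch of the same alias pattern -- the class name, constructor, warning text and sample values are made up for illustration and are not the pyspark code:

    import warnings

    class LinearModelSketch(object):
        """Toy stand-in for a fitted linear model (not a pyspark class)."""

        def __init__(self, coefficients, intercept):
            self.coefficients = coefficients
            self.intercept = intercept

        @property
        def weights(self):
            """Deprecated alias kept so existing callers keep working."""
            warnings.warn("weights is deprecated. Use coefficients instead.",
                          DeprecationWarning)
            return self.coefficients

    model = LinearModelSketch([0.5, -1.2], 0.1)
    # The old name still resolves to the same values, just with a warning.
    assert model.weights == model.coefficients

Callers migrating off the old name only need to swap the attribute; the values returned are identical.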
--- R/pkg/R/mllib.R | 6 +- .../classification/LogisticRegression.scala | 11 +- .../apache/spark/ml/r/SparkRWrappers.scala | 15 +- .../ml/regression/AFTSurvivalRegression.scala | 32 +-- .../ml/regression/IsotonicRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 15 +- .../ml/classification/JavaOneVsRestSuite.java | 6 +- .../LogisticRegressionSuite.scala | 152 ++++++++------- .../MultilayerPerceptronClassifierSuite.scala | 6 +- .../ml/classification/OneVsRestSuite.scala | 6 +- .../AFTSurvivalRegressionSuite.scala | 12 +- .../ml/regression/LinearRegressionSuite.scala | 184 +++++++++--------- python/pyspark/ml/classification.py | 13 ++ python/pyspark/ml/regression.py | 12 ++ 14 files changed, 263 insertions(+), 211 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index aadd5b8da5e3b..60bfadb8e7503 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -92,9 +92,9 @@ setMethod("summary", signature(x = "PipelineModel"), function(x, ...) { features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelFeatures", x@model) - weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelWeights", x@model) - coefficients <- as.matrix(unlist(weights)) + coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelCoefficients", x@model) + coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 6f839ff4d7cd8..a1335e7a1bde8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -392,11 +392,14 @@ class LogisticRegression(override val uid: String) @Experimental class LogisticRegressionModel private[ml] ( override val uid: String, - val weights: Vector, + val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { + @deprecated("Use coefficients instead.", "1.6.0") + def weights: Vector = coefficients + override def setThreshold(value: Double): this.type = super.setThreshold(value) override def getThreshold: Double = super.getThreshold @@ -407,7 +410,7 @@ class LogisticRegressionModel private[ml] ( /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { - BLAS.dot(features, weights) + intercept + BLAS.dot(features, coefficients) + intercept } /** Score (probability) for class label 1. For binary classification only. 
*/ @@ -416,7 +419,7 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } - override val numFeatures: Int = weights.size + override val numFeatures: Int = coefficients.size override val numClasses: Int = 2 @@ -483,7 +486,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 21ebf6d916db7..9162ec0e4e153 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -51,13 +51,22 @@ private[r] object SparkRWrappers { pipeline.fit(df) } + @deprecated("Use getModelCoefficients instead.", "1.6.0") def getModelWeights(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => Array(m.intercept) ++ m.weights.toArray - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No weights available for LogisticRegressionModel") // SPARK-9492 + case m: LogisticRegressionModel => + Array(m.intercept) ++ m.weights.toArray + } + } + + def getModelCoefficients(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.coefficients.toArray + case m: LogisticRegressionModel => + Array(m.intercept) ++ m.coefficients.toArray } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index ac2c3d825f13c..4dbbc7d39931b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -200,17 +200,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size /* - The weights vector has three parts: + The coefficients vector has three parts: the first element: Double, log(sigma), the log of scale parameter the second element: Double, intercept of the beta parameter the third to the end elements: Doubles, regression coefficients vector of the beta parameter */ - val initialWeights = Vectors.zeros(numFeatures + 2) + val initialCoefficients = Vectors.zeros(numFeatures + 2) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeights.toBreeze.toDenseVector) + initialCoefficients.toBreeze.toDenseVector) - val weights = { + val coefficients = { val arrayBuilder = mutable.ArrayBuilder.make[Double] var state: optimizer.State = null while (states.hasNext) { @@ -227,10 +227,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val coefficients = Vectors.dense(weights.slice(2, weights.length)) - val intercept = weights(1) - val scale = math.exp(weights(0)) - val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + val regressionCoefficients = Vectors.dense(coefficients.slice(2, coefficients.length)) + val intercept = coefficients(1) + 
val scale = math.exp(coefficients(0)) + val model = new AFTSurvivalRegressionModel(uid, regressionCoefficients, intercept, scale) copyValues(model.setParent(this)) } @@ -251,7 +251,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("1.6.0") class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override val uid: String, - @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val regressionCoefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { @@ -275,7 +275,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { // scale parameter for the Weibull distribution of lifetime - val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + val lambda = math.exp(BLAS.dot(regressionCoefficients, features) + intercept) // shape parameter for the Weibull distribution of lifetime val k = 1 / scale val quantiles = $(quantileProbabilities).map { @@ -286,7 +286,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predict(features: Vector): Double = { - math.exp(BLAS.dot(coefficients, features) + intercept) + math.exp(BLAS.dot(regressionCoefficients, features) + intercept) } @Since("1.6.0") @@ -309,7 +309,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { - copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + copyValues(new AFTSurvivalRegressionModel(uid, regressionCoefficients, intercept, scale), extra) .setParent(parent) } } @@ -369,17 +369,17 @@ class AFTSurvivalRegressionModel private[ml] ( * \frac{\partial (-\iota)}{\partial (\log\sigma)}= * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] * }}} - * @param weights The log of scale parameter, the intercept and + * @param coefficients including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. 
*/ -private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) +private class AFTAggregator(coefficients: BDV[Double], fitIntercept: Boolean) extends Serializable { // beta is the intercept and regression coefficients to the covariates - private val beta = weights.slice(1, weights.length) + private val beta = coefficients.slice(1, coefficients.length) // sigma is the scale parameter of the AFT model - private val sigma = math.exp(weights(0)) + private val sigma = math.exp(coefficients(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 2ff500f291abc..f4a17c8f9a582 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -87,8 +87,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures lit(1.0) } dataset.select(col($(labelCol)), f, w) - .map { case Row(label: Double, feature: Double, weights: Double) => - (label, feature, weights) + .map { case Row(label: Double, feature: Double, weight: Double) => + (label, feature, weight) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f663b9bd9ac73..6e9c7442b8110 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -203,7 +203,7 @@ class LinearRegression(override val uid: String) val yMean = ySummarizer.mean(0) val yStd = math.sqrt(ySummarizer.variance(0)) - // If the yStd is zero, then the intercept is yMean with zero weights; + // If the yStd is zero, then the intercept is yMean with zero coefficient; // as a result, training is not needed. if (yStd == 0.0) { logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + @@ -331,14 +331,17 @@ class LinearRegression(override val uid: String) @Experimental class LinearRegressionModel private[ml] ( override val uid: String, - val weights: Vector, + val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None - override val numFeatures: Int = weights.size + @deprecated("Use coefficients instead.", "1.6.0") + def weights: Vector = coefficients + + override val numFeatures: Int = coefficients.size /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is @@ -387,11 +390,11 @@ class LinearRegressionModel private[ml] ( override protected def predict(features: Vector): Double = { - dot(features, weights) + intercept + dot(features, coefficients) + intercept } override def copy(extra: ParamMap): LinearRegressionModel = { - val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra) + val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } @@ -400,7 +403,7 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the - * training weights except for the objective trace. 
+ * training coefficients except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 253cabf0133d0..cbabafe1b541d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -47,16 +47,16 @@ public void setUp() { jsql = new SQLContext(jsc); int nPoints = 3; - // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2. // As a result, we are drawing samples from probability distribution of an actual model. - double[] weights = { + double[] coefficients = { -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; List points = JavaConverters.seqAsJavaListConverter( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) ).asJava(); datasetRDD = jsc.parallelize(points, 2); dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e0a795e5e0b00..325faf37e8eea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -48,21 +48,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.classification.LogisticRegressionSuite val nPoints = 10000 - val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 1) + coefficients, xMean, xVariance, true, nPoints, 42), 1) data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 - val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) sqlContext.createDataFrame(sc.parallelize(testData, 4)) } @@ -296,8 +297,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - 
weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -308,14 +309,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7996864 */ val interceptR = 2.8366423 - val weightsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) + val coefficientsR = Vectors.dense(-0.5895848, 0.8931147, -0.3925051, -0.7996864) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) + assert(model1.coefficients ~= coefficientsR relTol 1E-3) // Without regularization, with or without standardization will converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) } test("binary logistic regression without intercept without regularization") { @@ -332,9 +333,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -345,14 +346,14 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.7407946 */ val interceptR = 0.0 - val weightsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) + val coefficientsR = Vectors.dense(-0.3534996, 1.2964482, -0.3571741, -0.7407946) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) // Without regularization, with or without standardization should converge to the same solution. assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) } test("binary logistic regression with intercept with L1 regularization") { @@ -371,8 +372,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -383,10 +384,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.02481551 */ val interceptR1 = -0.05627428 - val weightsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.04325749, -0.02481551) assert(model1.intercept ~== interceptR1 relTol 1E-2) - assert(model1.weights ~= weightsR1 absTol 2E-2) + assert(model1.coefficients ~= coefficientsR1 absTol 2E-2) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -395,9 +396,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -408,10 +409,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.3722152 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.1665453, 0.0) assert(model2.intercept ~== interceptR2 relTol 1E-2) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression without intercept with L1 regularization") { @@ -430,9 +431,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -443,10 +444,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.03891782 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) + val coefficientsR1 = Vectors.dense(0.0, 0.0, -0.05189203, -0.03891782) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 absTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 absTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -455,9 +456,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -468,10 +469,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . 
*/ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.08420782, 0.0) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression with intercept with L2 regularization") { @@ -490,8 +491,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -502,10 +503,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.10062872 */ val interceptR1 = 0.15021751 - val weightsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) + val coefficientsR1 = Vectors.dense(-0.07251837, 0.10724191, -0.04865309, -0.10062872) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* Using the following R code to load the data and train the model using glmnet package. @@ -514,9 +515,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -527,10 +528,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.06266838 */ val interceptR2 = 0.48657516 - val weightsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) + val coefficientsR2 = Vectors.dense(-0.05155371, 0.02301057, -0.11482896, -0.06266838) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) } test("binary logistic regression without intercept with L2 regularization") { @@ -549,9 +550,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -562,10 +563,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.09799775 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) + val coefficientsR1 = Vectors.dense(-0.06099165, 0.12857058, -0.04708770, -0.09799775) assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* Using the following R code to load the data and train the 
model using glmnet package. @@ -574,9 +575,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -587,10 +588,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.053314311 */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) + val coefficientsR2 = Vectors.dense(-0.005679651, 0.048967094, -0.093714016, -0.053314311) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) } test("binary logistic regression with intercept with ElasticNet regularization") { @@ -609,8 +610,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -621,10 +622,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.15458796 */ val interceptR1 = 0.57734851 - val weightsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) + val coefficientsR1 = Vectors.dense(-0.05310287, 0.0, -0.08849250, -0.15458796) assert(model1.intercept ~== interceptR1 relTol 6E-3) - assert(model1.weights ~== weightsR1 absTol 5E-3) + assert(model1.coefficients ~== coefficientsR1 absTol 5E-3) /* Using the following R code to load the data and train the model using glmnet package. 
@@ -633,9 +634,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -646,10 +647,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.05350074 */ val interceptR2 = 0.51555993 - val weightsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) + val coefficientsR2 = Vectors.dense(0.0, 0.0, -0.18807395, -0.05350074) assert(model2.intercept ~== interceptR2 relTol 6E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression without intercept with ElasticNet regularization") { @@ -668,9 +669,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, intercept=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -681,10 +682,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 -0.142534158 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) + val coefficientsR1 = Vectors.dense(-0.001005743, 0.072577857, -0.081203769, -0.142534158) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 absTol 1E-2) /* Using the following R code to load the data and train the model using glmnet package. @@ -693,9 +694,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + coefficients = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, intercept=FALSE, standardize=FALSE)) - weights + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -706,10 +707,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) + val coefficientsR2 = Vectors.dense(0.0, 0.03345223, -0.11304532, 0.0) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 absTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 absTol 1E-3) } test("binary logistic regression with intercept with strong L1 regularization") { @@ -732,8 +733,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { }).histogram /* - For binary logistic regression with strong L1 regularization, all the weights will be zeros. - As a result, + For binary logistic regression with strong L1 regularization, all the coefficients + will be zeros. 
As a result, {{{ P(0) = 1 / (1 + \exp(b)), and P(1) = \exp(b) / (1 + \exp(b)) @@ -743,13 +744,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { }}} */ val interceptTheory = math.log(histogram(1) / histogram(0)) - val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) + val coefficientsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) - assert(model1.weights ~= weightsTheory absTol 1E-6) + assert(model1.coefficients ~= coefficientsTheory absTol 1E-6) assert(model2.intercept ~== interceptTheory relTol 1E-5) - assert(model2.weights ~= weightsTheory absTol 1E-6) + assert(model2.coefficients ~= coefficientsTheory absTol 1E-6) /* Using the following R code to load the data and train the model using glmnet package. @@ -758,8 +759,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE) label = factor(data$V1) features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - weights + coefficients = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + coefficients 5 x 1 sparse Matrix of class "dgCMatrix" s0 @@ -770,10 +771,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data.V5 . */ val interceptR = -0.248065 - val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) + val coefficientsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights ~== weightsR absTol 1E-6) + assert(model1.coefficients ~== coefficientsR absTol 1E-6) } test("evaluate on test set") { @@ -814,10 +815,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("binary logistic regression with weighted samples") { val (dataset, weightedDataset) = { val nPoints = 1000 - val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) // Let's over-sample the positive samples twice. 
val data1 = testData.flatMap { case labeledPoint: LabeledPoint => @@ -863,9 +865,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model1a0 = trainer1a.fit(dataset) val model1a1 = trainer1a.fit(weightedDataset) val model1b = trainer1b.fit(weightedDataset) - assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.coefficients !~= model1a1.coefficients absTol 1E-3) assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) - assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 2d1df9b2b82e8..17db8c44777d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -53,16 +53,16 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp test("3 class classification with 2 hidden layers") { val nPoints = 1000 - // The following weights are taken from OneVsRestSuite.scala + // The following coefficients are taken from OneVsRestSuite.scala // they represent 3-class iris dataset - val weights = Array( + val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) val rdd = sc.parallelize(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 42), 2) val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 977f0e0b70c1a..5ea71c5317b7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -43,16 +43,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val nPoints = 1000 - // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. + // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2 // As a result, we are drawing samples from probability distribution of an actual model. 
- val weights = Array( + val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 42), 2) dataset = sqlContext.createDataFrame(rdd) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172b..c0f791bce13d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -141,12 +141,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val coefficientsR = Vectors.dense(-0.039) + val regressionCoefficientsR = Vectors.dense(-0.039) val interceptR = 1.759 val scaleR = 1.41 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* @@ -212,12 +212,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val regressionCoefficientsR = Vectors.dense(-0.0844, 0.0677) val interceptR = 1.9206 val scaleR = 0.977 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* @@ -282,12 +282,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 6 n= 1000 */ - val coefficientsR = Vectors.dense(0.896, -0.709) + val regressionCoefficientsR = Vectors.dense(0.896, -0.709) val interceptR = 0.0 val scaleR = 1.52 assert(model.intercept === interceptR) - assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a2a5c0bbdcb90..235c796d785a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -122,8 +122,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) label <- as.numeric(data$V1) - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 6.298698 @@ -131,17 +131,18 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
7.199082 */ val interceptR = 6.298698 - val weightsR = Vectors.dense(4.700706, 7.199082) + val coefficientsR = Vectors.dense(4.700706, 7.199082) assert(model1.intercept ~== interceptR relTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) + assert(model1.coefficients ~= coefficientsR relTol 1E-3) assert(model2.intercept ~== interceptR relTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -159,37 +160,37 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val modelWithoutIntercept2 = trainer2.fit(datasetWithDenseFeatureWithoutIntercept) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, intercept = FALSE)) - > weights + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . as.numeric.data.V2. 6.995908 as.numeric.data.V3. 5.275131 */ - val weightsR = Vectors.dense(6.995908, 5.275131) + val coefficientsR = Vectors.dense(6.995908, 5.275131) assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.weights ~= weightsR relTol 1E-3) + assert(model1.coefficients ~= coefficientsR relTol 1E-3) assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.weights ~= weightsR relTol 1E-3) + assert(model2.coefficients ~= coefficientsR relTol 1E-3) /* Then again with the data with no intercept: - > weightsWithoutIntercept + > coefficientsWithourIntercept 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . as.numeric.data3.V2. 4.70011 as.numeric.data3.V3. 7.19943 */ - val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943) + val coefficientsWithourInterceptR = Vectors.dense(4.70011, 7.19943) assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3) + assert(modelWithoutIntercept1.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3) - assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3) + assert(modelWithoutIntercept2.coefficients ~= coefficientsWithourInterceptR relTol 1E-3) } } @@ -211,8 +212,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", + alpha = 1.0, lambda = 0.57 )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 6.24300 @@ -220,14 +222,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
6.679841 */ val interceptR1 = 6.24300 - val weightsR1 = Vectors.dense(4.024821, 6.679841) + val coefficientsR1 = Vectors.dense(4.024821, 6.679841) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - standardize=FALSE)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, standardize=FALSE )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 6.416948 @@ -235,16 +237,17 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 6.724286 */ val interceptR2 = 6.416948 - val weightsR2 = Vectors.dense(3.893869, 6.724286) + val coefficientsR2 = Vectors.dense(3.893869, 6.724286) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { case Row(features: DenseVector, prediction1: Double) => - val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) + - model1.intercept + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -269,9 +272,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - intercept=FALSE)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -279,15 +282,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.772913 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(6.299752, 4.772913) + val coefficientsR1 = Vectors.dense(6.299752, 4.772913) assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - intercept=FALSE, standardize=FALSE)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, + lambda = 0.57, intercept=FALSE, standardize=FALSE )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -295,16 +298,17 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
4.764229 */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(6.232193, 4.764229) + val coefficientsR2 = Vectors.dense(6.232193, 4.764229) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { case Row(features: DenseVector, prediction1: Double) => - val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) + - model1.intercept + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -321,8 +325,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 5.269376 @@ -330,15 +334,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.712356) */ val interceptR1 = 5.269376 - val weightsR1 = Vectors.dense(3.736216, 5.712356) + val coefficientsR1 = Vectors.dense(3.736216, 5.712356) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, standardize=FALSE)) - > weights + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 5.791109 @@ -346,15 +350,16 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.910406 */ val interceptR2 = 5.791109 - val weightsR2 = Vectors.dense(3.435466, 5.910406) + val coefficientsR2 = Vectors.dense(3.435466, 5.910406) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -370,9 +375,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, intercept = FALSE)) - > weights + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -380,15 +385,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
4.214502 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(5.522875, 4.214502) + val coefficientsR1 = Vectors.dense(5.522875, 4.214502) assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, intercept = FALSE, standardize=FALSE)) - > weights + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -396,15 +401,16 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 4.187419 */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(5.263704, 4.187419) + val coefficientsR2 = Vectors.dense(5.263704, 4.187419) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => val prediction2 = - features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -428,8 +434,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6 )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 6.324108 @@ -437,15 +444,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 5.200403 */ val interceptR1 = 5.696056 - val weightsR1 = Vectors.dense(3.670489, 6.001122) + val coefficientsR1 = Vectors.dense(3.670489, 6.001122) assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 standardize=FALSE)) - > weights + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) 6.114723 @@ -453,16 +460,17 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
6.146531 */ val interceptR2 = 6.114723 - val weightsR2 = Vectors.dense(3.409937, 6.146531) + val coefficientsR2 = Vectors.dense(3.409937, 6.146531) assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { case Row(features: DenseVector, prediction1: Double) => - val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) + - model1.intercept + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -487,9 +495,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = trainer2.fit(datasetWithDenseFeature) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - intercept=FALSE)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -497,15 +505,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.dataM.V3. 4.322251 */ val interceptR1 = 0.0 - val weightsR1 = Vectors.dense(5.673348, 4.322251) + val coefficientsR1 = Vectors.dense(5.673348, 4.322251) assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.weights ~= weightsR1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) /* - weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - intercept=FALSE, standardize=FALSE)) - > weights + coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, + lambda = 1.6, intercept=FALSE, standardize=FALSE )) + > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . @@ -513,16 +521,17 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { as.numeric.data.V3. 
4.297622 */ val interceptR2 = 0.0 - val weightsR2 = Vectors.dense(5.477988, 4.297622) + val coefficientsR2 = Vectors.dense(5.477988, 4.297622) assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.weights ~= weightsR2 relTol 1E-3) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { case Row(features: DenseVector, prediction1: Double) => - val prediction2 = features(0) * model1.weights(0) + features(1) * model1.weights(1) + - model1.intercept + val prediction2 = + features(0) * model1.coefficients(0) + features(1) * model1.coefficients(1) + + model1.intercept assert(prediction1 ~== prediction2 relTol 1E-5) } } @@ -554,7 +563,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedResiduals = datasetWithDenseFeature.select("features", "label") .map { case Row(features: DenseVector, label: Double) => val prediction = - features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + features(0) * model.coefficients(0) + features(1) * model.coefficients(1) + + model.intercept label - prediction } .zip(model.summary.residuals.map(_.getDouble(0))) @@ -663,9 +673,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model1a1 = trainer1a.fit(weightedData) val model1b = trainer1b.fit(weightedData) - assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.coefficients !~= model1a1.coefficients absTol 1E-3) assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) - assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) val trainer2a = (new LinearRegression).setFitIntercept(true) @@ -675,9 +685,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model2a0 = trainer2a.fit(data) val model2a1 = trainer2a.fit(weightedData) val model2b = trainer2b.fit(weightedData) - assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.coefficients !~= model2a1.coefficients absTol 1E-3) assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) - assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.coefficients ~== model2b.coefficients absTol 1E-3) assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) val trainer3a = (new LinearRegression).setFitIntercept(false) @@ -687,8 +697,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model3a0 = trainer3a.fit(data) val model3a1 = trainer3a.fit(weightedData) val model3b = trainer3b.fit(weightedData) - assert(model3a0.weights !~= model3a1.weights absTol 1E-3) - assert(model3a0.weights ~== model3b.weights absTol 1E-3) + assert(model3a0.coefficients !~= model3a1.coefficients absTol 1E-3) + assert(model3a0.coefficients ~== model3b.coefficients absTol 1E-3) val trainer4a = (new LinearRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) @@ -697,8 +707,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model4a0 = trainer4a.fit(data) val model4a1 = trainer4a.fit(weightedData) val model4b = trainer4b.fit(weightedData) - assert(model4a0.weights !~= model4a1.weights absTol 1E-3) - assert(model4a0.weights ~== model4b.weights absTol 1E-3) + assert(model4a0.coefficients !~= model4a1.coefficients absTol 
1E-3) + assert(model4a0.coefficients ~== model4b.coefficients absTol 1E-3) } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4cbe7fbd482da..2e468f67b8987 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -15,6 +15,9 @@ # limitations under the License. # +import warnings + +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -189,8 +192,18 @@ def weights(self): """ Model weights. """ + + warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") + @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + @property def intercept(self): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dc68815556d4e..ab26616f4a01d 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -117,8 +119,18 @@ def weights(self): """ Model weights. """ + + warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") + @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + @property @since("1.4.0") def intercept(self): From 476f4348e2ea57ea05f4b470abfe76d97eeb20ce Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Mon, 2 Nov 2015 17:02:31 -0800 Subject: [PATCH 131/324] [SPARK-11236] [TEST-MAVEN] [TEST-HADOOP1.0] [CORE] Update Tachyon dependency 0.7.1 -> 0.8.1 This is a reopening of #9204 which failed hadoop1 sbt tests. With the original PR, a classpath issue would occur due to the MIMA plugin pulling in hadoop-2.2 dependencies regardless of the hadoop version when building the `oldDeps` project. These affect the hadoop1 sbt build because they are placed in `lib_managed` and Tachyon 0.8.0's default hadoop version is 2.2. Author: Calvin Jia Closes #9395 from calvinjia/spark-11236. 
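Stepping back to the `weights` -> `coefficients` rename in the MLlib and PySpark hunks above: the old accessor is kept but deprecated (the Python property now warns "weights is deprecated. Use coefficients instead."), so existing callers keep working while new code reads `coefficients`. A minimal Scala sketch of the intended migration, not part of any patch in this series; `sqlContext` is assumed to exist (e.g. a spark-shell session) and the tiny training set is illustrative only:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Creating a DataFrame from LabeledPoints yields the default "label"/"features" columns.
    val training = sqlContext.createDataFrame(Seq(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))

    val model = new LogisticRegression().setMaxIter(10).fit(training)
    val w = model.weights        // assumed still available here, but deprecated in favor of:
    val c = model.coefficients   // the same vector under the new name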
--- core/pom.xml | 6 +----- make-distribution.sh | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 1b6b13517bd56..570a25cf325a2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -262,7 +262,7 @@ org.tachyonproject tachyon-client - 0.7.1 + 0.8.1 org.apache.hadoop @@ -284,10 +284,6 @@ org.tachyonproject tachyon-underfs-glusterfs - - org.tachyonproject - tachyon-underfs-s3 - diff --git a/make-distribution.sh b/make-distribution.sh index 24418ace26270..e1c2afdbc6d87 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,9 +33,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.7.1" +TACHYON_VERSION="0.8.1" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" -TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" +TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" MAKE_TGZ=false NAME=none @@ -240,10 +240,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then fi tar xzf "${TACHYON_TGZ}" - cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/assembly/target/tachyon-assemblies-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/servers/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx From 21ad846238a9a79564e2e99a1def89fd31a0870d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 2 Nov 2015 19:07:31 -0800 Subject: [PATCH 132/324] [MINOR][ML] removed the old `getModelWeights` function Removed the old `getModelWeights` function which was private and renamed into `getModelCoefficients` Author: DB Tsai Closes #9426 from dbtsai/feature-minor. --- .../scala/org/apache/spark/ml/r/SparkRWrappers.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9162ec0e4e153..24f76de806d8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -51,16 +51,6 @@ private[r] object SparkRWrappers { pipeline.fit(df) } - @deprecated("Use getModelCoefficients instead.", "1.6.0") - def getModelWeights(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => - Array(m.intercept) ++ m.weights.toArray - case m: LogisticRegressionModel => - Array(m.intercept) ++ m.weights.toArray - } - } - def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => From 2cef1bb0b560a03aa7308f694b0c66347b90c9ea Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 2 Nov 2015 19:18:45 -0800 Subject: [PATCH 133/324] =?UTF-8?q?[SPARK-5354][SQL]=20Cached=20tables=20s?= =?UTF-8?q?hould=20preserve=20partitioning=20and=20ord=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ering. For cached tables, we can just maintain the partitioning and ordering from the source relation. 
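A rough sketch of the user-visible effect, modeled on the CachedTableSuite changes below; `sqlContext` is assumed to exist, the table names "orders" and "users" and their shared "key" column are hypothetical, and `distributeBy` is simply the repartitioning API the suite itself uses. Because the cached scan now reports the partitioning and ordering of the plan it caches, joining two tables cached with matching hash partitioning should plan without any Exchange:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.execution.Exchange

    // Hypothetical source tables, both hash-partitioned on the join key before caching.
    sqlContext.table("orders").distributeBy(Column("key") :: Nil, 8).registerTempTable("t1")
    sqlContext.table("users").distributeBy(Column("key") :: Nil, 8).registerTempTable("t2")
    sqlContext.cacheTable("t1")
    sqlContext.cacheTable("t2")

    val joined = sqlContext.sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key")
    // With this change the cached partitioning is reused, so no shuffle should be planned:
    joined.queryExecution.executedPlan.collect { case e: Exchange => e }.size   // expected: 0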
Author: Nong Li Closes #9404 from nongli/spark-5354. --- .../columnar/InMemoryColumnarTableScan.scala | 7 +++ .../apache/spark/sql/execution/Exchange.scala | 40 ++++++++++--- .../apache/spark/sql/CachedTableSuite.scala | 59 +++++++++++++++++++ 3 files changed, 97 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index b4607b12fcefa..7eb1ad7cd8198 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel @@ -209,6 +210,12 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + // The cached version does not change the outputPartitioning of the original SparkPlan. + override def outputPartitioning: Partitioning = relation.child.outputPartitioning + + // The cached version does not change the outputOrdering of the original SparkPlan. + override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + override def outputsUnsafeRows: Boolean = true private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 7f60c8f5eaa95..e81108b7884d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -194,12 +194,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una */ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - private def numPartitions: Int = sqlContext.conf.numShufflePartitions + private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions /** * Given a required distribution, returns a partitioning that satisfies that distribution. 
*/ - private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = { + private def createPartitioning(requiredDistribution: Distribution, + numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) @@ -220,7 +221,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (child.outputPartitioning.satisfies(distribution)) { child } else { - Exchange(canonicalPartitioning(distribution), child) + Exchange(createPartitioning(distribution, defaultPartitions), child) } } @@ -229,12 +230,33 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (children.length > 1 && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { - children = children.zip(requiredChildDistributions).map { case (child, distribution) => - val targetPartitioning = canonicalPartitioning(distribution) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - Exchange(targetPartitioning, child) + + // First check if the existing partitions of the children all match. This means they are + // partitioned by the same partitioning into the same number of partitions. In that case, + // don't try to make them match `defaultPartitions`, just use the existing partitioning. + // TODO: this should be a cost based descision. For example, a big relation should probably + // maintain its existing number of partitions and smaller partitions should be shuffled. + // defaultPartitions is arbitrary. + val numPartitions = children.head.outputPartitioning.numPartitions + val useExistingPartitioning = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + child.outputPartitioning.guarantees( + createPartitioning(distribution, numPartitions)) + } + } + + children = if (useExistingPartitioning) { + children + } else { + children.zip(requiredChildDistributions).map { + case (child, distribution) => { + val targetPartitioning = createPartitioning(distribution, defaultPartitions) + if (child.outputPartitioning.guarantees(targetPartitioning)) { + child + } else { + Exchange(targetPartitioning, child) + } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fd566c8276bc1..605954b105d1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD import scala.concurrent.duration._ @@ -353,4 +354,62 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) } + + /** + * Verifies that the plan for `df` contains `expected` number of Exchange operators. 
+ */ + private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + } + + test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { + val table3x = testData.unionAll(testData).unionAll(testData) + table3x.registerTempTable("testData3x") + + sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") + sqlContext.cacheTable("orderedTable") + assertCached(sqlContext.table("orderedTable")) + // Should not have an exchange as the query is already sorted on the group by key. + verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) + checkAnswer( + sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), + sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) + sqlContext.uncacheTable("orderedTable") + + // Set up two tables distributed in the same way. Try this with the data distributed into + // different number of partitions. + for (numPartitions <- 1 until 10 by 4) { + testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1") + testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + sqlContext.dropTempTable("t1") + sqlContext.dropTempTable("t2") + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle. + testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1") + testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + sqlContext.dropTempTable("t1") + sqlContext.dropTempTable("t2") + } } From 9cb5c731dadff9539126362827a258d6b65754bb Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 2 Nov 2015 20:32:08 -0800 Subject: [PATCH 134/324] [SPARK-11329][SQL] Support star expansion for structs. 1. Supporting expanding structs in Projections. i.e. "SELECT s.*" where s is a struct type. This is fixed by allowing the expand function to handle structs in addition to tables. 2. Supporting expanding * inside aggregate functions of structs. "SELECT max(struct(col1, structCol.*))" This requires recursively expanding the expressions. In this case, it it the aggregate expression "max(...)" and we need to recursively expand its children inputs. Author: Nong Li Closes #9343 from nongli/spark-11329. 
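An illustrative sketch of the two new expansions; the source table "src", its columns, and the temp table name are hypothetical, and `struct(...)` is the built-in SQL function for building a struct column:

    // A hypothetical table with a struct column named "record".
    sqlContext.sql("SELECT id, struct(a, b, c) AS record FROM src").registerTempTable("structured")

    // 1. Struct expansion in a projection: expands to record.a, record.b, record.c.
    sqlContext.sql("SELECT record.* FROM structured")

    // 2. Star expansion inside an aggregate function argument.
    sqlContext.sql("SELECT max(struct(id, record.*)) FROM structured")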
--- .../apache/spark/sql/catalyst/SqlParser.scala | 6 +- .../sql/catalyst/analysis/Analyzer.scala | 46 ++++-- .../sql/catalyst/analysis/unresolved.scala | 78 +++++++--- .../scala/org/apache/spark/sql/Column.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 133 ++++++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- 6 files changed, 230 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0fef04302714e..d7567e8613e3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -466,9 +466,9 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } - | primary - ) + | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target)) } + } | primary + ) protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index beabacfc88e32..912c967b95f08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -279,6 +279,24 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { + /** + * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree + * rooted at each expression. 
+ */ + def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { + exprs.flatMap { + case s: Star => s.expand(child, resolver) + case e => + e.transformDown { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + } :: Nil + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -286,44 +304,42 @@ class Analyzer( case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) - case o => o :: Nil - } - UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + val newChildren = expandStarExpressions(args, child) + UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil + case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => + val newChildren = expandStarExpressions(args, child) + Alias(child = f.copy(children = newChildren), name)() :: Nil case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { - case s: Star => s.expand(child.output, resolver) + case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) + case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { - case s: Star => s.expand(t.child.output, resolver) + case s: Star => s.expand(t.child, resolver) case o => o :: Nil } ) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy( - aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output, resolver) - case o => o :: Nil - } - ) + val expanded = expandStarExpressions(a.aggregateExpressions, a.child) + .map(_.asInstanceOf[NamedExpression]) + a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids. 
case j @ Join(left, right, _, _) if !j.selfJoinResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c97365003935e..6975662e2b738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.{TableIdentifier, errors} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.{TableIdentifier, errors} +import org.apache.spark.sql.types.{DataType, StructType} /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] + def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] } @@ -166,26 +166,68 @@ abstract class Star extends LeafExpression with NamedExpression { * Represents all of the input attributes to a given relational operator, for example in * "SELECT * FROM ...". * - * @param table an optional table that should be the target of the expansion. If omitted all - * tables' columns are produced. + * This is also used to expand structs. For example: + * "SELECT record.* from (SELECT struct(a,b,c) as record ...) + * + * @param target an optional name that should be the target of the expansion. If omitted all + * targets' columns are produced. This can either be a table name or struct name. This + * is a list of identifiers that is the path of the expansion. */ -case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { +case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { + + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { - val expandedAttributes: Seq[Attribute] = table match { + // First try to expand assuming it is table.*. + val expandedAttributes: Seq[Attribute] = target match { // If there is no table specified, use all input attributes. - case None => input + case None => input.output // If there is a table, pick out attributes that are part of this table. 
- case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) + case Some(t) => if (t.size == 1) { + input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty) + } else { + List() + } } - expandedAttributes.zip(input).map { - case (n: NamedExpression, _) => n - case (e, originalAttribute) => - Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + if (!expandedAttributes.isEmpty) { + if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) { + return expandedAttributes + } else { + require(expandedAttributes.size == input.output.size) + expandedAttributes.zip(input.output).map { + case (e, originalAttribute) => + Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + } + } + return expandedAttributes + } + + require(target.isDefined) + + // Try to resolve it as a struct expansion. If there is a conflict and both are possible, + // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. + val attribute = input.resolve(target.get, resolver) + if (attribute.isDefined) { + // This target resolved to an attribute in child. It must be a struct. Expand it. + attribute.get.dataType match { + case s: StructType => { + s.fields.map( f => { + val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + Alias(extract, target.get + "." + f.name)() + }) + } + case _ => { + throw new AnalysisException("Can only star expand struct data types. Attribute: `" + + target.get + "`") + } + } + } else { + val from = input.inputSet.map(_.name).mkString(", ") + val targetString = target.get.mkString(".") + throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") } } - override def toString: String = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = target.map(_ + ".").getOrElse("") + "*" } /** @@ -225,7 +267,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * @param expressions Expressions to expand. 
*/ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { - override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e4f4cf1533ac4..3cde9d6cb4708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -60,7 +60,8 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName( + name.substring(0, name.length - 2)))) case _ => UnresolvedAttribute.quotedString(name) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5413ef1287da1..ee54bff24b196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1932,4 +1932,137 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(sampled.count() == sampledOdd.count() + sampledEven.count()) } } + + test("Struct Star Expansion") { + val structDf = testData2.select("a", "b").as("record") + + checkAnswer( + structDf.select($"record.a", $"record.b"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer( + structDf.select($"record.*", $"record.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 1, 2) :: Row(2, 1, 2, 1) :: Row(2, 2, 2, 2) :: + Row(3, 1, 3, 1) :: Row(3, 2, 3, 2) :: Nil) + + checkAnswer( + sql("select struct(a, b) as r1, struct(b, a) as r2 from testData2").select($"r1.*", $"r2.*"), + Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) :: + Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil) + + // Try with a registered table. 
+ sql("select struct(a, b) as record from testData2").registerTempTable("structTable") + checkAnswer(sql("SELECT record.* FROM structTable"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin), + Row(Row(1, 1)) :: Nil) + + // Try with an alias on the select list + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select struct(a,b) as record from testData2) tmp + """.stripMargin).select($"r.*"), + Row(3, 2) :: Nil) + + // With GROUP BY + checkAnswer(sql( + """ + | SELECT min(struct(record.*)) FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin), + Row(Row(1, 1)) :: Row(Row(2, 1)) :: Row(Row(3, 1)) :: Nil) + + // With GROUP BY and alias + checkAnswer(sql( + """ + | SELECT max(struct(record.*)) as r FROM + | (select a as a, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin).select($"r.*"), + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil) + + // With GROUP BY and alias and additional fields in the struct + checkAnswer(sql( + """ + | SELECT max(struct(a, record.*, b)) as r FROM + | (select a as a, b as b, struct(a,b) as record from testData2) tmp + | GROUP BY a + """.stripMargin).select($"r.*"), + Row(1, 1, 2, 2) :: Row(2, 2, 2, 2) :: Row(3, 3, 2, 2) :: Nil) + + // Create a data set that contains nested structs. + val nestedStructData = sql( + """ + | SELECT struct(r1, r2) as record FROM + | (SELECT struct(a, b) as r1, struct(b, a) as r2 FROM testData2) tmp + """.stripMargin) + + checkAnswer(nestedStructData.select($"record.*"), + Row(Row(1, 1), Row(1, 1)) :: Row(Row(1, 2), Row(2, 1)) :: Row(Row(2, 1), Row(1, 2)) :: + Row(Row(2, 2), Row(2, 2)) :: Row(Row(3, 1), Row(1, 3)) :: Row(Row(3, 2), Row(2, 3)) :: Nil) + checkAnswer(nestedStructData.select($"record.r1"), + Row(Row(1, 1)) :: Row(Row(1, 2)) :: Row(Row(2, 1)) :: Row(Row(2, 2)) :: + Row(Row(3, 1)) :: Row(Row(3, 2)) :: Nil) + checkAnswer( + nestedStructData.select($"record.r1.*"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + + // Try with a registered table + nestedStructData.registerTempTable("nestedStructTable") + checkAnswer(sql("SELECT record.* FROM nestedStructTable"), + nestedStructData.select($"record.*")) + checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"), + nestedStructData.select($"record.r1")) + checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"), + nestedStructData.select($"record.r1.*")) + + // Create paths with unusual characters. + val specialCharacterPath = sql( + """ + | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM + | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp + """.stripMargin) + specialCharacterPath.registerTempTable("specialCharacterTable") + checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) + checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + + // Try star expanding a scalar. This should fail. 
+ assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains( + "Can only star expand struct data types.")) + + // Try resolving something not there. + assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) + .getMessage.contains("cannot resolve")) + } + + + test("Struct Star Expansion - Name conflict") { + // Create a data set that contains a naming conflict + val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") + nameConflict.registerTempTable("nameConflict") + // Unqualified should resolve to table. + checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), + Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: + Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) + // Qualify the struct type with the table name. + checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3697761f20c28..ab88c1e68fd72 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1505,7 +1505,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(name)) + UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) From efaa4721b511a1d29229facde6457a6dcda18966 Mon Sep 17 00:00:00 2001 From: Yves Raimond Date: Mon, 2 Nov 2015 20:35:59 -0800 Subject: [PATCH 135/324] [SPARK-11432][GRAPHX] Personalized PageRank shouldn't use uniform initialization Changes the personalized pagerank initialization to be non-uniform. Author: Yves Raimond Closes #9386 from moustaki/personalized-pagerank-init. --- .../apache/spark/graphx/lib/PageRank.scala | 29 ++++++++++++------- .../spark/graphx/lib/PageRankSuite.scala | 13 ++++++--- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 8c0a461e99fa4..52b237fc15093 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -104,18 +104,23 @@ object PageRank extends Logging { graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + val personalized = srcId isDefined + val src: VertexId = srcId.getOrElse(-1L) + // Initialize the PageRank graph with each edge attribute having - // weight 1/outDegree and each vertex with attribute 1.0. + // weight 1/outDegree and each vertex with attribute resetProb. + // When running personalized pagerank, only the source vertex + // has an attribute resetProb. All others are set to 0. 
var rankGraph: Graph[Double, Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values - .mapVertices( (id, attr) => resetProb ) + .mapVertices { (id, attr) => + if (!(id != src && personalized)) resetProb else 0.0 + } - val personalized = srcId isDefined - val src: VertexId = srcId.getOrElse(-1L) def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 @@ -192,6 +197,9 @@ object PageRank extends Logging { graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, srcId: Option[VertexId] = None): Graph[Double, Double] = { + val personalized = srcId.isDefined + val src: VertexId = srcId.getOrElse(-1L) + // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. val pagerankGraph: Graph[(Double, Double), Double] = graph @@ -202,13 +210,11 @@ object PageRank extends Logging { // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initalPR, delta = 0) - .mapVertices( (id, attr) => (0.0, 0.0) ) + .mapVertices { (id, attr) => + if (id == src) (resetProb, Double.NegativeInfinity) else (0.0, 0.0) + } .cache() - val personalized = srcId.isDefined - val src: VertexId = srcId.getOrElse(-1L) - - // Define the three functions needed to implement PageRank in the GraphX // version of Pregel def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { @@ -225,7 +231,8 @@ object PageRank extends Logging { teleport = oldPR*delta val newPR = teleport + (1.0 - resetProb) * msgSum - (newPR, newPR - oldPR) + val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR + (newPR, newDelta) } def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { @@ -239,7 +246,7 @@ object PageRank extends Logging { def messageCombiner(a: Double, b: Double): Double = a + b // The initial message received by all vertices in PageRank - val initialMessage = resetProb / (1.0 - resetProb) + val initialMessage = if (personalized) 0.0 else resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. 
val vp = if (personalized) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 45f1e3011035e..bdff31446f8ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -109,17 +109,22 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { assert(notMatching === 0) val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == resetProb) || - (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * - (nVertices - 1)) )) < 1.0E-5) + val correct = (vid > 0 && pr == 0.0) || + (vid == 0 && pr == resetProb) if (!correct) 1 else 0 } assert(staticErrors.sum === 0) val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + + // We have one outbound edge from 1 to 0 + val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb) + .vertices.cache() + val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache() + assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol) } - } // end of test Star PageRank + } // end of test Star PersonalPageRank test("Grid PageRank") { withSpark { sc => From 9cf56c96b7d02a14175d40b336da14c2e1c88339 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 2 Nov 2015 21:18:38 -0800 Subject: [PATCH 136/324] [SPARK-11469][SQL] Allow users to define nondeterministic udfs. This is the first task (https://issues.apache.org/jira/browse/SPARK-11469) of https://issues.apache.org/jira/browse/SPARK-11438 Author: Yin Huai Closes #9393 from yhuai/udfNondeterministic. 
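As a quick illustration of the API this change introduces (a minimal sketch, not part of the patch itself; it assumes a live `sqlContext`, mirroring the UDFSuite additions below):

    import org.apache.spark.sql.functions.{col, udf}

    // A Scala UDF is deterministic by default, so the optimizer may constant-fold
    // it or duplicate it while collapsing projections.
    val plusOne = udf((x: Int) => x + 1)

    // Marking it nondeterministic disables constant folding and makes the
    // optimizer conservative about rewriting it (see the foldable/deterministic
    // overrides added to ScalaUDF below).
    val plusOneNd = plusOne.nondeterministic

    // Either form can also be registered by name for use from SQL text,
    // via the new register(name, udf) overload.
    sqlContext.udf.register("plusOne", plusOne)
    sqlContext.udf.register("plusOneNd", plusOneNd)

    val df = sqlContext.range(1, 2)
      .select(col("id").as("a"), plusOneNd(col("id")).as("b"))
    df.show()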
--- project/MimaExcludes.scala | 47 +++++ .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../apache/spark/sql/UDFRegistration.scala | 164 ++++++++++-------- .../spark/sql/UserDefinedFunction.scala | 13 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 105 +++++++++++ .../datasources/parquet/ParquetIOSuite.scala | 4 +- 6 files changed, 262 insertions(+), 78 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8282f7ea62400..ec0e44b7f2d66 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -112,6 +112,53 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 11c7950c0613b..a04af7f1dd877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,13 +30,18 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) + inputTypes: Seq[DataType] = Nil, + isDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true override def toString: String = s"UDF(${children.mkString(",")})" + override def foldable: Boolean = deterministic && children.forall(_.foldable) + + override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic) + // scalastyle:off /** This method has been generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c533a..f5b95e13e47bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -58,8 +58,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. - * @param udaf the UDAF needs to be registered. + * @param udaf the UDAF that needs to be registered. * @return the registered UDAF. + * + * @since 1.5.0 */ def register( name: String, @@ -69,6 +71,22 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } + /** + * Register a user-defined function (UDF). + * + * @param name the name of the UDF. + * @param udf the UDF that needs to be registered. + * @return the registered UDF. 
+ * + * @since 1.6.0 + */ + def register( + name: String, + udf: UserDefinedFunction): UserDefinedFunction = { + functionRegistry.registerFunction(name, udf.builder) + udf + } + // scalastyle:off /* register 0-22 were generated by this script @@ -86,9 +104,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try($inputTypes).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf }""") } @@ -118,9 +136,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -131,9 +149,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -144,9 +162,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -157,9 +175,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -170,9 +188,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends 
Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -183,9 +201,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -196,9 +214,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -209,9 +227,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -222,9 +240,9 @@ class UDFRegistration private[sql] 
(sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -235,9 +253,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -248,9 +266,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -261,9 +279,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: 
Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -274,9 +292,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -287,9 +305,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, 
udf.builder) + udf } /** @@ -300,9 +318,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -313,9 +331,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -326,9 +344,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -339,9 +357,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -352,9 +370,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -365,9 +383,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -378,9 +396,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -391,9 +409,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -404,9 +422,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 0f8cd280b5acb..1319391db5375 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,11 +44,20 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil) { + inputTypes: Seq[DataType] = Nil, + deterministic: Boolean = true) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic)) } + + protected[sql] def builder: Seq[Expression] => ScalaUDF = { + (exprs: Seq[Expression]) => + ScalaUDF(f, dataType, exprs, inputTypes, deterministic) + } + + def nondeterministic: UserDefinedFunction = + UserDefinedFunction(f, dataType, inputTypes, deterministic = false) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0435a0dba6ad..6e510f0b8aff4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } + + private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { + val udfs = df.queryExecution.optimizedPlan.collect { + case p: logical.Project => p.projectList.flatMap { + case e => e.collect { + case udf: ScalaUDF => udf + } + } + }.flatten + assert(udfs.length === expectedNumUDFs) + } + + test("foldable udf") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sql("SELECT 1 as a") + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 0) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using UDFRegistration") { + import org.apache.spark.sql.functions._ + + val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) + sqlContext.udf.register("plusOne2", myUDF.nondeterministic) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne1", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne2", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using udf function") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + // nondeterministicUDF will not be foldable. 
+ val df = sql("SELECT 1 as a") + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("override a registered udf") { + sqlContext.udf.register("intExpected", (x: Int) => x) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + + sqlContext.udf.register("intExpected", (x: Int) => x + 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 72744799897be..f14b2886a9ecb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) From c34c27fe9244939d8c905cd689536dfb81c74d7d Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Mon, 2 Nov 2015 23:52:36 -0800 Subject: [PATCH 137/324] [SPARK-9034][SQL] Reflect field names defined in GenericUDTF Hive GenericUDTF#initialize() defines field names in a returned schema though, the current HiveGenericUDTF drops these names. We might need to reflect these in a logical plan tree. Author: navis.ryu Closes #8456 from navis/SPARK-9034. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 11 +++++------ .../spark/sql/catalyst/expressions/generators.scala | 12 +++++++----- .../main/scala/org/apache/spark/sql/DataFrame.scala | 10 +++++----- .../sql/hive/execution/HiveCompatibilitySuite.scala | 1 + .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- ...GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e | 1 + ...GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 | 1 + ...l_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + ...l_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 | 0 ...l_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 | 2 ++ ...l_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a | 0 ...l_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 | 2 ++ ...al_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe | 0 ...l_view_noalias-6-16d227442dd775615c6ecfceedc6c612 | 0 ...l_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 | 2 ++ .../spark/sql/hive/execution/HiveQuerySuite.scala | 6 ++++++ 16 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e create mode 100644 sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-6-16d227442dd775615c6ecfceedc6c612 create mode 100644 sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 912c967b95f08..899ee67352df4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -147,7 +147,7 @@ class Analyzer( case u @ UnresolvedAlias(child) => child match { case ne: NamedExpression => ne case e if !e.resolved => u - case g: Generator if g.elementTypes.size > 1 => MultiAlias(g, Nil) + case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() case other => Alias(other, s"_c$i")() } @@ -722,7 +722,7 @@ class Analyzer( /** * Construct the output attributes for a [[Generator]], given a list of names. If the list of - * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults. + * names is empty names are assigned from field names in generator. 
*/ private def makeGeneratorOutput( generator: Generator, @@ -731,13 +731,12 @@ class Analyzer( if (names.length == elementTypes.length) { names.zip(elementTypes).map { - case (name, (t, nullable)) => + case (name, (t, nullable, _)) => AttributeReference(name, t, nullable)() } } else if (names.isEmpty) { - elementTypes.zipWithIndex.map { - // keep the default column names as Hive does _c0, _c1, _cN - case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + elementTypes.map { + case (t, nullable, name) => AttributeReference(name, t, nullable)() } } else { failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1a2092c909c56..894a0730d1c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -53,7 +53,7 @@ trait Generator extends Expression { * The output element data types in structure of Seq[(DataType, Nullable)] * TODO we probably need to add more information like metadata etc. */ - def elementTypes: Seq[(DataType, Boolean)] + def elementTypes: Seq[(DataType, Boolean, String)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: InternalRow): TraversableOnce[InternalRow] @@ -69,7 +69,7 @@ trait Generator extends Expression { * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - elementTypes: Seq[(DataType, Boolean)], + elementTypes: Seq[(DataType, Boolean, String)], function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator with CodegenFallback { @@ -112,9 +112,11 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } } - override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { - case ArrayType(et, containsNull) => (et, containsNull) :: Nil - case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil + // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) + override def elementTypes: Seq[(DataType, Boolean, String)] = child.dataType match { + case ArrayType(et, containsNull) => (et, containsNull, "col") :: Nil + case MapType(kt, vt, valueContainsNull) => + (kt, false, "key") :: (vt, valueContainsNull, "value") :: Nil } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 53ad3c0266cdb..fc0ab632f9930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1175,7 +1175,8 @@ class DataFrame private[sql]( def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val elementTypes = schema.toAttributes.map { + attr => (attr.dataType, attr.nullable, attr.name) } val names = schema.toAttributes.map(_.name) val convert = CatalystTypeConverters.createToCatalystConverter(schema) @@ -1184,7 +1185,7 @@ class DataFrame private[sql]( val generator = UserDefinedGenerator(elementTypes, 
rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + qualifier = None, generatorOutput = Nil, logicalPlan) } /** @@ -1203,8 +1204,7 @@ class DataFrame private[sql]( val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? - val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } - val names = attributes.map(_.name) + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } def rowFunction(row: Row): TraversableOnce[InternalRow] = { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) @@ -1213,7 +1213,7 @@ class DataFrame private[sql]( val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) Generate(generator, join = true, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) + qualifier = None, generatorOutput = Nil, logicalPlan) } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 6ed40b03975d0..2d0d7b8af3581 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -661,6 +661,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_star", "lateral_view", "lateral_view_cp", + "lateral_view_noalias", "lateral_view_ppd", "leftsemijoin", "leftsemijoin_mr", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 0b5e863506142..a9db70119d011 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -511,7 +511,7 @@ private[hive] case class HiveGenericUDTF( protected lazy val collector = new UDTFCollector override lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { - field => (inspectorToDataType(field.getFieldObjectInspector), true) + field => (inspectorToDataType(field.getFieldObjectInspector), true, field.getFieldName) } @transient diff --git a/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e new file mode 100644 index 0000000000000..1cf253f92c055 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #1-0-ff502d8c06f4b32f57aa45057b7fab0e @@ -0,0 +1 @@ +238 diff --git a/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-9034 Reflect field names defined in GenericUDTF #2-0-d6d0def30a7fad5f90fd835361820c30 @@ -0,0 +1 @@ +238 val_238 diff 
--git a/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 b/sql/hive/src/test/resources/golden/lateral_view_noalias-1-72509f06e1f7c5d5ccc292f775f8eea7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 b/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 new file mode 100644 index 0000000000000..0da0d93886e01 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-2-6d5806dd1d2511911a5de1e205523f42 @@ -0,0 +1,2 @@ +key1 100 +key2 200 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a b/sql/hive/src/test/resources/golden/lateral_view_noalias-3-155b3cc2f5054725a9c2acca3c38c00a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 b/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 new file mode 100644 index 0000000000000..0da0d93886e01 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-4-3b7045ace234af8e5e86d8ac23ccee56 @@ -0,0 +1,2 @@ +key1 100 +key2 200 diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe b/sql/hive/src/test/resources/golden/lateral_view_noalias-5-e1eca4e08216897d090259d4fd1e3fe new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-6-16d227442dd775615c6ecfceedc6c612 b/sql/hive/src/test/resources/golden/lateral_view_noalias-6-16d227442dd775615c6ecfceedc6c612 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 b/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 new file mode 100644 index 0000000000000..4ba46bbda5b04 --- /dev/null +++ b/sql/hive/src/test/resources/golden/lateral_view_noalias-7-66cb5ab20690dd85b2ed95bbfb9481d3 @@ -0,0 +1,2 @@ +key1 100 key1 100 +key2 200 key2 200 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index e597d6865f67a..fc72e3c7dc6aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -563,6 +563,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("Specify the udtf output", "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #1", + "SELECT col FROM (SELECT explode(array(key,value)) FROM src LIMIT 1) t") + + createQueryTest("SPARK-9034 Reflect field names defined in GenericUDTF #2", + "SELECT key,value FROM (SELECT explode(map(key,value)) FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src 
TABLESAMPLE(0.1 PERCENT) s") sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") From d728d5c98658c44ed2949b55d36edeaa46f8c980 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 3 Nov 2015 00:12:49 -0800 Subject: [PATCH 138/324] [SPARK-9858][SPARK-9859][SPARK-9861][SQL] Add an ExchangeCoordinator to estimate the number of post-shuffle partitions for aggregates and joins https://issues.apache.org/jira/browse/SPARK-9858 https://issues.apache.org/jira/browse/SPARK-9859 https://issues.apache.org/jira/browse/SPARK-9861 Author: Yin Huai Closes #9276 from yhuai/numReducer. --- .../plans/physical/partitioning.scala | 8 + .../scala/org/apache/spark/sql/SQLConf.scala | 27 + .../apache/spark/sql/execution/Exchange.scala | 217 +++++++- .../sql/execution/ExchangeCoordinator.scala | 260 ++++++++++ .../spark/sql/execution/ShuffledRowRDD.scala | 134 ++++- .../execution/ExchangeCoordinatorSuite.scala | 479 ++++++++++++++++++ .../spark/sql/execution/PlannerSuite.scala | 8 +- .../execution/UnsafeRowSerializerSuite.scala | 7 +- .../sql/execution/joins/InnerJoinSuite.scala | 19 +- 9 files changed, 1115 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 86b9417477ba3..9312c8123e92e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -165,6 +165,11 @@ sealed trait Partitioning { * produced by `A` could have also been produced by `B`. */ def guarantees(other: Partitioning): Boolean = this == other + + def withNumPartitions(newNumPartitions: Int): Partitioning = { + throw new IllegalStateException( + s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}") + } } object Partitioning { @@ -249,6 +254,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def withNumPartitions(newNumPartitions: Int): HashPartitioning = { + HashPartitioning(expressions, newNumPartitions) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6f2892085a8f8..ed8b634ad5630 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -233,6 +233,25 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = + longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", + defaultValue = Some(64 * 1024 * 1024), + doc = "The target post-shuffle input size in bytes of a task.") + + val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled", + defaultValue = Some(false), + doc = "When true, enable adaptive query execution.") + + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = + intConf("spark.sql.adaptive.minNumPostShufflePartitions", + defaultValue = Some(-1), + doc = "The advisory minimal number of post-shuffle partitions provided to " + + "ExchangeCoordinator. 
This setting is used in our test to make sure we " + + "have enough parallelism to expose issues that will not be exposed with a " + + "single partition. When the value is a non-positive value, this setting will" + + "not be provided to ExchangeCoordinator.", + isPublic = false) + val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", defaultValue = Some(true), doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + @@ -487,6 +506,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + private[spark] def targetPostShuffleInputSize: Long = + getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) + + private[spark] def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + + private[spark] def minNumPostShufflePartitions: Int = + getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) + private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index e81108b7884d1..0f72ec6cc107a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -36,9 +36,23 @@ import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + var newPartitioning: Partitioning, + child: SparkPlan, + @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { - override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" + override def nodeName: String = { + val extraInfo = coordinator match { + case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => + "Shuffle" + case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => + "May shuffle" + case None => "Shuffle without coordinator" + } + + val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + s"$simpleNodeName($extraInfo)" + } /** * Returns true iff we can support the data type, and we are not doing range partitioning. @@ -129,7 +143,27 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } - protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + override protected def doPrepare(): Unit = { + // If an ExchangeCoordinator is needed, we register this Exchange operator + // to the coordinator when we do prepare. It is important to make sure + // we register this operator right before the execution instead of register it + // in the constructor because it is possible that we create new instances of + // Exchange operators when we transform the physical plan + // (then the ExchangeCoordinator will hold references of unneeded Exchanges). + // So, we should only call registerExchange just before we start to execute + // the plan. + coordinator match { + case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) + case None => + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. 
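// A hedged sketch of how a user would switch on the adaptive shuffle behaviour
// driven by the new SQLConf entries above (values illustrative only; assumes a
// SQLContext named sqlContext is in scope).
sqlContext.setConf("spark.sql.adaptive.enabled", "true")
sqlContext.setConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize",
  (64L * 1024 * 1024).toString)  // target roughly 64 MB of input per post-shuffle partition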
Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { val rdd = child.execute() val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) @@ -181,7 +215,54 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } } - new ShuffledRowRDD(rddWithPartitionIds, serializer, part.numPartitions) + + // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds + // are in the form of (partitionId, row) and every partitionId is in the expected range + // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rddWithPartitionIds, + new PartitionIdPassthrough(part.numPartitions), + Some(serializer)) + + dependency + } + + /** + * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. + * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional + * partition start indices array. If this optional array is defined, the returned + * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. + */ + private[sql] def preparePostShuffleRDD( + shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { + // If an array of partition start indices is provided, we need to use this array + // to create the ShuffledRowRDD. Also, we need to update newPartitioning to + // update the number of post-shuffle partitions. + specifiedPartitionStartIndices.foreach { indices => + assert(newPartitioning.isInstanceOf[HashPartitioning]) + newPartitioning = newPartitioning.withNumPartitions(indices.length) + } + new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { + coordinator match { + case Some(exchangeCoordinator) => + val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) + assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) + shuffleRDD + case None => + val shuffleDependency = prepareShuffleDependency() + preparePostShuffleRDD(shuffleDependency) + } + } +} + +object Exchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { + Exchange(newPartitioning, child, None: Option[ExchangeCoordinator]) } } @@ -193,13 +274,22 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * input partition ordering requirements are met. */ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - // TODO: Determine the number of partitions. - private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions + private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions + + private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + + private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + + private def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } /** * Given a required distribution, returns a partitioning that satisfies that distribution. 
*/ - private def createPartitioning(requiredDistribution: Distribution, + private def createPartitioning( + requiredDistribution: Distribution, numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition @@ -209,6 +299,98 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } } + /** + * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled + * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]]. + */ + private def withExchangeCoordinator( + children: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + val supportsCoordinator = + if (children.exists(_.isInstanceOf[Exchange])) { + // Right now, ExchangeCoordinator only support HashPartitionings. + children.forall { + case e @ Exchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.exists(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withCoordinator = + if (adaptiveExecutionEnabled && supportsCoordinator) { + val coordinator = + new ExchangeCoordinator( + children.length, + targetPostShuffleInputSize, + minNumPostShufflePartitions) + children.zip(requiredChildDistributions).map { + case (e: Exchange, _) => + // This child is an Exchange, we need to add the coordinator. + e.copy(coordinator = Some(coordinator)) + case (child, distribution) => + // If this child is not an Exchange, we need to add an Exchange for now. + // Ideally, we can try to avoid this Exchange. However, when we reach here, + // there are at least two children operators (because if there is a single child + // and we can avoid Exchange, supportsCoordinator will be false and we + // will not reach here.). Although we can make two children have the same number of + // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. + // For example, let's say we have the following plan + // Join + // / \ + // Agg Exchange + // / \ + // Exchange t2 + // / + // t1 + // In this case, because a post-shuffle partition can include multiple pre-shuffle + // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes + // after shuffle. So, even we can use the child Exchange operator of the Join to + // have a number of post-shuffle partitions that matches the number of partitions of + // Agg, we cannot say these two children are partitioned in the same way. + // Here is another case + // Join + // / \ + // Agg1 Agg2 + // / \ + // Exchange1 Exchange2 + // / \ + // t1 t2 + // In this case, two Aggs shuffle data with the same column of the join condition. + // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same + // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 + // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle + // partitions by using a partitionStartIndices [0, 3]. 
However, Agg2 may fetch its + // pre-shuffle partitions by using another partitionStartIndices [0, 4]. + // So, Agg1 and Agg2 are actually not co-partitioned. + // + // It will be great to introduce a new Partitioning to represent the post-shuffle + // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) + assert(targetPartitioning.isInstanceOf[HashPartitioning]) + Exchange(targetPartitioning, child, Some(coordinator)) + } + } else { + // If we do not need ExchangeCoordinator, the original children are returned. + children + } + + withCoordinator + } + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering @@ -221,7 +403,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (child.outputPartitioning.satisfies(distribution)) { child } else { - Exchange(createPartitioning(distribution, defaultPartitions), child) + Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } } @@ -234,7 +416,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // First check if the existing partitions of the children all match. This means they are // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. - // TODO: this should be a cost based descision. For example, a big relation should probably + // TODO: this should be a cost based decision. For example, a big relation should probably // maintain its existing number of partitions and smaller partitions should be shuffled. // defaultPartitions is arbitrary. val numPartitions = children.head.outputPartitioning.numPartitions @@ -250,7 +432,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } else { children.zip(requiredChildDistributions).map { case (child, distribution) => { - val targetPartitioning = createPartitioning(distribution, defaultPartitions) + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { @@ -261,12 +444,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } } + // Now, we need to add ExchangeCoordinator if necessary. + // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. + // However, with the way that we plan the query, we do not have a place where we have a + // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator + // at here for now. + // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, + // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. + children = withExchangeCoordinator(children, requiredChildDistributions) + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. 
if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child) + sqlContext.planner.BasicOperators.getSortOperator( + requiredOrdering, + global = false, + child) } else { child } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala new file mode 100644 index 0000000000000..8dbd69e1f44b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.{Map => JMap, HashMap => JHashMap} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, MapOutputStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. + * Right now, the work of this coordinator is to determine the number of post-shuffle partitions + * for a stage that needs to fetch shuffle data from one or multiple stages. + * + * A coordinator is constructed with three parameters, `numExchanges`, + * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. + * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to + * this coordinator. So, when we start to do any actual work, we have a way to make sure that + * we have got expected number of [[Exchange]]s. + * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's + * input data size. With this parameter, we can estimate the number of post-shuffle partitions. + * This parameter is configured through + * `spark.sql.adaptive.shuffle.targetPostShuffleInputSize`. + * - `minNumPostShufflePartitions` is an optional parameter. If it is defined, this coordinator + * will try to make sure that there are at least `minNumPostShufflePartitions` post-shuffle + * partitions. + * + * The workflow of this coordinator is described as follows: + * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator, + * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. + * This happens in the `doPrepare` method. + * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will + * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]]. 
+ * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will + * immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. + * - If this coordinator has not made the decision on how to shuffle data, it will ask those + * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size + * statistics of pre-shuffle partitions, this coordinator will determine the number of + * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices + * to a single post-shuffle partition whenever necessary. + * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered + * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can + * lookup the corresponding [[RDD]]. + * + * The strategy used to determine the number of post-shuffle partitions is described as follows. + * To determine the number of post-shuffle partitions, we have a target input size for a + * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages + * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and + * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until + * the size of a post-shuffle partition is equal or greater than the target size. + * For example, we have two stages with the following pre-shuffle partition size statistics: + * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB] + * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB] + * assuming the target input size is 128 MB, we will have three post-shuffle partitions, + * which are: + * - post-shuffle partition 0: pre-shuffle partition 0 and 1 + * - post-shuffle partition 1: pre-shuffle partition 2 + * - post-shuffle partition 2: pre-shuffle partition 3 and 4 + */ +private[sql] class ExchangeCoordinator( + numExchanges: Int, + advisoryTargetPostShuffleInputSize: Long, + minNumPostShufflePartitions: Option[Int] = None) + extends Logging { + + // The registered Exchange operators. + private[this] val exchanges = ArrayBuffer[Exchange]() + + // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. + private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] = + new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + + // A boolean that indicates if this coordinator has made decision on how to shuffle data. + // This variable will only be updated by doEstimationIfNecessary, which is protected by + // synchronized. + @volatile private[this] var estimated: Boolean = false + + /** + * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be + * called in the `doPrepare` method of an [[Exchange]] operator. + */ + def registerExchange(exchange: Exchange): Unit = synchronized { + exchanges += exchange + } + + def isEstimated: Boolean = estimated + + /** + * Estimates partition start indices for post-shuffle partitions based on + * mapOutputStatistics provided by all pre-shuffle stages. + */ + private[sql] def estimatePartitionStartIndices( + mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { + // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit + // a stage when the number of partitions of this dependency is 0. 
+ assert(mapOutputStatistics.length <= numExchanges) + + // If minNumPostShufflePartitions is defined, it is possible that we need to use a + // value less than advisoryTargetPostShuffleInputSize as the target input size of + // a post shuffle task. + val targetPostShuffleInputSize = minNumPostShufflePartitions match { + case Some(numPartitions) => + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we + // only have a single post-shuffle partition. + val maxPostShuffleInputSize = + math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) + math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) + + case None => advisoryTargetPostShuffleInputSize + } + + logInfo( + s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + + s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + + // Make sure we do get the same number of pre-shuffle partitions for those stages. + val distinctNumPreShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + assert( + distinctNumPreShufflePartitions.length == 1, + "There should be only one distinct value of the number pre-shuffle partitions " + + "among registered Exchange operator.") + val numPreShufflePartitions = distinctNumPreShufflePartitions.head + + val partitionStartIndices = ArrayBuffer[Int]() + // The first element of partitionStartIndices is always 0. + partitionStartIndices += 0 + + var postShuffleInputSize = 0L + + var i = 0 + while (i < numPreShufflePartitions) { + // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages. + // Then, we add the total size to postShuffleInputSize. + var j = 0 + while (j < mapOutputStatistics.length) { + postShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If the current postShuffleInputSize is equal or greater than the + // targetPostShuffleInputSize, We need to add a new element in partitionStartIndices. + if (postShuffleInputSize >= targetPostShuffleInputSize) { + if (i < numPreShufflePartitions - 1) { + // Next start index. + partitionStartIndices += i + 1 + } else { + // This is the last element. So, we do not need to append the next start index to + // partitionStartIndices. + } + // reset postShuffleInputSize. + postShuffleInputSize = 0L + } + + i += 1 + } + + partitionStartIndices.toArray + } + + private def doEstimationIfNecessary(): Unit = synchronized { + // It is unlikely that this method will be called from multiple threads + // (when multiple threads trigger the execution of THIS physical) + // because in common use cases, we will create new physical plan after + // users apply operations (e.g. projection) to an existing DataFrame. + // However, if it happens, we have synchronized to make sure only one + // thread will trigger the job submission. + if (!estimated) { + // Make sure we have the expected number of registered Exchange operators. 
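// Standalone sketch of the packing rule implemented above (plain Scala, not
// Spark code): walk the per-partition input sizes and start a new post-shuffle
// partition whenever the accumulated size reaches the target.
import scala.collection.mutable.ArrayBuffer

def packPartitions(bytesByPartition: Array[Long], targetSize: Long): Array[Int] = {
  val startIndices = ArrayBuffer(0)
  var accumulated = 0L
  for (i <- bytesByPartition.indices) {
    accumulated += bytesByPartition(i)
    if (accumulated >= targetSize && i < bytesByPartition.length - 1) {
      startIndices += i + 1
      accumulated = 0L
    }
  }
  startIndices.toArray
}

// The element-wise sums of the two example stages in the class comment are
// [110, 30, 170, 15, 35] MB; with a 128 MB target this gives start indices
// [0, 2, 3], i.e. post-shuffle partitions covering [0,1], [2] and [3,4].
// Note that when minNumPostShufflePartitions is set, the advisory target is
// first capped at ceil(totalInputSize / minNumPostShufflePartitions): a 100 MB
// total with a requested minimum of 5 partitions lowers a 64 MB target to 20 MB.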
+ assert(exchanges.length == numExchanges) + + val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + + // Submit all map stages + val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() + val submittedStageFutures = ArrayBuffer[SimpleFutureAction[MapOutputStatistics]]() + var i = 0 + while (i < numExchanges) { + val exchange = exchanges(i) + val shuffleDependency = exchange.prepareShuffleDependency() + shuffleDependencies += shuffleDependency + if (shuffleDependency.rdd.partitions.length != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + submittedStageFutures += + exchange.sqlContext.sparkContext.submitMapStage(shuffleDependency) + } + i += 1 + } + + // Wait for the finishes of those submitted map stages. + val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) + i = 0 + while (i < submittedStageFutures.length) { + // This call is a blocking call. If the stage has not finished, we will wait at here. + mapOutputStatistics(i) = submittedStageFutures(i).get() + i += 1 + } + + // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the + // number of post-shuffle partitions. + val partitionStartIndices = + if (mapOutputStatistics.length == 0) { + None + } else { + Some(estimatePartitionStartIndices(mapOutputStatistics)) + } + + i = 0 + while (i < numExchanges) { + val exchange = exchanges(i) + val rdd = + exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices) + newPostShuffleRDDs.put(exchange, rdd) + + i += 1 + } + + // Finally, we set postShuffleRDDs and estimated. + assert(postShuffleRDDs.isEmpty) + assert(newPostShuffleRDDs.size() == numExchanges) + postShuffleRDDs.putAll(newPostShuffleRDDs) + estimated = true + } + } + + def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = { + doEstimationIfNecessary() + + if (!postShuffleRDDs.containsKey(exchange)) { + throw new IllegalStateException( + s"The given $exchange is not registered in this coordinator.") + } + + postShuffleRDDs.get(exchange) + } + + override def toString: String = { + s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index fb338b90bf79b..42891287a3006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -17,14 +17,23 @@ package org.apache.spark.sql.execution +import java.util.Arrays + import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow -private class ShuffledRowRDDPartition(val idx: Int) extends Partition { - override val index: Int = idx - override def hashCode(): Int = idx +/** + * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition + * (identified by `postShufflePartitionIndex`) contains a range of pre-shuffle partitions + * (`startPreShufflePartitionIndex` to `endPreShufflePartitionIndex - 1`, inclusive). 
+ */ +private final class ShuffledRowRDDPartition( + val postShufflePartitionIndex: Int, + val startPreShufflePartitionIndex: Int, + val endPreShufflePartitionIndex: Int) extends Partition { + override val index: Int = postShufflePartitionIndex + override def hashCode(): Int = postShufflePartitionIndex } /** @@ -35,33 +44,107 @@ private class PartitionIdPassthrough(override val numPartitions: Int) extends Pa override def getPartition(key: Any): Int = key.asInstanceOf[Int] } +/** + * A Partitioner that might group together one or more partitions from the parent. + * + * @param parent a parent partitioner + * @param partitionStartIndices indices of partitions in parent that should create new partitions + * in child (this should be an array of increasing partition IDs). For example, if we have a + * parent with 5 partitions, and partitionStartIndices is [0, 2, 4], we get three output + * partitions, corresponding to partition ranges [0, 1], [2, 3] and [4] of the parent partitioner. + */ +class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: Array[Int]) + extends Partitioner { + + @transient private lazy val parentPartitionMapping: Array[Int] = { + val n = parent.numPartitions + val result = new Array[Int](n) + for (i <- 0 until partitionStartIndices.length) { + val start = partitionStartIndices(i) + val end = if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1) else n + for (j <- start until end) { + result(j) = i + } + } + result + } + + override def numPartitions: Int = partitionStartIndices.length + + override def getPartition(key: Any): Int = { + parentPartitionMapping(parent.getPartition(key)) + } + + override def equals(other: Any): Boolean = other match { + case c: CoalescedPartitioner => + c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices) + case _ => + false + } + + override def hashCode(): Int = 31 * parent.hashCode() + Arrays.hashCode(partitionStartIndices) +} + /** * This is a specialized version of [[org.apache.spark.rdd.ShuffledRDD]] that is optimized for * shuffling rows instead of Java key-value pairs. Note that something like this should eventually * be implemented in Spark core, but that is blocked by some more general refactorings to shuffle * interfaces / internals. * - * @param prev the RDD being shuffled. Elements of this RDD are (partitionId, Row) pairs. - * Partition ids should be in the range [0, numPartitions - 1]. - * @param serializer the serializer used during the shuffle. - * @param numPartitions the number of post-shuffle partitions. + * This RDD takes a [[ShuffleDependency]] (`dependency`), + * and a optional array of partition start indices as input arguments + * (`specifiedPartitionStartIndices`). + * + * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle + * (i.e. map output). Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * `dependency.partitioner` is the original partitioner used to partition + * map output, and `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions + * (i.e. the number of partitions of the map output). + * + * When `specifiedPartitionStartIndices` is defined, `specifiedPartitionStartIndices.length` + * will be the number of post-shuffle partitions. For this case, the `i`th post-shuffle + * partition includes `specifiedPartitionStartIndices[i]` to + * `specifiedPartitionStartIndices[i+1] - 1` (inclusive). 
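// To make the CoalescedPartitioner above concrete: a standalone sketch of the
// parent-to-coalesced index table it builds, reusing the example from its
// comment (5 parent partitions, partitionStartIndices = [0, 2, 4]).
def buildParentMapping(numParentPartitions: Int, partitionStartIndices: Array[Int]): Array[Int] = {
  val result = new Array[Int](numParentPartitions)
  for (i <- partitionStartIndices.indices) {
    val start = partitionStartIndices(i)
    val end =
      if (i < partitionStartIndices.length - 1) partitionStartIndices(i + 1)
      else numParentPartitions
    for (j <- start until end) result(j) = i
  }
  result
}

// buildParentMapping(5, Array(0, 2, 4)) == Array(0, 0, 1, 1, 2), i.e. parent
// partitions 0-1 map to coalesced partition 0, 2-3 to 1, and 4 to 2, matching
// the ranges [0, 1], [2, 3] and [4] described above.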
+ * + * When `specifiedPartitionStartIndices` is not defined, there will be + * `dependency.partitioner.numPartitions` post-shuffle partitions. For this case, + * a post-shuffle partition is created for every pre-shuffle partition. */ class ShuffledRowRDD( - @transient var prev: RDD[Product2[Int, InternalRow]], - serializer: Serializer, - numPartitions: Int) - extends RDD[InternalRow](prev.context, Nil) { + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None) + extends RDD[InternalRow](dependency.rdd.context, Nil) { - private val part: Partitioner = new PartitionIdPassthrough(numPartitions) + private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency[Int, InternalRow, InternalRow](prev, part, Some(serializer))) + private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match { + case Some(indices) => indices + case None => + // When specifiedPartitionStartIndices is not defined, every post-shuffle partition + // corresponds to a pre-shuffle partition. + (0 until numPreShufflePartitions).toArray } - override val partitioner = Some(part) + private[this] val part: Partitioner = + new CoalescedPartitioner(dependency.partitioner, partitionStartIndices) + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override val partitioner: Option[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { - Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRowRDDPartition(i)) + assert(partitionStartIndices.length == part.numPartitions) + Array.tabulate[Partition](partitionStartIndices.length) { i => + val startIndex = partitionStartIndices(i) + val endIndex = + if (i < partitionStartIndices.length - 1) { + partitionStartIndices(i + 1) + } else { + numPreShufflePartitions + } + new ShuffledRowRDDPartition(i, startIndex, endIndex) + } } override def getPreferredLocations(partition: Partition): Seq[String] = { @@ -71,15 +154,20 @@ class ShuffledRowRDD( } override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val dep = dependencies.head.asInstanceOf[ShuffleDependency[Int, InternalRow, InternalRow]] - SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) - .read() - .asInstanceOf[Iterator[Product2[Int, InternalRow]]] - .map(_._2) + val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition] + // The range of pre-shuffle partitions that we are fetching at here is + // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. 
+ val reader = + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + shuffledRowPartition.startPreShufflePartitionIndex, + shuffledRowPartition.endPreShufflePartitionIndex, + context) + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } override def clearDependencies() { super.clearDependencies() - prev = null + dependency = null } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala new file mode 100644 index 0000000000000..25f2f5caeed15 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql._ +import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, MapOutputStatistics} + +class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSQLContext: Option[SQLContext] = _ + private var originalInstantiatedSQLContext: Option[SQLContext] = _ + + override protected def beforeAll(): Unit = { + originalActiveSQLContext = SQLContext.getActiveContextOption() + originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + + SQLContext.clearActive() + originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + } + + private def checkEstimation( + coordinator: ExchangeCoordinator, + bytesByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int]): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { + case (bytesByPartitionId, index) => + new MapOutputStatistics(index, bytesByPartitionId) + } + val estimatedPartitionStartIndices = + coordinator.estimatePartitionStartIndices(mapOutputStatistics) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + } + + test("test estimatePartitionStartIndices - 1 Exchange") { + val coordinator = new ExchangeCoordinator(1, 100L) + + { + // All bytes per partition are 0. 
+ val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // Some bytes per partition are 0 and total size is less than the target size. + // 1 post-shuffle partition is needed. + val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partitions are needed. + val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // There are a few large pre-shuffle partitions. + val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) + val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // All pre-shuffle partitions are larger than the targeted size. + val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + + { + // The last pre-shuffle partition is in a single post-shuffle partition. + val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) + val expectedPartitionStartIndices = Array[Int](0, 4) + checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + } + } + + test("test estimatePartitionStartIndices - 2 Exchanges") { + val coordinator = new ExchangeCoordinator(2, 100L) + + { + // If there are multiple values of the number of pre-shuffle partitions, + // we should see an assertion error. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) + val mapOutputStatistics = + Array( + new MapOutputStatistics(0, bytesByPartitionId1), + new MapOutputStatistics(1, bytesByPartitionId2)) + intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics)) + } + + { + // All bytes per partition are 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // Some bytes per partition are 0. + // 1 post-shuffle partition is needed. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. 
+ val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // 2 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // There are a few large pre-shuffle partitions. + val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) + val expectedPartitionStartIndices = Array[Int](0, 2, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // All pairs of pre-shuffle partitions are larger than the targeted size. + val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + } + + test("test estimatePartitionStartIndices and enforce minimal number of reducers") { + val coordinator = new ExchangeCoordinator(2, 100L, Some(2)) + + { + // The minimal number of post-shuffle partitions is not enforced because + // the size of data is 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // The minimal number of post-shuffle partitions is enforced. + val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) + val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + + { + // The number of post-shuffle partitions is determined by the coordinator. 
+ val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) + val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices) + } + } + + /////////////////////////////////////////////////////////////////////////// + // Query tests + /////////////////////////////////////////////////////////////////////////// + + val numInputPartitions: Int = 10 + + def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(actual, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + def withSQLContext( + f: SQLContext => Unit, + targetNumPostShufflePartitions: Int, + minNumPostShufflePartitions: Option[Int]): Unit = { + val sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set("spark.ui.enabled", "false") + .set("spark.driver.allowMultipleContexts", "true") + .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set( + SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, + targetNumPostShufflePartitions.toString) + minNumPostShufflePartitions match { + case Some(numPartitions) => + sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString) + case None => + sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") + } + val sparkContext = new SparkContext(sparkConf) + val sqlContext = new TestSQLContext(sparkContext) + try f(sqlContext) finally sparkContext.stop() + } + + Seq(Some(3), None).foreach { minNumPostShufflePartitions => + val testNameNote = minNumPostShufflePartitions match { + case Some(numPartitions) => "(minNumPostShufflePartitions: 3)" + case None => "" + } + + test(s"determining the number of reducers: aggregate operator$testNameNote") { + val test = { sqlContext: SQLContext => + val df = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 20 as key", "id as value") + val agg = df.groupBy("key").count + + // Check the answer first. + checkAnswer( + agg, + sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val exchanges = agg.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 1) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 2) + case o => + } + } + } + + withSQLContext(test, 1536, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: join operator$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2")) + + // Check the answer first. 
+ val expectedAnswer = + sqlContext + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .unionAll(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 2) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 2) + case o => + } + } + } + + withSQLContext(test, 16384, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 1$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count + .toDF("key1", "cnt1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + .groupBy("key2") + .count + .toDF("key2", "cnt2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2")) + + // Check the answer first. + val expectedAnswer = + sqlContext + .range(0, 500) + .selectExpr("id", "2 as cnt") + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 4) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + assert(exchanges.forall(_.coordinator.isDefined)) + assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(1, 2)) + } + } + + withSQLContext(test, 6144, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 2$testNameNote") { + val test = { sqlContext: SQLContext => + val df1 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count + .toDF("key1", "cnt1") + val df2 = + sqlContext + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + sqlContext + .range(0, 1000) + .selectExpr("id % 500 as key", "2 as cnt", "id as value") + checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. 
+ val exchanges = join.queryExecution.executedPlan.collect { + case e: Exchange => e + } + assert(exchanges.length === 3) + minNumPostShufflePartitions match { + case Some(numPartitions) => + exchanges.foreach { + case e: Exchange => + assert(e.coordinator.isDefined) + assert(e.outputPartitioning.numPartitions === 3) + case o => + } + + case None => + assert(exchanges.forall(_.coordinator.isDefined)) + assert(exchanges.map(_.outputPartitioning.numPartitions).toSeq.toSet === Set(2, 3)) + } + } + + withSQLContext(test, 6144, minNumPostShufflePartitions) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ebdab1c26d7bd..2076c573b56c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -268,7 +268,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + if (outputPlan.collect { case e: Exchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -306,7 +306,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + if (outputPlan.collect { case e: Exchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -326,7 +326,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -349,7 +349,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index d32572b54b8a8..09e258299de5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -152,7 +152,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) .asInstanceOf[RDD[Product2[Int, InternalRow]]] - val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2) + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rowsRDD, + new PartitionIdPassthrough(2), + Some(new UnsafeRowSerializer(2))) + val shuffled = new ShuffledRowRDD(dependency) shuffled.count() } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index da58e96f3e6f7..066c16e535c76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -49,7 +49,16 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "e") )), new StructType().add("n", IntegerType).add("l", StringType)) - private lazy val myTestData = Seq( + private lazy val myTestData1 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + private lazy val myTestData2 = Seq( (1, 1), (1, 2), (2, 1), @@ -184,8 +193,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) { - lazy val left = myTestData.where("a = 1") - lazy val right = myTestData.where("a = 1") + lazy val left = myTestData1.where("a = 1") + lazy val right = myTestData2.where("a = 1") testInnerJoin( "inner join, multiple matches", left, @@ -201,8 +210,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } { - lazy val left = myTestData.where("a = 1") - lazy val right = myTestData.where("a = 2") + lazy val left = myTestData1.where("a = 1") + lazy val right = myTestData2.where("a = 2") testInnerJoin( "inner join, no matches", left, From 67e23b39ac3cdee06668fa9131951278b9731e29 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 3 Nov 2015 11:42:08 +0100 Subject: [PATCH 139/324] [SPARK-10429] [SQL] make mutableProjection atomic Right now, SQL's mutable projection updates every value of the mutable project after it evaluates the corresponding expression. This makes the behavior of MutableProjection confusing and complicate the implementation of common aggregate functions like stddev because developers need to be aware that when evaluating {{i+1}}th expression of a mutable projection, {{i}}th slot of the mutable row has already been updated. This PR make the MutableProjection atomic, by generating all the results of expressions first, then copy them into mutableRow. Had run a mircro-benchmark, there is no notable performance difference between using class members and local variables. 
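The buffer-then-copy idea described above is easy to demonstrate outside of Spark. The sketch below is a hypothetical, plain-Scala illustration (it does not use Spark's `InternalRow` or expression classes): expressions are modeled as functions over the buffer row, and the contrast between in-place updates and two-pass updates shows why the atomic form matters when an aggregate's update expressions read the same buffer they write.

```scala
// Minimal sketch, assuming expressions are plain functions over an Array[Any] "row".
object AtomicProjectionSketch {
  type Row = Array[Any]

  // Non-atomic: slot i is overwritten before expression i+1 runs, so later
  // expressions observe the new value of earlier slots.
  def updateInPlace(exprs: Seq[Row => Any], buffer: Row): Row = {
    var i = 0
    while (i < exprs.length) {
      buffer(i) = exprs(i)(buffer)
      i += 1
    }
    buffer
  }

  // Atomic: evaluate every expression against the old buffer first, then copy
  // the results back, so all expressions see the same snapshot of the buffer.
  def updateAtomically(exprs: Seq[Row => Any], buffer: Row): Row = {
    val results = new Array[Any](exprs.length)
    var i = 0
    while (i < exprs.length) {
      results(i) = exprs(i)(buffer)
      i += 1
    }
    i = 0
    while (i < exprs.length) {
      buffer(i) = results(i)
      i += 1
    }
    buffer
  }

  def main(args: Array[String]): Unit = {
    // slot 0: count = count + 1; slot 1: a value derived from count (old or new?)
    val exprs: Seq[Row => Any] = Seq(
      row => row(0).asInstanceOf[Long] + 1L,
      row => row(0).asInstanceOf[Long] * 10L
    )
    println(updateInPlace(exprs, Array[Any](1L, 0L)).mkString(", "))    // 2, 20 (saw the updated count)
    println(updateAtomically(exprs, Array[Any](1L, 0L)).mkString(", ")) // 2, 10 (saw the old count)
  }
}
```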
cc yhuai Author: Davies Liu Closes #9422 from davies/atomic_mutable and squashes the following commits: bbc1758 [Davies Liu] support wide table 8a0ae14 [Davies Liu] fix bug bec07da [Davies Liu] refactor 2891628 [Davies Liu] make mutableProjection atomic --- .../sql/catalyst/expressions/Projection.scala | 13 +- .../expressions/aggregate/functions.scala | 154 ++++++++---------- .../codegen/GenerateMutableProjection.scala | 28 +++- 3 files changed, 97 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index afe52e6a667eb..a6fe730f6dad4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + private[this] val buffer = new Array[Any](expressions.size) + expressions.foreach(_.foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { - mutableRow(i) = exprArray(i).eval(input) + // Store the result into buffer first, to make the projection atomic (needed by aggregation) + buffer(i) = exprArray(i).eval(input) + i += 1 + } + i = 0 + while (i < exprArray.length) { + mutableRow(i) = buffer(i) i += 1 } mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 5d2eb7b017ab9..f2c3eca095115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate { case _ => DoubleType } - private val currentSum = AttributeReference("currentSum", sumDataType)() - private val currentCount = AttributeReference("currentCount", LongType)() + private val sum = AttributeReference("sum", sumDataType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentSum :: currentCount :: Nil + override val aggBufferAttributes = sum :: count :: Nil override val initialValues = Seq( - /* currentSum = */ Cast(Literal(0), sumDataType), - /* currentCount = */ Literal(0L) + /* sum = */ Cast(Literal(0), sumDataType), + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentSum = */ + /* sum = */ Add( - currentSum, + sum, 
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentSum = */ currentSum.left + currentSum.right, - /* currentCount = */ currentCount.left + currentCount.right + /* sum = */ sum.left + sum.right, + /* count = */ count.left + count.right ) - // If all input are nulls, currentCount will be 0 and we will get null after the division. + // If all input are nulls, count will be 0 and we will get null after the division. override val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => - Cast(currentSum, resultType) / Cast(currentCount, resultType) + Cast(sum, resultType) / Cast(count, resultType) } } @@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val currentCount = AttributeReference("currentCount", LongType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentCount :: Nil + override val aggBufferAttributes = count :: Nil override val initialValues = Seq( - /* currentCount = */ Literal(0L) + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentCount = */ currentCount.left + currentCount.right + /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(currentCount, LongType) + override val evaluateExpression = Cast(count, LongType) } /** @@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { private val resultType = DoubleType - private val preCount = AttributeReference("preCount", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() - private val preAvg = AttributeReference("preAvg", resultType)() - private val currentAvg = AttributeReference("currentAvg", resultType)() - private val currentMk = AttributeReference("currentMk", resultType)() + private val count = AttributeReference("count", resultType)() + private val avg = AttributeReference("avg", resultType)() + private val mk = AttributeReference("mk", resultType)() - override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil + override val aggBufferAttributes = count :: avg :: mk :: Nil override val initialValues = Seq( - /* preCount = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType), - /* preAvg = */ Cast(Literal(0), resultType), - /* currentAvg = */ Cast(Literal(0), resultType), - /* currentMk = */ Cast(Literal(0), resultType) + /* count = */ Cast(Literal(0), resultType), + /* avg = */ Cast(Literal(0), resultType), + /* mk = */ Cast(Literal(0), resultType) ) override val updateExpressions = { + val value = Cast(child, resultType) + val newCount = count + Cast(Literal(1), resultType) // update average // avg = avg + (value - avg)/count - def avgAdd: Expression = { - currentAvg + ((Cast(child, resultType) - currentAvg) 
/ currentCount) - } + val newAvg = avg + (value - avg) / newCount // update sum of square of difference from mean // Mk = Mk + (value - preAvg) * (value - updatedAvg) - def mkAdd: Expression = { - val delta1 = Cast(child, resultType) - preAvg - val delta2 = Cast(child, resultType) - currentAvg - currentMk + (delta1 * delta2) - } + val newMk = mk + (value - avg) * (value - newAvg) Seq( - /* preCount = */ If(IsNull(child), preCount, currentCount), - /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), - /* preAvg = */ If(IsNull(child), preAvg, currentAvg), - /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), - /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + /* count = */ If(IsNull(child), count, newCount), + /* avg = */ If(IsNull(child), avg, newAvg), + /* mk = */ If(IsNull(child), mk, newMk) ) } override val mergeExpressions = { // count merge - def countMerge: Expression = { - currentCount.left + currentCount.right - } + val newCount = count.left + count.right // average merge - def avgMerge: Expression = { - ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) - } + val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount // update sum of square differences - def mkMerge: Expression = { - val avgDelta = currentAvg.right - preAvg - val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / - (preCount + currentCount.right) - - currentMk.left + currentMk.right + mkDelta + val newMk = { + val avgDelta = avg.right - avg.left + val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount + mk.left + mk.right + mkDelta } Seq( - /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), - /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), - /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), - /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) + /* count = */ If(IsNull(count.left), count.right, + If(IsNull(count.right), count.left, newCount)), + /* avg = */ If(IsNull(avg.left), avg.right, + If(IsNull(avg.right), avg.left, newAvg)), + /* mk = */ If(IsNull(mk.left), mk.right, + If(IsNull(mk.right), mk.left, newMk)) ) } override val evaluateExpression = { - // when currentCount == 0, return null - // when currentCount == 1, return 0 - // when currentCount >1 - // stddev_samp = sqrt (currentMk/(currentCount -1)) - // stddev_pop = sqrt (currentMk/currentCount) - val varCol = { + // when count == 0, return null + // when count == 1, return 0 + // when count >1 + // stddev_samp = sqrt (mk/(count -1)) + // stddev_pop = sqrt (mk/count) + val varCol = if (isSample) { - currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) - } - else { - currentMk / currentCount + mk / Cast((count - Cast(Literal(1), resultType)), resultType) + } else { + mk / count } - } - If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(count, Cast(Literal(1), resultType)), 
Cast(Literal(0), resultType), Cast(Sqrt(varCol), resultType))) } } @@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private val sumDataType = resultType - private val currentSum = AttributeReference("currentSum", sumDataType)() + private val sum = AttributeReference("sum", sumDataType)() private val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = currentSum :: Nil + override val aggBufferAttributes = sum :: Nil override val initialValues = Seq( - /* currentSum = */ Literal.create(null, sumDataType) + /* sum = */ Literal.create(null, sumDataType) ) override val updateExpressions = Seq( - /* currentSum = */ - Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) override val mergeExpressions = { - val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( - /* currentSum = */ - Coalesce(Seq(add, currentSum.left)) + /* sum = */ + Coalesce(Seq(add, sum.left)) ) } - override val evaluateExpression = Cast(currentSum, resultType) + override val evaluateExpression = Cast(sum, resultType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e8ee64756d5d0..4b66069b5f55a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } + val updates = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => if (e.dataType.isInstanceOf[DecimalType]) { // Can't call setNullAt on DecimalType, because we need to keep the offset s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } else { s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { mutableRow.setNullAt($i); } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } } + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" public Object generate($exprType[] expr) { @@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public Object apply(Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections + // copy all the results into 
MutableRow + $allUpdates return mutableRow; } } From 425ff03f5ac4f3ddda1ba06656e620d5426f4209 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 3 Nov 2015 12:47:39 +0100 Subject: [PATCH 140/324] [SPARK-11436] [SQL] rebind right encoder when join 2 datasets When we join 2 datasets, we will combine 2 encoders into a tupled one, and use it as the encoder for the jioned dataset. Assume both of the 2 encoders are flat, their `constructExpression`s both reference to the first element of input row. However, when we combine 2 encoders, the schema of input row changed, now the right encoder should reference to second element of input row. So we should rebind right encoder to let it know the new schema of input row before combine it. Author: Wenchen Fan Closes #9391 from cloud-fan/join and squashes the following commits: 846d3ab [Wenchen Fan] rebind right encoder when join 2 datasets --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +++- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e0ab5f593e933..ed98a2541598f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -390,7 +390,9 @@ class Dataset[T] private( val rightEncoder = if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(leftEncoder, rightEncoder) + ExpressionEncoder.tuple( + leftEncoder, + rightEncoder.rebind(right.output, left.output ++ right.output)) withPlan[(T, U)](other) { (left, right) => Project( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 993e6d269ee03..95b8d05cf4414 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -214,4 +214,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { cogrouped, 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + } } From b86f2cab67989f09ba1ba8604e52cd4b1e44e436 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 3 Nov 2015 13:02:17 +0100 Subject: [PATCH 141/324] [SPARK-11404] [SQL] Support for groupBy using column expressions This PR adds a new method `groupBy(cols: Column*)` to `Dataset` that allows users to group using column expressions instead of a lambda function. Since the return type of these expressions is not known at compile time, we just set the key type as a generic `Row`. If the user would like to work the key in a type-safe way, they can call `grouped.asKey[Type]`, which is also added in this PR. 
```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").asKey[String] val agged = grouped.mapGroups { case (g, iter) => Iterator((g, iter.map(_._2).sum)) } agged.collect() res0: Array(("a", 30), ("b", 3), ("c", 1)) ``` Author: Michael Armbrust Closes #9359 from marmbrus/columnGroupBy and squashes the following commits: bbcb03b [Michael Armbrust] Update DatasetSuite.scala 8fd2908 [Michael Armbrust] Update DatasetSuite.scala 0b0e2f8 [Michael Armbrust] [SPARK-11404] [SQL] Support for groupBy using column expressions --- .../scala/org/apache/spark/sql/Dataset.scala | 36 ++++++++++++-- .../org/apache/spark/sql/GroupedDataset.scala | 28 +++++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 48 +++++++++++++++++++ 3 files changed, 106 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ed98a2541598f..7b75aeec4cf3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner @@ -78,9 +79,17 @@ class Dataset[T] private( * ************* */ /** - * Returns a new `Dataset` where each record has been mapped on to the specified type. - * TODO: should bind here... - * TODO: document binding rules + * Returns a new `Dataset` where each record has been mapped on to the specified type. The + * method used to map columns depend on the type of `U`: + * - When `U` is a class, fields for the class will be mapped to columns of the same name + * (case sensitivity is determined by `spark.sql.caseSensitive`) + * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will + * be assigned to `_1`). + * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the + * [[DataFrame]] will be used. + * + * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select` + * along with `alias` or `as` to rearrange or rename as required. * @since 1.6.0 */ def as[U : Encoder]: Dataset[U] = { @@ -225,6 +234,27 @@ class Dataset[T] private( withGroupingKey.newColumns) } + /** + * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. 
+ * @since 1.6.0 + */ + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataset[Row, T] = { + val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias) + val withKey = Project(withKeyColumns, logicalPlan) + val executed = sqlContext.executePlan(withKey) + + val dataAttributes = executed.analyzed.output.dropRight(cols.size) + val keyAttributes = executed.analyzed.output.takeRight(cols.size) + + new GroupedDataset( + RowEncoder(keyAttributes.toStructType), + encoderFor[T], + executed, + dataAttributes, + keyAttributes) + } + /* ****************** * * Typed Relational * * ****************** */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 612f2b60cd405..96d6e9dd548e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -34,11 +34,33 @@ class GroupedDataset[K, T] private[sql]( private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { - private implicit def kEnc = kEncoder - private implicit def tEnc = tEncoder + private implicit val kEnc = kEncoder match { + case e: ExpressionEncoder[K] => e.resolve(groupingAttributes) + case other => + throw new UnsupportedOperationException("Only expression encoders are currently supported") + } + + private implicit val tEnc = tEncoder match { + case e: ExpressionEncoder[T] => e.resolve(dataAttributes) + case other => + throw new UnsupportedOperationException("Only expression encoders are currently supported") + } + private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext + /** + * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified + * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + */ + def asKey[L : Encoder]: GroupedDataset[L, T] = + new GroupedDataset( + encoderFor[L], + tEncoder, + queryExecution, + dataAttributes, + groupingAttributes) + /** * Returns a [[Dataset]] that contains each unique key. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 95b8d05cf4414..5973fa7f2a76b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -203,6 +203,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1") + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g.getString(0), iter.map(_._2).sum)) + } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } + + test("groupBy columns asKey, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").asKey[String] + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } + + test("groupBy columns asKey tuple, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) + } + + test("groupBy columns asKey class, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) + } + test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() From 233e534ac43ea25ac1b0e6a985f6928d46c5d03a Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Tue, 3 Nov 2015 12:46:11 +0000 Subject: [PATCH 142/324] [SPARK-11344] Made ApplicationDescription and DriverDescription case classes DriverDescription refactored to case class because it included no mutable fields. ApplicationDescription had one mutable field, which was appUiUrl. This field was set by the driver to point to the driver web UI. Master was modifying this field when the application was removed to redirect requests to history server. This was wrong because objects which are sent over the wire should be immutable. Now appUiUrl is immutable in ApplicationDescription and always points to the driver UI even if it is already shutdown. The UI url which master exposes to the user and modifies dynamically is now included into ApplicationInfo - a data object which describes the application state internally in master. That URL in ApplicationInfo is initialised with the value from ApplicationDescription. ApplicationDescription also included value user, which is now a part of case class fields. Author: Jacek Lewandowski Closes #9299 from jacek-lewandowski/SPARK-11344. 
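As a rough illustration of the pattern this change follows (hypothetical, simplified names rather than Spark's real classes): the description that is serialized and sent over the wire stays an immutable case class, while the master-side bookkeeping object owns the only mutable field and falls back to the original value when nothing has overridden it.

```scala
// Immutable message: safe to serialize, share between threads, and compare.
case class AppDescription(name: String, appUiUrl: String)

// Master-side state: owns the mutable history-server override.
class AppInfo(val desc: AppDescription) {
  @volatile var appUIUrlAtHistoryServer: Option[String] = None

  // URL the web UI should link to right now.
  def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(desc.appUiUrl)
}

object AppInfoExample {
  def main(args: Array[String]): Unit = {
    val info = new AppInfo(AppDescription("my-app", "http://driver-host:4040"))
    println(info.curAppUIUrl)                            // http://driver-host:4040
    info.appUIUrlAtHistoryServer = Some("/history/app-20151103")
    println(info.curAppUIUrl)                            // /history/app-20151103
  }
}
```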
--- .../spark/deploy/ApplicationDescription.scala | 33 ++++++------------- .../spark/deploy/DriverDescription.scala | 21 ++++-------- .../spark/deploy/master/ApplicationInfo.scala | 7 ++++ .../apache/spark/deploy/master/Master.scala | 12 ++++--- .../deploy/master/ui/ApplicationPage.scala | 2 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../apache/spark/deploy/DeployTestUtils.scala | 3 +- 7 files changed, 34 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index ae99432f5ce86..78bbd5c03f4a6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -19,30 +19,17 @@ package org.apache.spark.deploy import java.net.URI -private[spark] class ApplicationDescription( - val name: String, - val maxCores: Option[Int], - val memoryPerExecutorMB: Int, - val command: Command, - var appUiUrl: String, - val eventLogDir: Option[URI] = None, +private[spark] case class ApplicationDescription( + name: String, + maxCores: Option[Int], + memoryPerExecutorMB: Int, + command: Command, + appUiUrl: String, + eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. lzf) - val eventLogCodec: Option[String] = None, - val coresPerExecutor: Option[Int] = None) - extends Serializable { - - val user = System.getProperty("user.name", "") - - def copy( - name: String = name, - maxCores: Option[Int] = maxCores, - memoryPerExecutorMB: Int = memoryPerExecutorMB, - command: Command = command, - appUiUrl: String = appUiUrl, - eventLogDir: Option[URI] = eventLogDir, - eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = - new ApplicationDescription( - name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) + eventLogCodec: Option[String] = None, + coresPerExecutor: Option[Int] = None, + user: String = System.getProperty("user.name", "")) { override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala index 659fb434a80f5..1f5626ab5a896 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -17,21 +17,12 @@ package org.apache.spark.deploy -private[deploy] class DriverDescription( - val jarUrl: String, - val mem: Int, - val cores: Int, - val supervise: Boolean, - val command: Command) - extends Serializable { - - def copy( - jarUrl: String = jarUrl, - mem: Int = mem, - cores: Int = cores, - supervise: Boolean = supervise, - command: Command = command): DriverDescription = - new DriverDescription(jarUrl, mem, cores, supervise, command) +private[deploy] case class DriverDescription( + jarUrl: String, + mem: Int, + cores: Int, + supervise: Boolean, + command: Command) { override def toString: String = s"DriverDescription (${command.mainClass})" } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index b40d20f9f7868..ac553b71115df 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -41,6 +41,7 @@ 
private[spark] class ApplicationInfo( @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ + @transient @volatile var appUIUrlAtHistoryServer: Option[String] = None // A cap on the number of executors this application can have at any given time. // By default, this is infinite. Only after the first allocation request is issued by the @@ -135,4 +136,10 @@ private[spark] class ApplicationInfo( } } + /** + * Returns the original application UI url unless there is its address at history server + * is defined + */ + def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(desc.appUiUrl) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 6715d6c70f497..b25a487806c7f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -768,7 +768,8 @@ private[deploy] class Master( ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) + val appId = newApplicationId(date) + new ApplicationInfo(now, appId, desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { @@ -920,7 +921,7 @@ private[deploy] class Master( val eventLogDir = app.desc.eventLogDir .getOrElse { // Event logging is not enabled for this application - app.desc.appUiUrl = notFoundBasePath + app.appUIUrlAtHistoryServer = Some(notFoundBasePath) return None } @@ -954,7 +955,7 @@ private[deploy] class Master( appIdToUI(app.id) = ui webUi.attachSparkUI(ui) // Application UI is successfully rebuilt, so link the Master UI to it - app.desc.appUiUrl = ui.basePath + app.appUIUrlAtHistoryServer = Some(ui.basePath) Some(ui) } catch { case fnf: FileNotFoundException => @@ -964,7 +965,7 @@ private[deploy] class Master( logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title" + app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") None case e: Exception => // Relay exception message to application UI page @@ -973,7 +974,8 @@ private[deploy] class Master( var msg = s"Exception in replaying log for application $appName!" logError(msg, e) msg = URLEncoder.encode(msg, "UTF-8") - app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title" + app.appUIUrlAtHistoryServer = + Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") None } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index e28e7e379ac91..f405aa2bdc8b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -76,7 +76,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
             <li><strong>Submit Date:</strong> {app.submitDate}</li>
             <li><strong>State:</strong> {app.state}</li>
-            <li><strong><a href={app.desc.appUiUrl}>Application Detail UI</a></strong></li>
+            <li><strong><a href={app.curAppUIUrl}>Application Detail UI</a></strong></li>
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index c3e20ebf8d6eb..ee539dd1f5113 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -206,7 +206,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
         {killLink}
    - + @@ -430,7 +430,11 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverActive = receiverInfo.map { info => if (info.active) "ACTIVE" else "INACTIVE" }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLocation = receiverInfo.map { info => + val executorId = if (info.executorId.isEmpty) emptyCell else info.executorId + val location = if (info.location.isEmpty) emptyCell else info.location + s"$executorId / $location" + }.getOrElse(emptyCell) val receiverLastError = receiverInfo.map { info => val msg = s"${info.lastErrorMessage} - ${info.lastError}" if (msg.size > 100) msg.take(97) + "..." else msg diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java index 8cc285aa7fb34..67b2a0703e02b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -29,6 +29,7 @@ public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStart receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); @@ -41,6 +42,7 @@ public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); @@ -53,6 +55,7 @@ public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopp receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala index 6d6d61e70cafc..0295e059f7bc2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -33,7 +33,8 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { streamId = 2, name = "test", active = true, - location = "localhost" + location = "localhost", + executorId = "1" )) listenerWrapper.onReceiverStarted(receiverStarted) assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) @@ -42,7 +43,8 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { streamId = 2, name = "test", active = false, - location = "localhost" + location = "localhost", + executorId = "1" )) listenerWrapper.onReceiverStopped(receiverStopped) assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) @@ -52,6 +54,7 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { name = "test", active = false, location = "localhost", + executorId = "1", lastErrorMessage = "failed", lastError = "failed", lastErrorTime = System.currentTimeMillis() @@ -197,6 +200,7 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { assert(javaReceiverInfo.name === receiverInfo.name) assert(javaReceiverInfo.active === receiverInfo.active) 
assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.executorId === receiverInfo.executorId) assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) assert(javaReceiverInfo.lastError === receiverInfo.lastError) assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index af4718b4eb705..34cd7435569e1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -130,20 +130,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost", "0") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost", "1") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost", "2") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) From 1431319e5bc46c7225a8edeeec482816d14a83b8 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 18:53:57 -0800 Subject: [PATCH 267/324] Add mockito as an explicit test dependency to spark-streaming While sbt successfully compiles as it properly pulls the mockito dependency, maven builds have broken. We need this in ASAP. tdas Author: Burak Yavuz Closes #9584 from brkyvz/fix-master. --- streaming/pom.xml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/streaming/pom.xml b/streaming/pom.xml index 145c8a7321c05..435e16db13ab4 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -93,6 +93,11 @@ selenium-java test + + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes From c4e19b3819df4cd7a1c495a00bd2844cf55f4dbd Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 9 Nov 2015 21:06:01 -0800 Subject: [PATCH 268/324] [SPARK-11587][SPARKR] Fix the summary generic to match base R The signature is summary(object, ...) as defined in https://stat.ethz.ch/R-manual/R-devel/library/base/html/summary.html Author: Shivaram Venkataraman Closes #9582 from shivaram/summary-fix. 
--- R/pkg/R/DataFrame.R | 6 +++--- R/pkg/R/generics.R | 2 +- R/pkg/R/mllib.R | 12 ++++++------ R/pkg/inst/tests/test_mllib.R | 6 ++++++ 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 44ce9414da5cf..e9013aa34a84f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1944,9 +1944,9 @@ setMethod("describe", #' @rdname summary #' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 083d37fee28a4..efef7d66b522c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) #' @rdname summary #' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +setGeneric("summary", function(object, ...) { standardGeneric("summary") }) # @rdname tojson # @export diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7ff859741b4a0..7126b7cde4bd7 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,17 +89,17 @@ setMethod("predict", signature(object = "PipelineModel"), #' model <- glm(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(x = "PipelineModel"), - function(x, ...) { +setMethod("summary", signature(object = "PipelineModel"), + function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", x@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", x@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelDevianceResiduals", x@model) + "getModelDevianceResiduals", object@model) devianceResiduals <- matrix(devianceResiduals, nrow = 1) colnames(devianceResiduals) <- c("Min", "Max") rownames(devianceResiduals) <- rep("", times = 1) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 2606407bdcb44..42287ea19adc5 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -113,3 +113,9 @@ test_that("summary coefficients match with native glm of family 'binomial'", { rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) + +test_that("summary works on base GLM models", { + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) From d6cd3a18e720e8f6f1f307e0dffad3512952d997 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Nov 2015 23:27:36 -0800 Subject: [PATCH 269/324] [SPARK-11599] [SQL] fix NPE when resolve Hive UDF in SQLParser The DataFrame APIs that takes a SQL expression always use SQLParser, then the HiveFunctionRegistry will called outside of Hive state, cause NPE if there is not a active Session State for current thread (in PySpark). cc rxin yhuai Author: Davies Liu Closes #9576 from davies/hive_udf. 
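A stripped-down sketch of the shape of this fix follows (hypothetical names, no Hive or Spark dependencies): the function lookup needs per-thread session state, so the overriding registry installs that state around every call instead of assuming the calling thread already has it.

```scala
// Minimal sketch; sessionState stands in for Hive's thread-local SessionState.
object ThreadStateSketch {
  private val sessionState = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }

  def withHiveState[A](state: String)(body: => A): A = {
    val previous = sessionState.get()
    sessionState.set(Some(state))
    try body finally sessionState.set(previous)
  }

  // The raw lookup fails on any thread where no state was installed.
  def lookupFunction(name: String): String =
    sessionState.get() match {
      case Some(state) => s"$name resolved against $state"
      case None        => sys.error(s"no active session state while resolving '$name'")
    }

  // Mirrors the override in the patch: install the state, then delegate.
  def safeLookupFunction(name: String): String =
    withHiveState("execution-hive-session") { lookupFunction(name) }

  def main(args: Array[String]): Unit = {
    // Works even from a fresh thread that never set up any session state.
    val t = new Thread(new Runnable {
      override def run(): Unit = println(safeLookupFunction("my_udf"))
    })
    t.start()
    t.join()
  }
}
```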
--- .../apache/spark/sql/hive/HiveContext.scala | 10 +++++- .../sql/hive/execution/HiveQuerySuite.scala | 33 ++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2d72b959af134..c5f69657f5293 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,7 +454,15 @@ class HiveContext private[hive]( // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { + super.lookupFunction(name, children) + } + } + } // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer // can't access the SessionState of metadataHive. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 78378c8b69c7a..f0a7a6cc7a1e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,22 +20,19 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin - import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.{SparkException, SparkFiles} case class TestData(a: Int, b: String) @@ -1237,6 +1234,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") From 521b3cae118d1e22c170e2aad43f9baa162db55e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Nov 2015 
23:28:32 -0800 Subject: [PATCH 270/324] [SPARK-11598] [SQL] enable tests for ShuffledHashOuterJoin Author: Davies Liu Closes #9573 from davies/join_condition. --- .../org/apache/spark/sql/JoinSuite.scala | 435 ++++++++++-------- 1 file changed, 231 insertions(+), 204 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a9ca46cab067d..3f3b837f7581c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -237,214 +237,241 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 2, 2) :: Nil) } - test("left outer join") { - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), - Row(1, "A", 1, "a") :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), - Row(1, "A", null, null) :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), - Row(1, "A", null, null) :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), - Row(1, "A", 1, "a") :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - // Make sure we are choosing left.outputPartitioning as the - // outputPartitioning for the outer join operator. 
- checkAnswer( - sql( - """ - |SELECT l.N, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) - - checkAnswer( - sql( - """ - |SELECT r.a, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), - Row(null, 6) :: Nil) - } + def test_outer_join(useSMJ: Boolean): Unit = { + + val algo = if (useSMJ) "SortMergeOuterJoin" else "ShuffledHashOuterJoin" + + test("left outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + // Make sure we are choosing left.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """. + stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + Row(null, 6) :: Nil) + } + } - test("right outer join") { - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), - Row(1, "a", 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), - Row(null, null, 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), - Row(null, null, 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), - Row(1, "a", 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - // Make sure we are choosing right.outputPartitioning as the - // outputPartitioning for the outer join operator. 
- checkAnswer( - sql( - """ - |SELECT l.a, count(*) - |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), - Row(null, 6)) + test("right outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + // Make sure we are choosing right.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + Row(null, + 6)) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + Row(1 + , 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + } + } - checkAnswer( - sql( - """ - |SELECT r.N, count(*) - |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY r.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) + test("full outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") + + val left = UnresolvedRelation(TableIdentifier("left"), None) + val right = UnresolvedRelation(TableIdentifier("right"), None) + + checkAnswer( + left.join(right, $"left.N" === $"right.N", "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join + // operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """. 
+ stripMargin), + Row( + null, 10)) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + Row + (1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + Row(1 + , 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """. + stripMargin), + Row(null, 10)) + } + } } - test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") - - val left = UnresolvedRelation(TableIdentifier("left"), None) - val right = UnresolvedRelation(TableIdentifier("right"), None) - - checkAnswer( - left.join(right, $"left.N" === $"right.N", "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", null, null) :: - Row(null, null, 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", null, null) :: - Row(null, null, 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. - checkAnswer( - sql( - """ - |SELECT l.a, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), - Row(null, 10)) - - checkAnswer( - sql( - """ - |SELECT r.N, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY r.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: - Row(null, 4) :: Nil) - - checkAnswer( - sql( - """ - |SELECT l.N, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: - Row(null, 4) :: Nil) - - checkAnswer( - sql( - """ - |SELECT r.a, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), - Row(null, 10)) - } + // test SortMergeOuterJoin + test_outer_join(true) + // test ShuffledHashOuterJoin + test_outer_join(false) test("broadcasted left semi join operator selection") { sqlContext.cacheManager.clearCache() From 5507a9d0935aa42d65c3a4fa65da680b5af14faf Mon Sep 17 00:00:00 2001 From: Paul Chandler Date: Tue, 10 Nov 2015 12:59:53 +0100 Subject: [PATCH 271/324] Fix typo in driver page "Comamnd property" => "Command property" Author: Paul Chandler Closes #9578 from pestilence669/fix_spelling. 
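The JoinSuite rewrite above turns the hand-written outer-join tests into a helper that registers every assertion twice, once per join implementation, by flipping SQLConf.SORTMERGE_JOIN through withSQLConf. A stripped-down, hedged sketch of that ScalaTest registration pattern, independent of Spark (PerConfigSuite and registerOuterJoinTests are illustrative names):

import org.scalatest.FunSuite

class PerConfigSuite extends FunSuite {
  // Register the same test body once per configuration value, as JoinSuite now does
  // with test_outer_join(true) / test_outer_join(false).
  private def registerOuterJoinTests(useSortMerge: Boolean): Unit = {
    val algo = if (useSortMerge) "SortMergeOuterJoin" else "ShuffledHashOuterJoin"
    test("left outer join: " + algo) {
      // The real suite runs the join under the chosen SQLConf value and compares
      // results with checkAnswer; here only the registration shape is shown.
      assert(algo.nonEmpty)
    }
  }

  registerOuterJoinTests(useSortMerge = true)
  registerOuterJoinTests(useSortMerge = false)
}

Calling the helper from the suite body at construction time registers one named test per configuration, so both code paths show up separately in the test report.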
--- .../scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index e8ef60bd5428a..bc67fd460d9a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -46,7 +46,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") val schedulerHeaders = Seq("Scheduler property", "Value") val commandEnvHeaders = Seq("Command environment variable", "Value") val launchedHeaders = Seq("Launched property", "Value") - val commandHeaders = Seq("Comamnd property", "Value") + val commandHeaders = Seq("Command property", "Value") val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") val driverDescription = Iterable.apply(driverState.description) val submissionState = Iterable.apply(driverState.submissionState) From a81f47ff7498e7063c855ccf75bba81ab101b43e Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 10 Nov 2015 10:05:53 -0800 Subject: [PATCH 272/324] [SPARK-11382] Replace example code in mllib-decision-tree.md using include_example https://issues.apache.org/jira/browse/SPARK-11382 B.T.W. I fix an error in naive_bayes_example.py. Author: Xusen Yin Closes #9596 from yinxusen/SPARK-11382. --- docs/mllib-decision-tree.md | 253 +----------------- ...JavaDecisionTreeClassificationExample.java | 91 +++++++ .../JavaDecisionTreeRegressionExample.java | 96 +++++++ .../decision_tree_classification_example.py | 55 ++++ .../mllib/decision_tree_regression_example.py | 56 ++++ .../main/python/mllib/naive_bayes_example.py | 1 + .../DecisionTreeClassificationExample.scala | 67 +++++ .../mllib/DecisionTreeRegressionExample.scala | 66 +++++ 8 files changed, 438 insertions(+), 247 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java create mode 100644 examples/src/main/python/mllib/decision_tree_classification_example.py create mode 100644 examples/src/main/python/mllib/decision_tree_regression_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index b5b454bc69245..77ce34e91af3c 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -194,137 +194,19 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "gini" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model for classification. -final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - impurity='gini', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_classification_example.py %}
    @@ -343,142 +225,19 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "variance" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model. -final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, - impurity='variance', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_regression_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..5839b0cf8a8f8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeClassificationExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. 
+ final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..ccde578249f7c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeRegressionExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. 
+ // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "variance"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model. + final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py new file mode 100644 index 0000000000000..1b529768b6c62 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeClassificationExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. 
+ model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py new file mode 100644 index 0000000000000..cf518eac67e81 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeRegressionExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. 
+ model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index a2e7dacf25491..f5e120c678fcf 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -20,6 +20,7 @@ """ from __future__ import print_function +from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..d427bbadaa0c1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeClassificationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. 
+ val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "gini" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..fb05e7d9c5065 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. 
+ val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "variance" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + // $example off$ + } +} +// scalastyle:on println From 689386b1c60997e4505749915f7005a52c207de2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Nov 2015 10:14:19 -0800 Subject: [PATCH 273/324] [SPARK-7841][BUILD] Stop using retrieveManaged to retrieve dependencies in SBT This patch modifies Spark's SBT build so that it no longer uses `retrieveManaged` / `lib_managed` to store its dependencies. The motivations for this change are nicely described on the JIRA ticket ([SPARK-7841](https://issues.apache.org/jira/browse/SPARK-7841)); my personal interest in doing this stems from the fact that `lib_managed` has caused me some pain while debugging dependency issues in another PR of mine. Removing our use of `lib_managed` would be trivial except for one snag: the Datanucleus JARs, required by Spark SQL's Hive integration, cannot be included in assembly JARs due to problems with merging OSGI `plugin.xml` files. As a result, several places in the packaging and deployment pipeline assume that these Datanucleus JARs are copied to `lib_managed/jars`. In the interest of maintaining compatibility, I have chosen to retain the `lib_managed/jars` directory _only_ for these Datanucleus JARs and have added custom code to `SparkBuild.scala` to automatically copy those JARs to that folder as part of the `assembly` task. `dev/mima` also depended on `lib_managed` in a hacky way in order to set classpaths when generating MiMa excludes; I've updated this to obtain the classpaths directly from SBT instead. /cc dragos marmbrus pwendell srowen Author: Josh Rosen Closes #9575 from JoshRosen/SPARK-7841. --- dev/mima | 2 +- project/SparkBuild.scala | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/dev/mima b/dev/mima index 2952fa65d42ff..d5baffc6ef8a3 100755 --- a/dev/mima +++ b/dev/mima @@ -38,7 +38,7 @@ generate_mima_ignore() { # it did not process the new classes (which are in assembly jar). 
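Before the full diff below, here is a condensed, hedged sketch of the sbt idea the commit message describes: instead of retrieveManaged mirroring every dependency into lib_managed, a dedicated task copies only the Datanucleus jars there as part of assembly. The snippet assumes sbt-assembly's assembly key is in scope and simplifies the paths used by the actual patch.

import java.nio.file.Files
import sbt._
import Keys._

// Custom task: copy only the Datanucleus jars into lib_managed/jars.
val deployDatanucleusJars = taskKey[Unit]("Copy datanucleus jars to lib_managed/jars")

deployDatanucleusJars := {
  val jars = (fullClasspath in assembly).value.map(_.data)
    .filter(_.getPath.contains("org.datanucleus"))
  val destDir = file("lib_managed/jars")
  destDir.mkdirs()
  jars.foreach { jar =>
    val dest = destDir / jar.getName
    if (!dest.exists()) Files.copy(jar.toPath, dest.toPath)
  }
}

// Run it whenever the assembly is built.
assembly <<= assembly.dependsOn(deployDatanucleusJars)

Keeping lib_managed/jars only for these jars preserves the packaging scripts that expect them there, while everything else resolves straight from the Ivy cache.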
generate_mima_ignore -export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" +export SPARK_CLASSPATH="$(build/sbt "export oldDeps/fullClasspath" | tail -n1)" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" generate_mima_ignore diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b75ed13a78c68..a9fb741d75933 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -16,6 +16,7 @@ */ import java.io._ +import java.nio.file.Files import scala.util.Properties import scala.collection.JavaConverters._ @@ -135,8 +136,6 @@ object SparkBuild extends PomBuild { .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), incOptions := incOptions.value.withNameHashing(true), - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", @@ -326,8 +325,6 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.5", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", @@ -404,6 +401,8 @@ object Assembly { val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") + val deployDatanucleusJars = taskKey[Unit]("Deploy datanucleus jars to the spark/lib_managed/jars directory") + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, hadoopVersion := { @@ -429,7 +428,20 @@ object Assembly { case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first - } + }, + deployDatanucleusJars := { + val jars: Seq[File] = (fullClasspath in assembly).value.map(_.data) + .filter(_.getPath.contains("org.datanucleus")) + var libManagedJars = new File(BuildCommons.sparkHome, "lib_managed/jars") + libManagedJars.mkdirs() + jars.foreach { jar => + val dest = new File(libManagedJars, jar.getName) + if (!dest.exists()) { + Files.copy(jar.toPath, dest.toPath) + } + } + }, + assembly <<= assembly.dependsOn(deployDatanucleusJars) ) } From 6e5fc37883ed81c3ee2338145a48de3036d19399 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 Nov 2015 10:40:08 -0800 Subject: [PATCH 274/324] [SPARK-11252][NETWORK] ShuffleClient should release connection after fetching blocks had been completed for external shuffle with yarn's external shuffle, ExternalShuffleClient of executors reserve its connections for yarn's NodeManager until application has been completed. so it will make NodeManager and executors have many socket connections. in order to reduce network pressure of NodeManager's shuffleService, after registerWithShuffleServer or fetchBlocks have been completed in ExternalShuffleClient, connection for NM's shuffleService needs to be closed.andrewor14 rxin vanzin Author: Lianhui Wang Closes #9227 from lianhuiwang/spark-11252. 
--- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../spark/network/TransportContext.java | 11 +++++- .../client/TransportClientFactory.java | 10 ++++++ .../server/TransportChannelHandler.java | 26 +++++++++----- .../network/TransportClientFactorySuite.java | 34 +++++++++++++++++++ .../shuffle/ExternalShuffleClient.java | 12 ++++--- 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 6840a3ae831f0..a039d543c35e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -47,7 +47,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + private val transportContext: TransportContext = + new TransportContext(transportConf, blockHandler, true) private var server: TransportServer = _ diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 43900e6f2c972..1b64b863a9fe5 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -59,15 +59,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -144,7 +153,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8b..42a4f664e697c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -158,6 +158,16 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port + * But this connection is not pooled. 
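The commit message above gives the motivation; the mechanism in the diff that follows is a closeIdleConnections flag on TransportContext plus an idle-event check in TransportChannelHandler. A small, hedged sketch of just that decision logic as a pure function (names are illustrative; the real handler reacts to Netty IdleStateEvents and calls ctx.close()):

object IdleConnectionPolicy {
  sealed trait IdleAction
  case object KeepOpen    extends IdleAction
  case object CloseAsDead extends IdleAction // outstanding requests have timed out
  case object CloseAsIdle extends IdleAction // quiet connection reclaimed from the shuffle service

  def onIdleEvent(
      outstandingRequests: Int,
      quietForNs: Long,
      requestTimeoutNs: Long,
      closeIdleConnections: Boolean): IdleAction = {
    val actuallyOverdue = quietForNs > requestTimeoutNs
    if (!actuallyOverdue) KeepOpen
    else if (outstandingRequests > 0) CloseAsDead // mirrors the existing error-and-close branch
    else if (closeIdleConnections) CloseAsIdle    // new: drop idle sockets to the NodeManager
    else KeepOpen
  }
}

Usage-wise, an ExternalShuffleClient-style caller opts into the new behaviour simply by constructing its TransportContext with closeIdleConnections = true, as the diff does for the external shuffle service and client.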
+ */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e3..f8fcd1c3d7d76 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). boolean isActuallyOverdue = System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. 
Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + ctx.close(); + } } } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb98..f447137419306 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -177,4 +179,36 @@ public void closeBlockClientsWithFactory() throws IOException { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf(new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be3..ef3a9dcc8711f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -78,7 +78,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +137,13 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, 
port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override From e0701c75601c43f69ed27fc7c252321703db51f2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 10 Nov 2015 11:06:29 -0800 Subject: [PATCH 275/324] [SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to evaluate AggregateExpression1s https://issues.apache.org/jira/browse/SPARK-9830 This PR contains the following main changes. * Removing `AggregateExpression1`. * Removing `Aggregate` operator, which is used to evaluate `AggregateExpression1`. * Removing planner rule used to plan `Aggregate`. * Linking `MultipleDistinctRewriter` to analyzer. * Renaming `AggregateExpression2` to `AggregateExpression` and `AggregateFunction2` to `AggregateFunction`. * Updating places where we create aggregate expression. The way to create aggregate expressions is `AggregateExpression(aggregateFunction, mode, isDistinct)`. * Changing `val`s in `DeclarativeAggregate`s that touch children of this function to `lazy val`s (when we create aggregate expression in DataFrame API, children of an aggregate function can be unresolved). Author: Yin Huai Closes #9556 from yhuai/removeAgg1. --- R/pkg/R/functions.R | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/tests.py | 2 +- .../spark/sql/catalyst/CatalystConf.scala | 10 +- .../apache/spark/sql/catalyst/SqlParser.scala | 14 +- .../sql/catalyst/analysis/Analyzer.scala | 26 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 46 +- .../DistinctAggregationRewriter.scala} | 235 +--- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../catalyst/analysis/HiveTypeCoercion.scala | 20 +- .../sql/catalyst/analysis/unresolved.scala | 4 + .../spark/sql/catalyst/dsl/package.scala | 22 +- .../expressions/aggregate/Average.scala | 31 +- .../aggregate/CentralMomentAgg.scala | 13 +- .../catalyst/expressions/aggregate/Corr.scala | 15 + .../expressions/aggregate/Count.scala | 28 +- .../expressions/aggregate/First.scala | 14 +- .../aggregate/HyperLogLogPlusPlus.scala | 17 + .../expressions/aggregate/Kurtosis.scala | 2 + .../catalyst/expressions/aggregate/Last.scala | 12 +- .../catalyst/expressions/aggregate/Max.scala | 17 +- .../catalyst/expressions/aggregate/Min.scala | 17 +- .../expressions/aggregate/Skewness.scala | 2 + .../expressions/aggregate/Stddev.scala | 31 +- .../catalyst/expressions/aggregate/Sum.scala | 29 +- .../expressions/aggregate/Variance.scala | 7 +- .../expressions/aggregate/interfaces.scala | 57 +- .../sql/catalyst/expressions/aggregates.scala | 1073 ----------------- .../sql/catalyst/optimizer/Optimizer.scala | 23 +- .../sql/catalyst/planning/patterns.scala | 74 -- .../spark/sql/catalyst/plans/QueryPlan.scala | 12 +- .../plans/logical/basicOperators.scala | 4 +- .../analysis/AnalysisErrorSuite.scala | 23 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 1 + .../ExpressionTypeCheckingSuite.scala | 6 +- .../optimizer/ConstantFoldingSuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 14 +- .../org/apache/spark/sql/DataFrame.scala | 13 +- .../org/apache/spark/sql/GroupedData.scala | 45 +- 
.../scala/org/apache/spark/sql/SQLConf.scala | 20 +- .../spark/sql/execution/Aggregate.scala | 205 ---- .../apache/spark/sql/execution/Expand.scala | 3 + .../spark/sql/execution/SparkPlanner.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 238 ++-- .../aggregate/AggregationIterator.scala | 28 +- .../aggregate/SortBasedAggregate.scala | 4 +- .../SortBasedAggregationIterator.scala | 8 +- .../aggregate/TungstenAggregate.scala | 6 +- .../TungstenAggregationIterator.scala | 36 +- .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../spark/sql/execution/aggregate/utils.scala | 20 +- .../spark/sql/expressions/Aggregator.scala | 5 +- .../spark/sql/expressions/WindowSpec.scala | 82 +- .../apache/spark/sql/expressions/udaf.scala | 6 +- .../org/apache/spark/sql/functions.scala | 53 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 69 +- .../spark/sql/UserDefinedTypeSuite.scala | 15 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 30 - .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 8 +- .../execution/AggregationQuerySuite.scala | 188 ++- 64 files changed, 743 insertions(+), 2260 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/{expressions/aggregate/Utils.scala => analysis/DistinctAggregationRewriter.scala} (58%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d7fd279279137..0b280870295a2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"), #' @export setMethod("approxCountDistinct", signature(x = "Column"), - function(x, rsd = 0.95) { + function(x, rsd = 0.05) { jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) column(jc) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b97c94dad834a..0dd75ba7ca820 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -866,7 +866,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. 
>>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 962f676d406d8..6e1cbde4239f3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -382,7 +382,7 @@ def expr(str): """Parses the expression string into the column that it represents >>> df.select(expr("length(name)")).collect() - [Row('length(name)=5), Row('length(name)=3)] + [Row(length(name)=5), Row(length(name)=3)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.expr(str)) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e224574bcb301..9f5f7cfdf7a69 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1017,7 +1017,7 @@ def test_expr(self): row = Row(a="length string", b=75) df = self.sqlCtx.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["'length(a)"]) + self.assertEqual(13, result["length(a)"]) def test_replace(self): schema = StructType([ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 3f351b07b37df..7c2b8a9407884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } + + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. 
*/ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index cd717c09f8e5e..2a132d8b82bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DataTypeParser @@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid expression $udfName(*)") } @@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) + case "count" => + aggregate.Count(exprs).toAggregateExpression(isDistinct = true) case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) + AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate $udfName") } @@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) + AggregateExpression( + HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), + mode = Complete, + isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate($s) $udfName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67352df4..b1e14390b7dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import 
org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -79,6 +79,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -525,21 +526,14 @@ class Analyzer( case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. - // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98d6637c0601b..8322e9930cd5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -108,7 +109,19 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + // TODO: Is it possible that the child of a agg function is another + // agg function? + aggExpr.aggregateFunction.children.foreach { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. 
+ case child if !child.deterministic => + failAnalysis( + s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in the arguments of an aggregate function.") + case child => // OK + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + @@ -120,14 +133,26 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + expr.dataType match { + case BinaryType => + failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case a: ArrayType => + failAnalysis(s"array type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case m: MapType => + failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case _ => // OK + } + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -179,7 +204,8 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + // The rule above is used to check Aggregate operator. failAnalysis( s"""nondeterministic expressions are only allowed in Project or Filter, found: | ${o.expressions.map(_.prettyString).mkString(",")} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala similarity index 58% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9b22ce2619731..397eff05686b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -15,215 +15,17 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.expressions.aggregate +package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.IntegerType /** - * Utility functions used by the query planner to convert our plan to new aggregation code path. - */ -object Utils { - - // Check if the DataType given cannot be part of a group by clause. - private def isUnGroupable(dt: DataType): Boolean = dt match { - case _: ArrayType | _: MapType => true - case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType)) - case _ => false - } - - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = - !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType)) - - private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - - val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.CountDistinct(children) => - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Kurtosis(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Kurtosis(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Skewness(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Skewness(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevPop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevPop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevSamp(child), - mode = 
aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.Corr(left, right) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Corr(left, right), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.ApproxCountDistinct(child, rsd) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VariancePop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VariancePop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VarianceSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VarianceSamp(child), - mode = aggregate.Complete, - isDistinct = false) - }) - - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - // TODO remove this. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to see if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." 
- throw new AnalysisException(errorMessage) - } - } - - def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate => - val converted = doConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case other => None - } -} - -/** - * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double + * This rule rewrites an aggregate query with distinct aggregations into an expanded double * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * @@ -298,9 +100,11 @@ object Utils { * we could improve this in the current rule by applying more advanced expression cannocalization * techniques. */ -object MultipleDistinctRewriter extends Rule[LogicalPlan] { +case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p + // We need to wait until this Aggregate operator is resolved. case a: Aggregate => rewrite(a) case p => p } @@ -310,7 +114,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Collect all aggregate expressions. val aggExpressions = a.aggregateExpressions.flatMap { e => e.collect { - case ae: AggregateExpression2 => ae + case ae: AggregateExpression => ae } } @@ -319,8 +123,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Only continue to rewrite if there is more than one distinct group. - if (distinctAggGroups.size > 1) { + val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { + // When the flag is set to specialize single distinct agg planning, + // we will rely on our Aggregation strategy to handle queries with a single + // distinct column and this aggregate operator does have grouping expressions. + distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + } else { + distinctAggGroups.size >= 1 + } + if (shouldRewrite) { // Create the attributes for the grouping id and the group by clause. val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { @@ -332,11 +143,11 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Functions used to modify aggregate functions and their inputs. def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) def patchAggregateFunctionChildren( - af: AggregateFunction2)( - attrs: Expression => Expression): AggregateFunction2 = { + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { af.withNewChildren(af.children.map { case afc => attrs(afc) - }).asInstanceOf[AggregateFunction2] + }).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -381,7 +192,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() // Select the result of the first aggregate in the last aggregate. 
- val result = AggregateExpression2( + val result = AggregateExpression( aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), mode = Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4334d16289a5..dfa749d1afa5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -177,6 +178,7 @@ object FunctionRegistry { expression[ToRadians]("radians"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84e2b1366f626..bf2bff0243fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -295,14 +296,17 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) - case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) - case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) + case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) } } @@ -562,12 +566,6 @@ object HiveTypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if 
e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index eae17c86ddc7a..6485bdfb30234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -141,6 +141,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false + override def prettyString: String = { + s"${name}(${children.map(_.prettyString).mkString(",")})" + } + override def toString: String = s"'$name(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d8df66430a695..af594c25c54cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ @@ -144,17 +145,18 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: 
Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c8c20ada5fbc7..7f9e5034702e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Average(child: Expression) extends DeclarativeAggregate { @@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } - private val sumDataType = child.dataType match { + private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributes = sum :: count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* sum = */ Add( sum, @@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate { /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. 
- override val evaluateExpression = child.dataType match { + override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index ef08b025ff556..984ce7f24dacc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 832338378fb38..00d7436b710d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -35,6 +37,9 @@ case class Corr( inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + def this(left: Expression, right: Expression) = + this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def children: Seq[Expression] = Seq(left, right) override def nullable: Boolean = false @@ -43,6 +48,16 @@ case class Corr( override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"corr requires that both 
arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) override def inputAggBufferAttributes: Seq[AttributeReference] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index ec0c8b483a909..09a1da9200df0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = count :: Nil + override lazy val aggBufferAttributes = count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = Cast(count, LongType) override def defaultResult: Option[Literal] = Option(Literal(0L)) } + +object Count { + def apply(children: Seq[Expression]): Count = { + // This is used to deal with COUNT DISTINCT. When we have multiple + // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). + // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any + // null in the arguments, we will not count that row. So, we use DropAnyNull at here + // to return a null when any field of the created STRUCT is null. + val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } + Count(child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9028143015853..35f57426feaf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val first = AttributeReference("first", child.dataType)() + private lazy val first = AttributeReference("first", child.dataType)() - private val valueSet = AttributeReference("valueSet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* first = */ Literal.create(null, child.dataType), /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), @@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set // to true, we use first.right. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). @@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara ) } - override val evaluateExpression: AttributeReference = first + override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8d341ee630bdb..8a95c541f1e86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -22,6 +22,7 @@ import java.util import com.clearspring.analytics.hash.MurmurHash +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus( extends ImperativeAggregate { import HyperLogLogPlusPlus._ + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = relativeSD match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + }, + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 6da39e7143447..bae78d98493b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -24,6 +24,8 @@ case class Kurtosis(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8636bfe8d07aa..be7e12d7a2336 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val last = AttributeReference("last", child.dataType)() + private lazy val last = AttributeReference("last", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(child), last, child) @@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(last.right), last.left, last.right) @@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val evaluateExpression: AttributeReference = last + override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b9d75ad452838..61cae44cd0f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Max(child: Expression) extends DeclarativeAggregate { @@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val max = AttributeReference("max", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") - override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + private lazy val max = AttributeReference("max", child.dataType)() - override val initialValues: Seq[Literal] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( /* max = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val greatest = Greatest(Seq(max.left, max.right)) Seq( /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) ) } - override val evaluateExpression: AttributeReference = max + override lazy val evaluateExpression: AttributeReference = max } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 5ed9cd348daba..242456d9e2e18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val min = AttributeReference("min", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") - override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + private lazy val min = AttributeReference("min", child.dataType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* min = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val least = Least(Seq(min.left, min.right)) Seq( /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) ) } - override val evaluateExpression: AttributeReference = min + override lazy val evaluateExpression: AttributeReference = min } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index 0def7ddfd9d3d..c593074fa2479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -24,6 +24,8 @@ case class Skewness(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 3f47ffe13cbc8..5b9eb7ae02f25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = DoubleType + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() + private lazy val resultType = DoubleType - override val aggBufferAttributes = count :: avg :: mk :: Nil + private lazy val count = AttributeReference("count", resultType)() + private lazy val avg = AttributeReference("avg", resultType)() + private lazy val mk = AttributeReference("mk", resultType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes = count :: avg :: mk :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* count = */ Cast(Literal(0), resultType), /* avg = */ Cast(Literal(0), resultType), /* mk = */ Cast(Literal(0), resultType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { val value = Cast(child, resultType) val newCount = count + Cast(Literal(1), resultType) @@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // count merge val newCount = count.left + count.right @@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = { + override lazy val evaluateExpression: Expression = { // when count == 0, return null // when count == 1, return 0 // when count >1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 7f8adbc56ad1d..c005ec9657211 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Sum(child: Expression) extends DeclarativeAggregate { @@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
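The (count, avg, mk) buffer kept by StddevAgg above is the classic online-variance formulation: the running mean absorbs delta/count and mk accumulates squared deviations from the mean. The following is an illustrative stand-alone sketch of that recurrence and of the count == 0 / count == 1 special cases noted above, in plain Scala rather than Catalyst expressions:

object OnlineStddevSketch {
  // rows seen so far, running mean, and running sum of squared deviations from the mean
  final case class Buffer(count: Double, avg: Double, mk: Double)

  val empty = Buffer(0.0, 0.0, 0.0)

  def update(b: Buffer, x: Double): Buffer = {
    val count = b.count + 1.0
    val delta = x - b.avg
    val avg   = b.avg + delta / count
    val mk    = b.mk + delta * (x - avg)
    Buffer(count, avg, mk)
  }

  // Mirrors the evaluate semantics described above: null (None) for no rows, 0 for one row.
  def stddevSamp(b: Buffer): Option[Double] =
    if (b.count == 0.0) None
    else if (b.count == 1.0) Some(0.0)
    else Some(math.sqrt(b.mk / (b.count - 1.0)))

  def main(args: Array[String]): Unit = {
    val b = Seq(1.0, 2.0, 3.0, 4.0).foldLeft(empty)(update)
    println(stddevSamp(b)) // Some(1.2909944487358056), the sample stddev of 1..4
  }
}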
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) // TODO: Remove this line once we remove the NullType from inputTypes. @@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate { case _ => child.dataType } - private val sumDataType = resultType + private lazy val sumDataType = resultType - private val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", sumDataType)() - private val zero = Cast(Literal(0), sumDataType) + private lazy val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* sum = */ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ @@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = Cast(sum, resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ec63534e5290a..ede2da2805966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression, } } -case class VariancePop(child: Expression, +case class VariancePop( + child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 5c5b3d1ccd3cd..3b441de34a49f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction2]]. */ +/** The mode of an [[AggregateFunction]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly + * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable { } /** - * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. 
*/ -private[sql] case class AggregateExpression2( - aggregateFunction: AggregateFunction2, +private[sql] case class AggregateExpression( + aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) extends AggregateExpression { + isDistinct: Boolean) + extends Expression + with Unevaluable { override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } + override def prettyString: String = aggregateFunction.prettyString + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } @@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2( * combined aggregation buffer which concatenates the aggregation buffers of the individual * aggregate functions. * - * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { +sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false @@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct + * field of the [[AggregateExpression]] to the given value because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { + AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) + } } /** @@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction2 { +abstract class ImperativeAggregate extends AggregateFunction { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation @@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and * `evaluateExpressions`. 
+ * + * Please note that children of an aggregate function can be unresolved (it will happen when + * we create this function in DataFrame API). So, if there is any fields in + * the implemented class that need to access fields of its children, please make + * those fields `lazy val`s. */ abstract class DeclarativeAggregate - extends AggregateFunction2 + extends AggregateFunction with Serializable with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala deleted file mode 100644 index 3dcf7915d77b3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import com.clearspring.analytics.stream.cardinality.HyperLogLog - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils} -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - - -trait AggregateExpression extends Expression with Unevaluable - -trait AggregateExpression1 extends AggregateExpression { - - /** - * Aggregate expressions should not be foldable. - */ - override def foldable: Boolean = false - - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction1 -} - -/** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. - */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) - -/** - * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. - */ -trait PartialAggregate1 extends AggregateExpression1 { - - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} - -/** - * A specific implementation of an aggregate function. 
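The interfaces above describe AggregateExpression as a thin container that pairs an AggregateFunction with its AggregateMode and DISTINCT flag, with toAggregateExpression() as the intended way to wrap a function. A toy, Spark-free model of that relationship (all names hypothetical) may make the split of responsibilities clearer:

object AggregateContainerSketch {
  sealed trait Mode
  case object Partial  extends Mode
  case object Final    extends Mode
  case object Complete extends Mode

  // The "function" side: only the computation, nothing about how it is being used.
  trait AggFunction {
    def toAggExpression(isDistinct: Boolean = false): AggExpression =
      AggExpression(this, Complete, isDistinct)
  }

  // The "expression" side: the container recording mode and the DISTINCT flag.
  final case class AggExpression(fn: AggFunction, mode: Mode, isDistinct: Boolean)

  final case class Sum(column: String) extends AggFunction

  def main(args: Array[String]): Unit = {
    println(Sum("amount").toAggExpression())                  // plain aggregation, Complete mode
    println(Sum("amount").toAggExpression(isDistinct = true)) // models sum(DISTINCT amount)
  }
}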
Used to wrap a generic - * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. - */ -abstract class AggregateFunction1 extends LeafExpression with Serializable { - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression1 - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType - - def update(input: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - throw new UnsupportedOperationException( - "AggregateFunction1 should not be used for generated aggregates") - } -} - -case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") -} - -case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: InternalRow): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMin.value -} - -case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") -} - -case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: InternalRow): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMax.value -} - -case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) - } - - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. 
- - var count: Long = _ - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } - } - - override def eval(input: InternalRow): Any = count -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = - new CollectHashSetFunction(expressions, this) -} - -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) - } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ -private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { - - override def sqlType: DataType = BinaryType - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def serialize(obj: Any): Array[Byte] = - obj.asInstanceOf[HyperLogLog].getBytes - - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def deserialize(datum: Any): HyperLogLog = - HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) - - override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: DataType = HyperLogLogUDT - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. 
- - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } - - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) - } - - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) -} - -case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def prettyName: String = "avg" - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -case class AverageFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) - - private def addFunction(value: Any) = Add(sum, - Cast(Literal.create(value, expr.dataType), calcType)) - - override def eval(input: InternalRow): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - val dt = DecimalType.bounded(precision + 14, scale + 4) - Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) - } - } - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) - } - } -} - -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") -} - -case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. 
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) - - override def update(input: InternalRow): Unit = { - sum.update(addFunction, input) - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - override def toString: String = s"sum(distinct $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") -} - -case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val seen = new scala.collection.mutable.HashSet[Any]() - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } - } - - override def eval(input: InternalRow): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) - } - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
- - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next()) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.get(0, null)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute, ignoreNulls), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) -} - -object First { - def apply(child: Expression): First = First(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): First = - First(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class FirstFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - private[this] var result: Any = null - - private[this] var valueSet: Boolean = false - - override def update(input: InternalRow): Unit = { - if (!valueSet) { - val value = expr.eval(input) - // When we have not set the result, we will set the result if we respect nulls - // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. 
- if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - valueSet = true - } - } - } - - override def eval(input: InternalRow): Any = result -} - -case class Last( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute, ignoreNulls), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this) -} - -object Last { - def apply(child: Expression): Last = Last(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): Last = - Last(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class LastFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - var result: Any = null - - override def update(input: InternalRow): Unit = { - val value = expr.eval(input) - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - } - } - - override def eval(input: InternalRow): Any = { - result - } -} - -/** - * Calculate Pearson Correlation Coefficient for the given columns. - * Only support AggregateExpression2. 
- * - */ -case class Corr(left: Expression, right: Expression) - extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { - override def nullable: Boolean = false - override def dataType: DoubleType.type = DoubleType - override def toString: String = s"corr($left, $right)" - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException( - "Corr only supports the new AggregateExpression2 and can only be used " + - "when spark.sql.useAggregate2 = true") - } -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def nullable: Boolean = true - override def dataType: DataType = DoubleType - - def isSample: Boolean - - override def asPartial: SplitEvaluation = { - val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() - SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) - } - - override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_pop($child)" - override def isSample: Boolean = false -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_samp($child)" - override def isSample: Boolean = true -} - -case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) -} - -case class ComputePartialStdFunction ( - expr: Expression, - base: AggregateExpression1 - ) extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private var partialCount: Long = 0L - - // the mean of data processed so far - private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update average based on this formula: - // avg = avg + (value - avg)/count - private def avgAddFunction (value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), partialAvg) - Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) - } - - // the sum of squares of difference from mean - private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update sum of square of difference from mean based on following formula: - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), prePartialAvg) - val delta2 = Subtract(Cast(value, computeType), partialAvg) - Add(partialMk, Multiply(delta1, 
delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - val prePartialAvg = partialAvg.copy() - partialCount += 1 - partialAvg.update(avgAddFunction(exprValue), input) - partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), - partialAvg.eval(null), - partialMk.eval(null))) - } -} - -case class MergePartialStd( - child: Expression, - isSample: Boolean -) extends UnaryExpression with AggregateExpression1 { - def this() = this(null, false) // required for serialization - - override def children: Seq[Expression] = child:: Nil - override def nullable: Boolean = false - override def dataType: DataType = DoubleType - override def toString: String = s"MergePartialStd($child)" - override def newInstance(): MergePartialStdFunction = { - new MergePartialStdFunction(child, this, isSample) - } -} - -case class MergePartialStdFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - def this() = this (null, null, false) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private val combineCount = MutableLiteral(zero.eval(null), computeType) - private val combineAvg = MutableLiteral(zero.eval(null), computeType) - private val combineMk = MutableLiteral(zero.eval(null), computeType) - - private def avgUpdateFunction(preCount: Expression, - partialCount: Expression, - partialAvg: Expression): Expression = { - Divide(Add(Multiply(combineAvg, preCount), - Multiply(partialAvg, partialCount)), - Add(preCount, partialCount)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] - - if (evaluatedExpr != null) { - val exprValue = evaluatedExpr.toArray(computeType) - val (partialCount, partialAvg, partialMk) = - (Literal.create(exprValue(0), computeType), - Literal.create(exprValue(1), computeType), - Literal.create(exprValue(2), computeType)) - - if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { - val preCount = combineCount.copy() - combineCount.update(Add(combineCount, partialCount), input) - - val preAvg = combineAvg.copy() - val avgDelta = Subtract(partialAvg, preAvg) - val mkDelta = Multiply(Multiply(avgDelta, avgDelta), - Divide(Multiply(preCount, partialCount), - combineCount)) - - // update average based on following formula - // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) - combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) - - // update sum of square differences from mean based on following formula - // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) - combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) - } - } - } - - override def eval(input: InternalRow): Any = { - val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] - - if (count == 0) null - else if (count < 2) zero.eval(null) - else { - // when total count > 2 - // stddev_samp = sqrt (combineMk/(combineCount -1)) - // stddev_pop = sqrt (combineMk/combineCount) - val varCol = { - if (isSample) { - Divide(combineMk, Cast(Literal(count - 1), computeType)) - } - else { - Divide(combineMk, Cast(Literal(count), 
computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -case class StddevFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - - def this() = this(null, null, false) // Required for serialization - - private val computeType = DoubleType - private var curCount: Long = 0L - private val zero = Cast(Literal(0), computeType) - private val curAvg = MutableLiteral(zero.eval(null), computeType) - private val curMk = MutableLiteral(zero.eval(null), computeType) - - private def curAvgAddFunction(value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), curAvg) - Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) - } - private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), preAvg) - val delta2 = Subtract(Cast(value, computeType), curAvg) - Add(curMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val preAvg: MutableLiteral = curAvg.copy() - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - curCount += 1L - curAvg.update(curAvgAddFunction(exprValue), input) - curMk.update(curMkAddFunction(exprValue, preAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - if (curCount == 0) null - else if (curCount < 2) zero.eval(null) - else { - // when total count > 2, - // stddev_samp = sqrt(curMk/(curCount - 1)) - // stddev_pop = sqrt(curMk/curCount) - val varCol = { - if (isSample) { - Divide(curMk, Cast(Literal(curCount - 1), computeType)) - } - else { - Divide(curMk, Cast(Literal(curCount), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -// placeholder -case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "kurtosis" -} - -// placeholder -case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "skewness" -} - -// placeholder -case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_pop" -} - -// placeholder -case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please 
set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_samp" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d222dfa33ad8a..f4dba67f13b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter @@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references))) + if (child.outputSet -- e.references -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] { object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ Count(expr) if !expr.nullable => Count(Literal(1)) + case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + // This rule should be only triggered when isDistinct field is false. + AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => @@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) Cast( - Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b975b904a332..6f4f11406d7c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -84,80 +84,6 @@ object PhysicalOperation extends PredicateHelper { } } -/** - * Matches a logical aggregation that can be performed on distributed data in two steps. The first - * operates on the data in each partition performing partial aggregation for each group. The second - * occurs after the shuffle and completes the aggregation. - * - * This pattern will only match if all aggregate expressions can be computed partially and will - * return the rewritten aggregation expressions for both phases. - * - * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. - * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. - */ -object PartialAggregation { - type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. 
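The DecimalAggregates rewrite above is profitable because a decimal whose widened precision still fits in a Long can be summed on its unscaled representation and rescaled once at the end, which is what the `prec + 10 <= MAX_LONG_DIGITS` guard protects. A rough stand-alone illustration of the idea in plain Scala (not the optimizer rule itself, names hypothetical):

object UnscaledSumSketch {
  // Dec(1234, 2) models the decimal 12.34 stored as unscaled value 1234 with scale 2.
  final case class Dec(unscaled: Long, scale: Int) {
    def toBigDecimal: BigDecimal = BigDecimal(unscaled) / BigDecimal(10).pow(scale)
  }

  // Summing the unscaled Longs and rescaling once gives the same result as
  // summing the decimals, as long as the unscaled sum cannot overflow.
  def sumUnscaled(values: Seq[Dec], scale: Int): Dec =
    Dec(values.map(_.unscaled).sum, scale)

  def main(args: Array[String]): Unit = {
    val xs = Seq(Dec(1234, 2), Dec(66, 2), Dec(-100, 2)) // 12.34 + 0.66 - 1.00
    println(sumUnscaled(xs, 2).toBigDecimal)             // prints 12
  }
}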
- val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - - case e: Expression => - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute - }.getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = namedGroupingExpressions.map(_._2) ++ - partialEvaluations.values.flatMap(_.partialEvaluations) - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child)) - } else { - None - } - case _ => None - } -} - - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0ec9f08571082..b9db7838db08a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Returns all of the expressions present in this query plan operator. */ def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. 
+ def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Traversable[_] => seqToExpressions(s) + case other => Nil + } + productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil - case seq: Traversable[_] => seq.flatMap { - case e: Expression => e :: Nil - case other => Nil - } + case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d771088d69dea..764f8aaebddf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Utils +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -219,8 +219,6 @@ case class Aggregate( !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fbdd3a7776f50..5a2368e329976 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -171,16 +171,18 @@ class AnalysisErrorSuite extends AnalysisTest { test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + // Since we manually construct the logical plan at here and Sum only accetp + // LongType, DoubleType, and DecimalType. We use LongType as the type of a. 
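The new seqToExpressions helper in QueryPlan.expressions (just above) changes one thing: nested Traversables are now walked recursively instead of only one level deep. The same recursion shape on a toy problem, as a sanity check (collectInts is a made-up name):

    // Collect every Int from arbitrarily nested collections, mirroring how
    // seqToExpressions digs Expressions out of nested Traversables.
    def collectInts(seq: Traversable[Any]): Traversable[Int] = seq.flatMap {
      case i: Int            => i :: Nil
      case s: Traversable[_] => collectInts(s)
      case _                 => Nil
    }

    collectInts(Seq(1, Seq(2, Seq(3)), "x"))   // yields 1, 2, 3 regardless of nesting depth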
val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + AttributeReference("a", LongType)(exprId = ExprId(2)))) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) + assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) } test("error test for self-join") { @@ -196,7 +198,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan = Aggregate( AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, LocalRelation( AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) @@ -207,13 +209,24 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Aggregate( AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, LocalRelation( AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) assertAnalysisError(plan2, "map type expression a cannot be used in grouping expression" :: Nil) + + val plan3 = + Aggregate( + AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + assertAnalysisError(plan3, + "array type expression a cannot be used in grouping expression" :: Nil) } test("Join can't work on binary and map types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 71d2939ecffe6..65f09b46afae1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest { val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) + assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 40c4ae7920918..fed591fd90a9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c9bcc68f02030..b902982add8ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{TypeCollection, StringType} @@ -140,15 +141,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for aggregates") { + // We use AggregateFunction directly at here because the error will be thrown from it + // instead of from AggregateExpression, which is the wrapper of an AggregateFunction. + // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) - assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) assertError(Min('complexField), "min does not support ordering on type") assertError(Max('complexField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") - assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index e67606288f514..8aaefa84937c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ed810a12808f0..0290fafe879f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -68,7 +68,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group") { val originalQuery = testRelation - .groupBy('a)('a, Count('b)) + .groupBy('a)('a, count('b)) .select('a) val 
optimized = Optimize.execute(originalQuery.analyze) @@ -84,7 +84,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group with alias") { val originalQuery = testRelation - .groupBy('a)('a as 'c, Count('b)) + .groupBy('a)('a as 'c, count('b)) .select('c) val optimized = Optimize.execute(originalQuery.analyze) @@ -656,7 +656,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .select('a, 'c) .where('a === 2) @@ -664,7 +664,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .where('a === 2) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .analyze comparePlans(optimized, correctAnswer) } @@ -672,7 +672,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push down filter when filter not on group by expression") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) val optimized = Optimize.execute(originalQuery.analyze) @@ -683,7 +683,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters partially which are subset of group by expressions") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L && 'a === 3) val optimized = Optimize.execute(originalQuery.analyze) @@ -691,7 +691,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a, 'b) .where('a === 3) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d25807cf8d09c..3b69247dc54ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -1338,7 +1339,7 @@ class DataFrame private[sql]( if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(First(attr), attr.name)() + Alias(new First(attr).toAggregateExpression(), attr.name)() } } Aggregate(groupCols, aggCols, logicalPlan) @@ -1381,11 +1382,11 @@ class DataFrame private[sql]( // The list of summary statistics to compute, in the form of expressions. 
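The describe() statistics below now go through toAggregateExpression(), but the user-facing behaviour is unchanged. For reference, the kind of call this code path serves, assuming an already initialized SQLContext named sqlContext (the DataFrame here is hypothetical):

    val df = sqlContext.createDataFrame(Seq((1, 20.0), (2, 30.0), (3, 40.0))).toDF("id", "score")
    df.describe().show()          // count, mean, stddev, min, max for all numeric columns
    df.describe("score").show()   // or for an explicit subset of columns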
val statistics = List[(String, Expression => Expression)]( - "count" -> Count, - "mean" -> Average, - "stddev" -> StddevSamp, - "min" -> Min, - "max" -> Max) + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f9eab5c2e965b..5babf2cc0ca25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,7 +71,7 @@ class GroupedData protected[sql]( } } - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { val columnExprs = if (colNames.isEmpty) { @@ -88,30 +89,28 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map(f)) + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) } private[this] def strToExpr(expr: String): (Expression => Expression) = { - expr.toLowerCase match { - case "avg" | "average" | "mean" => Average - case "max" => Max - case "min" => Min - case "stddev" | "std" => StddevSamp - case "stddev_pop" => StddevPop - case "stddev_samp" => StddevSamp - case "variance" => VarianceSamp - case "var_pop" => VariancePop - case "var_samp" => VarianceSamp - case "sum" => Sum - case "skewness" => Skewness - case "kurtosis" => Kurtosis - case "count" | "size" => - // Turn count(*) into count(1) - (inputExpr: Expression) => inputExpr match { - case s: Star => Count(Literal(1)) - case _ => Count(inputExpr) - } + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). 
+ case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } } + (inputExpr: Expression) => exprToFunc(inputExpr) } /** @@ -213,7 +212,7 @@ class GroupedData protected[sql]( * * @since 1.3.0 */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ed8b634ad5630..b7314189b5403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -448,15 +448,24 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) - val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", - defaultValue = Some(true), doc = "") - val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), isPublic = false, doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = + booleanConf("spark.sql.specializeSingleDistinctAggPlanning", + defaultValue = Some(true), + isPublic = false, + doc = "When true, if a query only has a single distinct column and it has " + + "grouping expressions, we will use our planner rule to handle this distinct " + + "column (other cases are handled by DistinctAggregationRewriter). " + + "When false, we will always use DistinctAggregationRewriter to plan " + + "aggregation queries with DISTINCT keyword. This is an internal flag that is " + + "used to benchmark the performance impact of using DistinctAggregationRewriter to " + + "plan aggregation queries with a single distinct column.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -532,8 +541,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) - private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = @@ -575,6 +582,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = + getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala deleted file mode 100644 index 6f3f1bd97ad52..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: List[Distribution] = { - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. - * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression1, - aggregate: AggregateExpression1, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - private[this] val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression1 => - ComputedAggregate( - a, - BindReferences.bindReference(a, child.output), - AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - private[this] val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. 
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { - val buffer = new Array[AggregateFunction1](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") - val numOutputRows = longMetric("numOutputRows") - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - numOutputRows += 1 - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[InternalRow] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): InternalRow = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - numOutputRows += 1 - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. 
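The operator being deleted here is essentially the loop below: one mutable buffer per group key, updated row by row, then evaluated once per group. A toy version with a running sum standing in for the AggregateFunction1 buffers:

    import scala.collection.mutable

    val rows = Seq(("a", 1), ("b", 2), ("a", 3))
    val buffers = mutable.HashMap.empty[String, Int]      // group key -> aggregation buffer (here, a running sum)
    for ((key, value) <- rows) {
      buffers(key) = buffers.getOrElse(key, 0) + value    // the per-row update step
    }
    // buffers: Map(a -> 4, b -> 2)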
- aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 55e95769d3faa..91530bd63798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -45,6 +45,9 @@ case class Expand( override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + private[this] val projection = { if (outputsUnsafeRows) { (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 0f98fe88b2101..a10d1edcc91aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,7 +38,6 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { DataSourceStrategy :: DDLStrategy :: TakeOrderedAndProject :: - HashAggregation :: Aggregation :: LeftSemiJoin :: EquiJoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd3bb33c57287..d65cb1bae7fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -146,148 +146,104 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object HashAggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Aggregations that can be performed in two phases, before and after the shuffle. 
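The "two phases" the removed HashAggregation comment refers to survive in the new code path as the Partial and Final modes; the idea itself fits in a few lines of plain Scala:

    // Toy two-phase sum: aggregate locally per "partition", then merge the partial results.
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))
    val partials = partitions.map(_.sum)   // Partial: before the shuffle, one value per partition
    val total = partials.sum               // Final: after the shuffle, merge the partials into 21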
- case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil - - case _ => Nil - } - - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { - case a: logical.Aggregate => - if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { - a.newAggregation.isDefined - } else { - Utils.checkInvalidAggregateFunction2(a) - false - } - case _ => false - } - - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = - exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) - } - /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && - sqlContext.conf.codegenEnabled => - val converted = p.newAggregation - converted match { - case None => Nil // Cannot convert to new aggregation code path. - case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets (aggregate.NewAggregation will not match). - sys.error( - "Multiple distinct column sets are not supported by the new aggregation" + - "code path.") - } + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. 
+ val aggregateFunctionToAttribute = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. - // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression2(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. 
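The dedup-and-rewrite step above (collect the distinct AggregateExpressions, map each to an attribute, then rewrite resultExpressions against those attributes) is what keeps an aggregate that appears several times from being computed several times. A much-simplified model with stand-in case classes (Agg, Add, Lit, Ref are invented for illustration, not catalyst expressions):

    sealed trait Expr
    case class Agg(name: String) extends Expr       // stands in for AggregateExpression
    case class Add(l: Expr, r: Expr) extends Expr
    case class Lit(v: Int) extends Expr
    case class Ref(id: String) extends Expr         // stands in for the aggregate's result attribute

    def collectAggs(e: Expr): Seq[Agg] = e match {
      case a: Agg    => Seq(a)
      case Add(l, r) => collectAggs(l) ++ collectAggs(r)
      case _         => Nil
    }
    def rewrite(e: Expr, m: Map[Agg, Ref]): Expr = e match {
      case a: Agg    => m(a)
      case Add(l, r) => Add(rewrite(l, m), rewrite(r, m))
      case other     => other
    }

    val results      = Seq(Add(Agg("sum(x)"), Lit(1)), Add(Agg("sum(x)"), Agg("count(x)")))
    val distinctAggs = results.flatMap(collectAggs).distinct       // sum(x) shows up once, not twice
    val aggToAttr    = distinctAggs.map(a => a -> Ref(a.name)).toMap
    val rewritten    = results.map(rewrite(_, aggToAttr))          // results now reference the single computed copy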
+ // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + val aggregateOperator = + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (functionsWithDistinct.nonEmpty) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } else { + aggregate.Utils.planAggregateWithoutPartial( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) } + } else if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + namedGroupingExpressions.map(_._2), + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } - val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { - if (functionsWithDistinct.nonEmpty) { - sys.error("Distinct columns cannot exist in Aggregate operator containing " + - "aggregate functions which don't support partial aggregation.") - } else { - aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - } else if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } else { - aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), - functionsWithDistinct, - functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - - aggregateOperator - } + aggregateOperator case _ => Nil } @@ -422,18 +378,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled - if (useNewAggregation && a.newAggregation.isDefined) { - // If this logical.Aggregate can be planned to 
use new aggregation code path - // (i.e. it can be planned by the Strategy Aggregation), we will not use the old - // aggregation code path. - Nil - } else { - Utils.checkInvalidAggregateFunction2(a) - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - } - } case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => execution.Window( projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 99fb7a40b72e1..008478a6a0e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -35,9 +35,9 @@ import scala.collection.mutable.ArrayBuffer abstract class AggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -76,14 +76,14 @@ abstract class AggregationIterator( // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction2] = { + protected val allAggregateFunctions: Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { + val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of @@ -135,7 +135,7 @@ abstract class AggregationIterator( } // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) // All imperative aggregate functions with mode Partial, PartialMerge, or Final. 
@@ -172,7 +172,7 @@ abstract class AggregationIterator( case (Some(Partial), None) => val updateExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -204,7 +204,7 @@ abstract class AggregationIterator( // allAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } // This projection is used to merge buffer values for all expression-based aggregates. val expressionAggMergeProjection = @@ -225,7 +225,7 @@ abstract class AggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -248,7 +248,7 @@ abstract class AggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalExpressionAggMergeProjection = newMutableProjection(mergeExpressions, mergeInputSchema)() @@ -256,7 +256,7 @@ abstract class AggregationIterator( val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -282,7 +282,7 @@ abstract class AggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. 
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -291,7 +291,7 @@ abstract class AggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -353,7 +353,7 @@ abstract class AggregationIterator( allAggregateFunctions.flatMap(_.aggBufferAttributes) val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 4d37106e007f5..fb7f30c2aec99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 64c673064f576..fe5c3195f867b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.metric.LongSQLMetric /** - * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been + * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. 
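The mode branches above all reduce to choosing between an update path (consume raw input rows) and a merge path (consume other buffers), followed by a final evaluate. A toy average buffer showing those three steps outside of any Spark types (AvgBuffer and the helper names are invented):

    case class AvgBuffer(sum: Double, count: Long)
    def update(b: AvgBuffer, input: Double) = AvgBuffer(b.sum + input, b.count + 1)            // Partial / Complete: raw rows
    def merge(b1: AvgBuffer, b2: AvgBuffer) = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)  // PartialMerge / Final: other buffers
    def evaluate(b: AvgBuffer): Double = if (b.count == 0) 0.0 else b.sum / b.count

    val left   = Seq(1.0, 2.0, 3.0).foldLeft(AvgBuffer(0, 0))(update)
    val right  = Seq(4.0, 5.0).foldLeft(AvgBuffer(0, 0))(update)
    val result = evaluate(merge(left, right))   // 3.0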
*/ class SortBasedAggregationIterator( @@ -31,9 +31,9 @@ class SortBasedAggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 15616915f7364..1edde1e5a16d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} @@ -30,9 +30,9 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index ce8d592c368ee..04391443920ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -64,12 +64,12 @@ import org.apache.spark.sql.types.StructType * @param groupingExpressions * expressions for grouping keys * @param nonCompleteAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], - * [[PartialMerge]], or [[Final]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. 
* @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs * when they are stored in the final aggregation buffer. * @param resultExpressions @@ -83,9 +83,9 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -106,7 +106,7 @@ class TungstenAggregationIterator( // A Seq containing all AggregateExpressions. // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + private[this] val allAggregateExpressions: Seq[AggregateExpression] = nonCompleteAggregateExpressions ++ completeAggregateExpressions // Check to make sure we do not have more than three modes in our AggregateExpressions. @@ -150,10 +150,10 @@ class TungstenAggregationIterator( // Initialize all AggregateFunctions by binding references, if necessary, // and setting inputBufferOffset and mutableBufferOffset. private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction2] = { + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction @@ -195,7 +195,7 @@ class TungstenAggregationIterator( functions } - private[this] var allAggregateFunctions: Array[AggregateFunction2] = + private[this] var allAggregateFunctions: Array[AggregateFunction] = initializeAllAggregateFunctions(initialInputBufferOffset) // Positions of those imperative aggregate functions in allAggregateFunctions. 
@@ -263,7 +263,7 @@ class TungstenAggregationIterator( case (Some(Partial), None) => val updateExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -286,7 +286,7 @@ class TungstenAggregationIterator( case (Some(PartialMerge), None) | (Some(Final), None) => val mergeExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -307,11 +307,11 @@ class TungstenAggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } @@ -321,7 +321,7 @@ class TungstenAggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -331,7 +331,7 @@ class TungstenAggregationIterator( Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -358,7 +358,7 @@ class TungstenAggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. 
val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -366,7 +366,7 @@ class TungstenAggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -414,7 +414,7 @@ class TungstenAggregationIterator( val joinedRow = new JoinedRow() val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() // These are the attributes of the row produced by `expressionAggEvalProjection` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d2f56e0fc14a4..20359c1e540e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index eaafd83158a15..79abf2d5929be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -28,8 +28,8 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -54,8 +54,8 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. 
@@ -137,9 +137,9 @@ object Utils { def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], - functionsWithDistinct: Seq[AggregateExpression2], - functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -253,16 +253,16 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction2] + }.asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) + AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) (rewrittenAggregateExpression, aggregateFunctionAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 0b3192a6da9d8..8cc25c2440633 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} @@ -70,7 +70,7 @@ abstract class Aggregator[-A, B, C] { implicit bEncoder: Encoder[B], cEncoder: Encoder[C]): TypedColumn[A, C] = { val expr = - new AggregateExpression2( + new AggregateExpression( TypedAggregateExpression(this), Complete, false) @@ -78,4 +78,3 @@ abstract class Aggregator[-A, B, C] { new TypedColumn[A, C](expr, encoderFor[C]) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8b9247adea200..fc873c04f88f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ /** @@ 
-141,40 +141,56 @@ class WindowSpec private[sql]( */ private[sql] def withAggregate(aggregate: Column): Column = { val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) + // First, we check if we get an aggregate function without the DISTINCT keyword. + // Right now, we do not support using a DISTINCT aggregate function as a + // window function. + case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => + aggregateFunction match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction( + "first_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction( + "last_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in a window operation.") + } + + case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => + throw new UnsupportedOperationException( + s"Distinct aggregate function ${aggregateFunction} is not supported " + + s"in window operation.") + + case wf: WindowFunction => + WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") + throw new 
UnsupportedOperationException(s"$x is not supported in a window operation.") } + new Column(windowExpr) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 258afadc76951..11dbf391cff98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ @@ -109,7 +109,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = false) @@ -123,7 +123,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6d56542ee0875..22104e4d48617 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -76,6 +77,12 @@ object functions extends LegacyFunctions { private def withExpr(expr: Expression): Column = Column(expr) + private def withAggregateFunction( + func: AggregateFunction, + isDistinct: Boolean = false): Column = { + Column(func.toAggregateExpression(isDistinct)) + } + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) @@ -154,7 +161,9 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) } + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. 
@@ -170,8 +179,8 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = withExpr { - ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd, 0, 0) } /** @@ -190,7 +199,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = withExpr { Average(e.expr) } + def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -226,7 +235,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = withExpr { + def corr(column1: Column, column2: Column): Column = withAggregateFunction { Corr(column1.expr, column2.expr) } @@ -246,7 +255,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = withExpr { + def count(e: Column): Column = withAggregateFunction { e.expr match { // Turn count(*) into count(1) case s: Star => Count(Literal(1)) @@ -269,8 +278,8 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = withExpr { - CountDistinct((expr +: exprs).map(_.expr)) + def countDistinct(expr: Column, exprs: Column*): Column = { + withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true) } /** @@ -289,7 +298,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = withExpr { First(e.expr) } + def first(e: Column): Column = withAggregateFunction { new First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -305,7 +314,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) } + def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } /** * Aggregate function: returns the last value in a group. @@ -313,7 +322,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = withExpr { Last(e.expr) } + def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -329,7 +338,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = withExpr { Max(e.expr) } + def max(e: Column): Column = withAggregateFunction { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -363,7 +372,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = withExpr { Min(e.expr) } + def min(e: Column): Column = withAggregateFunction { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -379,7 +388,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = withExpr { Skewness(e.expr) } + def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. 
@@ -387,7 +396,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the unbiased sample standard deviation of @@ -396,7 +405,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the population standard deviation of @@ -405,7 +414,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) } + def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } /** * Aggregate function: returns the sum of all values in the expression. @@ -413,7 +422,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = withExpr { Sum(e.expr) } + def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -429,7 +438,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) } + def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -445,7 +454,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -453,7 +462,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the population variance of the values in a group. 
@@ -461,7 +470,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) } + def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3de277a79a52c..441a0c6d0e36e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -237,34 +237,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-8828 sum should return null if all input values are null") { - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -507,29 +483,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("literal in agg grouping expressions") { - def literalInAggTest(): Unit = { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - literalInAggTest() - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - literalInAggTest() - } + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) } test("aggregates with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a229e5814df89..e31c528f3a633 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,16 +21,13 @@ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} -import com.clearspring.analytics.stream.cardinality.HyperLogLog - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -134,16 +131,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("HyperLogLogUDT") { - val hyperLogLogUDT = HyperLogLogUDT - val hyperLogLog = new HyperLogLog(0.4) - (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) - - val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) - assert(actual.cardinality() === hyperLogLog.cardinality()) - assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) - } - test("OpenHashSetUDT") { val openHashSetUDT = new OpenHashSetUDT(IntegerType) val set = new OpenHashSet[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2076c573b56c1..44634dacbde68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = sqlContext.planner import planner._ - val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val plannedOption = Aggregation(query).headOption val planned = plannedOption.getOrElse( fail(s"Could query play aggregation query $query. Is it an aggregation query?")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cdd885ba14203..4b4f5c6c45c7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -152,36 +152,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("Aggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // ... 
-> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } - } - test("SortBasedAggregate metrics") { // Because SortBasedAggregate may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c5f69657f5293..ba6204633b9ca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -584,7 +584,6 @@ class HiveContext private[hive]( HiveTableScans, DataSinks, Scripts, - HashAggregation, Aggregation, LeftSemiJoin, EquiJoinSelection, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ab88c1e68fd72..6f8ed413a06cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -38,6 +38,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -1508,9 +1509,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => + Count(Literal(1)).toAggregateExpression() /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea36c132bb190..6bf2c53440baf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -69,11 +69,7 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with 
TestHiveSingleton { import testImplicits._ - var originalUseAggregate2: Boolean = _ - override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -120,7 +116,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -447,73 +442,80 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), - | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | count(value1), - | count(*), - | count(1), - | count(DISTINCT value1), - | key - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) + Seq(true, false).foreach { specializeSingleDistinctAgg => + val conf = + (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, + specializeSingleDistinctAgg.toString) + withSQLConf(conf) { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
+ checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + } } test("single distinct multiple columns set") { @@ -699,48 +701,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) - - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - val errorMessage = intercept[SparkException] { - val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") - val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) - }.getMessage - assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + - "Corr only supports the new AggregateExpression2")) - } - } - - test("test Last implemented based on AggregateExpression1") { - // TODO: Remove this test once we remove AggregateExpression1. - import org.apache.spark.sql.functions._ - val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - - checkAnswer( - df.groupBy("i").agg(last("j")), - df - ) - } - } - - test("error handling") { - withSQLConf("spark.sql.useAggregate2" -> "false") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - } } test("no aggregation function (SPARK-11486)") { From 47735cdc2a878cfdbe76316d3ff8314a45dabf54 Mon Sep 17 00:00:00 2001 From: "Oscar D. Lara Yejas" Date: Tue, 10 Nov 2015 11:07:57 -0800 Subject: [PATCH 276/324] [SPARK-10863][SPARKR] Method coltypes() (New version) This is a follow up on PR #8984, as the corresponding branch for such PR was damaged. Author: Oscar D. Lara Yejas Closes #9579 from olarayej/SPARK-10863_NEW14. 
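A quick usage sketch (not part of the patch itself), based on the roxygen example and the tests added below. It assumes an already initialized SparkR session with `sqlContext` in scope:

    # Spark double/string columns are reported with their R equivalents.
    irisDF <- createDataFrame(sqlContext, iris)
    coltypes(irisDF)
    # [1] "numeric"   "numeric"   "numeric"   "numeric"   "character"

Complex Spark types (map, array, struct) have no direct R equivalent, so coltypes() falls back to reporting the original Spark type string for those columns, as exercised by the new test case.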
--- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 6 ++-- R/pkg/R/DataFrame.R | 49 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/R/schema.R | 15 +--------- R/pkg/R/types.R | 43 ++++++++++++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++- 7 files changed, 124 insertions(+), 18 deletions(-) create mode 100644 R/pkg/R/types.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 3d6edb70ec98e..369714f7b99c2 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -34,4 +34,5 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'types.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 56b8ed0bf271b..52fd6c9f76c54 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,9 +23,11 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", "attach", "cache", "collect", + "coltypes", "columns", "count", "cov", @@ -262,6 +264,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") - -export("as.data.frame") + "print.structType") \ No newline at end of file diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e9013aa34a84f..cc868069d1e5a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,3 +2152,52 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) + +#' Returns the column types of a DataFrame. +#' +#' @name coltypes +#' @title Get column types of a DataFrame +#' @family dataframe_funcs +#' @param x (DataFrame) +#' @return value (character) A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @examples \dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#' } +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) \ No newline at end of file diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index efef7d66b522c..89731affeb898 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1047,3 +1047,7 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 6f0e9a94e9bfa..c6ddb562270b7 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -115,20 +115,7 @@ structField.jobj <- function(x) { } checkType <- function(type) { - primtiveTypes <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - if (type %in% primtiveTypes) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { return() } else { # Check complex types diff --git a/R/pkg/R/types.R 
b/R/pkg/R/types.R new file mode 100644 index 0000000000000..1828c23ab0f6d --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). +PRIMITIVE_TYPES <- as.environment(list( + "byte"="integer", + "tinyint"="integer", + "smallint"="integer", + "integer"="integer", + "bigint"="numeric", + "float"="numeric", + "double"="numeric", + "decimal"="numeric", + "string"="character", + "binary"="raw", + "boolean"="logical", + "timestamp"="POSIXct", + "date"="Date")) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map"=NA, + "array"=NA, + "struct"=NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index fbdb9a8f1ef6b..06f52d021cff8 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1467,8 +1467,9 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) +irisDF <- createDataFrame(sqlContext, iris) + test_that("Method as.data.frame as a synonym for collect()", { - irisDF <- createDataFrame(sqlContext, iris) expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -1503,6 +1504,27 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) +test_that("Method coltypes() to get R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1=c(1,2,3), + c2=c(T,F,T), + c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a"="b", "c"="d", "e"="f"))))) + expect_equal(coltypes(x), "map") +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From dfcfcbcc0448ebc6f02eba6bf0495832a321c87e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Nov 2015 11:14:25 -0800 Subject: [PATCH 277/324] [SPARK-11578][SQL][FOLLOW-UP] complete the user facing api for typed aggregation Currently the user facing api for 
typed aggregation has some limitations: * the customized typed aggregation must be the first of aggregation list * the customized typed aggregation can only use long as buffer type * the customized typed aggregation can only use flat type as result type This PR tries to remove these limitations. Author: Wenchen Fan Closes #9599 from cloud-fan/agg. --- .../catalyst/encoders/ExpressionEncoder.scala | 6 +++ .../aggregate/TypedAggregateExpression.scala | 50 +++++++++++++----- .../spark/sql/expressions/Aggregator.scala | 5 ++ .../spark/sql/DatasetAggregatorSuite.scala | 52 +++++++++++++++++++ 4 files changed, 99 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index c287aebeeee05..005c0627f56b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -185,6 +185,12 @@ case class ExpressionEncoder[T]( }) } + def shift(delta: Int): ExpressionEncoder[T] = { + copy(constructExpression = constructExpression transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) + }) + } + /** * Returns a copy of this encoder where the expressions used to create an object given an * input row have been modified to pull the object out from a nested struct, instead of the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 24d8122b6222b..0e5bc1f9abf28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types._ object TypedAggregateExpression { def apply[A, B : Encoder, C : Encoder]( @@ -67,8 +67,11 @@ case class TypedAggregateExpression( override def nullable: Boolean = true - // TODO: this assumes flat results... - override def dataType: DataType = cEncoder.schema.head.dataType + override def dataType: DataType = if (cEncoder.flat) { + cEncoder.schema.head.dataType + } else { + cEncoder.schema + } override def deterministic: Boolean = true @@ -93,32 +96,51 @@ case class TypedAggregateExpression( case a: AttributeReference => inputMapping(a) }) - // TODO: this probably only works when we are in the first column. val bAttributes = bEncoder.schema.toAttributes lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { + // todo: need a more neat way to assign the value. 
+ var i = 0 + while (i < aggBufferAttributes.length) { + aggBufferSchema(i).dataType match { + case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) + case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + } + i += 1 + } + } + override def initialize(buffer: MutableRow): Unit = { - // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for - // this in execution. - buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int]) + val zero = bEncoder.toRow(aggregator.zero) + updateBuffer(buffer, zero) } override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.fromRow(buffer) + val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) val returned = boundB.toRow(merged) - buffer.setInt(mutableAggBufferOffset, returned.getInt(0)) + + updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - buffer1.setLong( - mutableAggBufferOffset, - buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset)) + val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val merged = aggregator.merge(b1, b2) + val returned = boundB.toRow(merged) + + updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - buffer.getInt(mutableAggBufferOffset) + val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val result = cEncoder.toRow(aggregator.present(b)) + dataType match { + case _: StructType => result + case _ => result.get(0, dataType) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 8cc25c2440633..3c1c457e06d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -57,6 +57,11 @@ abstract class Aggregator[-A, B, C] { */ def reduce(b: B, a: A): B + /** + * Merge two intermediate values + */ + def merge(b1: B, b2: B): B + /** * Transform the output of the reduction. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 340470c096b87..206095a519762 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + override def merge(b1: N, b2: N): N = numeric.plus(b1, b2) + override def present(reduction: N): N = reduction } +object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable { + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def present(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1 +} + +object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] + with Serializable { + + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def present(reduction: (Long, Long)): (Long, Long) = reduction +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { count("*")), ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) } + + test("typed aggregation: complex case") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + TypedAverage.toColumn), + ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } } From 53600854c270d4c953fe95fbae528740b5cf6603 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Nov 2015 11:21:31 -0800 Subject: [PATCH 278/324] [SPARK-11590][SQL] use native json_tuple in lateral view Author: Wenchen Fan Closes #9562 from cloud-fan/json-tuple. 
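A minimal sketch (not from the patch itself) of the user-facing effect, mirroring the tests touched below. It assumes a HiveContext bound to `sqlContext` (LATERAL VIEW is parsed by HiveQl) and a made-up temp table name `jsonTable`:

    import org.apache.spark.sql.functions.json_tuple
    import sqlContext.implicits._

    val df = Seq(("1", """{"f1": "value1", "f2": "value2"}""")).toDF("key", "jstring")

    // json_tuple is now a Generator, so the extracted fields come back as
    // top-level columns named c0, c1, ... rather than a single struct column.
    df.select($"key", json_tuple($"jstring", "f1", "f2")).show()

    // The same native expression now backs LATERAL VIEW json_tuple(...) as well,
    // rather than resolving to Hive's GenericUDTFJSONTuple.
    df.registerTempTable("jsonTable")
    sqlContext.sql(
      """SELECT key, jt.a, jt.b
        |FROM jsonTable
        |LATERAL VIEW json_tuple(jstring, 'f1', 'f2') jt AS a, b""".stripMargin).show()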
--- .../expressions/jsonExpressions.scala | 23 +++++--------- .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++-------- .../org/apache/spark/sql/DataFrame.scala | 8 +++-- .../org/apache/spark/sql/functions.scala | 12 +++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 23 ++++++++------ .../org/apache/spark/sql/hive/HiveQl.scala | 4 +++ .../apache/spark/sql/hive/HiveQlSuite.scala | 13 ++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 8 files changed, 104 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8c9853e628d2c..8cd73236a7876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -314,7 +314,7 @@ case class GetJsonObject(json: Expression, path: Expression) } case class JsonTuple(children: Seq[Expression]) - extends Expression with CodegenFallback { + extends Generator with CodegenFallback { import SharedFactory._ @@ -324,8 +324,8 @@ case class JsonTuple(children: Seq[Expression]) } // if processing fails this shared value will be returned - @transient private lazy val nullRow: InternalRow = - new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) + @transient private lazy val nullRow: Seq[InternalRow] = + new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil // the json body is the first child @transient private lazy val jsonExpr: Expression = children.head @@ -344,15 +344,8 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override lazy val dataType: StructType = { - val fields = fieldExpressions.zipWithIndex.map { - case (_, idx) => StructField( - name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize - dataType = StringType, - nullable = true) - } - - StructType(fields) + override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { + case (_, idx) => (StringType, true, s"c$idx") } override def prettyName: String = "json_tuple" @@ -367,7 +360,7 @@ case class JsonTuple(children: Seq[Expression]) } } - override def eval(input: InternalRow): InternalRow = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val json = jsonExpr.eval(input).asInstanceOf[UTF8String] if (json == null) { return nullRow @@ -383,7 +376,7 @@ case class JsonTuple(children: Seq[Expression]) } } - private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = { + private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = { // only objects are supported if (parser.nextToken() != JsonToken.START_OBJECT) { return nullRow @@ -433,7 +426,7 @@ case class JsonTuple(children: Seq[Expression]) parser.skipChildren() } - new GenericInternalRow(row) + new GenericInternalRow(row) :: Nil } private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f33125f463e14..7b754091f4714 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -209,8 +209,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("f5") :: Nil + private def checkJsonTuple(jt: JsonTuple, expected: InternalRow): Unit = { + assert(jt.eval(null).toSeq.head === expected) + } + test("json_tuple - hive key 1") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: jsonTupleQuery), @@ -218,7 +222,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: jsonTupleQuery), @@ -226,7 +230,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2 (mix of foldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: Literal("f1") :: NonFoldableLiteral("f2") :: @@ -238,7 +242,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: jsonTupleQuery), @@ -247,7 +251,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable json)") { - checkEvaluation( + checkJsonTuple( JsonTuple( NonFoldableLiteral( """{"f1": "value13", "f4": "value44", @@ -258,7 +262,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal( """{"f1": "value13", "f4": "value44", | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: @@ -273,43 +277,43 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 4 - null json") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal(null) :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - hive key 5 - null and empty fields") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) } test("json_tuple - hive key 6 - invalid json (array)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (object start only)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (no object end)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (invalid json)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("\\") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - preserve newlines") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), 
InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3b69247dc54ef..9368435a63c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -750,10 +750,14 @@ class DataFrame private[sql]( // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to // make it a NamedExpression. case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) + case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analyzer will generate the - // correct defaults after the nested expression's type has been resolved. + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. case Column(explode: Explode) => MultiAlias(explode, Nil) + case Column(jt: JsonTuple) => MultiAlias(jt, Nil) + case Column(expr: Expression) => Alias(expr, expr.prettyString)() } Project(namedExpressions.toSeq, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 22104e4d48617..a59d738010f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2307,6 +2307,18 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Creates a new row for a json column according to the given field names. + * + * @group collection_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def json_tuple(json: Column, fields: String*): Column = withExpr { + require(fields.length > 0, "at least 1 field name should be given.") + JsonTuple(json.expr +: fields.map(Literal.apply)) + } + /** * Returns length of array or map. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index e3531d0d6d799..14fd56fc8c222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -41,23 +41,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") - val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: - Row("2", Row("value12", "2", "value3", "4.01", null)) :: - Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: - Row("4", Row(null, null, null, null, null)) :: - Row("5", Row("", null, null, null, null)) :: - Row("6", Row(null, null, null, null, null)) :: + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: Nil - checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected) + checkAnswer( + df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), + expected) } test("json_tuple filter and group") { val df: DataFrame = tuples.toDF("key", "jstring") val expr = df - .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt") - .where($"jt.c0".isNotNull) - .groupBy($"jt.c1") + .select(functions.json_tuple($"jstring", "f1", "f2")) + .where($"c0".isNotNull) + .groupBy($"c1") .count() val expected = Row(null, 1) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6f8ed413a06cd..091caab921fe9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1821,6 +1821,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head @@ -1833,6 +1834,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => (Explode(nodeToExpr(child)), attributes) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + (JsonTuple(children.map(nodeToExpr)), attributes) + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 528a7398b10df..a330362b4e1d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite @@ -183,4 +185,15 @@ class HiveQlSuite extends SparkFunSuite with 
BeforeAndAfterAll { assertError("select interval '.1111111111' second", "nanosecond 1111111111 outside range") } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = HiveQl.parseSql( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9a425d7f6b265..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1448,4 +1448,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) } } + + test("SPARK-11590: use native json_tuple in lateral view") { + checkAnswer(sql( + """ + |SELECT a, b + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin), Row("value1", "12")) + + // we should use `c0`, `c1`... as the name of fields if no alias is provided, to follow hive. + checkAnswer(sql( + """ + |SELECT c0, c1 + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt + """.stripMargin), Row("value1", "12")) + + // we can also use `json_tuple` in project list. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2') + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + """.stripMargin), Row("value1", "12")) + + // we can also mix `json_tuple` with other project expressions. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str + |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test + """.stripMargin), Row("value1", "12", 3.14, "hello")) + } } From 87aedc48c01dffbd880e6ca84076ed47c68f88d0 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 10 Nov 2015 11:28:53 -0800 Subject: [PATCH 279/324] [SPARK-10371][SQL] Implement subexpr elimination for UnsafeProjections This patch adds the building blocks for codegening subexpr elimination and implements it end to end for UnsafeProjection. The building blocks can be used to do the same thing for other operators. It introduces some utilities to compute common sub expressions. Expressions can be added to this data structure. The expr and its children will be recursively matched against existing expressions (ones previously added) and grouped into common groups. This is built using the existing `semanticEquals`. It does not understand things like commutative or associative expressions. This can be done as future work. After building this data structure, the codegen process takes advantage of it by: 1. Generating a helper function in the generated class that computes the common subexpression. This is done for all common subexpressions that have at least two occurrences and the expression tree is sufficiently complex. 2. When generating the apply() function, if the helper function exists, call that instead of regenerating the expression tree. Repeated calls to the helper function shortcircuit the evaluation logic. Author: Nong Li Author: Nong Li This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #9480 from nongli/spark-10371. 
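A toy model may help when reading the codegen changes that follow. The sketch below is illustrative only: plain structural equality stands in for `semanticEquals`, and a memoizing interpreter stands in for the generated `evalExpr` helper functions and their `isLoaded` flags. Names such as `countSubtrees` are made up for the sketch and do not exist in Catalyst.

```
// Standalone sketch of the two ideas in this patch (illustrative, not Catalyst code):
// 1) find sub-trees that occur more than once, 2) evaluate each such sub-tree at most
// once per row, which is what the generated helper functions do behind isLoaded.
object SubexprEliminationSketch {
  sealed trait Expr
  case class Col(i: Int) extends Expr
  case class Add(l: Expr, r: Expr) extends Expr

  def children(e: Expr): Seq[Expr] = e match {
    case Add(l, r) => Seq(l, r)
    case _ => Nil
  }

  // Count every non-leaf sub-tree; structural equality stands in for semanticEquals,
  // and skipping leaves mirrors the ignoreLeaf behavior of addExprTree.
  def countSubtrees(roots: Seq[Expr]): Map[Expr, Int] = {
    def visit(e: Expr, acc: Map[Expr, Int]): Map[Expr, Int] = {
      val acc2 =
        if (children(e).nonEmpty) acc.updated(e, acc.getOrElse(e, 0) + 1) else acc
      children(e).foldLeft(acc2)((m, c) => visit(c, m))
    }
    roots.foldLeft(Map.empty[Expr, Int])((m, r) => visit(r, m))
  }

  // Interpreter with per-row memoization of the common sub-trees, standing in for the
  // isLoaded/isNull/primitive member variables of the generated class.
  def eval(e: Expr, row: Array[Int], cache: scala.collection.mutable.Map[Expr, Int],
      common: Set[Expr]): Int = e match {
    case Col(i) => row(i)
    case add @ Add(l, r) if common(add) =>
      cache.getOrElseUpdate(add, eval(l, row, cache, common) + eval(r, row, cache, common))
    case Add(l, r) => eval(l, row, cache, common) + eval(r, row, cache, common)
  }

  def main(args: Array[String]): Unit = {
    val shared = Add(Col(0), Col(1))              // col1 + col2
    val exprs = Seq(shared, Add(shared, Col(2)))  // (col1 + col2), (col1 + col2) + col3
    val common = countSubtrees(exprs).filter(_._2 > 1).keySet
    val row = Array(1, 2, 3)
    val cache = scala.collection.mutable.Map.empty[Expr, Int]
    println(exprs.map(eval(_, row, cache, common)))  // List(3, 6); shared evaluated once
  }
}
```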
--- .../expressions/EquivalentExpressions.scala | 106 ++++++++++++ .../sql/catalyst/expressions/Expression.scala | 50 +++++- .../sql/catalyst/expressions/Projection.scala | 16 ++ .../expressions/codegen/CodeGenerator.scala | 110 ++++++++++++- .../codegen/GenerateUnsafeProjection.scala | 36 ++++- .../expressions/namedExpressions.scala | 4 + .../SubexpressionEliminationSuite.scala | 153 ++++++++++++++++++ .../scala/org/apache/spark/sql/SQLConf.scala | 8 + .../spark/sql/execution/SparkPlan.scala | 5 + .../spark/sql/execution/basicOperators.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++ 11 files changed, 523 insertions(+), 16 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 0000000000000..e7380d21f98af --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + val hash = e.semanticHash() + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override def hashCode: Int = hash + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] = + new mutable.HashMap[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns if there was already a matching expression. 
+ */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get.+= (expr) + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this datastructure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + if (!skip && root.deterministic && !addExpr(root)) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all fo the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the datastructure as a string. If all is false, skips sets of equivalent + * expressions with cardinality 1. + */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => { + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + }} + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 96fcc799e537a..7d5741eefcc7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + val subExprState = ctx.subExprEliminationExprs.get(this) + if (subExprState.isDefined) { + // This expression is repeated meaning the code to evaluated has already been added + // as a function, `subExprState.fnName`. Just call that. + val code = + s""" + |/* $this */ + |${subExprState.get.fnName}(${ctx.INPUT_ROW}); + |""".stripMargin.trim + GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value) + } else { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + // Add `this` in the comment. 
+ ve.copy(s"/* $this */\n" + ve.code.trim) + } } /** @@ -145,11 +157,37 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + // Non-determinstic expressions cannot be equal + if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq checkSemantic(elements1, elements2) } + /** + * Returns the hash for this expression. Expressions that compute the same result, even if + * they differ cosmetically should return the same hash. + */ + def semanticHash() : Int = { + def computeHash(e: Seq[Any]): Int = { + // See http://stackoverflow.com/questions/113511/hash-code-implementation + var hash: Int = 17 + e.foreach(i => { + val h: Int = i match { + case (e: Expression) => e.semanticHash() + case (Some(e: Expression)) => e.semanticHash() + case (t: Traversable[_]) => computeHash(t.toSeq) + case null => 0 + case (o) => o.hashCode() + } + hash = hash * 37 + h + }) + hash + } + + computeHash(this.productIterator.toSeq) + } + /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 79dabe8e925ad..9f0b7821ae74a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -144,6 +144,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f0f7a6cf0cc4d..60a3d6018496c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -92,6 +92,33 @@ class CodeGenContext { addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two exprsesion generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. 
+ */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState( + val isLoaded: String, code: GeneratedExpressionCode, val fnName: String) + + // Foreach expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] = + mutable.HashMap[Expression, SubExprEliminationState]() + + // The collection of isLoaded variables that need to be reset on each row. + val subExprIsLoadedVariables: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -317,6 +344,87 @@ class CodeGenContext { functions.map(name => s"$name($row);").mkString("\n") } } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpresses, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the exprs that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach(e => { + val expr = e.head + val isLoaded = freshName("isLoaded") + val isNull = freshName("isNull") + val primitive = freshName("primitive") + val fnName = freshName("evalExpr") + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow ${INPUT_ROW}) { + | if (!$isLoaded) { + | ${code.code.trim} + | $isLoaded = true; + | $isNull = ${code.isNull}; + | $primitive = ${code.value}; + | } + |} + """.stripMargin + code.code = fn + code.isNull = isNull + code.value = primitive + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + // Maintain the loaded value and isNull as member variables. This is necessary if the codegen + // function is split across multiple functions. + // TODO: maintaining this as a local variable probably allows the compiler to do better + // optimizations. 
+ addMutableState("boolean", isLoaded, s"$isLoaded = false;") + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), primitive, + s"$primitive = ${defaultValue(expr.dataType)};") + subExprIsLoadedVariables += isLoaded + + val state = SubExprEliminationState(isLoaded, code, fnName) + e.foreach(subExprEliminationExprs.put(_, state)) + }) + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code will for each + * expression will be combined in the `expressions` order. + */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2136f82ba4752..9ef226141421b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" ${input.code} if (${input.isNull}) { - $setNull + ${setNull.trim} } else { - $writeField + ${writeField.trim} } """ } @@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $rowWriter.initialize($bufferHolder, ${inputs.length}); ${ctx.splitExpressions(row, writeFields)} - """ + """.trim } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { - val exprEvals = expressions.map(e => e.gen(ctx)) + def createCode( + ctx: CodeGenContext, + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") @@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + // Reset the isLoaded flag for each row. 
+ val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + val code = s""" $bufferHolder.reset(); + $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} + $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) @@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } + protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + create(expressions, false) + } - val eval = createCode(ctx, expressions) + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" public Object generate($exprType[] exprs) { @@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificUnsafeProjection($exprType[] expressions) { @@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code} + ${eval.code.trim} return ${eval.value}; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9ab5c299d0f55..f80bcfcb0b0bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -203,6 +203,10 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 0000000000000..9de066e99d637 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode == b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. 
+ assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + 
assert(equivalence.getAllEquivalentExprs.isEmpty) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b7314189b5403..89e196c066007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -268,6 +268,11 @@ private[spark] object SQLConf { doc = "When true, use the new optimized Tungsten physical execution backend.", isPublic = false) + val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", + defaultValue = Some(true), // use CODEGEN_ENABLED as default + doc = "When true, common subexpressions will be eliminated.", + isPublic = false) + val DIALECT = stringConf( "spark.sql.dialect", defaultValue = Some("sql"), @@ -541,6 +546,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8bb293ae87e64..8650ac500b652 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } else { false } + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled + } else { + false + } /** * Whether the "prepare" method is called. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 145de0db9edaa..303d636164adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(projectList, child.output) + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 441a0c6d0e36e..19e850a46fdfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1970,4 +1970,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) } } + + test("Common subexpression elimination") { + // select from a table to prevent constant folding. 
+ val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } } From f14e95115c0939a77ebcb00209696a87fd651ff9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 10 Nov 2015 11:34:36 -0800 Subject: [PATCH 280/324] [ML][R] SparkR::glm summary result to compare with native R Follow up #9561. Due to [SPARK-11587](https://issues.apache.org/jira/browse/SPARK-11587) has been fixed, we should compare SparkR::glm summary result with native R output rather than hard-code one. mengxr Author: Yanbo Liang Closes #9590 from yanboliang/glm-r-test. --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 31 ++++++++++--------------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7126b7cde4bd7..f23e1c7f1fce4 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -106,7 +106,7 @@ setMethod("summary", signature(object = "PipelineModel"), coefficients <- matrix(coefficients, ncol = 4) colnames(coefficients) <- c("Estimate", "Std. 
Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) - return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients)) + return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) } else { coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 42287ea19adc5..d497ad8c9daa3 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -72,22 +72,17 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) - coefs <- unlist(stats$Coefficients) - devianceResiduals <- unlist(stats$DevianceResiduals) + coefs <- unlist(stats$coefficients) + devianceResiduals <- unlist(stats$devianceResiduals) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331) - rTValue <- c(7.123, 7.557, -13.644, -10.798) - rPValue <- c(0.0, 0.0, 0.0, 0.0) + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rCoefs <- unlist(rStats$coefficients) rDevianceResiduals <- c(-0.95096, 0.72918) - expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6)) - expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5)) - expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3)) - expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6)) + expect_true(all(abs(rCoefs - coefs) < 1e-5)) expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - rownames(stats$Coefficients) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -96,21 +91,15 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$Coefficients) + coefs <- as.vector(stats$coefficients[,1]) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) - rStdError <- c(3.0974, 0.5169, 0.8628) - rTValue <- c(-4.212, 3.680, 0.469) - rPValue <- c(0.000, 0.000, 0.639) - - expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4)) - expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4)) - expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3)) - expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3)) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) expect_true(all( - rownames(stats$Coefficients) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) From 18350a57004eb87cafa9504ff73affab4b818e06 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 11:36:43 -0800 Subject: [PATCH 281/324] [SPARK-11618][ML] Minor refactoring of basic ML import/export Refactoring * separated overwrite and param save logic in DefaultParamsWriter * added sparkVersion to DefaultParamsWriter CC: mengxr Author: Joseph K. Bradley Closes #9587 from jkbradley/logreg-io. 
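For readers skimming the diff below, the refactor is essentially a template-method split: the base `Writer.save()` now owns the existence/overwrite check and delegates the actual serialization to a new protected `saveImpl()`. Here is a toy sketch of that shape; it uses `java.nio.file` instead of Hadoop's `FileSystem` purely to stay self-contained, so the class names and I/O details are assumptions for illustration, not the ml.util API.

```
// Toy illustration of the save()/saveImpl() split: the shared overwrite logic lives in
// the base class, subclasses only serialize.
import java.io.IOException
import java.nio.file.{Files, Paths}

abstract class ToyWriter {
  protected var shouldOverwrite: Boolean = false

  def overwrite(): this.type = { shouldOverwrite = true; this }

  // Shared concern: existence check and optional delete happen in one place...
  final def save(path: String): Unit = {
    val p = Paths.get(path)
    if (Files.exists(p)) {
      if (shouldOverwrite) Files.delete(p)
      else throw new IOException(s"Path $path already exists. Use overwrite().save(path).")
    }
    saveImpl(path)
  }

  // ...while subclasses only implement the actual serialization.
  protected def saveImpl(path: String): Unit
}

class ToyParamsWriter(params: Map[String, String]) extends ToyWriter {
  override protected def saveImpl(path: String): Unit = {
    val body = params.map { case (k, v) => s"$k=$v" }.mkString("\n")
    Files.write(Paths.get(path), body.getBytes("UTF-8"))
  }
}

object ToyWriterDemo {
  def main(args: Array[String]): Unit = {
    val w = new ToyParamsWriter(Map("uid" -> "lr_1234", "maxIter" -> "10"))
    val path = Files.createTempFile("toy-metadata", ".txt").toString
    // The temp file already exists, so a plain save() refuses to clobber it...
    try w.save(path) catch { case e: IOException => println(e.getMessage) }
    // ...while overwrite().save() replaces it.
    w.overwrite().save(path)
  }
}
```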
--- .../org/apache/spark/ml/util/ReadWrite.scala | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ea790e0dddc7f..cbdf913ba8dfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite { protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { SQLContext.getOrCreate(SparkContext.getOrCreate()) } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext } /** @@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite { +abstract class Writer extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit /** * Overwrites if the output path already exists. @@ -147,28 +172,9 @@ trait Readable[T] { * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { - - /** - * Saves the ML component to the input path. - */ - override def save(path: String): Unit = { - val sc = sqlContext.sparkContext - - val hadoopConf = sc.hadoopConfiguration - val fs = FileSystem.get(hadoopConf) - val p = new Path(path) - if (fs.exists(p)) { - if (shouldOverwrite) { - logInfo(s"Path $path already exists. It will be overwritten.") - // TODO: Revert back to the original content if save is not successful. - fs.delete(p, true) - } else { - throw new IOException( - s"Path $path already exists. 
Please use write.overwrite().save(path) to overwrite it.") - } - } +private[ml] class DefaultParamsWriter(instance: Params) extends Writer { + override protected def saveImpl(path: String): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg }.toList val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) val metadataPath = new Path(path, "metadata").toString @@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg */ private[ml] class DefaultParamsReader[T] extends Reader[T] { - /** - * Loads the ML component from the input path. - */ override def load(path: String): T = { implicit val format = DefaultFormats - val sc = sqlContext.sparkContext val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) From dba1a62cf1baa9ae1ee665d592e01dfad78331a2 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 10 Nov 2015 14:25:06 -0800 Subject: [PATCH 282/324] [SPARK-7316][MLLIB] RDD sliding window with step Implementation of step capability for sliding window function in MLlib's RDD. Though one can use current sliding window with step 1 and then filter every Nth window, it will take more time and space (N*data.count times more than needed). For example, below are the results for various windows and steps on 10M data points: Window | Step | Time | Windows produced ------------ | ------------- | ---------- | ---------- 128 | 1 | 6.38 | 9999873 128 | 10 | 0.9 | 999988 128 | 100 | 0.41 | 99999 1024 | 1 | 44.67 | 9998977 1024 | 10 | 4.74 | 999898 1024 | 100 | 0.78 | 99990 ``` import org.apache.spark.mllib.rdd.RDDFunctions._ val rdd = sc.parallelize(1 to 10000000, 10) rdd.count val window = 1024 val step = 1 val t = System.nanoTime(); val windows = rdd.sliding(window, step); println(windows.count); println((System.nanoTime() - t) / 1e9) ``` Author: unknown Author: Alexander Ulanov Author: Xiangrui Meng Closes #5855 from avulanov/SPARK-7316-sliding. --- .../apache/spark/mllib/rdd/RDDFunctions.scala | 11 ++- .../apache/spark/mllib/rdd/SlidingRDD.scala | 71 ++++++++++--------- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 11 +-- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 78172843be56e..19a047ded257c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -37,15 +37,20 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Array[T]] = { + def sliding(windowSize: Int, step: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") - if (windowSize == 1) { + if (windowSize == 1 && step == 1) { self.map(Array(_)) } else { - new SlidingRDD[T](self, windowSize) + new SlidingRDD[T](self, windowSize, step) } } + /** + * [[sliding(Int, Int)*]] with step = 1. 
+ */ + def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1) + /** * Reduces the elements of this RDD in a multi-level tree pattern. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 1facf83d806d0..ead8db6344998 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.{TaskContext, Partition} import org.apache.spark.rdd.RDD private[mllib] -class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int) extends Partition with Serializable { override val index: Int = idx } /** - * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to @@ -40,19 +40,24 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] * * @param parent the parent RDD * @param windowSize the window size, must be greater than 1 + * @param step step size for windows * - * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]] + * @see [[scala.collection.IterableLike.sliding(Int, Int)*]] */ private[mllib] -class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int) extends RDD[Array[T]](parent) { - require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1), + "Window size and step must be greater than 0, " + + s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.") override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) - .sliding(windowSize) + .drop(part.offset) + .sliding(windowSize, step) .withPartial(false) .map(_.toArray) } @@ -62,40 +67,42 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int override def getPartitions: Array[Partition] = { val parentPartitions = parent.partitions - val n = parentPartitions.size + val n = parentPartitions.length if (n == 0) { Array.empty } else if (n == 1) { - Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0)) } else { - val n1 = n - 1 val w1 = windowSize - 1 - // Get the first w1 items of each partition, starting from the second partition. - val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) - val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + // Get partition sizes and first w1 elements. 
+ val (sizes, heads) = parent.mapPartitions { iter => + val w1Array = iter.take(w1).toArray + Iterator.single((w1Array.length + iter.length, w1Array)) + }.collect().unzip + val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]] var i = 0 + var cumSize = 0 var partitionIndex = 0 - while (i < n1) { - var j = i - val tail = mutable.ListBuffer[T]() - // Keep appending to the current tail until appended a head of size w1. - while (j < n1 && nextHeads(j).size < w1) { - tail ++= nextHeads(j) - j += 1 + while (i < n) { + val mod = cumSize % step + val offset = if (mod == 0) 0 else step - mod + val size = sizes(i) + if (offset < size) { + val tail = mutable.ListBuffer.empty[T] + // Keep appending to the current tail until it has w1 elements. + var j = i + 1 + while (j < n && tail.length < w1) { + tail ++= heads(j).take(w1 - tail.length) + j += 1 + } + if (sizes(i) + tail.length >= offset + windowSize) { + partitions += + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + partitionIndex += 1 + } } - if (j < n1) { - tail ++= nextHeads(j) - j += 1 - } - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) - partitionIndex += 1 - // Skip appended heads. - i = j - } - // If the head of last partition has size w1, we also need to add this partition. - if (nextHeads.last.size == w1) { - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + cumSize += size + i += 1 } partitions.toArray } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index bc64172614830..ac93733bab5f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -28,9 +28,12 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { - val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList - val expected = data.sliding(windowSize).map(_.toList).toList - assert(sliding === expected) + for (step <- 1 to 3) { + val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList + val expected = data.sliding(windowSize, step) + .map(_.toList).toList.filter(l => l.size == windowSize) + assert(sliding === expected) + } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") @@ -40,7 +43,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) - assert(rdd.partitions.size === data.length) + assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) From 724cf7a38c551bf2a79b87a8158bbe1725f9f888 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 10 Nov 2015 14:30:19 -0800 Subject: [PATCH 283/324] [SPARK-11616][SQL] Improve toString for Dataset Author: Michael Armbrust Closes #9586 from marmbrus/dataset-toString. 
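The change itself is small: the `[name: type, ...]` rendering that DataFrame already used moves into a shared `Queryable` trait so Dataset reports its schema the same way. A minimal sketch of just the formatting, on a plain field list rather than a real `StructType`:

```
// What the shared toString produces, sketched on a plain field list; the real logic
// lives in the new Queryable trait and reads name/simpleString from the schema.
object ToStringSketch {
  case class Field(name: String, dataTypeSimpleString: String)

  def render(schema: Seq[Field]): String =
    schema.map(f => s"${f.name}: ${f.dataTypeSimpleString}").mkString("[", ", ", "]")

  def main(args: Array[String]): Unit = {
    // Mirrors the DatasetSuite expectation: Seq((1, 2)).toDS().toString == "[_1: int, _2: int]"
    println(render(Seq(Field("_1", "int"), Field("_2", "int"))))
  }
}
```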
--- .../org/apache/spark/sql/DataFrame.scala | 14 ++----- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/Queryable.scala | 37 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 5 +++ 4 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 9368435a63c35..691b476fff8d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -116,7 +116,8 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) + extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. @@ -234,15 +235,6 @@ class DataFrame private[sql]( sb.toString() } - override def toString: String = { - try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } - } - /** * Returns the object itself. * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6d2968e2881f8..a7e5ab19bf846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** @@ -62,7 +62,7 @@ import org.apache.spark.sql.types.StructType class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, - unresolvedEncoder: Encoder[T]) extends Serializable { + unresolvedEncoder: Encoder[T]) extends Queryable with Serializable { /** The encoder for this [[Dataset]] that has been resolved to its output schema. 
*/ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala new file mode 100644 index 0000000000000..9ca383896a09b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.types.StructType + +import scala.util.control.NonFatal + +/** A trait that holds shared code between DataFrames and Datasets. */ +private[sql] trait Queryable { + def schema: StructType + def queryExecution: QueryExecution + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index aea5a700d0204..621148528714f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -313,4 +313,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") checkAnswer(joined, ("2", 2)) } + + test("toString") { + val ds = Seq((1, 2)).toDS() + assert(ds.toString == "[_1: int, _2: int]") + } } From 638c51d9380081b3b8182be2c2460bd53b8b0a4f Mon Sep 17 00:00:00 2001 From: Pravin Gadakh Date: Tue, 10 Nov 2015 14:47:04 -0800 Subject: [PATCH 284/324] [SPARK-11550][DOCS] Replace example code in mllib-optimization.md using include_example Author: Pravin Gadakh Closes #9516 from pravingadakh/SPARK-11550. --- docs/mllib-optimization.md | 145 +----------------- .../examples/mllib/JavaLBFGSExample.java | 108 +++++++++++++ .../spark/examples/mllib/LBFGSExample.scala | 90 +++++++++++ 3 files changed, 200 insertions(+), 143 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index a3bd130ba077c..ad7bcd9bfd407 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -220,154 +220,13 @@ L-BFGS optimizer.
    Refer to the [`LBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) and [`SquaredL2Updater` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.SquaredL2Updater) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val numFeatures = data.take(1)(0).features.size - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - -// Append 1 into the training data as intercept. -val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() - -val test = splits(1) - -// Run training algorithm to build the model -val numCorrections = 10 -val convergenceTol = 1e-4 -val maxNumIterations = 20 -val regParam = 0.1 -val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) - -val (weightsWithIntercept, loss) = LBFGS.runLBFGS( - training, - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept) - -val model = new LogisticRegressionModel( - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), - weightsWithIntercept(weightsWithIntercept.size - 1)) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Loss of each step in training process") -loss.foreach(println) -println("Area under ROC = " + auROC) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LBFGSExample.scala %}
    Refer to the [`LBFGS` Java docs](api/java/org/apache/spark/mllib/optimization/LBFGS.html) and [`SquaredL2Updater` Java docs](api/java/org/apache/spark/mllib/optimization/SquaredL2Updater.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import java.util.Random; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.optimization.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class LBFGSExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - int numFeatures = data.take(1).get(0).features().size(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD trainingInit = data.sample(false, 0.6, 11L); - JavaRDD test = data.subtract(trainingInit); - - // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); - training.cache(); - - // Run training algorithm to build the model. - int numCorrections = 10; - double convergenceTol = 1e-4; - int maxNumIterations = 20; - double regParam = 0.1; - Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); - - Tuple2 result = LBFGS.runLBFGS( - training.rdd(), - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept); - Vector weightsWithIntercept = result._1(); - double[] loss = result._2(); - - final LogisticRegressionModel model = new LogisticRegressionModel( - Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), - (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(scoreAndLabels.rdd()); - double auROC = metrics.areaUnderROC(); - - System.out.println("Loss of each step in training process"); - for (double l : loss) - System.out.println(l); - System.out.println("Area under ROC = " + auROC); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLBFGSExample.java %}
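For context on the `include_example` tags used above: the Jekyll include_example plugin extracts the region between the `$example on$` and `$example off$` markers from the referenced source file, so the published snippet stays in sync with compiled example code. The skeleton below shows only the marker convention; the object name is illustrative, and the full LBFGS example files added by this patch follow.

// Hypothetical skeleton showing the $example on$ / $example off$ convention.
object MarkerConventionExample {
  def main(args: Array[String]): Unit = {
    // Setup outside the markers compiles and runs but is not rendered in the docs.
    // $example on$
    val rendered = "only code between the markers appears on the generated page"
    println(rendered)
    // $example off$
  }
}
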
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java new file mode 100644 index 0000000000000..355883f61bd64 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example off$ + +public class JavaLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. + int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. 
+ JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + // $example off$ + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala new file mode 100644 index 0000000000000..61d2e7715f53d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object LBFGSExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("LBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + val numFeatures = data.take(1)(0).features.size + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + + // Append 1 into the training data as intercept. + val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() + + val test = splits(1) + + // Run training algorithm to build the model + val numCorrections = 10 + val convergenceTol = 1e-4 + val maxNumIterations = 20 + val regParam = 0.1 + val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) + + val (weightsWithIntercept, loss) = LBFGS.runLBFGS( + training, + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept) + + val model = new LogisticRegressionModel( + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), + weightsWithIntercept(weightsWithIntercept.size - 1)) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. 
+ val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Loss of each step in training process") + loss.foreach(println) + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println From 32790fe7249b0efe2cbc5c4ee2df0fb687dcd624 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 10 Nov 2015 15:47:10 -0800 Subject: [PATCH 285/324] [SPARK-11567] [PYTHON] Add Python API for corr Aggregate function like `df.agg(corr("col1", "col2")` davies Author: felixcheung Closes #9536 from felixcheung/pyfunc. --- python/pyspark/sql/functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6e1cbde4239f3..c3da513c13897 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -255,6 +255,22 @@ def coalesce(*cols): return Column(jc) +@since(1.6) +def corr(col1, col2): + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. + + >>> a = [x * x - 2 * x + 3.5 for x in range(20)] + >>> b = range(20) + >>> corrDf = sqlContext.createDataFrame(zip(a, b)) + >>> corrDf = corrDf.agg(corr(corrDf._1, corrDf._2).alias('c')) + >>> corrDf.selectExpr('abs(c - 0.9572339139475857) < 1e-16 as t').collect() + [Row(t=True)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) + + @since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. From 1dde39d796bbf42336051a86bedf871c7fddd513 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Nov 2015 15:58:30 -0800 Subject: [PATCH 286/324] [SPARK-9818] Re-enable Docker tests for JDBC data source This patch re-enables tests for the Docker JDBC data source. These tests were reverted in #4872 due to transitive dependency conflicts introduced by the `docker-client` library. This patch should avoid those problems by using a version of `docker-client` which shades its transitive dependencies and by performing some build-magic to work around problems with that shaded JAR. In addition, I significantly refactored the tests to simplify the setup and teardown code and to fix several Docker networking issues which caused problems when running in `boot2docker`. Closes #8101. Author: Josh Rosen Author: Yijie Shen Closes #9503 from JoshRosen/docker-jdbc-tests. 
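As a sketch of how the framework below is meant to be extended: supporting another database means describing its image via DatabaseOnDocker and writing a suite that extends DockerJDBCIntegrationSuite, mirroring the MySQL and Postgres suites in this patch. The MariaDB image tag, credentials, and table in this example are assumptions for illustration, not part of the change.

package org.apache.spark.sql.jdbc

import java.sql.Connection
import java.util.Properties

import org.apache.spark.tags.DockerTest

// Hypothetical suite following the DatabaseOnDocker / DockerJDBCIntegrationSuite pattern.
@DockerTest
class MariaDBIntegrationSuite extends DockerJDBCIntegrationSuite {
  override val db = new DatabaseOnDocker {
    override val imageName = "mariadb:10.0"   // assumed image tag
    override val env = Map("MYSQL_ROOT_PASSWORD" -> "rootpass")
    override val jdbcPort: Int = 3306
    override def getJdbcUrl(ip: String, port: Int): String =
      s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass"
  }

  // Runs once the container accepts JDBC connections.
  override def dataPreparation(conn: Connection): Unit = {
    conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate()
    conn.prepareStatement("INSERT INTO tbl VALUES (42, 'fred')").executeUpdate()
  }

  test("read the seeded table through JDBC") {
    val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties)
    assert(df.collect().length == 1)
  }
}
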
--- docker-integration-tests/pom.xml | 149 ++++++++++++++++ .../sql/jdbc/DockerJDBCIntegrationSuite.scala | 160 ++++++++++++++++++ .../sql/jdbc/MySQLIntegrationSuite.scala | 153 +++++++++++++++++ .../sql/jdbc/PostgresIntegrationSuite.scala | 82 +++++++++ .../org/apache/spark/util/DockerUtils.scala | 68 ++++++++ pom.xml | 14 ++ project/SparkBuild.scala | 14 +- .../org/apache/spark/tags/DockerTest.java | 26 +++ 8 files changed, 664 insertions(+), 2 deletions(-) create mode 100644 docker-integration-tests/pom.xml create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala create mode 100644 tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml new file mode 100644 index 0000000000000..dee0c4aa37ae8 --- /dev/null +++ b/docker-integration-tests/pom.xml @@ -0,0 +1,149 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + spark-docker-integration-tests_2.10 + jar + Spark Project Docker Integration Tests + http://spark.apache.org/ + + docker-integration-tests + + + + + com.spotify + docker-client + shaded + test + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + + + com.fasterxml.jackson.core + jackson-databind + + + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.connectors + jersey-apache-connector + + + org.glassfish.jersey.media + jersey-media-json-jackson + + + + + + com.google.guava + guava + 18.0 + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + + + + com.sun.jersey + jersey-server + 1.19 + test + + + com.sun.jersey + jersey-core + 1.19 + test + + + com.sun.jersey + jersey-servlet + 1.19 + test + + + com.sun.jersey + jersey-json + 1.19 + test + + + stax + stax-api + + + + + + diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala new file mode 100644 index 0000000000000..c503c4a13b482 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.net.ServerSocket +import java.sql.Connection + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.spotify.docker.client._ +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.DockerUtils +import org.apache.spark.sql.test.SharedSQLContext + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled. + */ + val imageName: String + + /** + * Environment variables to set inside of the Docker container while launching it. + */ + val env: Map[String, String] + + /** + * The container-internal JDBC port that the database listens on. + */ + val jdbcPort: Int + + /** + * Return a JDBC URL that connects to the database running at the given IP address and port. + */ + def getJdbcUrl(ip: String, port: Int): String +} + +abstract class DockerJDBCIntegrationSuite + extends SparkFunSuite + with BeforeAndAfterAll + with Eventually + with SharedSQLContext { + + val db: DatabaseOnDocker + + private var docker: DockerClient = _ + private var containerId: String = _ + protected var jdbcUrl: String = _ + + override def beforeAll() { + super.beforeAll() + try { + docker = DefaultDockerClient.fromEnv.build() + // Check that Docker is actually up + try { + docker.ping() + } catch { + case NonFatal(e) => + log.error("Exception while connecting to Docker. 
Check whether Docker is running.") + throw e + } + // Ensure that the Docker image is installed: + try { + docker.inspectImage(db.imageName) + } catch { + case e: ImageNotFoundException => + log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") + docker.pull(db.imageName) + } + // Configure networking (necessary for boot2docker / Docker Machine) + val externalPort: Int = { + val sock = new ServerSocket(0) + val port = sock.getLocalPort + sock.close() + port + } + val dockerIp = DockerUtils.getDockerIp() + val hostConfig: HostConfig = HostConfig.builder() + .networkMode("bridge") + .portBindings( + Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) + .build() + // Create the database container: + val config = ContainerConfig.builder() + .image(db.imageName) + .networkDisabled(false) + .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) + .hostConfig(hostConfig) + .exposedPorts(s"${db.jdbcPort}/tcp") + .build() + containerId = docker.createContainer(config).id + // Start the container and wait until the database can accept JDBC connections: + docker.startContainer(containerId) + jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) + eventually(timeout(60.seconds), interval(1.seconds)) { + val conn = java.sql.DriverManager.getConnection(jdbcUrl) + conn.close() + } + // Run any setup queries: + val conn: Connection = java.sql.DriverManager.getConnection(jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } catch { + case NonFatal(e) => + try { + afterAll() + } finally { + throw e + } + } + } + + override def afterAll() { + try { + if (docker != null) { + try { + if (containerId != null) { + docker.killContainer(containerId) + docker.removeContainer(containerId) + } + } catch { + case NonFatal(e) => + logWarning(s"Could not stop container $containerId", e) + } finally { + docker.close() + } + } + } finally { + super.afterAll() + } + } + + /** + * Prepare databases and tables for testing. + */ + def dataPreparation(connection: Connection): Unit +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala new file mode 100644 index 0000000000000..c68e4dc4933b1 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "mysql:5.7.9" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + 
val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala new file mode 100644 index 0000000000000..164a7f396280c --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "postgres:9.4.5" + override val env = Map( + "POSTGRES_PASSWORD" -> "rootpass" + ) + override val jdbcPort = 5432 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " + + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + } + + test("Type mapping for various types") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 10) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Double")) + assert(types(3).equals("class java.lang.Long")) + assert(types(4).equals("class java.lang.Boolean")) + assert(types(5).equals("class [B")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class java.lang.Boolean")) + assert(types(8).equals("class java.lang.String")) + assert(types(9).equals("class java.lang.String")) + assert(rows(0).getString(0).equals("hello")) + assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(rows(0).getBoolean(4) == false) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7) == true) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + } + + test("Basic write test") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test only that it doesn't crash. 
+ } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala new file mode 100644 index 0000000000000..87271776d8564 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{Inet4Address, NetworkInterface, InetAddress} + +import scala.collection.JavaConverters._ +import scala.sys.process._ +import scala.util.Try + +private[spark] object DockerUtils { + + def getDockerIp(): String = { + /** If docker-machine is setup on this box, attempts to find the ip from it. */ + def findFromDockerMachine(): Option[String] = { + sys.env.get("DOCKER_MACHINE_NAME").flatMap { name => + Try(Seq("/bin/bash", "-c", s"docker-machine ip $name 2>/dev/null").!!.trim).toOption + } + } + sys.env.get("DOCKER_IP") + .orElse(findFromDockerMachine()) + .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) + .getOrElse { + // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // certain interfaces. + val address = InetAddress.getLocalHost + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. 
+ val blackListedIFs = Seq( + "vboxnet0", // Mac + "docker0" // Linux + ) + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => + !blackListedIFs.contains(i.getName) + } + val reOrderedNetworkIFs = activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) + return strippedAddress.getHostAddress + } + } + address.getHostAddress + } + } +} diff --git a/pom.xml b/pom.xml index fd8c773513881..c499a80aa0f43 100644 --- a/pom.xml +++ b/pom.xml @@ -98,6 +98,7 @@ sql/catalyst sql/core sql/hive + docker-integration-tests unsafe assembly external/twitter @@ -778,6 +779,19 @@ 0.11 test + + com.spotify + docker-client + shaded + 3.2.1 + test + + + guava + com.google.guava + + + org.apache.curator curator-recipes diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a9fb741d75933..b7c619224329f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,8 +43,9 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl, dockerIntegrationTests) = + Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", + "docker-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") @@ -240,6 +241,8 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + /** * Adds the ability to run the spark shell directly from SBT without building an assembly @@ -291,6 +294,13 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +object DockerIntegrationTests { + // This serves to override the override specified in DependencyOverrides: + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "18.0" + ) +} + /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/tags/src/main/java/org/apache/spark/tags/DockerTest.java new file mode 100644 index 0000000000000..0fecf3b8f979a --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/DockerTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface DockerTest { } From e281b87398f1298cc3df8e0409c7040acdddce03 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 16:20:10 -0800 Subject: [PATCH 287/324] [SPARK-5565][ML] LDA wrapper for Pipelines API This adds LDA to spark.ml, the Pipelines API. It follows the design doc in the JIRA: [https://issues.apache.org/jira/browse/SPARK-5565], with one major change: * I eliminated doc IDs. These are not necessary with DataFrames since the user can add an ID column as needed. Note: This will conflict with [https://github.com/apache/spark/pull/9484], but I'll try to merge [https://github.com/apache/spark/pull/9484] first and then rebase this PR. CC: hhbyyh feynmanliang If you have a chance to make a pass, that'd be really helpful--thanks! Now that I'm done traveling & this PR is almost ready, I'll see about reviewing other PRs critical for 1.6. CC: mengxr Author: Joseph K. Bradley Closes #9513 from jkbradley/lda-pipelines. --- .../org/apache/spark/ml/clustering/LDA.scala | 701 ++++++++++++++++++ .../spark/mllib/clustering/LDAModel.scala | 29 +- .../apache/spark/ml/clustering/LDASuite.scala | 221 ++++++ 3 files changed, 946 insertions(+), 5 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala new file mode 100644 index 0000000000000..f66233ed3d0f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -0,0 +1,701 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, + EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, + LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, + OnlineLDAOptimizer => OldOnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} +import org.apache.spark.sql.types.StructType + + +private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter + with HasSeed with HasCheckpointInterval { + + /** + * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10. + * @group param + */ + @Since("1.6.0") + final val k = new IntParam(this, "k", "number of topics (clusters) to infer", + ParamValidators.gt(1)) + + /** @group getParam */ + @Since("1.6.0") + def getK: Int = $(k) + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). + * + * If not set by the user, then docConcentration is set automatically. If set to + * singleton vector [alpha], then alpha is replicated to a vector of length k in fitting. + * Otherwise, the [[docConcentration]] vector must be length k. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val docConcentration = new DoubleArrayParam(this, "docConcentration", + "Concentration parameter (commonly named \"alpha\") for the prior placed on documents'" + + " distributions over topics (\"theta\").", (alpha: Array[Double]) => alpha.forall(_ >= 0.0)) + + /** @group getParam */ + @Since("1.6.0") + def getDocConcentration: Array[Double] = $(docConcentration) + + /** Get docConcentration used by spark.mllib LDA */ + protected def getOldDocConcentration: Vector = { + if (isSet(docConcentration)) { + Vectors.dense(getDocConcentration) + } else { + Vectors.dense(-1.0) + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * If not set by the user, then topicConcentration is set automatically. 
+ * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val topicConcentration = new DoubleParam(this, "topicConcentration", + "Concentration parameter (commonly named \"beta\" or \"eta\") for the prior placed on topic'" + + " distributions over terms.", ParamValidators.gtEq(0)) + + /** @group getParam */ + @Since("1.6.0") + def getTopicConcentration: Double = $(topicConcentration) + + /** Get topicConcentration used by spark.mllib LDA */ + protected def getOldTopicConcentration: Double = { + if (isSet(topicConcentration)) { + getTopicConcentration + } else { + -1.0 + } + } + + /** Supported values for Param [[optimizer]]. */ + @Since("1.6.0") + final val supportedOptimizers: Array[String] = Array("online", "em") + + /** + * Optimizer or inference algorithm used to estimate the LDA model. + * Currently supported (case-insensitive): + * - "online": Online Variational Bayes (default) + * - "em": Expectation-Maximization + * + * For details, see the following papers: + * - Online LDA: + * Hoffman, Blei and Bach. "Online Learning for Latent Dirichlet Allocation." + * Neural Information Processing Systems, 2010. + * [[http://www.cs.columbia.edu/~blei/papers/HoffmanBleiBach2010b.pdf]] + * - EM: + * Asuncion et al. "On Smoothing and Inference for Topic Models." + * Uncertainty in Artificial Intelligence, 2009. + * [[http://arxiv.org/pdf/1205.2662.pdf]] + * + * @group param + */ + @Since("1.6.0") + final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), + (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + + /** @group getParam */ + @Since("1.6.0") + def getOptimizer: String = $(optimizer) + + /** + * Output column with estimates of the topic mixture distribution for each document (often called + * "theta" in the literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @group param + */ + @Since("1.6.0") + final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" + + " with estimates of the topic mixture distribution for each document (often called \"theta\"" + + " in the literature). Returns a vector of zeros for an empty document.") + + setDefault(topicDistributionCol -> "topicDistribution") + + /** @group getParam */ + @Since("1.6.0") + def getTopicDistributionCol: String = $(topicDistributionCol) + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) + * Default: 1024, following Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" + + " parameter that downweights early iterations. 
Larger values make early iterations count less.", + ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningOffset: Double = $(learningOffset) + + /** + * Learning rate, set as an exponential decay rate. + * This should be between (0.5, 1.0] to guarantee asymptotic convergence. + * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). + * Default: 0.51, based on Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" + + " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" + + " convergence.", ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningDecay: Double = $(learningDecay) + + /** + * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, + * in range (0, 1]. + * + * Note that this should be adjusted in synch with [[LDA.maxIter]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Note: This is the same as the `miniBatchFraction` parameter in + * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]]. + * + * Default: 0.05, i.e., 5% of total documents. + * @group param + */ + @Since("1.6.0") + final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" + + " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].", + ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("1.6.0") + def getSubsamplingRate: Double = $(subsamplingRate) + + /** + * Indicates whether the docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. + * Setting this to true will make the model more expressive and fit the training data better. + * Default: false + * @group expertParam + */ + @Since("1.6.0") + final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration", + "Indicates whether the docConcentration (Dirichlet parameter for document-topic" + + " distribution) will be optimized during training.") + + /** @group expertGetParam */ + @Since("1.6.0") + def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) + } + + @Since("1.6.0") + override def validateParams(): Unit = { + if (isSet(docConcentration)) { + if (getDocConcentration.length != 1) { + require(getDocConcentration.length == getK, s"LDA docConcentration was of length" + + s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + + s" length either 1 (scalar) or k (num topics).") + } + getOptimizer match { + case "online" => + require(getDocConcentration.forall(_ >= 0), + "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + + getDocConcentration.mkString(",")) + case "em" => + require(getDocConcentration.forall(_ >= 0), + "For EM optimizer, docConcentration values must be >= 1. 
Found values: " + + getDocConcentration.mkString(",")) + } + } + if (isSet(topicConcentration)) { + getOptimizer match { + case "online" => + require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + + s" must be >= 0. Found value: $getTopicConcentration") + case "em" => + require(getTopicConcentration >= 0, s"For EM optimizer, topicConcentration" + + s" must be >= 1. Found value: $getTopicConcentration") + } + } + } + + private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + } +} + + +/** + * :: Experimental :: + * Model fitted by [[LDA]]. + * + * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary) + * @param oldLocalModel Underlying spark.mllib model. + * If this model was produced by Online LDA, then this is the + * only model representation. + * If this model was produced by EM, then this local + * representation may be built lazily. + * @param sqlContext Used to construct local DataFrames for returning query results + */ +@Since("1.6.0") +@Experimental +class LDAModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val vocabSize: Int, + @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel], + @Since("1.6.0") @transient protected val sqlContext: SQLContext) + extends Model[LDAModel] with LDAParams with Logging { + + /** Returns underlying spark.mllib model */ + @Since("1.6.0") + protected def getModel: OldLDAModel = oldLocalModel match { + case Some(m) => m + case None => + // Should never happen. + throw new RuntimeException("LDAModel required local model format," + + " but the underlying model is missing.") + } + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("1.6.0") + override def copy(extra: ParamMap): LDAModel = { + val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext) + copyValues(copied, extra).setParent(parent) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + if ($(topicDistributionCol).nonEmpty) { + val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext)) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + } else { + logWarning("LDAModel.transform was called without any output columns. Set an output column" + + " such as topicDistributionCol to produce results.") + dataset + } + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Value for [[docConcentration]] estimated from data. + * If Online LDA was used and [[optimizeDocConcentration]] was set to false, + * then this returns the fixed (given) value for the [[docConcentration]] parameter. + */ + @Since("1.6.0") + def estimatedDocConcentration: Vector = getModel.docConcentration + + /** + * Inferred topics, where each topic is represented by a distribution over terms. 
+ * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM, + * then this method could involve collecting a large amount of data to the driver + * (on the order of vocabSize x k). + */ + @Since("1.6.0") + def topicsMatrix: Matrix = getModel.topicsMatrix + + /** Indicates whether this instance is of type [[DistributedLDAModel]] */ + @Since("1.6.0") + def isDistributed: Boolean = false + + /** + * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting + * a large [[topicsMatrix]] to the driver. This implementation may be changed in the + * future. + * + * @param dataset test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + @Since("1.6.0") + def logLikelihood(dataset: DataFrame): Double = oldLocalModel match { + case Some(m) => + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + m.logLikelihood(oldDataset) + case None => + // Should never happen. + throw new RuntimeException("LocalLDAModel.logLikelihood was called," + + " but the underlying model is missing.") + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * @param dataset test corpus to use for calculating perplexity + * @return Variational upper bound on log perplexity per token. + */ + @Since("1.6.0") + def logPerplexity(dataset: DataFrame): Double = oldLocalModel match { + case Some(m) => + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + m.logPerplexity(oldDataset) + case None => + // Should never happen. + throw new RuntimeException("LocalLDAModel.logPerplexity was called," + + " but the underlying model is missing.") + } + + /** + * Return the topics described by their top-weighted terms. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * Default value of 10. + * @return Local DataFrame with one topic per Row, with columns: + * - "topic": IntegerType: topic index + * - "termIndices": ArrayType(IntegerType): term indices, sorted in order of decreasing + * term importance + * - "termWeights": ArrayType(DoubleType): corresponding sorted term weights + */ + @Since("1.6.0") + def describeTopics(maxTermsPerTopic: Int): DataFrame = { + val topics = getModel.describeTopics(maxTermsPerTopic).zipWithIndex.map { + case ((termIndices, termWeights), topic) => + (topic, termIndices.toSeq, termWeights.toSeq) + } + sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights") + } + + @Since("1.6.0") + def describeTopics(): DataFrame = describeTopics(10) +} + + +/** + * :: Experimental :: + * + * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM). + * + * This model stores the inferred topics, the full training dataset, and the topic distribution + * for each training document. + */ +@Since("1.6.0") +@Experimental +class DistributedLDAModel private[ml] ( + uid: String, + vocabSize: Int, + private val oldDistributedModel: OldDistributedLDAModel, + sqlContext: SQLContext) + extends LDAModel(uid, vocabSize, None, sqlContext) { + + /** + * Convert this distributed model to a local representation. 
This discards info about the + * training dataset. + */ + @Since("1.6.0") + def toLocal: LDAModel = { + if (oldLocalModel.isEmpty) { + oldLocalModel = Some(oldDistributedModel.toLocal) + } + new LDAModel(uid, vocabSize, oldLocalModel, sqlContext) + } + + @Since("1.6.0") + override protected def getModel: OldLDAModel = oldDistributedModel + + @Since("1.6.0") + override def copy(extra: ParamMap): DistributedLDAModel = { + val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext) + if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel + copyValues(copied, extra).setParent(parent) + copied + } + + @Since("1.6.0") + override def topicsMatrix: Matrix = { + if (oldLocalModel.isEmpty) { + oldLocalModel = Some(oldDistributedModel.toLocal) + } + super.topicsMatrix + } + + @Since("1.6.0") + override def isDistributed: Boolean = true + + @Since("1.6.0") + override def logLikelihood(dataset: DataFrame): Double = { + if (oldLocalModel.isEmpty) { + oldLocalModel = Some(oldDistributedModel.toLocal) + } + super.logLikelihood(dataset) + } + + @Since("1.6.0") + override def logPerplexity(dataset: DataFrame): Double = { + if (oldLocalModel.isEmpty) { + oldLocalModel = Some(oldDistributedModel.toLocal) + } + super.logPerplexity(dataset) + } + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) + * + * Notes: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + * - This is computed from the topic distributions computed during training. If you call + * [[logLikelihood()]] on the same training dataset, the topic distributions will be computed + * again, possibly giving different results. + */ + @Since("1.6.0") + lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | Dirichlet hyperparameters) + */ + @Since("1.6.0") + lazy val logPrior: Double = oldDistributedModel.logPrior +} + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "term" = "word": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over terms representing some concept + * - "document": one piece of text, corresponding to one row in the input data + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * + * Input data (featuresCol): + * LDA is given a collection of documents as input data, via the featuresCol parameter. + * Each document is specified as a [[Vector]] of length vocabSize, where each entry is the + * count for the corresponding term (word) in the document. Feature transformers such as + * [[org.apache.spark.ml.feature.Tokenizer]] and [[org.apache.spark.ml.feature.CountVectorizer]] + * can be useful for converting text to word count vectors. 
+ * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Since("1.6.0") +@Experimental +class LDA @Since("1.6.0") ( + @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("lda")) + + setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10, + learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05, + optimizeDocConcentration -> true) + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("1.6.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.6.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Array[Double]): this.type = set(docConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Double): this.type = set(docConcentration, Array(value)) + + /** @group setParam */ + @Since("1.6.0") + def setTopicConcentration(value: Double): this.type = set(topicConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setOptimizer(value: String): this.type = set(optimizer, value) + + /** @group setParam */ + @Since("1.6.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningOffset(value: Double): this.type = set(learningOffset, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningDecay(value: Double): this.type = set(learningDecay, value) + + /** @group setParam */ + @Since("1.6.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value) + + @Since("1.6.0") + override def copy(extra: ParamMap): LDA = defaultCopy(extra) + + @Since("1.6.0") + override def fit(dataset: DataFrame): LDAModel = { + transformSchema(dataset.schema, logging = true) + val oldLDA = new OldLDA() + .setK($(k)) + .setDocConcentration(getOldDocConcentration) + .setTopicConcentration(getOldTopicConcentration) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setCheckpointInterval($(checkpointInterval)) + .setOptimizer(getOldOptimizer) + // TODO: persist here, or in old LDA? 
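+    // The spark.mllib optimizers expect an RDD[(docId, termCounts)]; LDA.getOldDataset
+    // (defined in the companion object below) pairs each feature vector with a generated id.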
+ val oldData = LDA.getOldDataset(dataset, $(featuresCol)) + val oldModel = oldLDA.run(oldData) + val newModel = oldModel match { + case m: OldLocalLDAModel => + new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext) + case m: OldDistributedLDAModel => + new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext) + } + copyValues(newModel).setParent(this) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + + +private[clustering] object LDA { + + /** Get dataset for spark.mllib LDA */ + def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + dataset + .withColumn("docId", monotonicallyIncreasingId()) + .select("docId", featuresCol) + .map { case Row(docId: Long, features: Vector) => + (docId, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 31d8a9fdea1c6..cd520f09bd466 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -183,8 +183,7 @@ abstract class LDAModel private[clustering] extends Saveable { /** * Local LDA model. * This model stores only the inferred topics. - * It may be used for computing topics for new documents, but it may give less accurate answers - * than the [[DistributedLDAModel]]. + * * @param topics Inferred topics (vocabSize x k matrix). */ @Since("1.3.0") @@ -353,7 +352,7 @@ class LocalLDAModel private[clustering] ( documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { - (id, Vectors.zeros(k)) + (id, Vectors.zeros(k)) } else { val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, @@ -366,6 +365,28 @@ class LocalLDAModel private[clustering] ( } } + /** Get a method usable as a UDF for [[topicDistributions()]] */ + private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = sc.broadcast(expElogbeta) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + (termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + Vectors.zeros(k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbetaBc.value, + docConcentrationBrz, + gammaShape, + k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + /** * Java-friendly version of [[topicDistributions]] */ @@ -477,8 +498,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] { /** * Distributed LDA model. * This model stores the inferred topics, the full training dataset, and the topic distributions. - * When computing topics for new documents, it may give more accurate answers - * than the [[LocalLDAModel]]. */ @Since("1.3.0") class DistributedLDAModel private[clustering] ( diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala new file mode 100644 index 0000000000000..edb927495e8bf --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +object LDASuite { + def generateLDAData( + sql: SQLContext, + rows: Int, + k: Int, + vocabSize: Int): DataFrame = { + val avgWC = 1 // average instances of each word in a doc + val sc = sql.sparkContext + val rng = new java.util.Random() + rng.setSeed(1) + val rdd = sc.parallelize(1 to rows).map { i => + Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) + }.map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } +} + + +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { + + val k: Int = 5 + val vocabSize: Int = 30 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + } + + test("default parameters") { + val lda = new LDA() + + assert(lda.getFeaturesCol === "features") + assert(lda.getMaxIter === 20) + assert(lda.isDefined(lda.seed)) + assert(lda.getCheckpointInterval === 10) + assert(lda.getK === 10) + assert(!lda.isSet(lda.docConcentration)) + assert(!lda.isSet(lda.topicConcentration)) + assert(lda.getOptimizer === "online") + assert(lda.getLearningDecay === 0.51) + assert(lda.getLearningOffset === 1024) + assert(lda.getSubsamplingRate === 0.05) + assert(lda.getOptimizeDocConcentration) + assert(lda.getTopicDistributionCol === "topicDistribution") + } + + test("set parameters") { + val lda = new LDA() + .setFeaturesCol("test_feature") + .setMaxIter(33) + .setSeed(123) + .setCheckpointInterval(7) + .setK(9) + .setTopicConcentration(0.56) + .setTopicDistributionCol("myOutput") + + assert(lda.getFeaturesCol === "test_feature") + assert(lda.getMaxIter === 33) + assert(lda.getSeed === 123) + assert(lda.getCheckpointInterval === 7) + assert(lda.getK === 9) + assert(lda.getTopicConcentration === 0.56) + assert(lda.getTopicDistributionCol === "myOutput") + + + // setOptimizer + lda.setOptimizer("em") + assert(lda.getOptimizer === "em") + lda.setOptimizer("online") + assert(lda.getOptimizer === "online") + lda.setLearningDecay(0.53) + assert(lda.getLearningDecay === 0.53) + lda.setLearningOffset(1027) + assert(lda.getLearningOffset === 1027) + lda.setSubsamplingRate(0.06) + assert(lda.getSubsamplingRate === 0.06) + lda.setOptimizeDocConcentration(false) + assert(!lda.getOptimizeDocConcentration) + } + + test("parameters validation") { + val lda = new LDA() + + // misc Params + intercept[IllegalArgumentException] { + new LDA().setK(1) + } + intercept[IllegalArgumentException] { + new LDA().setOptimizer("no_such_optimizer") + } + 
intercept[IllegalArgumentException] { + new LDA().setDocConcentration(-1.1) + } + intercept[IllegalArgumentException] { + new LDA().setTopicConcentration(-1.1) + } + + // validateParams() + lda.validateParams() + lda.setDocConcentration(1.1) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray) + withClue("LDA docConcentration validity check failed for bad array length") { + intercept[IllegalArgumentException] { + lda.validateParams() + } + } + + // Online LDA + intercept[IllegalArgumentException] { + new LDA().setLearningOffset(0) + } + intercept[IllegalArgumentException] { + new LDA().setLearningDecay(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(1.1) + } + } + + test("fit & transform with Online LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) + val model = lda.fit(dataset) + + MLTestingUtils.checkCopy(model) + + assert(!model.isInstanceOf[DistributedLDAModel]) + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(!model.isDistributed) + + // transform() + val transformed = model.transform(dataset) + val expectedColumns = Array("features", lda.getTopicDistributionCol) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + transformed.select(lda.getTopicDistributionCol).collect().foreach { r => + val topicDistribution = r.getAs[Vector](0) + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + } + + // logLikelihood, logPerplexity + val ll = model.logLikelihood(dataset) + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPerplexity(dataset) + assert(lp >= 0.0 && lp != Double.PositiveInfinity) + + // describeTopics + val topics = model.describeTopics(3) + assert(topics.count() === k) + assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + topics.select("termIndices").collect().foreach { case r: Row => + val termIndices = r.getAs[Seq[Int]](0) + assert(termIndices.length === 3 && termIndices.toSet.size === 3) + } + topics.select("termWeights").collect().foreach { case r: Row => + val termWeights = r.getAs[Seq[Double]](0) + assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0)) + } + } + + test("fit & transform with EM LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) + val model_ = lda.fit(dataset) + + MLTestingUtils.checkCopy(model_) + + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(model.isDistributed) + + val localModel = model.toLocal + assert(!localModel.isInstanceOf[DistributedLDAModel]) + + // training logLikelihood, logPrior + val ll = model.trainingLogLikelihood + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPrior + assert(lp <= 0.0 && lp != Double.NegativeInfinity) + } +} From 3121e78168808c015fb21da8b0d44bb33649fb81 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 10 Nov 2015 16:25:22 -0800 
Subject: [PATCH 288/324] [SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1 and update toString of Exchange https://issues.apache.org/jira/browse/SPARK-9830 This is the follow-up pr for https://github.com/apache/spark/pull/9556 to address davies' comments. Author: Yin Huai Closes #9607 from yhuai/removeAgg1-followup. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 58 +++++--- .../expressions/aggregate/Average.scala | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Stddev.scala | 2 +- .../catalyst/expressions/aggregate/Sum.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 127 ++++++++++++++---- .../scala/org/apache/spark/sql/SQLConf.scala | 1 + .../apache/spark/sql/execution/Exchange.scala | 8 +- .../apache/spark/sql/execution/commands.scala | 10 ++ 10 files changed, 160 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b1e14390b7dc0..a9cd9a77038e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -532,7 +532,7 @@ class Analyzer( case min: Min if isDistinct => AggregateExpression(min, Complete, isDistinct = false) // We get an aggregate function, we need to wrap it in an AggregateExpression. - case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct) + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8322e9930cd5a..5a4b0c1e39ce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -110,17 +110,21 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case aggExpr: AggregateExpression => - // TODO: Is it possible that the child of a agg function is another - // agg function? - aggExpr.aggregateFunction.children.foreach { - // This is just a sanity check, our analysis rule PullOutNondeterministic should - // already pull out those nondeterministic expressions and evaluate them in - // a Project node. - case child if !child.deterministic => + aggExpr.aggregateFunction.children.foreach { child => + child.foreach { + case agg: AggregateExpression => + failAnalysis( + s"It is not allowed to use an aggregate function in the argument of " + + s"another aggregate function. 
Please use the inner aggregate function " + + s"in a sub-query.") + case other => // OK + } + + if (!child.deterministic) { failAnalysis( s"nondeterministic expression ${expr.prettyString} should not " + s"appear in the arguments of an aggregate function.") - case child => // OK + } } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( @@ -133,19 +137,33 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkSupportedGroupingDataType( + expressionString: String, + dataType: DataType): Unit = dataType match { + case BinaryType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in binary type or its inner field is " + + s"in binary type") + case a: ArrayType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in array type or its inner field is " + + s"in array type") + case m: MapType => + failAnalysis(s"expression $expressionString cannot be used in " + + s"grouping expression because it is in map type or its inner field is " + + s"in map type") + case s: StructType => + s.fields.foreach { f => + checkSupportedGroupingDataType(expressionString, f.dataType) + } + case udt: UserDefinedType[_] => + checkSupportedGroupingDataType(expressionString, udt.sqlType) + case _ => // OK + } + def checkValidGroupingExprs(expr: Expression): Unit = { - expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case a: ArrayType => - failAnalysis(s"array type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK - } + checkSupportedGroupingDataType(expr.prettyString, expr.dataType) + if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 7f9e5034702e9..94ac4bf09b90b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. 
override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function average") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 984ce7f24dacc..de5872ab11eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 5b9eb7ae02f25..2748009623355 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function stddev") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index c005ec9657211..cfb042e0aa782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function sum") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5a2368e329976..2e7c3bd67b554 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ +import scala.beans.{BeanProperty, BeanInfo} + +@BeanInfo +private[sql] case class GroupableData(@BeanProperty data: Int) + +private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { + + override def sqlType: DataType = IntegerType + + override def serialize(obj: Any): Int = { + obj match { + case groupableData: GroupableData => groupableData.data + } + } + + override def deserialize(datum: Any): GroupableData = { + datum match { + case data: Int => GroupableData(data) + } + } + + override def userClass: Class[GroupableData] = classOf[GroupableData] + + private[spark] override def asNullable: GroupableUDT = this +} + +@BeanInfo +private[sql] case class UngroupableData(@BeanProperty data: Array[Int]) + +private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { + + override def sqlType: DataType = ArrayType(IntegerType) + + override def serialize(obj: Any): ArrayData = { + obj match { + case groupableData: UngroupableData => new GenericArrayData(groupableData.data) + } + } + + override def deserialize(datum: Any): UngroupableData = { + datum match { + case data: Array[Int] => UngroupableData(data) + } + } + + override def userClass: Class[UngroupableData] = classOf[UngroupableData] + + private[spark] override def asNullable: UngroupableUDT = this +} + case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest { assert(error.message.contains("Conflicting attributes")) } - test("aggregation can't work on binary and map types") { - val plan = - Aggregate( - AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + test("check grouping expression data types") { + def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = { + val plan = + Aggregate( + AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", dataType)(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + shouldSuccess match { + case true => + assertAnalysisSuccess(plan, true) + case false => + assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil) + } - assertAnalysisError(plan, - "binary type expression a cannot be used in grouping expression" :: Nil) + } - val plan2 = - Aggregate( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + val supportedDataTypes = Seq( + StringType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", StringType, nullable = true), + new GroupableUDT()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = 
true) + } - assertAnalysisError(plan2, - "map type expression a cannot be used in grouping expression" :: Nil) + val unsupportedDataTypes = Seq( + BinaryType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new UngroupableUDT()) + unsupportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = false) + } + } - val plan3 = + test("we should fail analysis when we find nested aggregate functions") { + val plan = Aggregate( - AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil, - Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil, LocalRelation( - AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)), + AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - assertAnalysisError(plan3, - "array type expression a cannot be used in grouping expression" :: Nil) + assertAnalysisError( + plan, + "It is not allowed to use an aggregate function in the argument of " + + "another aggregate function." :: Nil) } test("Join can't work on binary and map types") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 89e196c066007..57d7d30e0eca2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -474,6 +474,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index a4ce328c1a9eb..b733b26987bcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -44,14 +44,14 @@ case class Exchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - "Shuffle" + s"(coordinator id: ${System.identityHashCode(coordinator)})" case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => - "May shuffle" - case None => "Shuffle without coordinator" + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" } val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" - s"$simpleNodeName($extraInfo)" + s"${simpleNodeName}${extraInfo}" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e5f60b15e7359..8b2755a58757c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property 
${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + + s"continue to be true.") + Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { From 21c562fa03430365f5c2b7d6de1f8f60ab2140d4 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 10 Nov 2015 16:28:21 -0800 Subject: [PATCH 289/324] [SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up (3) This PR is a 2nd follow-up for [SPARK-9241](https://issues.apache.org/jira/browse/SPARK-9241). It contains the following improvements: * Fix for a potential bug in distinct child expression and attribute alignment. * Improved handling of duplicate distinct child expressions. * Added test for distinct UDAF with multiple children. cc yhuai Author: Herman van Hovell Closes #9566 from hvanhovell/SPARK-9241-followup-2. --- .../DistinctAggregationRewriter.scala | 9 ++-- .../execution/AggregationQuerySuite.scala | 41 +++++++++++++++++-- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 397eff05686b6..c0c960471a61a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap - val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup expand & aggregate operators for distinct aggregate expressions. 
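+    // Note: distinctAggChildAttrMap is kept as a Seq above so attribute order stays aligned
+    // with distinctAggChildren; this map is built only for child-expression lookups.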
+ val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) => val id = Literal(i + 1) @@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val operators = expressions.map { e => val af = e.aggregateFunction val naf = patchAggregateFunctionChildren(af) { x => - evalWithinGroup(id, distinctAggChildAttrMap(x)) + evalWithinGroup(id, distinctAggChildAttrLookup(x)) } (e, e.copy(aggregateFunction = naf, isDistinct = false)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6bf2c53440baf..8253921563b3a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class LongProductSum extends UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // Register UDAFs sqlContext.udf.register("mydoublesum", new MyDoubleSum) sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { @@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te | count(distinct value2), | sum(distinct value2), | count(distinct value1, value2), + | longProductSum(distinct value1, value2), | count(value1), | sum(value1), | count(value2), | sum(value2), + | longProductSum(value1, value2), | count(*), | count(1) |FROM agg2 |GROUP BY key """.stripMargin), - Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) :: - Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) :: - Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) :: - Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil) + Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) } test("test count") { From a3989058c0938c8c59c278e7d1a766701cfa255b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 10 Nov 2015 16:32:32 -0800 Subject: [PATCH 290/324] [SPARK-10827][CORE] AppClient should not use `askWithReply` in `receiveAndReply` Changed AppClient to be non-blocking in `receiveAndReply` by using a separate thread to wait for response 
and reply to the context. The threads are managed by a thread pool. Also added unit tests for the AppClient interface. Author: Bryan Cutler Closes #9317 from BryanCutler/appClient-receiveAndReply-SPARK-10827. --- .../spark/deploy/client/AppClient.scala | 33 ++- .../spark/deploy/client/AppClientSuite.scala | 209 ++++++++++++++++++ 2 files changed, 238 insertions(+), 4 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 25ea6925434ab..3f29da663b798 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -49,8 +49,8 @@ private[spark] class AppClient( private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var endpoint: RpcEndpointRef = null - private var appId: String = null + @volatile private var endpoint: RpcEndpointRef = null + @volatile private var appId: String = null @volatile private var registered = false private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint @@ -77,6 +77,11 @@ private[spark] class AppClient( private val registrationRetryThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") + // A thread pool to perform receive then reply actions in a thread so as not to block the + // event loop. + private val askAndReplyThreadPool = + ThreadUtils.newDaemonCachedThreadPool("appclient-receive-and-reply-threadpool") + override def onStart(): Unit = { try { registerWithMaster(1) @@ -200,7 +205,7 @@ private[spark] class AppClient( case r: RequestExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](r)) + case Some(m) => askAndReplyAsync(m, context, r) case None => logWarning("Attempted to request executors before registering with Master.") context.reply(false) @@ -208,13 +213,32 @@ private[spark] class AppClient( case k: KillExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](k)) + case Some(m) => askAndReplyAsync(m, context, k) case None => logWarning("Attempted to kill executors before registering with Master.") context.reply(false) } } + private def askAndReplyAsync[T]( + endpointRef: RpcEndpointRef, + context: RpcCallContext, + msg: T): Unit = { + // Create a thread to ask a message and reply with the result. Allow thread to be + // interrupted during shutdown, otherwise context must be notified of NonFatal errors. 
+ askAndReplyThreadPool.execute(new Runnable { + override def run(): Unit = { + try { + context.reply(endpointRef.askWithRetry[Boolean](msg)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(t) => + context.sendFailure(t) + } + } + }) + } + override def onDisconnected(address: RpcAddress): Unit = { if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") @@ -252,6 +276,7 @@ private[spark] class AppClient( registrationRetryThread.shutdownNow() registerMasterFutures.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() + askAndReplyThreadPool.shutdownNow() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala new file mode 100644 index 0000000000000..1e5c05a73f8aa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, Master} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.Utils + +/** + * End-to-end tests for application client in standalone mode. + */ +class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { + private val numWorkers = 2 + private val conf = new SparkConf() + private val securityManager = new SecurityManager(conf) + + private var masterRpcEnv: RpcEnv = null + private var workerRpcEnvs: Seq[RpcEnv] = null + private var master: Master = null + private var workers: Seq[Worker] = null + + /** + * Start the local cluster. + * Note: local-cluster mode is insufficient because we want a reference to the Master. 
+ */ + override def beforeAll(): Unit = { + super.beforeAll() + masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) + workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) + } + master = makeMaster() + workers = makeWorkers(10, 2048) + // Wait until all workers register with master successfully + eventually(timeout(60.seconds), interval(10.millis)) { + assert(getMasterState.workers.size === numWorkers) + } + } + + override def afterAll(): Unit = { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + super.afterAll() + } + + test("interface methods of AppClient using local Master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + ci.client.start() + + // Client should connect with one Master which registers the application + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.connectedIdList.size === 1, "client listener should have one connection") + assert(apps.size === 1, "master should have 1 registered app") + } + + // Send message to Master to request Executors, verify request by change in executor limit + val numExecutorsRequested = 1 + assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") + } + + // Send request to kill executor, verify request was made + assert { + val apps = getApplications() + val executorId: String = apps.head.executors.head._2.fullId + ci.client.killExecutors(Seq(executorId)) + } + + // Issue stop command for Client to disconnect from Master + ci.client.stop() + + // Verify Client is marked dead and unregistered from Master + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.deadReasonList.size === 1, "client should have been marked dead") + assert(apps.isEmpty, "master should have 0 registered apps") + } + } + + test("request from AppClient before initialized with master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + // requests to master should fail immediately + assert(ci.client.requestTotalExecutors(3) === false) + } + + // =============================== + // | Utility methods for testing | + // =============================== + + /** Return a SparkConf for applications that want to talk to our Master. */ + private def appConf: SparkConf = { + new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .set("spark.executor.memory", "256m") + } + + /** Make a master to which our application will send executor requests. */ + private def makeMaster(): Master = { + val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, securityManager, conf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + /** Make a few workers that talk to our master. 
*/ + private def makeWorkers(cores: Int, memory: Int): Seq[Worker] = { + (0 until numWorkers).map { i => + val rpcEnv = workerRpcEnvs(i) + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + worker + } + } + + /** Get the Master state */ + private def getMasterState: MasterStateResponse = { + master.self.askWithRetry[MasterStateResponse](RequestMasterState) + } + + /** Get the applictions that are active from Master */ + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + /** Application Listener to collect events */ + private class AppClientCollector extends AppClientListener with Logging { + val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + @volatile var disconnectedCount: Int = 0 + val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + + def connected(id: String): Unit = { + connectedIdList += id + } + + def disconnected(): Unit = { + synchronized { + disconnectedCount += 1 + } + } + + def dead(reason: String): Unit = { + deadReasonList += reason + } + + def executorAdded( + id: String, + workerId: String, + hostPort: String, + cores: Int, + memory: Int): Unit = { + execAddedList += id + } + + def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + execRemovedList += id + } + } + + /** Create AppClient and supporting objects */ + private class AppClientInst(masterUrl: String) { + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, securityManager) + private val cmd = new Command(TestExecutor.getClass.getCanonicalName.stripSuffix("$"), + List(), Map(), Seq(), Seq(), Seq()) + private val desc = new ApplicationDescription("AppClientSuite", Some(1), 512, cmd, "ignored") + val listener = new AppClientCollector + val client = new AppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) + } + +} From c0e48dfa611fa5d94132af7e6f6731f60ab833da Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 10 Nov 2015 16:42:28 -0800 Subject: [PATCH 291/324] [SPARK-11566] [MLLIB] [PYTHON] Refactoring GaussianMixtureModel.gaussians in Python cc jkbradley Author: Yu ISHIKAWA Closes #9534 from yu-iskw/SPARK-11566. 
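For reference, a minimal Scala-side sketch (not part of this patch; it assumes an existing SparkContext `sc` and toy data) of the per-component (mu, sigma) structure that the wrapper below serializes for Python:

import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.1), Vectors.dense(0.2, -0.1), Vectors.dense(0.1, 0.0),
  Vectors.dense(9.0, 8.5), Vectors.dense(8.8, 9.2), Vectors.dense(9.1, 8.9)))
val gmm = new GaussianMixture().setK(2).setSeed(1L).run(data)
gmm.gaussians.foreach { g =>
  // each mixture component carries a mean vector and a covariance matrix,
  // i.e. the (mu, sigma) pairs that GaussianMixtureModelWrapper dumps via SerDe
  println(s"mu = ${g.mu}, sigma = ${g.sigma}")
}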
--- .../python/GaussianMixtureModelWrapper.scala | 21 ++++++------------- python/pyspark/mllib/clustering.py | 2 +- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 0ec88ef77d695..6a3b20c88d2d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,14 +17,11 @@ package org.apache.spark.mllib.api.python -import java.util.{List => JList} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} import org.apache.spark.mllib.clustering.GaussianMixtureModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -36,17 +33,11 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { /** * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian */ - val gaussians: JList[Object] = { - val modelGaussians = model.gaussians - var i = 0 - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - while (i < k) { - mu += modelGaussians(i).mu - sigma += modelGaussians(i).sigma - i += 1 + val gaussians: Array[Byte] = { + val modelGaussians = model.gaussians.map { gaussian => + Array[Any](gaussian.mu, gaussian.sigma) } - List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) } def save(sc: SparkContext, path: String): Unit = model.save(sc, path) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 1fa061dc2da99..c9e6f1dec6bf8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -266,7 +266,7 @@ def gaussians(self): """ return [ MultivariateGaussian(gaussian[0], gaussian[1]) - for gaussian in zip(*self.call("gaussians"))] + for gaussian in self.call("gaussians")] @property @since('1.4.0') From 33112f9c48680c33d663978f76806ebf0ea39789 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 10 Nov 2015 16:50:22 -0800 Subject: [PATCH 292/324] [SPARK-10192][CORE] simple test w/ failure involving a shared dependency just trying to increase test coverage in the scheduler, this already works. It includes a regression test for SPARK-9809 copied some test utils from https://github.com/apache/spark/pull/5636, we can wait till that is merged first Author: Imran Rashid Closes #8402 from squito/test_retry_in_shared_shuffle_dep. 
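For context, a minimal user-level sketch (illustration only, assuming an existing SparkContext `sc`) of the scenario this test exercises: two jobs share one shuffle dependency, so the second job plans a "skipped" stage over the already-computed map output, and a fetch failure there only regenerates the missing shuffle partitions:

val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))
val sums = pairs.reduceByKey(_ + _)  // a single ShuffleDependency reused by both jobs below
sums.count()    // job 1: runs the shuffle map stage plus a result stage
sums.collect()  // job 2: reuses the existing map output via a "skipped" stage; on a fetch
                // failure only the missing map outputs are recomputed (cf. SPARK-9809)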
--- .../spark/scheduler/DAGSchedulerSuite.scala | 51 ++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3816b8c4a09aa..068b49bd5844b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -594,11 +594,17 @@ class DAGSchedulerSuite * @param stageId - The current stageId * @param attemptIdx - The current attempt count */ - private def completeNextResultStageWithSuccess(stageId: Int, attemptIdx: Int): Unit = { + private def completeNextResultStageWithSuccess( + stageId: Int, + attemptIdx: Int, + partitionToResult: Int => Int = _ => 42): Unit = { val stageAttempt = taskSets.last checkStageId(stageId, attemptIdx, stageAttempt) assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) - complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map(_ => (Success, 42)).toSeq) + val taskResults = stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (Success, partitionToResult(idx)) + } + complete(stageAttempt, taskResults.toSeq) } /** @@ -1054,6 +1060,47 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * Run two jobs, with a shared dependency. We simulate a fetch failure in the second job, which + * requires regenerating some outputs of the shared dependency. One key aspect of this test is + * that the second job actually uses a different stage for the shared dependency (a "skipped" + * stage). + */ + test("shuffle fetch failure in a reused shuffle dependency") { + // Run the first job successfully, which creates one shuffle dependency + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(0, 0, 2) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + assertDataStructuresEmpty() + + // submit another job w/ the shared dependency, and have a fetch failure + val reduce2 = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduce2, Array(0, 1)) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + // the scheduler now creates a new task set to regenerate the missing map output, but this time + // using a different stage, the "skipped" one + + // SPARK-9809 -- this stage is submitted without a task for each partition (because some of + // the shuffle map output is still available from stage 0); make sure we've still got internal + // accumulators setup + assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + assert(results === Map(0 -> 1234, 1 -> 1235)) + + assertDataStructuresEmpty() + } + /** * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we * have completions from both the first & second attempt of stage 1. 
So all the map output is From 3e0a6cf1e02a19b37c68d3026415d53bb57a576b Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 10 Nov 2015 16:51:25 -0800 Subject: [PATCH 293/324] [SPARK-11572] Exit AsynchronousListenerBus thread when stop() is called As vonnagy reported in the following thread: http://search-hadoop.com/m/q3RTtk982kvIow22 Attempts to join the thread in AsynchronousListenerBus resulted in lock up because AsynchronousListenerBus thread was still getting messages `SparkListenerExecutorMetricsUpdate` from the DAGScheduler Author: tedyu Closes #9546 from ted-yu/master. --- .../org/apache/spark/util/AsynchronousListenerBus.scala | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 61b5a4cecddce..b8481eabc7618 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -67,15 +67,12 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri processingEvent = true } try { - val event = eventQueue.poll - if (event == null) { + if (stopped.get()) { // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } return } + val event = eventQueue.poll + assert(event != null, "event queue was empty but the listener bus was not stopped") postToAll(event) } finally { self.synchronized { From 900917541651abe7125f0d205085d2ab6a00d92c Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 10 Nov 2015 16:52:26 -0800 Subject: [PATCH 294/324] [SPARK-11615] Drop @VisibleForTesting annotation See http://search-hadoop.com/m/q3RTtjpe8r1iRbTj2 for discussion. Summary: addition of VisibleForTesting annotation resulted in spark-shell malfunctioning. Author: tedyu Closes #9585 from tedyu/master. --- .../src/main/scala/org/apache/spark/rpc/netty/Inbox.scala | 8 ++++---- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 2 -- .../org/apache/spark/util/AsynchronousListenerBus.scala | 5 ++--- .../org/apache/spark/util/collection/ExternalSorter.scala | 3 +-- scalastyle-config.xml | 7 +++++++ .../org/apache/spark/sql/execution/QueryExecution.scala | 3 --- .../spark/network/shuffle/ShuffleTestAccessor.scala | 1 - 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index c72b588db57fe..464027f07cc88 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -21,8 +21,6 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import com.google.common.annotations.VisibleForTesting - import org.apache.spark.{Logging, SparkException} import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} @@ -193,8 +191,10 @@ private[netty] class Inbox( def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } - /** Called when we are dropping a message. Test cases override this to test message dropping. */ - @VisibleForTesting + /** + * Called when we are dropping a message. Test cases override this to test message dropping. + * Exposed for testing. 
+ */ protected def onDrop(message: InboxMessage): Unit = { logWarning(s"Drop $message because $endpointRef is stopped") } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 77d034fa5ba2c..ca37829216f22 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -21,8 +21,6 @@ import java.util.concurrent.TimeoutException import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import com.google.common.annotations.VisibleForTesting - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index b8481eabc7618..b3b54af972cb4 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import com.google.common.annotations.VisibleForTesting import org.apache.spark.SparkContext /** @@ -119,8 +118,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri * For testing only. Wait until there are no more events in the queue, or until the specified * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue * emptied. + * Exposed for testing. */ - @VisibleForTesting @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { val finishTime = System.currentTimeMillis + timeoutMillis @@ -137,8 +136,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Return whether the listener daemon thread is still alive. + * Exposed for testing. */ - @VisibleForTesting def listenerThreadIsAlive: Boolean = listenerThread.isAlive /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index a44e72b7c16d3..bd6844d045cad 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,7 +23,6 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ @@ -608,8 +607,8 @@ private[spark] class ExternalSorter[K, V, C]( * * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. + * Exposed for testing. 
*/ - @VisibleForTesting def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 64a0c71bbef2a..050c3f360476f 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,13 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + @VisibleForTesting + + + Class\.forName Date: Tue, 10 Nov 2015 16:54:06 -0800 Subject: [PATCH 295/324] [SPARK-11361][STREAMING] Show scopes of RDD operations inside DStream.foreachRDD and DStream.transform in DAG viz Currently, when a DStream sets the scope for RDD generated by it, that scope is not allowed to be overridden by the RDD operations. So in case of `DStream.foreachRDD`, all the RDDs generated inside the foreachRDD get the same scope - `foreachRDD + + org.apache.xbean + xbean-asm5-shaded + org.apache.hadoop hadoop-client diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 1b49dca9dc78b..e27d2e6c94f7b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -21,8 +21,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -325,11 +325,11 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM4) { +private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name.contains("apply")) { - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { throw new ReturnStatementInClosureException @@ -337,7 +337,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) { } } } else { - new MethodVisitor(ASM4) {} + new MethodVisitor(ASM5) {} } } } @@ -361,7 +361,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { override def visitMethod( access: Int, @@ -376,7 +376,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -385,7 +385,8 @@ private[util] class FieldAccessFinder( } } - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { for (cl <- 
fields.keys if cl.getName == owner.replace('/', '.')) { // Check for calls a getter method for a variable in an interpreter wrapper object. // This means that the corresponding field will be accessed, so we should save it. @@ -408,7 +409,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -423,9 +424,9 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? diff --git a/docs/building-spark.md b/docs/building-spark.md index 4f73adb85446c..3d38edbdad4bc 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -190,6 +190,10 @@ Running only Java 8 tests and nothing else. mvn install -DskipTests -Pjava8-tests +or + + sbt -Pjava8-tests java8-tests/test + Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. For these tests to run your system must have a JDK 8 installation. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. diff --git a/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala new file mode 100644 index 0000000000000..fa0681db41088 --- /dev/null +++ b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Test cases where JDK8-compiled Scala user code is used with Spark. 
+ */ +class JDK8ScalaSuite extends SparkFunSuite with SharedSparkContext { + test("basic RDD closure test (SPARK-6152)") { + sc.parallelize(1 to 1000).map(x => x * x).count() + } +} diff --git a/graphx/pom.xml b/graphx/pom.xml index 987b831021a54..8cd66c5b2e826 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -47,6 +47,10 @@ test-jar test + + org.apache.xbean + xbean-asm5-shaded + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 74a7de18d4161..a6d0cb6409664 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,11 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.spark.util.Utils - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm5.Opcodes._ +import org.apache.spark.util.Utils /** * Includes an utility function to test whether a function accesses a specific attribute @@ -107,18 +106,19 @@ private[graphx] object BytecodeUtils { * MethodInvocationFinder("spark/graph/Foo", "test") * its methodsInvoked variable will contain the set of methods invoked directly by * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual metod invoked by inspecting the bytecode. + * determine the actual method invoked by inspecting the bytecode. 
*/ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) diff --git a/pom.xml b/pom.xml index c499a80aa0f43..01afa80617891 100644 --- a/pom.xml +++ b/pom.xml @@ -393,6 +393,14 @@ + + + org.apache.xbean + xbean-asm5-shaded + 4.4 + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 004941d5f50ae..3d2d235a00c93 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -23,15 +23,14 @@ import java.net.{HttpURLConnection, URI, URL, URLEncoder} import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils import org.apache.spark.util.ParentClassLoader -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ - /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. @@ -192,7 +191,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM4, cv) { +extends ClassVisitor(ASM5, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) @@ -202,7 +201,7 @@ extends ClassVisitor(ASM4, cv) { // field in the class to point to it, but do nothing otherwise. 
mv.visitCode() mv.visitVarInsn(ALOAD, 0) // load this - mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) mv.visitVarInsn(ALOAD, 0) // load this // val classType = className.replace('.', '/') // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c96855e261ee8..9fd6b5a07ec86 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -110,6 +110,11 @@ mockito-core test + + org.apache.xbean + xbean-asm5-shaded + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4b4f5c6c45c7a..97162249d9951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -21,8 +21,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -41,22 +41,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { l += 1L l.add(1L) } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") } test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. val l = sparkContext.accumulator(0L) val f = () => { l += 1L } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } /** @@ -486,7 +484,7 @@ private class BoxingFinder( method: MethodIdentifier[_] = null, val boxingInvokes: mutable.Set[String] = mutable.Set.empty, visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { private val primitiveBoxingClassName = Set("java/lang/Long", @@ -503,11 +501,12 @@ private class BoxingFinder( MethodVisitor = { if (method != null && (method.name != name || method.desc != desc)) { // If method is specified, skip other methods. 
- return new MethodVisitor(ASM4) {} + return new MethodVisitor(ASM5) {} } - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { if (primitiveBoxingClassName.contains(owner)) { // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) @@ -522,10 +521,9 @@ private class BoxingFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => - visitedMethods += m - cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) - } + val cl = BoxingFinder.getClassReader(classOfMethodOwner) + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) } } } @@ -535,22 +533,14 @@ private class BoxingFinder( private object BoxingFinder { - def getClassReader(cls: Class[_]): Option[ClassReader] = { + def getClassReader(cls: Class[_]): ClassReader = { val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) val baos = new ByteArrayOutputStream(128) // Copy data over, before delegating to ClassReader - // else we can run out of open file handles. Utils.copyStream(resourceStream, baos, true) - // ASM4 doesn't support Java 8 classes, which requires ASM5. - // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), - // then ClassReader will throw IllegalArgumentException, - // However, since this is only for testing, it's safe to skip these classes. - try { - Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) - } catch { - case _: IllegalArgumentException => None - } + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } } From 27029bc8f6246514bd0947500c94cf38dc8616c3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 11 Nov 2015 11:24:55 -0800 Subject: [PATCH 312/324] [SPARK-11639][STREAMING][FLAKY-TEST] Implement BlockingWriteAheadLog for testing the BatchedWriteAheadLog Several elements could be drained if the main thread is not fast enough. zsxwing warned me about a similar problem, but missed it here :( Submitting the fix using a waiter. cc tdas Author: Burak Yavuz Closes #9605 from brkyvz/fix-flaky-test. --- .../streaming/util/BatchedWriteAheadLog.scala | 3 + .../streaming/util/WriteAheadLogSuite.scala | 124 +++++++++++------- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 9727ed2ba1445..6e6ed8d819721 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -182,6 +182,9 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp buffer.clear() } } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() } /** Static methods for aggregating and de-aggregating records. 
*/ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index e96f4c2a29347..9e13f25c2efea 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util.concurrent.{ExecutionException, ThreadPoolExecutor} -import java.util.concurrent.atomic.AtomicInteger +import java.util.{Iterator => JIterator} +import java.util.concurrent.ThreadPoolExecutor import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.util.{Failure, Success} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -37,12 +36,12 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ThreadUtils, ManualClock, Utils} -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( @@ -315,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( allowBatching = true, closeFileAfterWrite = false, - "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually { + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { import BatchedWriteAheadLog._ import WriteAheadLogSuite._ @@ -326,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( private var walBatchingExecutionContext: ExecutionContextExecutorService = _ private val sparkConf = new SparkConf() + private val queueLength = PrivateMethod[Int]('getQueueLength) + override def beforeEach(): Unit = { wal = mock[WriteAheadLog] walHandle = mock[WriteAheadLogRecordHandle] @@ -366,7 +371,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } // we make the write requests in separate threads so that we don't block the test thread - private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { val p = Promise[Unit]() p.completeWith(Future { val v = wal.write(event, time) @@ -375,28 +380,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - /** - * In order to block the writes on the writer thread, we mock the write method, and block it - * for some time with a promise. 
- */ - private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = { - // we would like to block the write so that we can queue requests - val promise = Promise[Any]() - when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( - new Answer[WriteAheadLogRecordHandle] { - override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { - Await.ready(promise.future, 4.seconds) - walHandle - } - } - ) - promise - } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // block the write so that we can batch some records - val promise = writeBlockingPromise(wal) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) val event1 = "hello" val event2 = "world" @@ -406,21 +392,27 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. - promiseWriteEvent(batchedWal, event1, 3L) - // rest of the records will be batched while it takes 3 to get written - promiseWriteEvent(batchedWal, event2, 5L) - promiseWriteEvent(batchedWal, event3, 8L) - promiseWriteEvent(batchedWal, event4, 12L) - promiseWriteEvent(batchedWal, event5, 10L) + writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + writeAsync(batchedWal, event4, 12L) + writeAsync(batchedWal, event5, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) } - promise.success(true) + blockingWal.allowWrite() val buffer1 = wrapArrayArrayByte(Array(event1)) val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) verify(wal, times(1)).write(meq(buffer1), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. @@ -437,27 +429,32 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } test("BatchedWriteAheadLog - fail everything in queue during shutdown") { - val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) - // block the write so that we can batch some records - writeBlockingPromise(wal) - - val event1 = ("hello", 3L) - val event2 = ("world", 5L) - val event3 = ("this", 8L) - val event4 = ("is", 9L) - val event5 = ("doge", 10L) + val event1 = "hello" + val event2 = "world" + val event3 = "this" // The queue.take() immediately takes the 3, and there is nothing left in the queue at that // moment. Then the promise blocks the writing of 3. The rest get queued. 
- val writePromises = Seq(event1, event2, event3, event4, event5).map { event => - promiseWriteEvent(batchedWal, event._1, event._2) + val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) eventually(timeout(1 second)) { - assert(walBatchingThreadPool.getActiveCount === 5) + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written } + val writePromises = Seq(promise1, promise2, promise3) + batchedWal.close() eventually(timeout(1 second)) { assert(writePromises.forall(_.isCompleted)) @@ -641,4 +638,37 @@ object WriteAheadLogSuite { def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. + */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } From df97df2b39194f60051f78cce23f0ba6cfe4b1df Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 11 Nov 2015 12:47:02 -0800 Subject: [PATCH 313/324] [SPARK-11644][SQL] Remove the option to turn off unsafe and codegen. Author: Reynold Xin Closes #9618 from rxin/SPARK-11644. 
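As a hedged illustration of the user-visible effect of this change (assuming an existing org.apache.spark.sql.SQLContext bound to a variable named sqlContext; the exact warning text is paraphrased), the three removed switches are still accepted by SET but become warn-and-ignore no-ops, handled in SetCommand as shown in the diff below:

    // Sketch only: after this patch the deprecated keys are recognized but ignored.
    // `sqlContext` is assumed to be an already constructed SQLContext.
    sqlContext.sql("SET spark.sql.tungsten.enabled=false").show()  // logs a deprecation warning, reports "true"
    sqlContext.sql("SET spark.sql.codegen=false").show()           // codegen can no longer be disabled
    sqlContext.sql("SET spark.sql.unsafe.enabled=false").show()    // unsafe mode stays on
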
--- .../scala/org/apache/spark/sql/SQLConf.scala | 27 +--- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 1 - .../spark/sql/execution/SparkPlan.scala | 120 +++++++---------- .../spark/sql/execution/SparkPlanner.scala | 4 - .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/aggregate/utils.scala | 13 +- .../apache/spark/sql/execution/commands.scala | 27 ++++ .../spark/sql/execution/joins/HashJoin.scala | 4 +- .../sql/execution/joins/HashOuterJoin.scala | 6 +- .../sql/execution/joins/HashSemiJoin.scala | 9 +- .../sql/execution/joins/SortMergeJoin.scala | 7 +- .../execution/joins/SortMergeOuterJoin.scala | 7 +- .../sql/execution/local/HashJoinNode.scala | 5 +- .../spark/sql/execution/local/LocalNode.scala | 80 +++++------- .../apache/spark/sql/sources/interfaces.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 31 +---- .../spark/sql/DataFrameTungstenSuite.scala | 68 +++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 23 +--- .../sql/execution/TungstenSortSuite.scala | 13 -- .../execution/joins/BroadcastJoinSuite.scala | 4 +- .../execution/local/HashJoinNodeSuite.scala | 23 +--- .../local/NestedLoopJoinNodeSuite.scala | 21 +-- .../execution/metric/SQLMetricsSuite.scala | 123 +++++------------- .../execution/AggregationQuerySuite.scala | 44 +------ .../sql/hive/execution/HiveExplainSuite.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 72 +++++----- 27 files changed, 257 insertions(+), 494 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 57d7d30e0eca2..e02b502b7b4d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -252,24 +252,8 @@ private[spark] object SQLConf { "not be provided to ExchangeCoordinator.", isPublic = false) - val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", - defaultValue = Some(true), - doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + - "manages memory and dynamically generates bytecode for expression evaluation.") - - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.", - isPublic = false) - - val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, use the new optimized Tungsten physical execution backend.", - isPublic = false) - val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", - defaultValue = Some(true), // use CODEGEN_ENABLED as default + defaultValue = Some(true), doc = "When true, common subexpressions will be eliminated.", isPublic = false) @@ -475,6 +459,9 @@ private[spark] object SQLConf { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" + val TUNGSTEN_ENABLED = "spark.sql.tungsten.enabled" + val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" } } @@ -541,14 +528,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = 
getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) - private[spark] def subexpressionEliminationEnabled: Boolean = - getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled) + getConf(SUBEXPRESSION_ELIMINATION_ENABLED) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b733b26987bcb..d0e4e068092f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -58,7 +58,7 @@ case class Exchange( * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && + GenerateUnsafeProjection.canSupport(child.schema) && !newPartitioning.isInstanceOf[RangePartitioning] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 77843f53b9bd0..5da5aea17e25b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -77,7 +77,6 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} """.stripMargin.trim } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8650ac500b652..1b833002f434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -54,18 +54,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // the value of subexpressionEliminationEnabled will be set by the desserializer after the // constructor has run. 
- val codegenEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.codegenEnabled - } else { - false - } - val unsafeEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.unsafeEnabled - } else { - false - } val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.subexpressionEliminationEnabled } else { @@ -233,83 +223,63 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) + log.debug(s"Creating Projection: $expressions, inputSchema: $inputSchema") + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } } } protected def newMutableProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) + expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } } } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } } } protected def newOrdering( - order: Seq[SortOrder], - inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - if (codegenEnabled) { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to 
interpreted", e) - new InterpretedOrdering(order, inputSchema) - } - } - } else { - new InterpretedOrdering(order, inputSchema) + order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { + try { + GenerateOrdering.generate(order, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new InterpretedOrdering(order, inputSchema) + } } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index a10d1edcc91aa..cf482ae4a05ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -27,10 +27,6 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val sparkContext: SparkContext = sqlContext.sparkContext - def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled - - def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled - def numPartitions: Int = sqlContext.conf.numShufflePartitions def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d65cb1bae7fb5..96242f160aa51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -327,8 +327,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { + if (TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) @@ -368,8 +367,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Project(projectList, child) => // If unsafe mode is enabled and we support these data types in Unsafe, use the // Tungsten project. Otherwise, use the normal project. - if (sqlContext.conf.unsafeEnabled && - UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + if (UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { execution.TungstenProject(projectList, planLater(child)) :: Nil } else { execution.Project(projectList, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 79abf2d5929be..a70e41436c7aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -59,13 +59,10 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. 
- val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // 1. Create an Aggregate Operator for partial aggregations. val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -144,11 +141,9 @@ object Utils { child: SparkPlan): Seq[SparkPlan] = { val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expression. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 8b2755a58757c..e29c281b951f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -121,6 +121,33 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " + + s"will be ignored. Tungsten will continue to be used.") + Seq(Row(SQLConf.Deprecated.TUNGSTEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " + + s"will be ignored. Codegen will continue to be used.") + Seq(Row(SQLConf.Deprecated.CODEGEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " + + s"will be ignored. Unsafe mode will continue to be used.") + Seq(Row(SQLConf.Deprecated.UNSAFE_ENABLED, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. 
case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7ce4a517838cb..997f7f494f4a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -45,9 +45,7 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) + UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema) } override def outputsUnsafeRows: Boolean = isUnsafeMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 15b06b1537f8c..3633f356b014b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -65,9 +65,9 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) + joinType != FullOuter && + UnsafeProjection.canSupport(buildKeys) && + UnsafeProjection.canSupport(self.schema) } override def outputsUnsafeRows: Boolean = isUnsafeMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index beb141ade616d..c7d13e0a72a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -34,11 +34,10 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema) - && UnsafeProjection.canSupport(right.schema)) + UnsafeProjection.canSupport(leftKeys) && + UnsafeProjection.canSupport(rightKeys) && + UnsafeProjection.canSupport(left.schema) && + UnsafeProjection.canSupport(right.schema) } override def outputsUnsafeRows: Boolean = supportUnsafe diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 17030947b7bbc..7aee8e3dd3fce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -54,10 +54,9 @@ case class SortMergeJoin( requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil protected[this] def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) + UnsafeProjection.canSupport(leftKeys) && + UnsafeProjection.canSupport(rightKeys) && + UnsafeProjection.canSupport(schema) } override def outputsUnsafeRows: Boolean = 
isUnsafeMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7e854e6702f77..5f1590c463836 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -90,10 +90,9 @@ case class SortMergeOuterJoin( } private def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) + UnsafeProjection.canSupport(leftKeys) && + UnsafeProjection.canSupport(rightKeys) && + UnsafeProjection.canSupport(schema) } override def outputsUnsafeRows: Boolean = isUnsafeMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index b1dc719ca8508..aef655727fbbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -46,10 +46,7 @@ trait HashJoinNode { private[this] var joinKeys: Projection = _ protected def isUnsafeMode: Boolean = { - (codegenEnabled && - unsafeEnabled && - UnsafeProjection.canSupport(schema) && - UnsafeProjection.canSupport(streamedKeys)) + UnsafeProjection.canSupport(schema) && UnsafeProjection.canSupport(streamedKeys) } private def streamSideKeyGenerator: Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index f96b62a67a254..d3381eac91d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -35,10 +35,6 @@ import org.apache.spark.sql.types.StructType */ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { - protected val codegenEnabled: Boolean = conf.codegenEnabled - - protected val unsafeEnabled: Boolean = conf.unsafeEnabled - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") /** @@ -111,21 +107,17 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) + s"Creating Projection: $expressions, inputSchema: $inputSchema") + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } } } @@ -133,41 +125,33 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( 
- s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) + s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } } } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 5b8841bc154a5..48de693a999d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -423,8 +423,6 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) - private val codegenEnabled = sqlContext.conf.codegenEnabled - private var _partitionSpec: PartitionSpec = _ private class FileStatusCache { @@ -661,7 +659,6 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... 
val dataSchema = this.dataSchema - val codegenEnabled = this.codegenEnabled val needConversion = this.needConversion val requiredOutput = requiredColumns.map { col => @@ -678,11 +675,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled) { + val buildProjection = GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) - } else { - () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) - } val projectedRows = { val mutableProjection = buildProjection() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f3a7aa280367a..e4f23fe17b757 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -621,11 +621,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - } + checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -844,31 +840,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) // We will reuse the same Expression object for LocalRelation. - val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) - assert(df.showString(5) == df.showString(5)) + val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df1.showString(5) == df1.showString(5)) } test("SPARK-8609: local DataFrame with random columns should return same value after sort") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) // We will reuse the same Expression object for LocalRelation. 
val df = (1 to 10).map(Tuple1.apply).toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 7ae12a7895f7e..68e99d6a6b816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -31,52 +31,46 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("test simple types") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") - assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) - } + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } test("test struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val struct = Row(1, 2L, 3.0F, 3.0) - val data = sparkContext.parallelize(Seq(Row(1, struct))) + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sparkContext.parallelize(Seq(Row(1, struct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(struct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) } test("test nested struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val innerStruct = Row(1, "abcd") - val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType) - .add("b5", new StructType() - .add("b5a", IntegerType) - .add("b5b", StringType)) - .add("b6", StringType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(outerStruct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 19e850a46fdfc..acabe32c67bc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -261,8 +261,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen") { - val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group 
some rows. sqlContext.table("testData") .unionAll(sqlContext.table("testData")) @@ -347,7 +345,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -567,12 +564,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } - } - test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"), @@ -1624,12 +1615,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen updates peak execution memory") { - withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { - testCodeGen( - "SELECT key, count(value) FROM testData GROUP BY key", - (1 to 100).map(i => Row(i, 1))) - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) } } @@ -1783,9 +1772,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // This bug will be triggered when Tungsten is enabled and there are multiple // SortMergeJoin operators executed in the same task. val confs = - SQLConf.SORTMERGE_JOIN.key -> "true" :: - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: - SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + SQLConf.SORTMERGE_JOIN.key -> "true" :: SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: Nil withSQLConf(confs: _*) { val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") val df2 = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 7a0f0dfd2b7f1..85486c08894c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -31,19 +31,6 @@ import org.apache.spark.sql.types._ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - } - - override def afterAll(): Unit = { - try { - sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) - } finally { - super.afterAll() - } - } - test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index dcbfdca71acb6..5b2998c3c76d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} /** - * Test various broadcast join operators with unsafe enabled. + * Test various broadcast join operators. * * Tests in this suite we need to run Spark in local-cluster mode. 
In particular, the use of * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered @@ -45,8 +45,6 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setAppName("testing") val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) - sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 8c2e78b2a9db7..44b0d9d4102a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -28,12 +28,9 @@ import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRig class HashJoinNodeSuite extends LocalNodeTest { // Test all combinations of the two dimensions: with/out unsafe and build sides - private val maybeUnsafeAndCodegen = Seq(false, true) private val buildSides = Seq(BuildLeft, BuildRight) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - testJoin(unsafeAndCodegen, buildSide) - } + buildSides.foreach { buildSide => + testJoin(buildSide) } /** @@ -45,10 +42,7 @@ class HashJoinNodeSuite extends LocalNodeTest { buildKeys: Seq[Expression], buildNode: LocalNode): HashedRelation = { - val isUnsafeMode = - conf.codegenEnabled && - conf.unsafeEnabled && - UnsafeProjection.canSupport(buildKeys) + val isUnsafeMode = UnsafeProjection.canSupport(buildKeys) val buildSideKeyGenerator = if (isUnsafeMode) { @@ -68,15 +62,10 @@ class HashJoinNodeSuite extends LocalNodeTest { /** * Test inner hash join with varying degrees of matches. 
*/ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + private def testJoin(buildSide: BuildSide): Unit = { + val testNamePrefix = buildSide val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) // Actual test body def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { @@ -119,7 +108,7 @@ class HashJoinNodeSuite extends LocalNodeTest { .map { case (k, v) => (k, v, k, rightInputMap(k)) } Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val makeUnsafeNode = wrapForUnsafe(makeNode) val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val actualOutput = hashJoinNode.collect().map { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 40299d9d5ee37..252f7cc8971f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -26,30 +26,21 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} class NestedLoopJoinNodeSuite extends LocalNodeTest { // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types - private val maybeUnsafeAndCodegen = Seq(false, true) private val buildSides = Seq(BuildLeft, BuildRight) private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - joinTypes.foreach { joinType => - testJoin(unsafeAndCodegen, buildSide, joinType) - } + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(buildSide, joinType) } } /** * Test outer nested loop joins with varying degrees of matches. 
*/ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide, - joinType: JoinType): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + private def testJoin(buildSide: BuildSide, joinType: JoinType): Unit = { + val testNamePrefix = s"$buildSide / $joinType" val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) // Actual test body def runTest( @@ -63,7 +54,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest { resolveExpressions( new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val makeUnsafeNode = wrapForUnsafe(makeNode) val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) val actualOutput = hashJoinNode.collect().map { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 97162249d9951..544c1ef303ae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -110,33 +110,23 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("Project metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( - "number of rows" -> 2L))) - ) - } + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) } test("TungstenProject metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("TungstenProject", Map( - "number of rows" -> 2L))) - ) - } + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) } test("Filter metrics") { @@ -150,71 +140,30 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("SortBasedAggregate metrics") { - // Because SortBasedAggregate may skip different rows if the number of partitions is different, - // this test should use the deterministic number of partitions. - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... 
-> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> - // SortBasedAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // Assume the execution plan is - // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) - // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 3L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } - } - test("TungstenAggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) - // -> TungstenAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) } test("SortMergeJoin metrics") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 22d2aefd699b5..61e3e913c23ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -808,54 +808,12 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } } -class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ +class TungstenAggregationQuerySuite extends AggregationQuerySuite - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") 
- super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} - -class TungstenAggregationQuerySuite extends AggregationQuerySuite { - - var originalUnsafeEnabled: Boolean = _ - - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ - - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") - } - override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => sqlContext.setConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 94162da4eae1a..a7b7ad0093915 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -37,8 +37,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", - "== Physical Plan ==", - "Code Generation") + "== Physical Plan ==") } test("explain create table command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 5f9a447759b48..5ab477efc4ee0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -28,11 +28,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton - import org.apache.spark.util.Utils + case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. 
@@ -92,44 +92,36 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { } test("Max/Min on named_struct") { - def testOrderInStruct(): Unit = { - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - - // nested struct cases - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - } - val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) } test("SPARK-6409 UDAF Average test") { From a9a6b80c718008aac7c411dfe46355efe58dee2e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 11 Nov 2015 12:48:51 -0800 Subject: [PATCH 314/324] [SPARK-11645][SQL] Remove OpenHashSet for the old aggregate. Author: Reynold Xin Closes #9621 from rxin/SPARK-11645. --- .../expressions/codegen/CodeGenerator.scala | 6 - .../codegen/GenerateUnsafeProjection.scala | 7 +- .../spark/sql/catalyst/expressions/sets.scala | 194 ------------------ .../sql/execution/SparkSqlSerializer.scala | 103 +--------- .../spark/sql/UserDefinedTypeSuite.scala | 11 - 5 files changed, 5 insertions(+), 316 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5a4bba232b04b..ccd91d3549b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ -// These classes are here to avoid issues with serialization and integration with quasiquotes. -class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] -class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] - /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. 
* @@ -205,8 +201,6 @@ class CodeGenContext { case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" case _: MapType => "MapData" - case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName - case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9ef226141421b..4c17d02a23725 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case dt: OpenHashSetUDT => false // it's not a standard UDT case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro in.map(BindReferences.bindReference(_, inputSchema)) def generate( - expressions: Seq[Expression], - subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { create(canonicalize(expressions), subexpressionEliminationEnabled) } protected def create(expressions: Seq[Expression]): UnsafeProjection = { - create(expressions, false) + create(expressions, subexpressionEliminationEnabled = false) } private def create( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala deleted file mode 100644 index d124d29d534b8..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - -/** The data type for expressions returning an OpenHashSet as the result. */ -private[sql] class OpenHashSetUDT( - val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { - - override def sqlType: DataType = ArrayType(elementType) - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def serialize(obj: Any): Seq[Any] = { - obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq - } - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def deserialize(datum: Any): OpenHashSet[Any] = { - val iterator = datum.asInstanceOf[Seq[Any]].iterator - val set = new OpenHashSet[Any] - while(iterator.hasNext) { - set.add(iterator.next()) - } - - set - } - - override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] - - private[spark] override def asNullable: OpenHashSetUDT = this -} - -/** - * Creates a new set of the specified type - */ -case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { - - override def nullable: Boolean = false - - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - - override def eval(input: InternalRow): Any = { - new OpenHashSet[Any]() - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - elementType match { - case IntegerType | LongType => - ev.isNull = "false" - s""" - ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}(); - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"new Set($dataType)" -} - -/** - * Adds an item to a set. - * For performance, this expression mutates its input during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class AddItemToSet(item: Expression, set: Expression) - extends Expression with CodegenFallback { - - override def children: Seq[Expression] = item :: set :: Nil - - override def nullable: Boolean = set.nullable - - override def dataType: DataType = set.dataType - - override def eval(input: InternalRow): Any = { - val itemEval = item.eval(input) - val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] - - if (itemEval != null) { - if (setEval != null) { - setEval.add(itemEval) - setEval - } else { - null - } - } else { - setEval - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = "false" - ev.value = setEval.value - itemEval.code + setEval.code + s""" - if (!${itemEval.isNull} && !${setEval.isNull}) { - (($htype)${setEval.value}).add(${itemEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"$set += $item" -} - -/** - * Combines the elements of two sets. - * For performance, this expression mutates its left input set during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. 
- */ -case class CombineSets(left: Expression, right: Expression) - extends BinaryExpression with CodegenFallback { - - override def nullable: Boolean = left.nullable - override def dataType: DataType = left.dataType - - override def eval(input: InternalRow): Any = { - val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] - if(leftEval != null) { - val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] - if (rightEval != null) { - val iterator = rightEval.iterator - while(iterator.hasNext) { - val rightValue = iterator.next() - leftEval.add(rightValue) - } - } - leftEval - } else { - null - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = leftEval.isNull - ev.value = leftEval.value - leftEval.code + rightEval.code + s""" - if (!${leftEval.isNull} && !${rightEval.isNull}) { - ${leftEval.value}.union((${htype})${rightEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } -} - -/** - * Returns the number of elements in the input set. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { - - override def dataType: DataType = LongType - - protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[OpenHashSet[Any]].size.toLong - - override def toString: String = s"$child.count()" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index b19ad4f1c563e..8317f648ccb4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.{SparkConf, SparkEnv} + private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { val kryo = super.newKryo() @@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) - kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], - new HyperLogLogSerializer) kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer) kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer) - // 
Specific hashsets must come first TODO: Move to core. - kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) - kryo.register(classOf[LongHashSet], new LongHashSetSerializer) - kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], - new OpenHashSetSerializer) kryo.register(classOf[Decimal]) kryo.register(classOf[JavaHashMap[_, _]]) @@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } private[execution] class KryoResourcePool(size: Int) - extends ResourcePool[SerializerInstance](size) { + extends ResourcePool[SerializerInstance](size) { val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] { new java.math.BigDecimal(input.readString()) } } - -private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { - def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) { - val bytes = hyperLogLog.getBytes() - output.writeInt(bytes.length) - output.writeBytes(bytes) - } - - def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = { - val length = input.readInt() - val bytes = input.readBytes(length) - HyperLogLog.Builder.build(bytes) - } -} - -private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { - def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - val numItems = input.readInt() - val set = new OpenHashSet[Any](numItems + 1) - var i = 0 - while (i < numItems) { - val row = - new GenericInternalRow(rowSerializer.read( - kryo, - input, - classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) - set.add(row) - i += 1 - } - set - } -} - -private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { - def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value: Int = iterator.next() - output.writeInt(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { - val numItems = input.readInt() - val set = new IntegerHashSet - var i = 0 - while (i < numItems) { - val value = input.readInt() - set.add(value) - i += 1 - } - set - } -} - -private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { - def write(kryo: Kryo, output: Output, hs: LongHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value = iterator.next() - output.writeLong(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { - val numItems = input.readInt() - val set = new LongHashSet - var i = 0 - while (i < numItems) { - val value = input.readLong() - set.add(value) - i += 1 - } - set - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index e31c528f3a633..f602f2fb89ca5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -23,7 +23,6 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -131,15 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("OpenHashSetUDT") { - val openHashSetUDT = new OpenHashSetUDT(IntegerType) - val set = new OpenHashSet[Int] - (1 to 10).foreach(i => set.add(i)) - - val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) - assert(actual.iterator.toSet === set.iterator.toSet) - } - test("UDTs with JSON") { val data = Seq( "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", @@ -163,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") assert(new MyDenseVectorUDT().typeName === "mydensevector") - assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } test("Catalyst type converter null handling for UDTs") { From dd77e278b99e45c20fdefb1c795f3c5148d577db Mon Sep 17 00:00:00 2001 From: Nick Evans Date: Wed, 11 Nov 2015 13:29:30 -0800 Subject: [PATCH 315/324] [SPARK-11335][STREAMING] update kafka direct python docs on how to get the offset ranges for a KafkaRDD tdas koeninger This updates the Spark Streaming + Kafka Integration Guide doc with a working method to access the offsets of a `KafkaRDD` through Python. Author: Nick Evans Closes #9289 from manygrams/update_kafka_direct_python_docs. --- docs/streaming-kafka-integration.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index ab7f0117c0b7f..b00351b2fbcc0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -181,7 +181,20 @@ Next, we discuss how to use this approach in your streaming application. );
    - Not supported yet + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges)
    From 2d76e44b1a88e08047806972b2d241a89e499bab Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 11 Nov 2015 14:30:38 -0800 Subject: [PATCH 316/324] [SPARK-11647] Attempt to reduce time/flakiness of Thriftserver CLI and SparkSubmit tests This patch aims to reduce the test time and flakiness of HiveSparkSubmitSuite, SparkSubmitSuite, and CliSuite. Key changes: - Disable IO synchronization calls for Derby writes, since durability doesn't matter for tests. This was done for HiveCompatibilitySuite in #6651 and resulted in huge test speedups. - Add a few missing `--conf`s to disable various Spark UIs. The CliSuite, in particular, never disabled these UIs, leaving it prone to port-contention-related flakiness. - Fix two instances where tests defined `beforeAll()` methods which were never called because the appropriate traits were not mixed in. I updated these tests suites to extend `BeforeAndAfterEach` so that they play nicely with our `ResetSystemProperties` trait. Author: Josh Rosen Closes #9623 from JoshRosen/SPARK-11647. --- .../spark/deploy/RPackageUtilsSuite.scala | 12 ++++++----- .../spark/deploy/SparkSubmitSuite.scala | 6 ++++-- .../sql/hive/thriftserver/CliSuite.scala | 21 ++++++++++++------- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 17 +++++++++++---- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 1ed4bae3ca21e..cc30ba223e1c3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -33,8 +33,12 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.ResetSystemProperties -class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { +class RPackageUtilsSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ResetSystemProperties { private val main = MavenCoordinate("a", "b", "c") private val dep1 = MavenCoordinate("a", "dep1", "c") @@ -60,11 +64,9 @@ class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { } } - def beforeAll() { - System.setProperty("spark.testing", "true") - } - override def beforeEach(): Unit = { + super.beforeEach() + System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1fd470cd3b01d..66a50512003dc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -37,10 +37,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class SparkSubmitSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } diff --git 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3fa5c8528b602..fcf039916913a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -27,7 +27,7 @@ import scala.concurrent.{Await, Promise} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} @@ -36,21 +36,26 @@ import org.apache.spark.{Logging, SparkFunSuite} * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() val scratchDirPath = Utils.createTempDir() - before { + override def beforeAll(): Unit = { + super.beforeAll() warehousePath.delete() metastorePath.delete() scratchDirPath.delete() } - after { - warehousePath.delete() - metastorePath.delete() - scratchDirPath.delete() + override def afterAll(): Unit = { + try { + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() + } finally { + super.afterAll() + } } /** @@ -79,6 +84,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local + | --driver-java-options -Dderby.system.durability=test + | --conf spark.ui.enabled=false | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10e4ae2c50308..24a3afee148c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -23,7 +23,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -42,14 +42,14 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class HiveSparkSubmitSuite extends SparkFunSuite with Matchers - // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we - // add a timestamp to provide more diagnosis information. 
+ with BeforeAndAfterEach with ResetSystemProperties with Timeouts { // TODO: rewrite these or mark them as slow tests to be run sparingly - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -66,6 +66,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -79,6 +80,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -93,6 +95,7 @@ class HiveSparkSubmitSuite val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--class", "Main", testJar) runSparkSubmit(args) @@ -104,6 +107,9 @@ class HiveSparkSubmitSuite "--class", SPARK_9757.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -114,6 +120,9 @@ class HiveSparkSubmitSuite "--class", SPARK_11009.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } From e1bcf6af9ba4f131f84d71660d0ab5598c0b7b67 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 11 Nov 2015 15:30:21 -0800 Subject: [PATCH 317/324] [SPARK-10827] replace volatile with Atomic* in AppClient.scala. This is a followup for #9317 to replace volatile fields with AtomicBoolean and AtomicReference. Author: Reynold Xin Closes #9611 from rxin/SPARK-10827. 
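For context, the conversion below follows the standard replacement of `@volatile var` fields with java.util.concurrent.atomic containers: reads and writes become explicit `get`/`set` calls, and atomic compare-and-set updates become possible if they are needed later. A minimal sketch of the two styles, using illustrative member names rather than the actual AppClient fields:

    import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

    // Before: cross-thread visibility comes from @volatile annotations on mutable vars.
    class VolatileStyle {
      @volatile private var appId: String = null
      @volatile private var registered = false
      def markRegistered(id: String): Unit = { appId = id; registered = true }
    }

    // After: the same state held in explicit atomic containers; callers go through get/set.
    class AtomicStyle {
      private val appId = new AtomicReference[String]
      private val registered = new AtomicBoolean(false)
      def markRegistered(id: String): Unit = { appId.set(id); registered.set(true) }
      def isRegistered: Boolean = registered.get
    }
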
--- .../spark/deploy/client/AppClient.scala | 68 ++++++++++--------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 3f29da663b798..afab362e213b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.client import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.util.control.NonFatal @@ -49,9 +50,9 @@ private[spark] class AppClient( private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - @volatile private var endpoint: RpcEndpointRef = null - @volatile private var appId: String = null - @volatile private var registered = false + private val endpoint = new AtomicReference[RpcEndpointRef] + private val appId = new AtomicReference[String] + private val registered = new AtomicBoolean(false) private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { @@ -59,16 +60,17 @@ private[spark] class AppClient( private var master: Option[RpcEndpointRef] = None // To avoid calling listener.disconnected() multiple times private var alreadyDisconnected = false - @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times - @volatile private var registerMasterFutures: Array[JFuture[_]] = null - @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + // To avoid calling listener.dead() multiple times + private val alreadyDead = new AtomicBoolean(false) + private val registerMasterFutures = new AtomicReference[Array[JFuture[_]]] + private val registrationRetryTimer = new AtomicReference[JScheduledFuture[_]] // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. private val registerMasterThreadPool = new ThreadPoolExecutor( 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time + masterRpcAddresses.length, // Make sure we can register with all masters at the same time 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable](), ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) @@ -100,7 +102,7 @@ private[spark] class AppClient( for (masterAddress <- masterRpcAddresses) yield { registerMasterThreadPool.submit(new Runnable { override def run(): Unit = try { - if (registered) { + if (registered.get) { return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") @@ -123,22 +125,22 @@ private[spark] class AppClient( * nthRetry means this is the nth attempt to register with master. 
*/ private def registerWithMaster(nthRetry: Int) { - registerMasterFutures = tryRegisterAllMasters() - registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + registerMasterFutures.set(tryRegisterAllMasters()) + registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = { Utils.tryOrExit { - if (registered) { - registerMasterFutures.foreach(_.cancel(true)) + if (registered.get) { + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerWithMaster(nthRetry + 1) } } } - }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) } /** @@ -163,10 +165,10 @@ private[spark] class AppClient( // RegisteredApplications due to an unstable network. // 2. Receive multiple RegisteredApplication from different masters because the master is // changing. - appId = appId_ - registered = true + appId.set(appId_) + registered.set(true) master = Some(masterRef) - listener.connected(appId) + listener.connected(appId.get) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) @@ -178,7 +180,7 @@ private[spark] class AppClient( cores)) // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not // guaranteed), `ExecutorStateChanged` may be sent to a dead master. - sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) + sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -193,13 +195,13 @@ private[spark] class AppClient( logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) alreadyDisconnected = false - masterRef.send(MasterChangeAcknowledged(appId)) + masterRef.send(MasterChangeAcknowledged(appId.get)) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopAppClient => markDead("Application has been stopped.") - sendToMaster(UnregisterApplication(appId)) + sendToMaster(UnregisterApplication(appId.get)) context.reply(true) stop() @@ -263,18 +265,18 @@ private[spark] class AppClient( } def markDead(reason: String) { - if (!alreadyDead) { + if (!alreadyDead.get) { listener.dead(reason) - alreadyDead = true + alreadyDead.set(true) } } override def onStop(): Unit = { - if (registrationRetryTimer != null) { - registrationRetryTimer.cancel(true) + if (registrationRetryTimer.get != null) { + registrationRetryTimer.get.cancel(true) } registrationRetryThread.shutdownNow() - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() askAndReplyThreadPool.shutdownNow() } @@ -283,19 +285,19 @@ private[spark] class AppClient( def start() { // Just launch an rpcEndpoint; it will call back into the listener. 
- endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) + endpoint.set(rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))) } def stop() { - if (endpoint != null) { + if (endpoint.get != null) { try { val timeout = RpcUtils.askRpcTimeout(conf) - timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) + timeout.awaitResult(endpoint.get.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - endpoint = null + endpoint.set(null) } } @@ -306,8 +308,8 @@ private[spark] class AppClient( * @return whether the request is acknowledged. */ def requestTotalExecutors(requestedTotal: Int): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal)) } else { logWarning("Attempted to request executors before driver fully initialized.") false @@ -319,8 +321,8 @@ private[spark] class AppClient( * @return whether the kill request is acknowledged. */ def killExecutors(executorIds: Seq[String]): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds)) } else { logWarning("Attempted to kill executors before driver fully initialized.") false From 1a21be15f655b9696ddac80aac629445a465f621 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 11 Nov 2015 15:41:36 -0800 Subject: [PATCH 318/324] [SPARK-11672][ML] disable spark.ml read/write tests Saw several failures on Jenkins, e.g., https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/2040/testReport/org.apache.spark.ml.util/JavaDefaultReadWriteSuite/testDefaultReadWrite/. This is the first failure in master build: https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3982/ I cannot reproduce it on local. So temporarily disable the tests and I will look into the issue under the same JIRA. I'm going to merge the PR after Jenkins passes compile. Author: Xiangrui Meng Closes #9641 from mengxr/SPARK-11672. 
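For orientation: the analyzer rewrites the pivot into an ordinary Aggregate whose expressions are wrapped in per-value conditionals, so the dotNET/Java example shown below is roughly equivalent to the following hand-written query. This is a sketch only, and it relies on the aggregate ignoring nulls, which is also why the rule special-cases First and Last.

import org.apache.spark.sql.functions._

courseSales.groupBy($"year").agg(
  sum(when($"course" === "dotNET", $"earnings")).as("dotNET"),
  sum(when($"course" === "Java", $"earnings")).as("Java"))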
--- .../org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java | 4 ++-- .../spark/ml/classification/LogisticRegressionSuite.scala | 2 +- .../scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 2 +- .../scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index c39538014be81..4f7aeac1ec54c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -23,7 +23,7 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Test; +import org.junit.Ignore; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; @@ -50,7 +50,7 @@ public void tearDown() { Utils.deleteRecursively(tempDir); } - @Test + @Ignore // SPARK-11672 public void testDefaultReadWrite() throws IOException { String uid = "my_params"; MyParams instance = new MyParams(uid); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 51b06b7eb6d53..e4c2f1baa4fa1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -872,7 +872,7 @@ class LogisticRegressionSuite assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) } - test("read/write") { + ignore("read/write") { // SPARK-11672 // Set some Params to make sure set Params are serialized. val lr = new LogisticRegression() .setElasticNetParam(0.1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9dfa1439cc303..a66fe03281935 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -68,7 +68,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } - test("read/write") { + ignore("read/write") { // SPARK-11672 val binarizer = new Binarizer() .setInputCol("feature") .setOutputCol("binarized_feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index cac4bd9aa3ab8..44e09c38f9375 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -105,7 +105,7 @@ object MyParams extends Readable[MyParams] { class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("default read/write") { + ignore("default read/write") { // SPARK-11672 val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } From b8ff6888e76b437287d7d6bf2d4b9c759710a195 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 11 Nov 2015 16:23:24 -0800 Subject: [PATCH 319/324] [SPARK-8992][SQL] Add pivot to dataframe api This adds a pivot method to the dataframe api. Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer. 
Currently the syntax is like: ~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~ ~~Would we be interested in the following syntax also/alternatively? and~~ courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) //or courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")) Later we can add it to `SQLParser`, but as Hive doesn't support it we cant add it there, right? ~~Also what would be the suggested Java friendly method signature for this?~~ Author: Andrew Ray Closes #7841 from aray/sql-pivot. --- .../sql/catalyst/analysis/Analyzer.scala | 42 +++++++ .../plans/logical/basicOperators.scala | 14 +++ .../org/apache/spark/sql/GroupedData.scala | 103 ++++++++++++++++-- .../scala/org/apache/spark/sql/SQLConf.scala | 7 ++ .../spark/sql/DataFramePivotSuite.scala | 87 +++++++++++++++ .../apache/spark/sql/test/SQLTestData.scala | 12 ++ 6 files changed, 255 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a9cd9a77038e7..2f4670b55bdba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveRelations :: ResolveReferences :: ResolveGroupingAnalytics :: + ResolvePivot :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -166,6 +167,10 @@ class Analyzer( case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.withNewAggs(assignAliases(g.aggregations)) + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) + if child.resolved && hasUnresolvedAlias(groupByExprs) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) } @@ -248,6 +253,43 @@ class Analyzer( } } + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. 
+ case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + Alias(filteredAggregate, name)() + } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 597f03e752707..32b09b59af436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -386,6 +386,20 @@ case class Rollup( this.copy(aggregations = aggs) } +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[Literal], + aggregates: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + } + } +} + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 5babf2cc0ca25..63dd7fbcbe9e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} +import org.apache.spark.sql.types.{StringType, NumericType} /** @@ -50,14 +50,8 @@ class GroupedData protected[sql]( aggExprs } - val aliasedAgg = aggregates.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. 
- case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } + val aliasedAgg = aggregates.map(alias) + groupType match { case GroupedData.GroupByType => DataFrame( @@ -68,9 +62,22 @@ class GroupedData protected[sql]( case GroupedData.CubeType => DataFrame( df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + case GroupedData.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + DataFrame( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { @@ -273,6 +280,77 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified + * aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) + * // Or without specifying column values + * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) + * }}} + * @param pivotColumn Column to pivot + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. + * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case GroupedData.GroupByType => + val pivotValues = if (values.nonEmpty) { + values.map { + case Column(literal: Literal) => literal + case other => + throw new UnsupportedOperationException( + s"The values of a pivot must be literals, found $other") + } + } else { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .map(Literal(_)).toSeq + if (values.length > maxValues) { + throw new RuntimeException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + + s"to at least the number of distinct values of the pivot column.") + } + values + } + new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. 
+ * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") + * // Or without specifying column values + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * @param pivotColumn Column to pivot + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. + * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: String, values: Any*): GroupedData = { + val resolvedPivotColumn = Column(df.resolve(pivotColumn)) + pivot(resolvedPivotColumn, values.map(functions.lit): _*) + } } @@ -307,4 +385,9 @@ private[sql] object GroupedData { * To indicate it's the ROLLUP */ private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e02b502b7b4d5..41d28d448ccc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -437,6 +437,13 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) + val DATAFRAME_PIVOT_MAX_VALUES = intConf( + "spark.sql.pivotMaxValues", + defaultValue = Some(10000), + doc = "When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error." + ) + val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), isPublic = false, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala new file mode 100644 index 0000000000000..0c23d142670c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFramePivotSuite extends QueryTest with SharedSQLContext{ + import testImplicits._ + + test("pivot courses with literals") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + .agg(sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year with literals") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with literals and multiple aggregations") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + .agg(sum($"earnings"), avg($"earnings")), + Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil + ) + } + + test("pivot year with string values (cast)") { + checkAnswer( + courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot year with int values") { + checkAnswer( + courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with no values") { + // Note Java comes before dotNet in sorted order + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil + ) + } + + test("pivot year with no values") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot max values inforced") { + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + intercept[RuntimeException]( + courseSales.groupBy($"year").pivot($"course") + ) + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 520dea7f7dd92..abad0d7eaaedf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val courseSales: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.registerTempTable("courseSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. 
*/ @@ -295,4 +306,5 @@ private[sql] object SQLTestData { case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) } From a40838adff7a0095cdd2b3ec7a487e5c48081e3f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 21 Oct 2015 23:48:43 -0700 Subject: [PATCH 320/324] Add support for colnames, colnames<-, coltypes<- --- R/pkg/NAMESPACE | 3 +- R/pkg/R/DataFrame.R | 53 +++++++++++++++++++++++++++++++- R/pkg/R/generics.R | 12 ++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1bc..1e9091eafdc63 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -27,6 +27,7 @@ exportMethods("arrange", "attach", "cache", "collect", + "colnames", "coltypes", "columns", "count", @@ -274,4 +275,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") \ No newline at end of file + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cc868069d1e5a..fd5f878e6f271 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,7 +254,7 @@ setMethod("dtypes", #' @family dataframe_funcs #' @rdname columns #' @name columns -#' @aliases names +#' @aliases names colnames #' @export #' @examples #'\dontrun{ @@ -293,6 +293,57 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @name coltypes +#' @param x (DataFrame) +#' @return value (character) A character vector with the target column types for the given DataFrame +#' @rdname coltypes +#' @aliases coltypes +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' coltypes(df) <- c("string", "integer") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0 || length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + cast(col, value[i]) + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 92ad4ee8685ee..b3197d02c24ee 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -399,6 +399,18 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) +#' @rdname colnames +#' @export +setGeneric("colnames", function(x) { standardGeneric("colnames") }) + +#' @rdname colnames<- +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes<- +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 9e453a1e7c2f4..cc5f276908feb 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -598,6 +598,30 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form expect_equal(testNames[2], "name") }) +test_that("names() colnames() set the column names", { + df <- jsonFile(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") +}) + +test_that("coltypes() set the column types", { + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("string", "integer") + expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + expect_error(coltypes(df) <- c("string"), + "Length of type vector should match the number of columns for DataFrame") +}) + test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) From 33430471bf42f44a430c19a2a23b1d6ec0ec22e4 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 12:33:28 -0800 Subject: [PATCH 321/324] Take R types instead to map to JVM types, add check for NA to keep column --- R/pkg/R/DataFrame.R | 24 +++++++++++++++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd5f878e6f271..4cbdba654f059 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -310,13 +310,22 @@ setMethod("colnames<-", dataFrame(sdf) }) +rToScalaTypes <- new.env() +rToScalaTypes[["integer"]] <- "integer" # in R, integer is 32bit +rToScalaTypes[["numeric"]] <- "double" # in R, numeric == double which is 64bit +rToScalaTypes[["double"]] <- "double" +rToScalaTypes[["character"]] <- "string" +rToScalaTypes[["logical"]] <- "boolean" + #' coltypes #' #' Set the column types of a DataFrame. #' #' @name coltypes #' @param x (DataFrame) -#' @return value (character) A character vector with the target column types for the given DataFrame +#' @return value (character) A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. 
#' @rdname coltypes #' @aliases coltypes #' @export @@ -326,7 +335,8 @@ setMethod("colnames<-", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) -#' coltypes(df) <- c("string", "integer") +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") #'} setMethod("coltypes<-", signature(x = "DataFrame", value = "character"), @@ -338,7 +348,15 @@ setMethod("coltypes<-", } newCols <- lapply(seq_len(ncols), function(i) { col <- getColumn(x, cols[i]) - cast(col, value[i]) + if (!is.na(value[i])) { + stype <- rToScalaTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } }) nx <- select(x, newCols) dataFrame(nx@sdf) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cc5f276908feb..683357cacb084 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -612,14 +612,20 @@ test_that("coltypes() set the column types", { expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) - coltypes(df) <- c("string", "integer") + coltypes(df) <- c("character", "integer") expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) value <- collect(df[, 2])[[3, 1]] expect_equal(value, collect(df1)[[3, 1]]) expect_equal(value, 22) - expect_error(coltypes(df) <- c("string"), + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("cast(name as string)", "string"), + c("cast(cast(age as int) as double)", "double"))) + + expect_error(coltypes(df) <- c("character"), "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) test_that("head() and first() return the correct data", { From 969dc0e1a7281efc9443220cbbfd9f9fa905d409 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 13:27:10 -0800 Subject: [PATCH 322/324] This seems to fix the Rd error - no idea why it worked before. --- R/pkg/R/generics.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b3197d02c24ee..54c1c034a9d0e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -399,15 +399,15 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) -#' @rdname colnames +#' @rdname columns #' @export setGeneric("colnames", function(x) { standardGeneric("colnames") }) -#' @rdname colnames<- +#' @rdname columns #' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) -#' @rdname coltypes<- +#' @rdname columns #' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) From 4e820ecba237ea9f1ac8ca3f8e4a19be012da9cd Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sun, 1 Nov 2015 16:33:40 -0800 Subject: [PATCH 323/324] fix test broken from column name change from cast --- R/pkg/inst/tests/test_sparkSQL.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 683357cacb084..8b552c8109209 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -613,14 +613,13 @@ test_that("coltypes() set the column types", { df1 <- select(df, cast(df$age, "integer")) coltypes(df) <- c("character", "integer") - expect_equal(dtypes(df), list(c("cast(name as string)", "string"), c("cast(age as int)", "int"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) value <- collect(df[, 2])[[3, 1]] expect_equal(value, collect(df1)[[3, 1]]) expect_equal(value, 22) coltypes(df) <- c(NA, "numeric") - expect_equal(dtypes(df), list(c("cast(name as string)", "string"), - c("cast(cast(age as int) as double)", "double"))) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) expect_error(coltypes(df) <- c("character"), "Length of type vector should match the number of columns for DataFrame") From e2399b57aa09eb10f14cc8001c48e0f636c63bab Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 11 Nov 2015 22:40:08 -0800 Subject: [PATCH 324/324] rebase, merge with coltypes change, fix generic, doc --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 120 +++++++++++++++---------------- R/pkg/R/generics.R | 12 ++-- R/pkg/R/types.R | 8 +++ R/pkg/inst/tests/test_sparkSQL.R | 44 ++++++------ 5 files changed, 97 insertions(+), 88 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 1e9091eafdc63..248f71499a2d8 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -29,6 +29,7 @@ exportMethods("arrange", "collect", "colnames", "coltypes", + "coltypes<-", "columns", "count", "cov", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 4cbdba654f059..b0639738dff10 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -263,6 +263,7 @@ setMethod("dtypes", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -272,7 +273,6 @@ setMethod("columns", }) }) -#' @family dataframe_funcs #' @rdname columns #' @name names setMethod("names", @@ -281,7 +281,6 @@ setMethod("names", columns(x) }) -#' @family dataframe_funcs #' @rdname columns #' @name names<- setMethod("names<-", @@ -297,7 +296,7 @@ setMethod("names<-", #' @name colnames setMethod("colnames", signature(x = "DataFrame"), - function(x) { + function(x, do.NULL = TRUE, prefix = "col") { columns(x) }) @@ -310,24 +309,67 @@ setMethod("colnames<-", dataFrame(sdf) }) -rToScalaTypes <- new.env() -rToScalaTypes[["integer"]] <- "integer" # in R, integer is 32bit -rToScalaTypes[["numeric"]] <- "double" # in R, numeric == double which is 64bit -rToScalaTypes[["double"]] <- "double" -rToScalaTypes[["character"]] <- "string" -rToScalaTypes[["logical"]] <- "boolean" +#' coltypes 
+#' +#' Get column types of a DataFrame +#' +#' @name coltypes +#' @param x (DataFrame) +#' @return value (character) A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @family dataframe_funcs +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) #' coltypes #' #' Set the column types of a DataFrame. #' -#' @name coltypes +#' @name coltypes<- #' @param x (DataFrame) -#' @return value (character) A character vector with the target column types for the given +#' @param value (character) A character vector with the target column types for the given #' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA #' to keep that column as-is. #' @rdname coltypes -#' @aliases coltypes #' @export #' @examples #'\dontrun{ @@ -343,7 +385,10 @@ setMethod("coltypes<-", function(x, value) { cols <- columns(x) ncols <- length(cols) - if (length(value) == 0 || length(value) != ncols) { + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { stop("Length of type vector should match the number of columns for DataFrame") } newCols <- lapply(seq_len(ncols), function(i) { @@ -2221,52 +2266,3 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) - -#' Returns the column types of a DataFrame. 
-#' -#' @name coltypes -#' @title Get column types of a DataFrame -#' @family dataframe_funcs -#' @param x (DataFrame) -#' @return value (character) A character vector with the column types of the given DataFrame -#' @rdname coltypes -#' @examples \dontrun{ -#' irisDF <- createDataFrame(sqlContext, iris) -#' coltypes(irisDF) -#' } -setMethod("coltypes", - signature(x = "DataFrame"), - function(x) { - # Get the data types of the DataFrame by invoking dtypes() function - types <- sapply(dtypes(x), function(x) {x[[2]]}) - - # Map Spark data types into R's data types using DATA_TYPES environment - rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { - - # Check for primitive types - type <- PRIMITIVE_TYPES[[x]] - - if (is.null(type)) { - # Check for complex types - for (t in names(COMPLEX_TYPES)) { - if (substring(x, 1, nchar(t)) == t) { - type <- COMPLEX_TYPES[[t]] - break - } - } - - if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) - } - } - type - }) - - # Find which types don't have mapping to R - naIndices <- which(is.na(rTypes)) - - # Assign the original scala data types to the unmatched ones - rTypes[naIndices] <- types[naIndices] - - rTypes - }) \ No newline at end of file diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 54c1c034a9d0e..96079c8b186b3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -401,13 +401,17 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname columns #' @export -setGeneric("colnames", function(x) { standardGeneric("colnames") }) +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns #' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) -#' @rdname columns +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes #' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) @@ -1099,7 +1103,3 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") - -#' @rdname coltypes -#' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index 1828c23ab0f6d..4b69589dfa247 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -41,3 +41,11 @@ COMPLEX_TYPES <- list( # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. 
+rToScalaTypes <- new.env() +rToScalaTypes[["integer"]] <- "integer" # in R, integer is 32bit +rToScalaTypes[["numeric"]] <- "double" # in R, numeric == double which is 64bit +rToScalaTypes[["double"]] <- "double" +rToScalaTypes[["character"]] <- "string" +rToScalaTypes[["logical"]] <- "boolean" diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8b552c8109209..576e67ac1a843 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -605,26 +605,12 @@ test_that("names() colnames() set the column names", { colnames(df) <- c("col3", "col4") expect_equal(names(df)[1], "col3") -}) - -test_that("coltypes() set the column types", { - df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") - expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) - - df1 <- select(df, cast(df$age, "integer")) - coltypes(df) <- c("character", "integer") - expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) - value <- collect(df[, 2])[[3, 1]] - expect_equal(value, collect(df1)[[3, 1]]) - expect_equal(value, 22) - - coltypes(df) <- c(NA, "numeric") - expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) - expect_error(coltypes(df) <- c("character"), - "Length of type vector should match the number of columns for DataFrame") - expect_error(coltypes(df) <- c("environment", "list"), - "Only atomic type is supported for column types") + # Test base::colnames + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x","Y") + expect_equal(colnames(m2), c("x", "Y")) }) test_that("head() and first() return the correct data", { @@ -1584,7 +1570,7 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) -test_that("Method coltypes() to get R's data types of a DataFrame", { +test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) data <- data.frame(c1=c(1,2,3), @@ -1603,6 +1589,24 @@ test_that("Method coltypes() to get R's data types of a DataFrame", { x <- createDataFrame(sqlContext, list(list(as.environment( list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") + + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) unlink(parquetPath)
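Putting the pieces of this patch series together, the resulting SparkR API ends up being used roughly as follows; this is a sketch assuming the two-column JSON fixture (name, age) from test_sparkSQL.R:

df <- jsonFile(sqlContext, jsonPath)

colnames(df)                       # same result as columns(df) / names(df)
colnames(df) <- c("col1", "col2")  # renames via toDF on the underlying Scala DataFrame

coltypes(df)                       # R types, e.g. "character" for the string column
coltypes(df) <- c(NA, "integer")   # NA keeps col1 as-is, casts col2 to int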
    - {app.desc.name} + {app.desc.name} {app.coresGranted} diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 967aa0976f0ce..3164760b08a71 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -31,8 +31,9 @@ private[deploy] object DeployTestUtils { } def createAppInfo() : ApplicationInfo = { + val appDesc = createAppDesc() val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + "id", appDesc, JsonConstants.submitDate, null, Int.MaxValue) appInfo.endTime = JsonConstants.currTimeInMillis appInfo } From d188a67762dfc09929e30931509be5851e29dfa5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 3 Nov 2015 22:30:23 +0800 Subject: [PATCH 143/324] [SPARK-10533][SQL] handle scientific notation in sqlParser https://issues.apache.org/jira/browse/SPARK-10533 val df = sqlContext.createDataFrame(Seq(("a",1.0),("b",2.0),("c",3.0))) df.filter("_2 < 2.0e1").show Scientific notation didn't work. Author: Daoyuan Wang Closes #9085 from adrian-wang/scinotation. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 15 +++++++++++++-- .../org/apache/spark/sql/catalyst/SqlParser.scala | 11 +++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 11 ++++++++--- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 2bac08eac4fe2..04ac4f20c66ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -82,6 +82,10 @@ class SqlLexical extends StdLexical { override def toString: String = chars } + case class DecimalLit(chars: String) extends Token { + override def toString: String = chars + } + /* This is a work around to support the lazy setting */ def initialize(keywords: Seq[String]): Unit = { reserved.clear() @@ -102,8 +106,12 @@ class SqlLexical extends StdLexical { } override lazy val token: Parser[Token] = - ( identChar ~ (identChar | digit).* ^^ - { case first ~ rest => processIdent((first :: rest).mkString) } + ( rep1(digit) ~ ('.' ~> digit.*).? ~ (exp ~> sign.? ~ rep1(digit)) ^^ { + case i ~ None ~ (sig ~ rest) => + DecimalLit(i.mkString + "e" + sig.mkString + rest.mkString) + case i ~ Some(d) ~ (sig ~ rest) => + DecimalLit(i.mkString + "." + d.mkString + "e" + sig.mkString + rest.mkString) + } | digit.* ~ identChar ~ (identChar | digit).* ^^ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? 
^^ { @@ -125,6 +133,9 @@ class SqlLexical extends StdLexical { override def identChar: Parser[Elem] = letter | elem('_') + private lazy val sign: Parser[Elem] = elem("s", c => c == '+' || c == '-') + private lazy val exp: Parser[Elem] = elem("e", c => c == 'E' || c == 'e') + override def whitespace: Parser[Any] = ( whitespaceChar | '/' ~ '*' ~ comment diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d7567e8613e3c..1ba559d9e3b18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -337,6 +337,9 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | sign.? ~ unsignedFloat ^^ { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } + | sign.? ~ unsignedDecimal ^^ { + case s ~ d => Literal(toDecimalOrDouble(s.getOrElse("") + d)) + } ) protected lazy val unsignedFloat: Parser[String] = @@ -344,6 +347,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) ) + protected lazy val unsignedDecimal: Parser[String] = + ( "." ~> decimalLit ^^ { u => "0." + u } + | elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) + ) + + def decimalLit: Parser[String] = + elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) + protected lazy val sign: Parser[String] = ("+" | "-") protected lazy val integral: Parser[String] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6b86c5951b413..a883bcb7b1012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -177,9 +177,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("filterExpr") { - checkAnswer( - testData.filter("key > 90"), - testData.collect().filter(_.getInt(0) > 90).toSeq) + val res = testData.collect().filter(_.getInt(0) > 90).toSeq + checkAnswer(testData.filter("key > 90"), res) + checkAnswer(testData.filter("key > 9.0e1"), res) + checkAnswer(testData.filter("key > .9e+2"), res) + checkAnswer(testData.filter("key > 0.9e+2"), res) + checkAnswer(testData.filter("key > 900e-1"), res) + checkAnswer(testData.filter("key > 900.0E-1"), res) + checkAnswer(testData.filter("key > 9.e+1"), res) } test("filterExpr using where") { From 57446eb69ceb6b8856ab22b54abb22b47b80f841 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Nov 2015 07:06:00 -0800 Subject: [PATCH 144/324] [SPARK-11256] Mark all Stage/ResultStage/ShuffleMapStage internal state as private. Author: Reynold Xin Closes #9219 from rxin/stage-cleanup1. 
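Concretely, each formerly public mutable field becomes a private backing variable exposed through a read-only accessor plus explicit mutators, so callers such as DAGScheduler can no longer reassign stage state directly. The ResultStage change below, for example, reduces to this shape (condensed sketch, not the exact formatting):

private[this] var _activeJob: Option[ActiveJob] = None
def activeJob: Option[ActiveJob] = _activeJob
def setActiveJob(job: ActiveJob): Unit = { _activeJob = Option(job) }
def removeActiveJob(): Unit = { _activeJob = None }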
--- .../apache/spark/scheduler/DAGScheduler.scala | 33 +++++----- .../apache/spark/scheduler/ResultStage.scala | 19 +++++- .../spark/scheduler/ShuffleMapStage.scala | 61 +++++++++++++------ .../org/apache/spark/scheduler/Stage.scala | 5 +- 4 files changed, 80 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 995862ece5944..5673fbf2c8fea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack} +import scala.collection.mutable.{HashMap, HashSet, Stack} import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -535,10 +535,8 @@ class DAGScheduler( jobIdToActiveJob -= job.jobId activeJobs -= job job.finalStage match { - case r: ResultStage => - r.resultOfJob = None - case m: ShuffleMapStage => - m.mapStageJobs = m.mapStageJobs.filter(_ != job) + case r: ResultStage => r.removeActiveJob() + case m: ShuffleMapStage => m.removeActiveJob(job) } } @@ -848,7 +846,7 @@ class DAGScheduler( val jobSubmissionTime = clock.getTimeMillis() jobIdToActiveJob(jobId) = job activeJobs += job - finalStage.resultOfJob = Some(job) + finalStage.setActiveJob(job) val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) listenerBus.post( @@ -880,7 +878,7 @@ class DAGScheduler( val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) clearCacheLocs() logInfo("Got map stage job %s (%s) with %d output partitions".format( - jobId, callSite.shortForm, dependency.rdd.partitions.size)) + jobId, callSite.shortForm, dependency.rdd.partitions.length)) logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) @@ -888,7 +886,7 @@ class DAGScheduler( val jobSubmissionTime = clock.getTimeMillis() jobIdToActiveJob(jobId) = job activeJobs += job - finalStage.mapStageJobs = job :: finalStage.mapStageJobs + finalStage.addActiveJob(job) val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) listenerBus.post( @@ -950,12 +948,12 @@ class DAGScheduler( // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. 
outputCommitCoordinator.stageStart(stage.id) - val taskIdToLocations = try { + val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try { stage match { case s: ShuffleMapStage => partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap case s: ResultStage => - val job = s.resultOfJob.get + val job = s.activeJob.get partitionsToCompute.map { id => val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) @@ -1016,7 +1014,7 @@ class DAGScheduler( } case stage: ResultStage => - val job = stage.resultOfJob.get + val job = stage.activeJob.get partitionsToCompute.map { id => val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) @@ -1132,7 +1130,7 @@ class DAGScheduler( // Cast to ResultStage here because it's part of the ResultTask // TODO Refactor this out to a function that accepts a ResultStage val resultStage = stage.asInstanceOf[ResultStage] - resultStage.resultOfJob match { + resultStage.activeJob match { case Some(job) => if (!job.finished(rt.outputId)) { updateAccumulators(event) @@ -1187,7 +1185,7 @@ class DAGScheduler( // we registered these map outputs. mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(_.headOption.orNull), + shuffleStage.outputLocInMapOutputTrackerFormat(), changeEpoch = true) clearCacheLocs() @@ -1197,8 +1195,7 @@ class DAGScheduler( // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + - shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { // Mark any map-stage jobs waiting on this stage as finished @@ -1312,8 +1309,10 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(_.headOption.orNull) - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + mapOutputTracker.registerMapOutputs( + shuffleId, + stage.outputLocInMapOutputTrackerFormat(), + changeEpoch = true) } if (shuffleToMapStage.isEmpty) { mapOutputTracker.incrementEpoch() diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c1d86af7e8fb5..d1687830ff7bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -41,10 +41,25 @@ private[spark] class ResultStage( * The active job for this result stage. Will be empty if the job has already finished * (e.g., because the job was cancelled). */ - var resultOfJob: Option[ActiveJob] = None + private[this] var _activeJob: Option[ActiveJob] = None + def activeJob: Option[ActiveJob] = _activeJob + + def setActiveJob(job: ActiveJob): Unit = { + _activeJob = Option(job) + } + + def removeActiveJob(): Unit = { + _activeJob = None + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + * + * This can only be called when there is an active job. 
+ */ override def findMissingPartitions(): Seq[Int] = { - val job = resultOfJob.get + val job = activeJob.get (0 until job.numPartitions).filter(id => !job.finished(id)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 3832d99eddaef..51416e5ce97fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -43,35 +43,53 @@ private[spark] class ShuffleMapStage( val shuffleDep: ShuffleDependency[_, _, _]) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + private[this] var _mapStageJobs: List[ActiveJob] = Nil + + private[this] var _numAvailableOutputs: Int = 0 + + /** + * List of [[MapStatus]] for each partition. The index of the array is the map partition id, + * and each value in the array is the list of possible [[MapStatus]] for a partition + * (a single task might run multiple times). + */ + private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + override def toString: String = "ShuffleMapStage " + id - /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ - var mapStageJobs: List[ActiveJob] = Nil + /** + * Returns the list of active jobs, + * i.e. map-stage jobs that were submitted to execute this stage independently (if any). + */ + def mapStageJobs: Seq[ActiveJob] = _mapStageJobs + + /** Adds the job to the active job list. */ + def addActiveJob(job: ActiveJob): Unit = { + _mapStageJobs = job :: _mapStageJobs + } + + /** Removes the job from the active job list. */ + def removeActiveJob(job: ActiveJob): Unit = { + _mapStageJobs = _mapStageJobs.filter(_ != job) + } /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - var numAvailableOutputs: Int = 0 + def numAvailableOutputs: Int = _numAvailableOutputs /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = numAvailableOutputs == numPartitions - - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). - */ - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + def isAvailable: Boolean = _numAvailableOutputs == numPartitions + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ override def findMissingPartitions(): Seq[Int] = { val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}") + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") missing } @@ -79,7 +97,7 @@ private[spark] class ShuffleMapStage( val prevList = outputLocs(partition) outputLocs(partition) = status :: prevList if (prevList == Nil) { - numAvailableOutputs += 1 + _numAvailableOutputs += 1 } } @@ -88,10 +106,19 @@ private[spark] class ShuffleMapStage( val newList = prevList.filterNot(_.location == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 + _numAvailableOutputs -= 1 } } + /** + * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned + * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, + * that position is filled with null. + */ + def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { + outputLocs.map(_.headOption.orNull) + } + /** * Removes all shuffle outputs associated with this executor. Note that this will also remove * outputs which are served by an external shuffle server (if one exists), as they are still @@ -105,12 +132,12 @@ private[spark] class ShuffleMapStage( outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true - numAvailableOutputs -= 1 + _numAvailableOutputs -= 1 } } if (becameUnavailable) { logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) + this, execId, _numAvailableOutputs, numPartitions, isAvailable)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5ce4a484344f1..7ea24a217bd39 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -71,8 +71,8 @@ private[scheduler] abstract class Stage( /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 - val name = callSite.shortForm - val details = callSite.longForm + val name: String = callSite.shortForm + val details: String = callSite.longForm private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty @@ -134,6 +134,7 @@ private[scheduler] abstract class Stage( def latestInfo: StageInfo = _latestInfo override final def hashCode(): Int = id + override final def equals(other: Any): Boolean = other match { case stage: Stage => stage != null && stage.id == id case _ => false From d6035d97c91fe78b1336ade48134252915263ea6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 3 Nov 2015 07:41:50 -0800 Subject: [PATCH 145/324] [SPARK-10304] [SQL] Partition discovery should throw an exception if the dir structure is invalid JIRA: https://issues.apache.org/jira/browse/SPARK-10304 This patch detects if the structure of partition directories is not valid. The test cases are from #8547. Thanks zhzhan. cc liancheng Author: Liang-Chi Hsieh Closes #8840 from viirya/detect_invalid_part_dir. 
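The heart of the change is small: parsePartition now also reports the base directory left once the k=v segments are stripped, and parsePartitions insists that every leaf agrees on that base. Schematically (simplified from the diff below):

val (partitionValues, optBasePaths) = paths.map { path =>
  parsePartition(path, defaultPartitionName, typeInference)
}.unzip

val basePaths = optBasePaths.flatten
assert(basePaths.distinct.size == 1,
  "Conflicting directory structures detected. Suspicious paths:" +
    basePaths.mkString("\n\t", "\n\t", "\n\n"))

Directories named _temporary contribute no base path at all, which is why the mixed _temporary layout in the new test still parses.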
--- .../datasources/PartitioningUtils.scala | 36 +++++++++++++------ .../ParquetPartitionDiscoverySuite.scala | 36 +++++++++++++++++-- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 0a2007e15843c..628c5e18936c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -77,9 +77,11 @@ private[sql] object PartitioningUtils { defaultPartitionName: String, typeInference: Boolean): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. - val pathsWithPartitionValues = paths.flatMap { path => - parsePartition(path, defaultPartitionName, typeInference).map(path -> _) - } + val (partitionValues, optBasePaths) = paths.map { path => + parsePartition(path, defaultPartitionName, typeInference) + }.unzip + + val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _)) if (pathsWithPartitionValues.isEmpty) { // This dataset is not partitioned. @@ -87,6 +89,12 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. + val basePaths = optBasePaths.flatMap(x => x) + assert( + basePaths.distinct.size == 1, + "Conflicting directory structures detected. Suspicious paths:\b" + + basePaths.mkString("\n\t", "\n\t", "\n\n")) + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) // Creates the StructType which represents the partition columns. @@ -110,12 +118,12 @@ private[sql] object PartitioningUtils { } /** - * Parses a single partition, returns column names and values of each partition column. For - * example, given: + * Parses a single partition, returns column names and values of each partition column, also + * the base path. For example, given: * {{{ * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 * }}} - * it returns: + * it returns the partition: * {{{ * PartitionValues( * Seq("a", "b", "c"), @@ -124,34 +132,40 @@ private[sql] object PartitioningUtils { * Literal.create("hello", StringType), * Literal.create(3.14, FloatType))) * }}} + * and the base path: + * {{{ + * /path/to/partition + * }}} */ private[sql] def parsePartition( path: Path, defaultPartitionName: String, - typeInference: Boolean): Option[PartitionValues] = { + typeInference: Boolean): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null var chopped = path + var basePath = path while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left // uncleaned. Here we simply ignore them. 
if (chopped.getName.toLowerCase == "_temporary") { - return None + return (None, None) } val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) maybeColumn.foreach(columns += _) + basePath = chopped chopped = chopped.getParent - finished = maybeColumn.isEmpty || chopped.getParent == null + finished = (maybeColumn.isEmpty && !columns.isEmpty) || chopped.getParent == null } if (columns.isEmpty) { - None + (None, Some(path)) } else { val (columnNames, values) = columns.reverse.unzip - Some(PartitionValues(columnNames, values)) + (Some(PartitionValues(columnNames, values)), Some(basePath)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 3a23b8ed66808..67b6a37fa502e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -58,14 +58,46 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha check(defaultPartitionName, Literal.create(null, NullType)) } + test("parse invalid partitioned directories") { + // Invalid + var paths = Seq( + "hdfs://host:9000/invalidPath", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello") + + var exception = intercept[AssertionError] { + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) + + // Valid + paths = Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/_temporary/path") + + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + + // Invalid + paths = Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/path1") + + exception = intercept[AssertionError] { + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) + } + test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)._1) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true).get + parsePartition(new Path(path), defaultPartitionName, true) }.getMessage assert(message.contains(expected)) From d6f10aa7ea2806c0fbcfc31d7dee91d28319fab7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 3 Nov 2015 08:29:07 -0800 Subject: [PATCH 146/324] [SPARK-9836][ML] Provide R-like summary statistics for OLS via normal equation solver https://issues.apache.org/jira/browse/SPARK-9836 Author: Yanbo Liang Closes #9413 from yanboliang/spark-9836. 
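A minimal sketch of how the new summary statistics fall out of the normal equation solver, assuming breeze on the classpath as in the patch and using made-up placeholder inputs: the diagonal of (A^T * W * A)^-1 is scaled by the residual variance to obtain coefficient standard errors, from which t-statistics and two-sided p-values follow.

import breeze.stats.distributions.StudentsT

object OlsSummarySketch {
  def main(args: Array[String]): Unit = {
    val coefficients     = Array(6.080, -0.600) // fitted betas (illustrative values)
    val diagInvAtWA      = Array(1.0, 0.5)      // diag of (A^T * W * A)^-1 (illustrative)
    val rss              = 7.68                 // (weighted) residual sum of squares
    val degreesOfFreedom = 1.0                  // n - #coefficients - 1 when fitting an intercept

    val sigma2    = rss / degreesOfFreedom
    val stdErrors = diagInvAtWA.map(d => math.sqrt(d * sigma2))
    val tValues   = coefficients.zip(stdErrors).map { case (b, se) => b / se }
    val dist      = StudentsT(degreesOfFreedom)
    val pValues   = tValues.map(t => 2.0 * (1.0 - dist.cdf(math.abs(t))))

    println("se: " + stdErrors.mkString(", "))
    println("t:  " + tValues.mkString(", "))
    println("p:  " + pValues.mkString(", "))
  }
}
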
--- .../spark/ml/optim/WeightedLeastSquares.scala | 15 +- .../ml/regression/LinearRegression.scala | 90 +++++++++++- .../mllib/linalg/CholeskyDecomposition.scala | 16 +++ .../ml/regression/LinearRegressionSuite.scala | 129 ++++++++++++++++++ 4 files changed, 243 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 3d64f7f296137..e612a2122ed62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -26,10 +26,12 @@ import org.apache.spark.rdd.RDD * Model fitted by [[WeightedLeastSquares]]. * @param coefficients model coefficients * @param intercept model intercept + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 */ private[ml] class WeightedLeastSquaresModel( val coefficients: DenseVector, - val intercept: Double) extends Serializable + val intercept: Double, + val diagInvAtWA: DenseVector) extends Serializable /** * Weighted least squares solver via normal equation. @@ -73,7 +75,9 @@ private[ml] class WeightedLeastSquares( val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() logInfo(s"Number of instances: ${summary.count}.") + val k = summary.k val triK = summary.triK + val wSum = summary.wSum val bBar = summary.bBar val bStd = summary.bStd val aBar = summary.aBar @@ -109,6 +113,11 @@ private[ml] class WeightedLeastSquares( val x = new DenseVector(CholeskyDecomposition.solve(aaBar.values, abBar.values)) + val aaInv = CholeskyDecomposition.inverse(aaBar.values, k) + // aaInv is a packed upper triangular matrix, here we get all elements on diagonal + val diagInvAtWA = new DenseVector((1 to k).map { i => + aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) + // compute intercept val intercept = if (fitIntercept) { bBar - BLAS.dot(aBar, x) @@ -116,7 +125,7 @@ private[ml] class WeightedLeastSquares( 0.0 } - new WeightedLeastSquaresModel(x, intercept) + new WeightedLeastSquaresModel(x, intercept, diagInvAtWA) } } @@ -131,7 +140,7 @@ private[ml] object WeightedLeastSquares { var k: Int = _ var count: Long = _ var triK: Int = _ - private var wSum: Double = _ + var wSum: Double = _ private var wwSum: Double = _ private var bSum: Double = _ private var bbSum: Double = _ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6e9c7442b8110..c51e30483ab3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.stats.distributions.StudentsT import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental @@ -36,7 +37,7 @@ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf, lit} +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel /** @@ -173,8 +174,11 @@ class LinearRegression(override val uid: String) 
summaryModel.transform(dataset), predictionColName, $(labelCol), + summaryModel, + model.diagInvAtWA.toArray, $(featuresCol), Array(0D)) + return lrModel.setSummary(trainingSummary) } @@ -221,6 +225,8 @@ class LinearRegression(override val uid: String) summaryModel.transform(dataset), predictionColName, $(labelCol), + model, + Array(0D), $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) @@ -316,6 +322,8 @@ class LinearRegression(override val uid: String) summaryModel.transform(dataset), predictionColName, $(labelCol), + model, + Array(0D), $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) @@ -371,7 +379,8 @@ class LinearRegressionModel private[ml] ( private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() - new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol)) + new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, + $(labelCol), this, Array(0D)) } /** @@ -412,9 +421,11 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + model: LinearRegressionModel, + diagInvAtWA: Array[Double], val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol) { + extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { /** Number of training iterations until termination */ val totalIterations = objectiveHistory.length @@ -430,7 +441,9 @@ class LinearRegressionTrainingSummary private[regression] ( class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, - val labelCol: String) extends Serializable { + val labelCol: String, + val model: LinearRegressionModel, + val diagInvAtWA: Array[Double]) extends Serializable { @transient private val metrics = new RegressionMetrics( predictions @@ -474,6 +487,75 @@ class LinearRegressionSummary private[regression] ( predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) } + /** Number of instances in DataFrame predictions */ + lazy val numInstances: Long = predictions.count() + + /** Degrees of freedom */ + private val degreesOfFreedom: Long = if (model.getFitIntercept) { + numInstances - model.coefficients.size - 1 + } else { + numInstances - model.coefficients.size + } + + /** + * The weighted residuals, the usual residuals rescaled by + * the square root of the instance weights. + */ + lazy val devianceResiduals: Array[Double] = { + val weighted = if (model.getWeightCol.isEmpty) lit(1.0) else sqrt(col(model.getWeightCol)) + val dr = predictions.select(col(model.getLabelCol).minus(col(model.getPredictionCol)) + .multiply(weighted).as("weightedResiduals")) + .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max")) + .first() + Array(dr.getDouble(0), dr.getDouble(1)) + } + + /** + * Standard error of estimated coefficients. + * Note that standard error of estimated intercept is not supported currently. + */ + lazy val coefficientStandardErrors: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No Std. 
Error of coefficients available for this LinearRegressionModel") + } else { + val rss = if (model.getWeightCol.isEmpty) { + meanSquaredError * numInstances + } else { + val t = udf { (pred: Double, label: Double, weight: Double) => + math.pow(label - pred, 2.0) * weight } + predictions.select(t(col(model.getPredictionCol), col(model.getLabelCol), + col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) + } + val sigma2 = rss / degreesOfFreedom + diagInvAtWA.map(_ * sigma2).map(math.sqrt(_)) + } + } + + /** T-statistic of estimated coefficients. + * Note that t-statistic of estimated intercept is not supported currently. + */ + lazy val tValues: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No t-statistic available for this LinearRegressionModel") + } else { + model.coefficients.toArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + } + } + + /** Two-sided p-value of estimated coefficients. + * Note that p-value of estimated intercept is not supported currently. + */ + lazy val pValues: Array[Double] = { + if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { + throw new UnsupportedOperationException( + "No p-value available for this LinearRegressionModel") + } else { + tValues.map { x => 2.0 * (1.0 - StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } + } + } + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 66eb40b6f4a69..0cd371e9cce34 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -40,4 +40,20 @@ private[spark] object CholeskyDecomposition { assert(code == 0, s"lapack.dpotrs returned $code.") bx } + + /** + * Computes the inverse of a real symmetric positive definite matrix A + * using the Cholesky factorization A = U**T*U. + * The input arguments are modified in-place to store the inverse matrix. 
+ * @param UAi the upper triangular factor U from the Cholesky factorization A = U**T*U + * @param k the dimension of A + * @return the upper triangle of the (symmetric) inverse of A + */ + def inverse(UAi: Array[Double], k: Int): Array[Double] = { + val info = new intW(0) + lapack.dpptri("U", k, UAi, info) + val code = info.`val` + assert(code == 0, s"lapack.dpptri returned $code.") + UAi + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 235c796d785a6..fbf83e8922861 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -35,6 +35,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var datasetWithDenseFeature: DataFrame = _ @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ + @transient var datasetWithWeight: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -73,6 +74,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { xMean = Seq.fill(featureSize)(r.nextDouble).toArray, xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2)) + + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + df <- as.data.frame(cbind(A, b)) + */ + datasetWithWeight = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -603,6 +620,16 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { // To clalify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) + val devianceResidualsR = Array(-0.35566, 0.34504) + val seCoefR = Array(0.0011756, 0.0009032) + val tValsR = Array(3998, 7971) + val pValsR = Array(0, 0) + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } + model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } } } @@ -725,4 +752,106 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .sliding(2) .forall(x => x(0) >= x(1))) } + + test("linear regression summary with weighted samples and intercept by normal solver") { + /* + R code: + + model <- glm(formula = "b ~ .", data = df, weights = w) + summary(model) + + Call: + glm(formula = "b ~ .", data = df, weights = w) + + Deviance Residuals: + 1 2 3 4 + 1.920 -1.358 -1.109 0.960 + + Coefficients: + Estimate Std. 
Error t value Pr(>|t|) + (Intercept) 18.080 9.608 1.882 0.311 + V1 6.080 5.556 1.094 0.471 + V2 -0.600 1.960 -0.306 0.811 + + (Dispersion parameter for gaussian family taken to be 7.68) + + Null deviance: 202.00 on 3 degrees of freedom + Residual deviance: 7.68 on 1 degrees of freedom + AIC: 18.783 + + Number of Fisher Scoring iterations: 2 + */ + + val model = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .fit(datasetWithWeight) + val coefficientsR = Vectors.dense(Array(6.080, -0.600)) + val interceptR = 18.080 + val devianceResidualsR = Array(-1.358, 1.920) + val seCoefR = Array(5.556, 1.960) + val tValsR = Array(1.094, -0.306) + val pValsR = Array(0.471, 0.811) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept ~== interceptR absTol 1E-3) + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + } + + test("linear regression summary with weighted samples and w/o intercept by normal solver") { + /* + R code: + + model <- glm(formula = "b ~ . -1", data = df, weights = w) + summary(model) + + Call: + glm(formula = "b ~ . -1", data = df, weights = w) + + Deviance Residuals: + 1 2 3 4 + 1.950 2.344 -4.600 2.103 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + V1 -3.7271 2.9032 -1.284 0.3279 + V2 3.0100 0.6022 4.998 0.0378 * + --- + Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 + + (Dispersion parameter for gaussian family taken to be 17.4376) + + Null deviance: 5962.000 on 4 degrees of freedom + Residual deviance: 34.875 on 2 degrees of freedom + AIC: 22.835 + + Number of Fisher Scoring iterations: 2 + */ + + val model = new LinearRegression() + .setWeightCol("weight") + .setSolver("normal") + .setFitIntercept(false) + .fit(datasetWithWeight) + val coefficientsR = Vectors.dense(Array(-3.7271, 3.0100)) + val interceptR = 0.0 + val devianceResidualsR = Array(-4.600, 2.344) + val seCoefR = Array(2.9032, 0.6022) + val tValsR = Array(-1.284, 4.998) + val pValsR = Array(0.3279, 0.0378) + + assert(model.coefficients ~== coefficientsR absTol 1E-3) + assert(model.intercept === interceptR) + model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => + assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } + } } From 3434572b141075f00698d94e6ee80febd3093c3b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 3 Nov 2015 08:31:16 -0800 Subject: [PATCH 147/324] [MINOR][ML] Fix naming conventions of AFTSurvivalRegression coefficients Rename ```regressionCoefficients``` back to ```coefficients```, and name ```weights``` to ```parameters```. See discussion [here](https://github.com/apache/spark/pull/9311/files#diff-e277fd0bc21f825d3196b4551c01fe5fR230). mengxr vectorijk dbtsai Author: Yanbo Liang Closes #9431 from yanboliang/aft-coefficients. 
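For context, the convention being restored here is that the optimizer works on one flat "parameters" vector laid out as [log(sigma), intercept, coefficients...], and the fitted model's named fields are simply read out of it. A minimal sketch with illustrative values (loosely the first R test case in the suite):

object AftParameterLayout {
  // parameters(0) = log(sigma), parameters(1) = intercept, the rest = coefficients.
  def unpack(parameters: Array[Double]): (Double, Double, Array[Double]) = {
    val scale        = math.exp(parameters(0))
    val intercept    = parameters(1)
    val coefficients = parameters.slice(2, parameters.length)
    (scale, intercept, coefficients)
  }

  def main(args: Array[String]): Unit = {
    val (scale, intercept, coefficients) = unpack(Array(0.344, 1.759, -0.039))
    println(s"scale=$scale intercept=$intercept coefficients=${coefficients.mkString(", ")}")
  }
}
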
--- .../ml/regression/AFTSurvivalRegression.scala | 38 +++++++++---------- .../AFTSurvivalRegressionSuite.scala | 12 +++--- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4dbbc7d39931b..b7d095872ffa5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -200,17 +200,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size /* - The coefficients vector has three parts: + The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter the second element: Double, intercept of the beta parameter the third to the end elements: Doubles, regression coefficients vector of the beta parameter */ - val initialCoefficients = Vectors.zeros(numFeatures + 2) + val initialParameters = Vectors.zeros(numFeatures + 2) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficients.toBreeze.toDenseVector) + initialParameters.toBreeze.toDenseVector) - val coefficients = { + val parameters = { val arrayBuilder = mutable.ArrayBuilder.make[Double] var state: optimizer.State = null while (states.hasNext) { @@ -227,10 +227,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val regressionCoefficients = Vectors.dense(coefficients.slice(2, coefficients.length)) - val intercept = coefficients(1) - val scale = math.exp(coefficients(0)) - val model = new AFTSurvivalRegressionModel(uid, regressionCoefficients, intercept, scale) + val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val intercept = parameters(1) + val scale = math.exp(parameters(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) copyValues(model.setParent(this)) } @@ -251,7 +251,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("1.6.0") class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override val uid: String, - @Since("1.6.0") val regressionCoefficients: Vector, + @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { @@ -275,7 +275,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { // scale parameter for the Weibull distribution of lifetime - val lambda = math.exp(BLAS.dot(regressionCoefficients, features) + intercept) + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) // shape parameter for the Weibull distribution of lifetime val k = 1 / scale val quantiles = $(quantileProbabilities).map { @@ -286,7 +286,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def predict(features: Vector): Double = { - math.exp(BLAS.dot(regressionCoefficients, features) + intercept) + math.exp(BLAS.dot(coefficients, features) + intercept) } @Since("1.6.0") @@ -309,7 +309,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { - copyValues(new AFTSurvivalRegressionModel(uid, 
regressionCoefficients, intercept, scale), extra) + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } } @@ -369,17 +369,17 @@ class AFTSurvivalRegressionModel private[ml] ( * \frac{\partial (-\iota)}{\partial (\log\sigma)}= * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] * }}} - * @param coefficients including three part: The log of scale parameter, the intercept and + * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. */ -private class AFTAggregator(coefficients: BDV[Double], fitIntercept: Boolean) +private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { // beta is the intercept and regression coefficients to the covariates - private val beta = coefficients.slice(1, coefficients.length) + private val beta = parameters.slice(1, parameters.length) // sigma is the scale parameter of the AFT model - private val sigma = math.exp(coefficients(0)) + private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 @@ -449,15 +449,15 @@ private class AFTAggregator(coefficients: BDV[Double], fitIntercept: Boolean) /** * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. - * It returns the loss and gradient at a particular point (coefficients). + * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. */ private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) extends DiffFunction[BDV[Double]] { - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index c0f791bce13d1..359f31027172b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -141,12 +141,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(-0.039) + val coefficientsR = Vectors.dense(-0.039) val interceptR = 1.759 val scaleR = 1.41 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* @@ -212,12 +212,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 5 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(-0.0844, 0.0677) + val coefficientsR = Vectors.dense(-0.0844, 0.0677) val interceptR = 1.9206 val scaleR = 0.977 assert(model.intercept ~== interceptR relTol 1E-3) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== 
coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* @@ -282,12 +282,12 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex Number of Newton-Raphson Iterations: 6 n= 1000 */ - val regressionCoefficientsR = Vectors.dense(0.896, -0.709) + val coefficientsR = Vectors.dense(0.896, -0.709) val interceptR = 0.0 val scaleR = 1.52 assert(model.intercept === interceptR) - assert(model.regressionCoefficients ~== regressionCoefficientsR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) assert(model.scale ~== scaleR relTol 1E-3) /* From f54ff19b1edd4903950cb334987a447445fa97ef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 3 Nov 2015 08:32:37 -0800 Subject: [PATCH 148/324] [SPARK-11349][ML] Support transform string label for RFormula Currently ```RFormula``` can only handle label with ```NumericType``` or ```BinaryType``` (cast it to ```DoubleType``` as the label of Linear Regression training), we should also support label of ```StringType``` which is needed for Logistic Regression (glm with family = "binomial"). For label of ```StringType```, we should use ```StringIndexer``` to transform it to 0-based index. Author: Yanbo Liang Closes #9302 from yanboliang/spark-11349. --- .../apache/spark/ml/feature/RFormula.scala | 10 +++++++++- .../spark/ml/feature/RFormulaSuite.scala | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f9b840097f3ed..5c43a41bee3b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R .setOutputCol($(featuresCol)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) + + if (dataset.schema.fieldNames.contains(resolvedFormula.label) && + dataset.schema(resolvedFormula.label).dataType == StringType) { + encoderStages += new StringIndexer() + .setInputCol(resolvedFormula.label) + .setOutputCol($(labelCol)) + } + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } @@ -172,7 +180,7 @@ class RFormulaModel private[feature]( override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) - if (hasLabelCol(schema)) { + if (hasLabelCol(withFeatures)) { withFeatures } else if (schema.exists(_.name == resolvedFormula.label)) { val nullable = schema(resolvedFormula.label).dataType match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index b56013008b116..dc20a5ec2152d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.collect() === expected.collect()) } + test("index string label") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", 
"baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), + ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)) + ).toDF("id", "a", "b", "features", "label") + // assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") val original = sqlContext.createDataFrame( From b2e4b314d989de8cad012bbddba703b31d8378a4 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Tue, 3 Nov 2015 08:51:40 -0800 Subject: [PATCH 149/324] [SPARK-9790][YARN] Expose in WebUI if NodeManager is the reason why executors were killed. Author: Mark Grover Closes #8093 from markgrover/nm2. --- .../main/scala/org/apache/spark/TaskEndReason.scala | 8 ++++++-- .../scala/org/apache/spark/rpc/RpcEndpointRef.scala | 4 ++-- .../org/apache/spark/scheduler/TaskSetManager.scala | 4 ++-- .../cluster/CoarseGrainedSchedulerBackend.scala | 5 +++-- .../scheduler/cluster/YarnSchedulerBackend.scala | 1 + .../scala/org/apache/spark/util/JsonProtocol.scala | 11 ++++++++--- .../spark/ui/jobs/JobProgressListenerSuite.scala | 2 +- .../org/apache/spark/util/JsonProtocolSuite.scala | 11 ++++++----- 8 files changed, 29 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 18278b292ff5a..13241b77bf97b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -223,8 +223,10 @@ case class TaskCommitDenied( * the task crashed the JVM. */ @DeveloperApi -case class ExecutorLostFailure(execId: String, exitCausedByApp: Boolean = true) - extends TaskFailedReason { +case class ExecutorLostFailure( + execId: String, + exitCausedByApp: Boolean = true, + reason: Option[String]) extends TaskFailedReason { override def toErrorString: String = { val exitBehavior = if (exitCausedByApp) { "caused by one of the running tasks" @@ -232,6 +234,8 @@ case class ExecutorLostFailure(execId: String, exitCausedByApp: Boolean = true) "unrelated to the running tasks" } s"ExecutorLostFailure (executor ${execId} exited due to an issue ${exitBehavior})" + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + + reason.map { r => s" Reason: $r" }.getOrElse("") } override def countTowardsTaskFailures: Boolean = exitCausedByApp diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index f25710bb5bd6e..623da3e9c11b8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -67,7 +67,7 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this * method retries, the message handling in the receiver side should be idempotent. * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. 
* * @param message the message to send @@ -82,7 +82,7 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method * retries, the message handling in the receiver side should be idempotent. * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. * * @param message the message to send diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 9b3fad9012abc..114468c48c44c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -802,8 +802,8 @@ private[spark] class TaskSetManager( case exited: ExecutorExited => exited.exitCausedByApp case _ => true } - handleFailedTask( - tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp)) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, + Some(reason.toString))) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 439a11927026b..ebce5021b19dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -125,7 +125,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } - } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -195,7 +194,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onDisconnected(remoteAddress: RpcAddress): Unit = { addressToExecutorId .get(remoteAddress) - .foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated"))) + .foreach(removeExecutor(_, SlaveLost("Remote RPC client disassociated. Likely due to " + + "containers exceeding thresholds, or network issues. 
Check driver logs for WARN " + + "messages."))) } // Make fake resource offers on just one executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index cb24072d7d941..d75d6f673e84e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -175,6 +175,7 @@ private[spark] abstract class YarnSchedulerBackend( addWebUIFilter(filterName, filterParams, proxyBase) case RemoveExecutor(executorId, reason) => + logWarning(reason.toString) removeExecutor(executorId, reason) } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ad6615c1124d0..ee2eb58cf5e2a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -367,9 +367,10 @@ private[spark] object JsonProtocol { ("Job ID" -> taskCommitDenied.jobID) ~ ("Partition ID" -> taskCommitDenied.partitionID) ~ ("Attempt Number" -> taskCommitDenied.attemptNumber) - case ExecutorLostFailure(executorId, exitCausedByApp) => + case ExecutorLostFailure(executorId, exitCausedByApp, reason) => ("Executor ID" -> executorId) ~ - ("Exit Caused By App" -> exitCausedByApp) + ("Exit Caused By App" -> exitCausedByApp) ~ + ("Loss Reason" -> reason.map(_.toString)) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -812,7 +813,11 @@ private[spark] object JsonProtocol { case `executorLostFailure` => val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - ExecutorLostFailure(executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true)) + val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + ExecutorLostFailure( + executorId.getOrElse("Unknown"), + exitCausedByApp.getOrElse(true), + reason) case `unknownReason` => UnknownReason } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index b140387d309f3..e02f5a1b20fe3 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -243,7 +243,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, - ExecutorLostFailure("0"), + ExecutorLostFailure("0", true, Some("Induced failure")), UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 86137f259c13d..953456c2caa89 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -152,7 +152,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) testTaskEndReason(TaskCommitDenied(2, 3, 4)) - testTaskEndReason(ExecutorLostFailure("100", true)) + testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) 
testTaskEndReason(UnknownReason) // BlockId @@ -296,10 +296,10 @@ class JsonProtocolSuite extends SparkFunSuite { test("ExecutorLostFailure backward compatibility") { // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. - val executorLostFailure = ExecutorLostFailure("100", true) + val executorLostFailure = ExecutorLostFailure("100", true, Some("Induced failure")) val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) .removeField({ _._1 == "Executor ID" }) - val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown", true, Some("Induced failure")) assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -603,10 +603,11 @@ class JsonProtocolSuite extends SparkFunSuite { assert(jobId1 === jobId2) assert(partitionId1 === partitionId2) assert(attemptNumber1 === attemptNumber2) - case (ExecutorLostFailure(execId1, exit1CausedByApp), - ExecutorLostFailure(execId2, exit2CausedByApp)) => + case (ExecutorLostFailure(execId1, exit1CausedByApp, reason1), + ExecutorLostFailure(execId2, exit2CausedByApp, reason2)) => assert(execId1 === execId2) assert(exit1CausedByApp === exit2CausedByApp) + assert(reason1 === reason2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } From ebf8b0b48deaad64f7ca27051caee763451e2623 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 3 Nov 2015 10:07:45 -0800 Subject: [PATCH 150/324] [SPARK-10978][SQL] Allow data sources to eliminate filters This PR adds a new method `unhandledFilters` to `BaseRelation`. Data sources which implement this method properly may avoid the overhead of defensive filtering done by Spark SQL. Author: Cheng Lian Closes #9399 from liancheng/spark-10978.unhandled-filters. --- .../datasources/DataSourceStrategy.scala | 131 ++++++++++++++---- .../apache/spark/sql/sources/interfaces.scala | 9 ++ .../parquet/ParquetFilterSuite.scala | 2 +- .../spark/sql/sources/FilteredScanSuite.scala | 129 ++++++++++++----- .../SimpleTextHadoopFsRelationSuite.scala | 47 ++++++- .../sql/sources/SimpleTextRelation.scala | 65 ++++++++- 6 files changed, 315 insertions(+), 68 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 65859865c8fbc..7265d6a4de2e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -43,7 +43,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l, projects, filters, - (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil + (requestedColumns, allPredicates, _) => + toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) => pruneFilterProject( @@ -266,47 +267,81 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { relation, projects, filterPredicates, - (requestedColumns, pushedFilters) => { - scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray) + (requestedColumns, _, pushedFilters) => { + scanBuilder(requestedColumns, pushedFilters.toArray) }) } - // Based on Catalyst expressions. + // Based on Catalyst expressions. 
The `scanBuilder` function accepts three arguments: + // + // 1. A `Seq[Attribute]`, containing all required column attributes. Used to handle relation + // traits that support column pruning (e.g. `PrunedScan` and `PrunedFilteredScan`). + // + // 2. A `Seq[Expression]`, containing all gathered Catalyst filter expressions, only used for + // `CatalystScan`. + // + // 3. A `Seq[Filter]`, containing all data source `Filter`s that are converted from (possibly a + // subset of) Catalyst filter expressions and can be handled by `relation`. Used to handle + // relation traits (`CatalystScan` excluded) that support filter push-down (e.g. + // `PrunedFilteredScan` and `HadoopFsRelation`). + // + // Note that 2 and 3 shouldn't be used together. protected def pruneFilterProjectRaw( - relation: LogicalRelation, - projects: Seq[NamedExpression], - filterPredicates: Seq[Expression], - scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = { + relation: LogicalRelation, + projects: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]) = { val projectSet = AttributeSet(projects.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = filterPredicates.reduceLeftOption(expressions.And) - val pushedFilters = filterPredicates.map { _ transform { + val candidatePredicates = filterPredicates.map { _ transform { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. }} + val (unhandledPredicates, pushedFilters) = + selectFilters(relation.relation, candidatePredicates) + + // A set of column attributes that are only referenced by pushed down filters. We can eliminate + // them from requested columns. + val handledSet = { + val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains) + val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references)) + AttributeSet(handledPredicates.flatMap(_.references)) -- + (projectSet ++ unhandledSet).map(relation.attributeMap) + } + + // Combines all Catalyst filter `Expression`s that are either not convertible to data source + // `Filter`s or cannot be handled by `relation`. + val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. - val requestedColumns = - projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. - .map(relation.attributeMap) // Match original case of attributes. + val requestedColumns = projects + // Safe due to if above. + .asInstanceOf[Seq[Attribute]] + // Match original case of attributes. + .map(relation.attributeMap) + // Don't request columns that are only referenced by pushed filters. + .filterNot(handledSet.contains) val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters), + scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { - val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + // Don't request columns that are only referenced by pushed filters. 
+ val requestedColumns = + (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, - scanBuilder(requestedColumns, pushedFilters), + scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation) execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } @@ -334,11 +369,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } /** - * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, - * and convert them. + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. */ - protected[sql] def selectFilters(filters: Seq[Expression]) = { - def translate(predicate: Expression): Option[Filter] = predicate match { + protected[sql] def translateFilter(predicate: Expression): Option[Filter] = { + predicate match { case expressions.EqualTo(a: Attribute, Literal(v, t)) => Some(sources.EqualTo(a.name, convertToScala(v, t))) case expressions.EqualTo(Literal(v, t), a: Attribute) => @@ -387,16 +423,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.IsNotNull(a.name)) case expressions.And(left, right) => - (translate(left) ++ translate(right)).reduceOption(sources.And) + (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And) case expressions.Or(left, right) => for { - leftFilter <- translate(left) - rightFilter <- translate(right) + leftFilter <- translateFilter(left) + rightFilter <- translateFilter(right) } yield sources.Or(leftFilter, rightFilter) case expressions.Not(child) => - translate(child).map(sources.Not) + translateFilter(child).map(sources.Not) case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringStartsWith(a.name, v.toString)) @@ -409,7 +445,52 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case _ => None } + } + + /** + * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s + * and can be handled by `relation`. + * + * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst + * predicate [[Expression]]s that are either not convertible or cannot be handled by + * `relation`. The second element contains all converted data source [[Filter]]s that can + * be handled by `relation`. + */ + protected[sql] def selectFilters( + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = { + + // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are + // called `predicate`s, while all data source filters of type `sources.Filter` are simply called + // `filter`s. + + val translated: Seq[(Expression, Filter)] = + for { + predicate <- predicates + filter <- translateFilter(predicate) + } yield predicate -> filter + + // A map from original Catalyst expressions to corresponding translated data source filters. + val translatedMap: Map[Expression, Filter] = translated.toMap + + // Catalyst predicate expressions that cannot be translated to data source filters. 
+ val unrecognizedPredicates = predicates.filterNot(translatedMap.contains) + + // Data source filters that cannot be handled by `relation` + val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet + + val (unhandled, handled) = translated.partition { + case (predicate, filter) => + unhandledFilters.contains(filter) + } + + // Catalyst predicate expressions that can be translated to data source filters, but cannot be + // handled by `relation`. + val (unhandledPredicates, _) = unhandled.unzip + + // Translated data source filters that can be handled by `relation` + val (_, handledFilters) = handled.unzip - filters.flatMap(translate) + (unrecognizedPredicates ++ unhandledPredicates, handledFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7a553511483ff..e296d631f0f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -233,6 +233,15 @@ abstract class BaseRelation { * @since 1.4.0 */ def needConversion: Boolean = true + + /** + * Given an array of [[Filter]]s, returns an array of [[Filter]]s that this data source relation + * cannot handle. Spark SQL will apply all returned [[Filter]]s against rows returned by this + * data source relation. + * + * @since 1.6.0 + */ + def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f88ddc77a6a4e..c24c9f025dad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -59,7 +59,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex }.flatten assert(analyzedPredicate.nonEmpty) - val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate) + val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter) assert(selectedFilters.nonEmpty) selectedFilters.foreach { pred => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 68ce37c00077e..7541e723029bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import org.apache.spark.sql.execution.datasources.LogicalRelation + import scala.language.existentials import org.apache.spark.rdd.RDD @@ -44,16 +46,39 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL StructField("b", IntegerType, nullable = false) :: StructField("c", StringType, nullable = false) :: Nil) + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + def unhandled(filter: Filter): Boolean = { + filter match { + case EqualTo(col, v) => col == "b" + case EqualNullSafe(col, v) => col == "b" + case LessThan(col, v: Int) => col == "b" + case LessThanOrEqual(col, v: Int) => col == "b" + case GreaterThan(col, v: Int) => col == "b" + case GreaterThanOrEqual(col, v: Int) => col == "b" + 
case In(col, values) => col == "b" + case IsNull(col) => col == "b" + case IsNotNull(col) => col == "b" + case Not(pred) => unhandled(pred) + case And(left, right) => unhandled(left) || unhandled(right) + case Or(left, right) => unhandled(left) || unhandled(right) + case _ => false + } + } + + filters.filter(unhandled) + } + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val rowBuilders = requiredColumns.map { case "a" => (i: Int) => Seq(i) case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase() * 5) + Seq(c * 5 + c.toUpperCase * 5) } FiltersPushed.list = filters + ColumnsRequired.set = requiredColumns.toSet // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { @@ -86,9 +111,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5 - !filters.map(translateFilterOnA(_)(a)).contains(false) && - !filters.map(translateFilterOnC(_)(c)).contains(false) + val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => @@ -101,6 +125,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } +// Used together with `SimpleFilteredScan` to check pushed columns. +object ColumnsRequired { + var set: Set[String] = Set.empty +} + class FilteredScanSuite extends DataSourceTest with SharedSQLContext { protected override lazy val sql = caseInsensitiveContext.sql _ @@ -115,12 +144,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { | to '10' |) """.stripMargin) + + // UDF for testing filter push-down + caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3) } sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", @@ -202,49 +234,64 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { "SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5))) - testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b", "c")) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1, Set("a")) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1, Set("b")) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1, Set("a", "b", "c")) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered 
WHERE a > 1", 9) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) - testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, Set("a", "b", "c")) + testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3) - testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4) - testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0, Set("a", "b", "c")) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0, Set("a", "b", "c")) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1, Set("a", "b", "c")) + testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0, Set("a", "b", "c")) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) - testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c")) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c")) - testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) - testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + // Columns only referenced by UDF filter must be required, as UDF 
filters can't be pushed down. + testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c")) - def testPushDown(sqlString: String, expectedCount: Int): Unit = { + // A query with an unconvertible filter, an unhandled filter, and a handled filter. + testPushDown( + """SELECT a + | FROM oneToTenFiltered + | WHERE udf_gt3(b) + | AND b < 16 + | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') + """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b")) + + def testPushDown( + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String]): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { @@ -254,6 +301,17 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") } val rawCount = rawPlan.execute().count() + assert(ColumnsRequired.set === requiredColumnNames) + + assert { + val table = caseInsensitiveContext.table("oneToTenFiltered") + val relation = table.queryExecution.logical.collectFirst { + case LogicalRelation(r, _) => r + }.get + + // `relation` should be able to handle all pushed filters + relation.unhandledFilters(FiltersPushed.list.toArray).isEmpty + } if (rawCount != expectedCount) { fail( @@ -264,4 +322,3 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { } } } - diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index a3a124488d983..d945408341fc9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -18,11 +18,16 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path - import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.execution.PhysicalRDD +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + import testImplicits._ + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName // We have a very limited number of supported types at here since it is just for a @@ -64,4 +69,44 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { .load(file.getCanonicalPath)) } } + + private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName) + private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName) + + test("unhandledFilters") { + withTempPath { dir => + + val path = dir.getCanonicalPath + writer.save(s"$path/p=0") + writer.save(s"$path/p=1") + + val isOdd = udf((_: Int) % 2 == 1) + val df = reader.load(path) + .filter( + // This filter is inconvertible + isOdd('a) && + // This filter is convertible but unhandled + 'a > 1 && + // This filter is convertible and handled + 'b > "val_1" && + // This filter references a partiiton column, won't be pushed down + 'p === 1 + ).select('a, 'p) + val rawScan = df.queryExecution.executedPlan collect { + case p: PhysicalRDD => p + } match { + case Seq(p) => p + } + + val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType) + + 
assertResult(Set((2, 1), (3, 1))) { + rawScan.execute().collect() + .map { CatalystTypeConverters.convertToScala(_, outputSchema) } + .map { case Row(a, p) => (a, p) }.toSet + } + + checkAnswer(df, Row(3, 1)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index aeaaa3e1c5220..da09e1b00ae48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources import java.text.NumberFormat -import java.util.UUID import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} @@ -26,12 +25,12 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.spark.rdd.RDD import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. @@ -124,6 +123,53 @@ class SimpleTextRelation( } } + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus]): RDD[Row] = { + + val fields = this.dataSchema.map(_.dataType) + val inputAttributes = this.dataSchema.toAttributes + val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name)) + val dataSchema = this.dataSchema + + val inputPaths = inputFiles.map(_.getPath).mkString(",") + sparkContext.textFile(inputPaths).mapPartitions { iterator => + // Constructs a filter predicate to simulate filter push-down + val predicate = { + val filterCondition: Expression = filters.collect { + // According to `unhandledFilters`, `SimpleTextRelation` only handles `GreaterThan` filter + case sources.GreaterThan(column, value) => + val dataType = dataSchema(column).dataType + val literal = Literal.create(value, dataType) + val attribute = inputAttributes.find(_.name == column).get + expressions.GreaterThan(attribute, literal) + }.reduceOption(expressions.And).getOrElse(Literal(true)) + InterpretedPredicate.create(filterCondition, inputAttributes) + } + + // Uses a simple projection to simulate column pruning + val projection = new InterpretedMutableProjection(outputAttributes, inputAttributes) + val toScala = { + val requiredSchema = StructType.fromAttributes(outputAttributes) + CatalystTypeConverters.createToScalaConverter(requiredSchema) + } + + iterator.map { record => + new GenericInternalRow(record.split(",", -1).zip(fields).map { + case (v, dataType) => + val value = if (v == "") null else v + // `Cast`ed values are always of internal types (e.g. 
UTF8String instead of String) + Cast(Literal(value), dataType).eval() + }) + }.filter { row => + predicate(row) + }.map { row => + toScala(projection(row)).asInstanceOf[Row] + } + } + } + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) @@ -134,6 +180,15 @@ class SimpleTextRelation( new SimpleTextOutputWriter(path, context) } } + + // `SimpleTextRelation` only handles `GreaterThan` filter. This is used to test filter push-down + // and `BaseRelation.unhandledFilters()`. + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { + filters.filter { + case _: GreaterThan => false + case _ => true + } + } } /** From a9676cc7107c5df6c62a58668c4d95ced1238370 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 3 Nov 2015 11:53:10 -0800 Subject: [PATCH 151/324] [SPARK-11407][SPARKR] Add doc for running from RStudio ![image](https://cloud.githubusercontent.com/assets/8969467/10871746/612ba44a-80a4-11e5-99a0-40b9931dee52.png) (This is without css, but you get the idea) shivaram Author: felixcheung Closes #9401 from felixcheung/rstudioprogrammingguide. --- docs/sparkr.md | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 497a276679f3b..437bd4756c276 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -30,14 +30,22 @@ The entry point into SparkR is the `SparkContext` which connects your R program You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name , any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the SparkContext. If you are working from the `sparkR` shell, the -`SQLContext` and `SparkContext` should already be created for you. +`SQLContext` and `SparkContext` should already be created for you, and you would not need to call +`sparkR.init`. +
    {% highlight r %} sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) {% endhighlight %} +
    + +## Starting Up from RStudio -In the event you are creating `SparkContext` instead of using `sparkR` shell or `spark-submit`, you +You can also start SparkR from RStudio. You can connect your R program to a Spark cluster from +RStudio, R shell, Rscript or other R IDEs. To start, make sure SPARK_HOME is set in environment +(you can check [Sys.getenv](https://stat.ethz.ch/R-manual/R-devel/library/base/html/Sys.getenv.html)), +load the SparkR package, and call `sparkR.init` as below. In addition to calling `sparkR.init`, you could also specify certain Spark driver properties. Normally these [Application properties](configuration.html#application-properties) and [Runtime Environment](configuration.html#runtime-environment) cannot be set programmatically, as the @@ -45,9 +53,41 @@ driver JVM process would have been started, in this case SparkR takes care of th them, pass them as you would other configuration properties in the `sparkEnvir` argument to `sparkR.init()`. +
    {% highlight r %} -sc <- sparkR.init("local[*]", "SparkR", "/home/spark", list(spark.driver.memory="2g")) +if (nchar(Sys.getenv("SPARK_HOME")) < 1) { + Sys.setenv(SPARK_HOME = "/home/spark") +} +library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))) +sc <- sparkR.init(master = "local[*]", sparkEnvir = list(spark.driver.memory="2g")) {% endhighlight %} +
+
+The following options can be set in `sparkEnvir` with `sparkR.init` from RStudio:
+
+<table class="table">
+  <tr><th>Property Name</th><th>Property group</th><th>spark-submit equivalent</th></tr>
+  <tr>
+    <td>spark.driver.memory</td>
+    <td>Application Properties</td>
+    <td>--driver-memory</td>
+  </tr>
+  <tr>
+    <td>spark.driver.extraClassPath</td>
+    <td>Runtime Environment</td>
+    <td>--driver-class-path</td>
+  </tr>
+  <tr>
+    <td>spark.driver.extraJavaOptions</td>
+    <td>Runtime Environment</td>
+    <td>--driver-java-options</td>
+  </tr>
+  <tr>
+    <td>spark.driver.extraLibraryPath</td>
+    <td>Runtime Environment</td>
+    <td>--driver-library-path</td>
+  </tr>
+</table>
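For example (a minimal sketch, assuming SparkR has been loaded as shown above; the jar path and Java option below are placeholders, not values from the patch), several of these properties can be combined in a single `sparkEnvir` list:

{% highlight r %}
# Placeholder values -- shown only to illustrate the sparkEnvir keys listed in the table above
sc <- sparkR.init(master = "local[*]",
                  sparkEnvir = list(spark.driver.memory = "2g",
                                    spark.driver.extraClassPath = "/path/to/extra.jar",
                                    spark.driver.extraJavaOptions = "-Dsome.driver.option=value"))
{% endhighlight %}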
    From 1d04dc95c0d3caa485936e65b0493bcc9719f27e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 3 Nov 2015 13:33:46 -0800 Subject: [PATCH 152/324] [SPARK-11467][SQL] add Python API for stddev/variance Add Python API for stddev/stddev_pop/stddev_samp/variance/var_pop/var_samp/skewness/kurtosis Author: Davies Liu Closes #9424 from davies/py_var. --- python/pyspark/sql/functions.py | 17 ++++ python/pyspark/sql/group.py | 88 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 67 -------------- 3 files changed, 105 insertions(+), 67 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fa04f4cd83b6f..2f7c2f4aacd47 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -122,6 +122,21 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_1_6 = { + # unary math functions + "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + + " the expression in a group.", + "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + + " the expression in a group.", + "stddev_pop": "Aggregate function: returns population standard deviation of" + + " the expression in a group.", + "variance": "Aggregate function: returns the population variance of the values in a group.", + "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", + "var_pop": "Aggregate function: returns the population variance of the values in a group.", + "skewness": "Aggregate function: returns the skewness of the values in a group.", + "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." +} + # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + @@ -172,6 +187,8 @@ def _(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) +for _name, _doc in _functions_1_6.items(): + globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 71c0bccc5eeff..946b53e71c2c6 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -167,6 +167,94 @@ def sum(self, *cols): [Row(sum(age)=7, sum(height)=165)] """ + @df_varargs_api + @since(1.6) + def stddev(self, *cols): + """Compute the sample standard deviation for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().stddev('age', 'height').collect() + [Row(STDDEV(age)=2.12..., STDDEV(height)=3.53...)] + """ + + @df_varargs_api + @since(1.6) + def stddev_samp(self, *cols): + """Compute the sample standard deviation for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().stddev_samp('age', 'height').collect() + [Row(STDDEV_SAMP(age)=2.12..., STDDEV_SAMP(height)=3.53...)] + """ + + @df_varargs_api + @since(1.6) + def stddev_pop(self, *cols): + """Compute the population standard deviation for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. 
+ + >>> df3.groupBy().stddev_pop('age', 'height').collect() + [Row(STDDEV_POP(age)=1.5, STDDEV_POP(height)=2.5)] + """ + + @df_varargs_api + @since(1.6) + def variance(self, *cols): + """Compute the sample variance for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().variance('age', 'height').collect() + [Row(VARIANCE(age)=2.25, VARIANCE(height)=6.25)] + """ + + @df_varargs_api + @since(1.6) + def var_pop(self, *cols): + """Compute the sample variance for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().var_pop('age', 'height').collect() + [Row(VAR_POP(age)=2.25, VAR_POP(height)=6.25)] + """ + + @df_varargs_api + @since(1.6) + def var_samp(self, *cols): + """Compute the sample variance for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().var_samp('age', 'height').collect() + [Row(VAR_SAMP(age)=4.5, VAR_SAMP(height)=12.5)] + """ + + @df_varargs_api + @since(1.6) + def skewness(self, *cols): + """Compute the skewness for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().skewness('age', 'height').collect() + [Row(SKEWNESS(age)=0.0, SKEWNESS(height)=0.0)] + """ + + @df_varargs_api + @since(1.6) + def kurtosis(self, *cols): + """Compute the kurtosis for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df3.groupBy().kurtosis('age', 'height').collect() + [Row(KURTOSIS(age)=-2.0, KURTOSIS(height)=-2.0)] + """ + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5a5c695e6ab3b..c8c52831668cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -254,14 +254,6 @@ object functions { */ def kurtosis(e: Column): Column = Kurtosis(e.expr) - /** - * Aggregate function: returns the kurtosis of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) - /** * Aggregate function: returns the last value in a group. * @@ -336,14 +328,6 @@ object functions { */ def skewness(e: Column): Column = Skewness(e.expr) - /** - * Aggregate function: returns the skewness of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def skewness(columnName: String): Column = skewness(Column(columnName)) - /** * Aggregate function: returns the unbiased sample standard deviation of * the expression in a group. @@ -353,15 +337,6 @@ object functions { */ def stddev(e: Column): Column = Stddev(e.expr) - /** - * Aggregate function: returns the unbiased sample standard deviation of - * the expression in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def stddev(columnName: String): Column = stddev(Column(columnName)) - /** * Aggregate function: returns the unbiased sample standard deviation of * the expression in a group. @@ -371,15 +346,6 @@ object functions { */ def stddev_samp(e: Column): Column = StddevSamp(e.expr) - /** - * Aggregate function: returns the unbiased sample standard deviation of - * the expression in a group. 
- * - * @group agg_funcs - * @since 1.6.0 - */ - def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) - /** * Aggregate function: returns the population standard deviation of * the expression in a group. @@ -389,15 +355,6 @@ object functions { */ def stddev_pop(e: Column): Column = StddevPop(e.expr) - /** - * Aggregate function: returns the population standard deviation of - * the expression in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) - /** * Aggregate function: returns the sum of all values in the expression. * @@ -438,14 +395,6 @@ object functions { */ def variance(e: Column): Column = Variance(e.expr) - /** - * Aggregate function: returns the population variance of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def variance(columnName: String): Column = variance(Column(columnName)) - /** * Aggregate function: returns the unbiased variance of the values in a group. * @@ -454,14 +403,6 @@ object functions { */ def var_samp(e: Column): Column = VarianceSamp(e.expr) - /** - * Aggregate function: returns the unbiased variance of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def var_samp(columnName: String): Column = var_samp(Column(columnName)) - /** * Aggregate function: returns the population variance of the values in a group. * @@ -470,14 +411,6 @@ object functions { */ def var_pop(e: Column): Column = VariancePop(e.expr) - /** - * Aggregate function: returns the population variance of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ - def var_pop(columnName: String): Column = var_pop(Column(columnName)) - ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// From f6fcb4874ce20a1daa91b7434cf9c0254a89e979 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Nov 2015 00:15:50 +0100 Subject: [PATCH 153/324] [SPARK-11477] [SQL] support create Dataset from RDD Author: Wenchen Fan Closes #9434 from cloud-fan/rdd2ds and squashes the following commits: 0892d72 [Wenchen Fan] support create Dataset from RDD --- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 9 +++++++++ .../main/scala/org/apache/spark/sql/SQLImplicits.scala | 4 ++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 7 +++++++ 3 files changed, 20 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 2cb94430e6178..5ad3871093fc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -499,6 +499,15 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } + def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { + val enc = encoderFor[T] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d)) + val plan = LogicalRDD(attributes, encoded)(self) + + new Dataset[T](this, plan) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index f460a86414c41..f2904e270811e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -48,6 +48,10 @@ abstract class SQLImplicits { implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { + DatasetHolder(_sqlContext.createDataset(rdd)) + } + implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5973fa7f2a76b..3e9b621cfd67f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { data: _*) } + test("toDS with RDD") { + val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() + checkAnswer( + ds.mapPartitions(_ => Iterator(1)), + 1, 1, 1) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( From 680b4e7bca935dc1569f35fa319bdfb01a12f7e0 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 3 Nov 2015 15:26:35 -0800 Subject: [PATCH 154/324] Fix typo in WebUI Author: Jacek Laskowski Closes #9444 from jaceklaskowski/TImely-fix. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 712782d27b3cf..51425e599e748 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -49,7 +49,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ("shuffle-read-time-proportion", "Shuffle Read Time"), ("executor-runtime-proportion", "Executor Computing Time"), ("shuffle-write-time-proportion", "Shuffle Write Time"), - ("serialization-time-proportion", "Result Serialization TIme"), + ("serialization-time-proportion", "Result Serialization Time"), ("getting-result-time-proportion", "Getting Result Time")) legendPairs.zipWithIndex.map { From 53e9cee3e4e845d1f875c487215c0f22503347b1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 3 Nov 2015 16:26:28 -0800 Subject: [PATCH 155/324] [SPARK-11466][CORE] Avoid mockito in multi-threaded FsHistoryProviderSuite test. The test functionality should be the same, but without using mockito; logs don't really say anything useful but I suspect it may be the cause of the flakiness, since updating mocks when multiple threads may be using it doesn't work very well. It also allows some other cleanup (= less test code in FsHistoryProvider). Author: Marcelo Vanzin Closes #9425 from vanzin/SPARK-11466. 
--- .../deploy/history/FsHistoryProvider.scala | 31 ++++++-------- .../history/FsHistoryProviderSuite.scala | 42 +++++++++---------- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 24aa386c7212b..718efc4f3bd5e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -113,35 +113,30 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } // Conf option used for testing the initialization code. - val initThread = if (!conf.getBoolean("spark.history.testing.skipInitialize", false)) { - initialize(None) - } else { - null - } + val initThread = initialize() - private[history] def initialize(errorHandler: Option[Thread.UncaughtExceptionHandler]): Thread = { + private[history] def initialize(): Thread = { if (!isFsInSafeMode()) { startPolling() - return null + null + } else { + startSafeModeCheckThread(None) } + } + private[history] def startSafeModeCheckThread( + errorHandler: Option[Thread.UncaughtExceptionHandler]): Thread = { // Cannot probe anything while the FS is in safe mode, so spawn a new thread that will wait // for the FS to leave safe mode before enabling polling. This allows the main history server // UI to be shown (so that the user can see the HDFS status). - // - // The synchronization in the run() method is needed because of the tests; mockito can - // misbehave if the test is modifying the mocked methods while the thread is calling - // them. val initThread = new Thread(new Runnable() { override def run(): Unit = { try { - clock.synchronized { - while (isFsInSafeMode()) { - logInfo("HDFS is still in safe mode. Waiting...") - val deadline = clock.getTimeMillis() + - TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S) - clock.waitTillTime(deadline) - } + while (isFsInSafeMode()) { + logInfo("HDFS is still in safe mode. 
Waiting...") + val deadline = clock.getTimeMillis() + + TimeUnit.SECONDS.toMillis(SAFEMODE_CHECK_INTERVAL_S) + clock.waitTillTime(deadline) } startPolling() } catch { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 833aab14ca2da..5cab17f8a38f5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -41,7 +41,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.io._ import org.apache.spark.scheduler._ -import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} +import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -423,22 +423,16 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("provider waits for safe mode to finish before initializing") { val clock = new ManualClock() - val conf = createTestConf().set("spark.history.testing.skipInitialize", "true") - val provider = spy(new FsHistoryProvider(conf, clock)) - doReturn(true).when(provider).isFsInSafeMode() - - val initThread = provider.initialize(None) + val provider = new SafeModeTestProvider(createTestConf(), clock) + val initThread = provider.initialize() try { provider.getConfig().keys should contain ("HDFS State") clock.setTime(5000) provider.getConfig().keys should contain ("HDFS State") - // Synchronization needed because of mockito. - clock.synchronized { - doReturn(false).when(provider).isFsInSafeMode() - clock.setTime(10000) - } + provider.inSafeMode = false + clock.setTime(10000) eventually(timeout(1 second), interval(10 millis)) { provider.getConfig().keys should not contain ("HDFS State") @@ -451,18 +445,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("provider reports error after FS leaves safe mode") { testDir.delete() val clock = new ManualClock() - val conf = createTestConf().set("spark.history.testing.skipInitialize", "true") - val provider = spy(new FsHistoryProvider(conf, clock)) - doReturn(true).when(provider).isFsInSafeMode() - + val provider = new SafeModeTestProvider(createTestConf(), clock) val errorHandler = mock(classOf[Thread.UncaughtExceptionHandler]) - val initThread = provider.initialize(Some(errorHandler)) + val initThread = provider.startSafeModeCheckThread(Some(errorHandler)) try { - // Synchronization needed because of mockito. - clock.synchronized { - doReturn(false).when(provider).isFsInSafeMode() - clock.setTime(10000) - } + provider.inSafeMode = false + clock.setTime(10000) eventually(timeout(1 second), interval(10 millis)) { verify(errorHandler).uncaughtException(any(), any()) @@ -530,4 +518,16 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc log } + private class SafeModeTestProvider(conf: SparkConf, clock: Clock) + extends FsHistoryProvider(conf, clock) { + + @volatile var inSafeMode = true + + // Skip initialization so that we can manually start the safe mode check thread. 
+ private[history] override def initialize(): Thread = null + + private[history] override def isFsInSafeMode(): Boolean = inSafeMode + + } + } From 5051262d4ca6a2c529c9b1ba86d54cce60a7af17 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Nov 2015 16:27:56 -0800 Subject: [PATCH 156/324] [SPARK-11489][SQL] Only include common first order statistics in GroupedData We added a bunch of higher order statistics such as skewness and kurtosis to GroupedData. I don't think they are common enough to justify being listed, since users can always use the normal statistics aggregate functions. That is to say, after this change, we won't support ```scala df.groupBy("key").kurtosis("colA", "colB") ``` However, we will still support ```scala df.groupBy("key").agg(kurtosis(col("colA")), kurtosis(col("colB"))) ``` Author: Reynold Xin Closes #9446 from rxin/SPARK-11489. --- python/pyspark/sql/group.py | 88 ----------- .../org/apache/spark/sql/GroupedData.scala | 146 ++++-------------- .../apache/spark/sql/JavaDataFrameSuite.java | 1 - 3 files changed, 28 insertions(+), 207 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 946b53e71c2c6..71c0bccc5eeff 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -167,94 +167,6 @@ def sum(self, *cols): [Row(sum(age)=7, sum(height)=165)] """ - @df_varargs_api - @since(1.6) - def stddev(self, *cols): - """Compute the sample standard deviation for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().stddev('age', 'height').collect() - [Row(STDDEV(age)=2.12..., STDDEV(height)=3.53...)] - """ - - @df_varargs_api - @since(1.6) - def stddev_samp(self, *cols): - """Compute the sample standard deviation for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().stddev_samp('age', 'height').collect() - [Row(STDDEV_SAMP(age)=2.12..., STDDEV_SAMP(height)=3.53...)] - """ - - @df_varargs_api - @since(1.6) - def stddev_pop(self, *cols): - """Compute the population standard deviation for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().stddev_pop('age', 'height').collect() - [Row(STDDEV_POP(age)=1.5, STDDEV_POP(height)=2.5)] - """ - - @df_varargs_api - @since(1.6) - def variance(self, *cols): - """Compute the sample variance for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().variance('age', 'height').collect() - [Row(VARIANCE(age)=2.25, VARIANCE(height)=6.25)] - """ - - @df_varargs_api - @since(1.6) - def var_pop(self, *cols): - """Compute the sample variance for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().var_pop('age', 'height').collect() - [Row(VAR_POP(age)=2.25, VAR_POP(height)=6.25)] - """ - - @df_varargs_api - @since(1.6) - def var_samp(self, *cols): - """Compute the sample variance for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().var_samp('age', 'height').collect() - [Row(VAR_SAMP(age)=4.5, VAR_SAMP(height)=12.5)] - """ - - @df_varargs_api - @since(1.6) - def skewness(self, *cols): - """Compute the skewness for each numeric columns for each group. 
- - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().skewness('age', 'height').collect() - [Row(SKEWNESS(age)=0.0, SKEWNESS(height)=0.0)] - """ - - @df_varargs_api - @since(1.6) - def kurtosis(self, *cols): - """Compute the kurtosis for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df3.groupBy().kurtosis('age', 'height').collect() - [Row(KURTOSIS(age)=-2.0, KURTOSIS(height)=-2.0)] - """ - def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index dc96384a4d28d..c2b2a4013d510 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -26,42 +26,14 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType -/** - * Companion object for GroupedData - */ -private[sql] object GroupedData { - def apply( - df: DataFrame, - groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) - } - - /** - * The Grouping Type - */ - private[sql] trait GroupType - - /** - * To indicate it's the GroupBy - */ - private[sql] object GroupByType extends GroupType - - /** - * To indicate it's the CUBE - */ - private[sql] object CubeType extends GroupType - - /** - * To indicate it's the ROLLUP - */ - private[sql] object RollupType extends GroupType -} /** * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. * + * The main method is the agg function, which has multiple variants. This class also contains + * convenience some first order statistics such as mean, sum for convenience. + * * @since 1.3.0 */ @Experimental @@ -124,7 +96,7 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min - case "stddev" => Stddev + case "stddev" | "std" => Stddev case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp case "variance" => Variance @@ -255,30 +227,6 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Average) } - /** - * Compute the skewness for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the skewness values for them. - * - * @since 1.6.0 - */ - @scala.annotation.varargs - def skewness(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Skewness) - } - - /** - * Compute the kurtosis for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the kurtosis values for them. - * - * @since 1.6.0 - */ - @scala.annotation.varargs - def kurtosis(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Kurtosis) - } - /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. @@ -316,86 +264,48 @@ class GroupedData protected[sql]( } /** - * Compute the sample standard deviation for each numeric columns for each group. + * Compute the sum for each numeric columns for each group. 
* The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the stddev for them. + * When specified columns are given, only compute the sum for them. * - * @since 1.6.0 + * @since 1.3.0 */ @scala.annotation.varargs - def stddev(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Stddev) + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Sum) } +} - /** - * Compute the population standard deviation for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the stddev for them. - * - * @since 1.6.0 - */ - @scala.annotation.varargs - def stddev_pop(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(StddevPop) - } - /** - * Compute the sample standard deviation for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the stddev for them. - * - * @since 1.6.0 - */ - @scala.annotation.varargs - def stddev_samp(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(StddevSamp) +/** + * Companion object for GroupedData. + */ +private[sql] object GroupedData { + + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): GroupedData = { + new GroupedData(df, groupingExprs, groupType: GroupType) } /** - * Compute the sum for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 + * The Grouping Type */ - @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Sum) - } + private[sql] trait GroupType /** - * Compute the sample variance for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the variance for them. - * - * @since 1.6.0 + * To indicate it's the GroupBy */ - @scala.annotation.varargs - def variance(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Variance) - } + private[sql] object GroupByType extends GroupType /** - * Compute the population variance for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the variance for them. - * - * @since 1.6.0 + * To indicate it's the CUBE */ - @scala.annotation.varargs - def var_pop(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(VariancePop) - } + private[sql] object CubeType extends GroupType /** - * Compute the sample variance for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - * When specified columns are given, only compute the variance for them. 
- * - * @since 1.6.0 + * To indicate it's the ROLLUP */ - @scala.annotation.varargs - def var_samp(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(VarianceSamp) - } + private[sql] object RollupType extends GroupType } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a1a3fdbb486ea..49f516e86d754 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -91,7 +91,6 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); - df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions From d648a4ad546eb05deab1005e92b815b2cbea621b Mon Sep 17 00:00:00 2001 From: lewuathe Date: Tue, 3 Nov 2015 16:38:22 -0800 Subject: [PATCH 157/324] [DOC] Missing link to R DataFrame API doc Author: lewuathe Author: Lewuathe Closes #9394 from Lewuathe/missing-link-to-R-dataframe. --- R/pkg/R/DataFrame.R | 105 +++++++++++++++++++++++++++++++--- docs/sql-programming-guide.md | 2 +- 2 files changed, 98 insertions(+), 9 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 87a2c66ffd2a9..df5bc8137187b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -23,15 +23,23 @@ NULL setOldClass("jobj") #' @title S4 class that represents a DataFrame -#' @description DataFrames can be created using functions like -#' \code{jsonFile}, \code{table} etc. +#' @description DataFrames can be created using functions like \link{createDataFrame}, +#' \link{jsonFile}, \link{table} etc. +#' @family dataframe_funcs #' @rdname DataFrame -#' @seealso jsonFile, table #' @docType class #' #' @slot env An R environment that stores bookkeeping states of the DataFrame #' @slot sdf A Java object reference to the backing Scala DataFrame +#' @seealso \link{createDataFrame}, \link{jsonFile}, \link{table} +#' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' df <- createDataFrame(sqlContext, faithful) +#'} setClass("DataFrame", slots = list(env = "environment", sdf = "jobj")) @@ -46,7 +54,6 @@ setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { #' @rdname DataFrame #' @export -#' #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the dataFrame is cached dataFrame <- function(sdf, isCached = FALSE) { @@ -61,6 +68,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname printSchema #' @name printSchema #' @export @@ -85,6 +93,7 @@ setMethod("printSchema", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname schema #' @name schema #' @export @@ -108,6 +117,7 @@ setMethod("schema", #' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @family dataframe_funcs #' @rdname explain #' @name explain #' @export @@ -138,6 +148,7 @@ setMethod("explain", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname isLocal #' @name isLocal #' @export @@ -162,6 +173,7 @@ setMethod("isLocal", #' @param x A SparkSQL DataFrame #' @param numRows The number of rows to print. Defaults to 20. 
#' +#' @family dataframe_funcs #' @rdname showDF #' @name showDF #' @export @@ -186,6 +198,7 @@ setMethod("showDF", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname show #' @name show #' @export @@ -212,6 +225,7 @@ setMethod("show", "DataFrame", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname dtypes #' @name dtypes #' @export @@ -237,6 +251,7 @@ setMethod("dtypes", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname columns #' @name columns #' @aliases names @@ -257,6 +272,7 @@ setMethod("columns", }) }) +#' @family dataframe_funcs #' @rdname columns #' @name names setMethod("names", @@ -265,6 +281,7 @@ setMethod("names", columns(x) }) +#' @family dataframe_funcs #' @rdname columns #' @name names<- setMethod("names<-", @@ -283,6 +300,7 @@ setMethod("names<-", #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table #' +#' @family dataframe_funcs #' @rdname registerTempTable #' @name registerTempTable #' @export @@ -310,6 +328,7 @@ setMethod("registerTempTable", #' @param overwrite A logical argument indicating whether or not to overwrite #' the existing rows in the table. #' +#' @family dataframe_funcs #' @rdname insertInto #' @name insertInto #' @export @@ -334,6 +353,7 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname cache #' @name cache #' @export @@ -360,6 +380,8 @@ setMethod("cache", #' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' #' @param x The DataFrame to persist +#' +#' @family dataframe_funcs #' @rdname persist #' @name persist #' @export @@ -386,6 +408,8 @@ setMethod("persist", #' #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted +#' +#' @family dataframe_funcs #' @rdname unpersist-methods #' @name unpersist #' @export @@ -412,6 +436,8 @@ setMethod("unpersist", #' #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. +#' +#' @family dataframe_funcs #' @rdname repartition #' @name repartition #' @export @@ -435,8 +461,10 @@ setMethod("repartition", # Convert the rows of a DataFrame into JSON objects and return an RDD where # each element contains a JSON string. # -#@param x A SparkSQL DataFrame +# @param x A SparkSQL DataFrame # @return A StringRRDD of JSON objects +# +# @family dataframe_funcs # @rdname tojson # @export # @examples @@ -462,6 +490,8 @@ setMethod("toJSON", #' #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved +#' +#' @family dataframe_funcs #' @rdname saveAsParquetFile #' @name saveAsParquetFile #' @export @@ -484,6 +514,8 @@ setMethod("saveAsParquetFile", #' Return a new DataFrame containing the distinct rows in this DataFrame. 
#' #' @param x A SparkSQL DataFrame +#' +#' @family dataframe_funcs #' @rdname distinct #' @name distinct #' @export @@ -506,6 +538,7 @@ setMethod("distinct", # #' @description Returns a new DataFrame containing distinct rows in this DataFrame #' +#' @family dataframe_funcs #' @rdname unique #' @name unique #' @aliases distinct @@ -522,6 +555,8 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' +#' @family dataframe_funcs #' @rdname sample #' @aliases sample_frac #' @export @@ -545,6 +580,7 @@ setMethod("sample", dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -560,6 +596,7 @@ setMethod("sample_frac", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname count #' @name count #' @aliases nrow @@ -583,6 +620,7 @@ setMethod("count", #' #' @name nrow #' +#' @family dataframe_funcs #' @rdname nrow #' @aliases count setMethod("nrow", @@ -595,6 +633,7 @@ setMethod("nrow", #' #' @param x a SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname ncol #' @name ncol #' @export @@ -615,6 +654,7 @@ setMethod("ncol", #' Returns the dimentions (number of rows and columns) of a DataFrame #' @param x a SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname dim #' @name dim #' @export @@ -637,6 +677,8 @@ setMethod("dim", #' @param x A SparkSQL DataFrame #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. +#' +#' @family dataframe_funcs #' @rdname collect #' @name collect #' @export @@ -704,6 +746,7 @@ setMethod("collect", #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. #' +#' @family dataframe_funcs #' @rdname limit #' @name limit #' @export @@ -724,6 +767,7 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' +#' @family dataframe_funcs #' @rdname take #' @name take #' @export @@ -752,6 +796,7 @@ setMethod("take", #' @param num The number of rows to return. Default is 6. #' @return A data.frame #' +#' @family dataframe_funcs #' @rdname head #' @name head #' @export @@ -774,6 +819,7 @@ setMethod("head", #' #' @param x A SparkSQL DataFrame #' +#' @family dataframe_funcs #' @rdname first #' @name first #' @export @@ -797,6 +843,7 @@ setMethod("first", # # @param x A Spark DataFrame # +# @family dataframe_funcs # @rdname DataFrame # @export # @examples @@ -827,6 +874,7 @@ setMethod("toRDD", #' @return a GroupedData #' @seealso GroupedData #' @aliases group_by +#' @family dataframe_funcs #' @rdname groupBy #' @name groupBy #' @export @@ -851,6 +899,7 @@ setMethod("groupBy", groupedData(sgd) }) +#' @family dataframe_funcs #' @rdname groupBy #' @name group_by setMethod("group_by", @@ -864,6 +913,7 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame +#' @family dataframe_funcs #' @rdname agg #' @name agg #' @aliases summarize @@ -874,6 +924,7 @@ setMethod("agg", agg(groupBy(x), ...) }) +#' @family dataframe_funcs #' @rdname agg #' @name summarize setMethod("summarize", @@ -889,6 +940,7 @@ setMethod("summarize", # the requested map function. 
# ################################################################################### +# @family dataframe_funcs # @rdname lapply setMethod("lapply", signature(X = "DataFrame", FUN = "function"), @@ -897,6 +949,7 @@ setMethod("lapply", lapply(rdd, FUN) }) +# @family dataframe_funcs # @rdname lapply setMethod("map", signature(X = "DataFrame", FUN = "function"), @@ -904,6 +957,7 @@ setMethod("map", lapply(X, FUN) }) +# @family dataframe_funcs # @rdname flatMap setMethod("flatMap", signature(X = "DataFrame", FUN = "function"), @@ -911,7 +965,7 @@ setMethod("flatMap", rdd <- toRDD(X) flatMap(rdd, FUN) }) - +# @family dataframe_funcs # @rdname lapplyPartition setMethod("lapplyPartition", signature(X = "DataFrame", FUN = "function"), @@ -920,6 +974,7 @@ setMethod("lapplyPartition", lapplyPartition(rdd, FUN) }) +# @family dataframe_funcs # @rdname lapplyPartition setMethod("mapPartitions", signature(X = "DataFrame", FUN = "function"), @@ -927,6 +982,7 @@ setMethod("mapPartitions", lapplyPartition(X, FUN) }) +# @family dataframe_funcs # @rdname foreach setMethod("foreach", signature(x = "DataFrame", func = "function"), @@ -935,6 +991,7 @@ setMethod("foreach", foreach(rdd, func) }) +# @family dataframe_funcs # @rdname foreach setMethod("foreachPartition", signature(x = "DataFrame", func = "function"), @@ -1034,6 +1091,7 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @param select expression for the single Column or a list of columns to select from the DataFrame #' @return A new DataFrame containing only the rows that meet the condition with selected columns #' @export +#' @family dataframe_funcs #' @rdname subset #' @name subset #' @aliases [ @@ -1064,6 +1122,7 @@ setMethod("subset", signature(x = "DataFrame"), #' @param col A list of columns or single Column or name #' @return A new DataFrame with selected columns #' @export +#' @family dataframe_funcs #' @rdname select #' @name select #' @family subsetting functions @@ -1091,6 +1150,7 @@ setMethod("select", signature(x = "DataFrame", col = "character"), } }) +#' @family dataframe_funcs #' @rdname select #' @export setMethod("select", signature(x = "DataFrame", col = "Column"), @@ -1102,6 +1162,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname select #' @export setMethod("select", @@ -1126,6 +1187,7 @@ setMethod("select", #' @param expr A string containing a SQL expression #' @param ... Additional expressions #' @return A DataFrame +#' @family dataframe_funcs #' @rdname selectExpr #' @name selectExpr #' @export @@ -1153,6 +1215,7 @@ setMethod("selectExpr", #' @param colName A string containing the name of the new column. #' @param col A Column expression. #' @return A DataFrame with the new column added. +#' @family dataframe_funcs #' @rdname withColumn #' @name withColumn #' @aliases mutate transform @@ -1178,6 +1241,7 @@ setMethod("withColumn", #' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. +#' @family dataframe_funcs #' @rdname withColumn #' @name mutate #' @aliases withColumn transform @@ -1211,6 +1275,7 @@ setMethod("mutate", }) #' @export +#' @family dataframe_funcs #' @rdname withColumn #' @name transform #' @aliases withColumn mutate @@ -1228,6 +1293,7 @@ setMethod("transform", #' @param existingCol The name of the column you want to change. #' @param newCol The new column name. #' @return A DataFrame with the column name changed. 
+#' @family dataframe_funcs #' @rdname withColumnRenamed #' @name withColumnRenamed #' @export @@ -1259,6 +1325,7 @@ setMethod("withColumnRenamed", #' @param x A DataFrame #' @param newCol A named pair of the form new_column_name = existing_column #' @return A DataFrame with the column name changed. +#' @family dataframe_funcs #' @rdname withColumnRenamed #' @name rename #' @aliases withColumnRenamed @@ -1303,6 +1370,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param decreasing A logical argument indicating sorting order for columns when #' a character vector is specified for col #' @return A DataFrame where all elements are sorted. +#' @family dataframe_funcs #' @rdname arrange #' @name arrange #' @aliases orderby @@ -1329,6 +1397,7 @@ setMethod("arrange", dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname arrange #' @export setMethod("arrange", @@ -1360,6 +1429,7 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) +#' @family dataframe_funcs #' @rdname arrange #' @name orderby setMethod("orderBy", @@ -1376,6 +1446,7 @@ setMethod("orderBy", #' @param condition The condition to filter on. This may either be a Column expression #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. +#' @family dataframe_funcs #' @rdname filter #' @name filter #' @family subsetting functions @@ -1399,6 +1470,7 @@ setMethod("filter", dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname filter #' @name where setMethod("where", @@ -1419,6 +1491,7 @@ setMethod("where", #' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', #' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. +#' @family dataframe_funcs #' @rdname join #' @name join #' @export @@ -1477,6 +1550,7 @@ setMethod("join", #' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right #' outer join will be returned. If all.x and all.y are set to TRUE, a full #' outer join will be returned. +#' @family dataframe_funcs #' @rdname merge #' @export #' @examples @@ -1608,6 +1682,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. +#' @family dataframe_funcs #' @rdname unionAll #' @name unionAll #' @export @@ -1627,9 +1702,10 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -# +#' #' @description Returns a new DataFrame containing rows of all parameters. -# +#' +#' @family dataframe_funcs #' @rdname rbind #' @name rbind #' @aliases unionAll @@ -1651,6 +1727,7 @@ setMethod("rbind", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. +#' @family dataframe_funcs #' @rdname intersect #' @name intersect #' @export @@ -1677,6 +1754,7 @@ setMethod("intersect", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. 
+#' @family dataframe_funcs #' @rdname except #' @name except #' @export @@ -1716,6 +1794,7 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' +#' @family dataframe_funcs #' @rdname write.df #' @name write.df #' @aliases saveDF @@ -1751,6 +1830,7 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) +#' @family dataframe_funcs #' @rdname write.df #' @name saveDF #' @export @@ -1781,6 +1861,7 @@ setMethod("saveDF", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' +#' @family dataframe_funcs #' @rdname saveAsTable #' @name saveAsTable #' @export @@ -1821,6 +1902,7 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame +#' @family dataframe_funcs #' @rdname describe #' @name describe #' @aliases summary @@ -1843,6 +1925,7 @@ setMethod("describe", dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname describe #' @name describe setMethod("describe", @@ -1857,6 +1940,7 @@ setMethod("describe", #' #' @description Computes statistics for numeric columns of the DataFrame #' +#' @family dataframe_funcs #' @rdname summary #' @name summary setMethod("summary", @@ -1881,6 +1965,7 @@ setMethod("summary", #' @param cols Optional list of column names to consider. #' @return A DataFrame #' +#' @family dataframe_funcs #' @rdname nafunctions #' @name dropna #' @aliases na.omit @@ -1910,6 +1995,7 @@ setMethod("dropna", dataFrame(sdf) }) +#' @family dataframe_funcs #' @rdname nafunctions #' @name na.omit #' @export @@ -1937,6 +2023,7 @@ setMethod("na.omit", #' column is simply ignored. #' @return A DataFrame #' +#' @family dataframe_funcs #' @rdname nafunctions #' @name fillna #' @export @@ -2000,6 +2087,7 @@ setMethod("fillna", #' @title Download data from a DataFrame into a data.frame #' @param x a DataFrame #' @return a data.frame +#' @family dataframe_funcs #' @rdname as.data.frame #' @examples \dontrun{ #' @@ -2020,6 +2108,7 @@ setMethod("as.data.frame", #' the DataFrame is searched by R when evaluating a variable, so columns in #' the DataFrame can be accessed by simply giving their names. #' +#' @family dataframe_funcs #' @rdname attach #' @title Attach DataFrame to R search path #' @param what (DataFrame) The DataFrame to attach diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f07c9573696ed..510b3599721a3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -160,7 +160,7 @@ showDF(df) ## DataFrame Operations -DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/DataFrame.html). Here we include some basic examples of structured data processing using DataFrames: From e352de0db2789919e1e0385b79f29b508a6b2b77 Mon Sep 17 00:00:00 2001 From: Nong Date: Tue, 3 Nov 2015 16:44:37 -0800 Subject: [PATCH 158/324] [SPARK-11329] [SQL] Cleanup from spark-11329 fix. 
Author: Nong Closes #9442 from nongli/spark-11483. --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 +- .../sql/catalyst/analysis/unresolved.scala | 18 +---- .../scala/org/apache/spark/sql/Column.scala | 6 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 79 +++++++++++-------- 4 files changed, 55 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 1ba559d9e3b18..440e9e28fa783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -477,8 +477,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target)) } - } | primary + | (ident <~ "."). + <~ "*" ^^ { case target => UnresolvedStar(Option(target))} + | primary ) protected lazy val signedPrimary: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6975662e2b738..eae17c86ddc7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -183,28 +183,16 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case None => input.output // If there is a table, pick out attributes that are part of this table. case Some(t) => if (t.size == 1) { - input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty) + input.output.filter(_.qualifiers.exists(resolver(_, t.head))) } else { List() } } - if (!expandedAttributes.isEmpty) { - if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) { - return expandedAttributes - } else { - require(expandedAttributes.size == input.output.size) - expandedAttributes.zip(input.output).map { - case (e, originalAttribute) => - Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) - } - } - return expandedAttributes - } - - require(target.isDefined) + if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. If there is a conflict and both are possible, // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. + require(target.isDefined) val attribute = input.resolve(target.get, resolver) if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. 
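For readers unfamiliar with star expansion, the following sketch (illustrative only, not part of this patch; `df`, the SQLContext, and the temp-table name are assumed) shows the user-facing behaviour that `UnresolvedStar` with a target and the `"name.*"` handling in `Column` implement, mirroring the queries exercised in the SQLQuerySuite changes further down:

```scala
import sqlContext.implicits._
import org.apache.spark.sql.functions.struct

// Assume df has two integer columns "a" and "b" (analogous to testData2 in the suite).
val structDF = df.select(struct($"a", $"b").as("record"))
structDF.registerTempTable("structTable")

// Expand the struct's fields back into top-level columns:
structDF.select($"record.*").show()                        // DataFrame API: "record.*" is parsed in Column
sqlContext.sql("SELECT record.* FROM structTable").show()  // SQL path: UnresolvedStar(Some(Seq("record")))
```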
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3cde9d6cb4708..c73f696962de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -60,8 +60,10 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName( - name.substring(0, name.length - 2)))) + case _ if name.endsWith(".*") => { + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + } case _ => UnresolvedAttribute.quotedString(name) }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ee54bff24b196..6388a8b9c3720 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct} +import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} @@ -1956,7 +1956,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // Try with a registered table. sql("select struct(a, b) as record from testData2").registerTempTable("structTable") - checkAnswer(sql("SELECT record.* FROM structTable"), + checkAnswer( + sql("SELECT record.* FROM structTable"), Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) checkAnswer(sql( @@ -2019,50 +2020,62 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) // Try with a registered table - nestedStructData.registerTempTable("nestedStructTable") - checkAnswer(sql("SELECT record.* FROM nestedStructTable"), - nestedStructData.select($"record.*")) - checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"), - nestedStructData.select($"record.r1")) - checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"), - nestedStructData.select($"record.r1.*")) - - // Create paths with unusual characters. + withTempTable("nestedStructTable") { + nestedStructData.registerTempTable("nestedStructTable") + checkAnswer( + sql("SELECT record.* FROM nestedStructTable"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT record.r1 FROM nestedStructTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT record.r1.* FROM nestedStructTable"), + nestedStructData.select($"record.r1.*")) + + // Try resolving something not there. 
+ assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) + .getMessage.contains("cannot resolve")) + } + + // Create paths with unusual characters val specialCharacterPath = sql( """ | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) - specialCharacterPath.registerTempTable("specialCharacterTable") - checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"), - nestedStructData.select($"record.*")) - checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), - nestedStructData.select($"record.r1")) - checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), - nestedStructData.select($"record.r2")) - checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), - nestedStructData.select($"record.r1.*")) + withTempTable("specialCharacterTable") { + specialCharacterPath.registerTempTable("specialCharacterTable") + checkAnswer( + specialCharacterPath.select($"`r&&b.c`.*"), + nestedStructData.select($"record.*")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) + checkAnswer( + sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), + nestedStructData.select($"record.r2")) + checkAnswer( + sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"), + nestedStructData.select($"record.r1.*")) + } // Try star expanding a scalar. This should fail. assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains( "Can only star expand struct data types.")) - - // Try resolving something not there. - assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable")) - .getMessage.contains("cannot resolve")) } - test("Struct Star Expansion - Name conflict") { // Create a data set that contains a naming conflict val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") - nameConflict.registerTempTable("nameConflict") - // Unqualified should resolve to table. - checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), - Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: - Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) - // Qualify the struct type with the table name. - checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), - Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + withTempTable("nameConflict") { + nameConflict.registerTempTable("nameConflict") + // Unqualified should resolve to table. + checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), + Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: + Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil) + // Qualify the struct type with the table name. + checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"), + Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) + } } } From 2692bdb7dbf36d6247f595d5fd0cb9cda89e1fdd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 3 Nov 2015 20:25:58 -0800 Subject: [PATCH 159/324] [SPARK-11455][SQL] fix case sensitivity of partition by depend on `caseSensitive` to do column name equality check, instead of just `==` Author: Wenchen Fan Closes #9410 from cloud-fan/partition. 
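To make the intent concrete, here is a small usage sketch (illustrative only, not part of the patch; the output path is a placeholder) mirroring the regression test added to DataFrameSuite below. With case-insensitive analysis enabled, the partition column should resolve even when it is written and read back with different casing:

```scala
import sqlContext.implicits._

sqlContext.setConf("spark.sql.caseSensitive", "false")

Seq(2012 -> "a").toDF("year", "val")
  .write.partitionBy("yEAr")       // partition column specified with different casing
  .parquet("/tmp/partitioned")     // placeholder path

// Previously the partition column was matched with plain `==`, so "yEAr" was not found
// in the schema; with resolver-based equality the column resolves as expected.
sqlContext.read.parquet("/tmp/partitioned").select("YeaR").show()   // -> 2012
```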
--- .../datasources/PartitioningUtils.scala | 7 ++--- .../datasources/ResolvedDataSource.scala | 27 ++++++++++++++----- .../sql/execution/datasources/rules.scala | 6 +++-- .../org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++ 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 628c5e18936c5..16dc23661c070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -287,10 +287,11 @@ private[sql] object PartitioningUtils { def validatePartitionColumnDataTypes( schema: StructType, - partitionColumns: Array[String]): Unit = { + partitionColumns: Array[String], + caseSensitive: Boolean): Unit = { - ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field => - field.dataType match { + ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { + field => field.dataType match { case _: AtomicType => // OK case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 54beabbf63b5f..86a306b8f941d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging { val maybePartitionsSchema = if (partitionColumns.isEmpty) { None } else { - Some(partitionColumnsSchema(schema, partitionColumns)) + Some(partitionColumnsSchema( + schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis)) } val caseInsensitiveOptions = new CaseInsensitiveMap(options) @@ -172,14 +173,24 @@ object ResolvedDataSource extends Logging { def partitionColumnsSchema( schema: StructType, - partitionColumns: Array[String]): StructType = { + partitionColumns: Array[String], + caseSensitive: Boolean): StructType = { + val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { + schema.find(f => equality(f.name, col)).getOrElse { throw new RuntimeException(s"Partition column $col not found in schema $schema") } }).asNullable } + private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { + if (caseSensitive) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. 
*/ def apply( sqlContext: SQLContext, @@ -207,14 +218,18 @@ object ResolvedDataSource extends Logging { path.makeQualified(fs.getUri, fs.getWorkingDirectory) } - PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns) + val caseSensitive = sqlContext.conf.caseSensitiveAnalysis + PartitioningUtils.validatePartitionColumnDataTypes( + data.schema, partitionColumns, caseSensitive) - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) + val equality = columnNameEquality(caseSensitive) + val dataSchema = StructType( + data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) val r = dataSource.createRelation( sqlContext, Array(outputPath.toString), Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), + Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index abc016bf020d9..1a8e7ab202dc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray) + PartitioningUtils.validatePartitionColumnDataTypes( + r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis) // Get all input data source relations of the query. val srcRelations = query.collect { @@ -190,7 +191,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns) + PartitioningUtils.validatePartitionColumnDataTypes( + query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) case _ => // OK } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a883bcb7b1012..a9e6413423118 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1118,4 +1118,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { if (!allSequential) throw new SparkException("Partition should contain all sequential values") }) } + + test("fix case sensitivity of partition by") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val p = path.getAbsolutePath + Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) + checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + } + } + } } From 8aff36e91de0fee2f3f56c6d240bb203b5bb48ba Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 4 Nov 2015 10:49:34 +0000 Subject: [PATCH 160/324] [SPARK-2960][DEPLOY] Support executing Spark from symlinks (reopen) This PR is based on the work of roji to support running Spark scripts from symlinks. Thanks for the great work roji . Would you mind taking a look at this PR, thanks a lot. 
For releases like HDP and others, normally it will expose the Spark executables as symlinks and put in `PATH`, but current Spark's scripts do not support finding real path from symlink recursively, this will make spark fail to execute from symlink. This PR try to solve this issue by finding the absolute path from symlink. Instead of using `readlink -f` like what this PR (https://github.com/apache/spark/pull/2386) implemented is that `-f` is not support for Mac, so here manually seeking the path through loop. I've tested with Mac and Linux (Cent OS), looks fine. This PR did not fix the scripts under `sbin` folder, not sure if it needs to be fixed also? Please help to review, any comment is greatly appreciated. Author: jerryshao Author: Shay Rojansky Closes #8669 from jerryshao/SPARK-2960. --- bin/beeline | 8 +++++--- bin/load-spark-env.sh | 32 ++++++++++++++++------------- bin/pyspark | 14 +++++++------ bin/run-example | 18 ++++++++-------- bin/spark-class | 15 +++++++------- bin/spark-shell | 9 +++++--- bin/spark-sql | 7 +++++-- bin/spark-submit | 6 ++++-- bin/sparkR | 9 +++++--- sbin/slaves.sh | 9 ++++---- sbin/spark-config.sh | 23 +++++++-------------- sbin/spark-daemon.sh | 23 +++++++++++---------- sbin/spark-daemons.sh | 9 ++++---- sbin/start-all.sh | 11 +++++----- sbin/start-history-server.sh | 11 +++++----- sbin/start-master.sh | 17 +++++++-------- sbin/start-mesos-dispatcher.sh | 11 +++++----- sbin/start-mesos-shuffle-service.sh | 11 +++++----- sbin/start-shuffle-service.sh | 11 +++++----- sbin/start-slave.sh | 18 ++++++++-------- sbin/start-slaves.sh | 19 ++++++++--------- sbin/start-thriftserver.sh | 11 +++++----- sbin/stop-all.sh | 14 ++++++------- sbin/stop-history-server.sh | 7 ++++--- sbin/stop-master.sh | 13 ++++++------ sbin/stop-mesos-dispatcher.sh | 9 ++++---- sbin/stop-mesos-shuffle-service.sh | 7 ++++--- sbin/stop-shuffle-service.sh | 7 ++++--- sbin/stop-slave.sh | 15 +++++++------- sbin/stop-slaves.sh | 15 +++++++------- sbin/stop-thriftserver.sh | 7 ++++--- 31 files changed, 213 insertions(+), 183 deletions(-) diff --git a/bin/beeline b/bin/beeline index 3fcb6df34339d..1627626941a73 100755 --- a/bin/beeline +++ b/bin/beeline @@ -23,8 +23,10 @@ # Enter posix mode for bash set -o posix -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +# Figure out if SPARK_HOME is set +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi CLASS="org.apache.hive.beeline.BeeLine" -exec "$FWDIR/bin/spark-class" $CLASS "$@" +exec "${SPARK_HOME}/bin/spark-class" $CLASS "$@" diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 95779e9ddbb18..eaea964ed5b3d 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -20,13 +20,17 @@ # This script loads spark-env.sh if it exists, and ensures it is only loaded once. # spark-env.sh is loaded from SPARK_CONF_DIR if set, or within the current directory's # conf/ subdirectory. -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" + +# Figure out where Spark is installed +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 # Returns the parent of the directory this script lives in. 
- parent_dir="$(cd "`dirname "$0"`"/..; pwd)" + parent_dir="${SPARK_HOME}" user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" @@ -42,18 +46,18 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" - ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" + ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.10" - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 - exit 1 - fi + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + exit 1 + fi - if [ -d "$ASSEMBLY_DIR2" ]; then - export SPARK_SCALA_VERSION="2.11" - else - export SPARK_SCALA_VERSION="2.10" - fi + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi fi diff --git a/bin/pyspark b/bin/pyspark index 18012ee4a0b4f..5eaa17d3c2016 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,9 +17,11 @@ # limitations under the License. # -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -source "$SPARK_HOME"/bin/load-spark-env.sh +source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` @@ -64,12 +66,12 @@ fi export PYSPARK_PYTHON # Add the PySpark classes to the Python path: -export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" -export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.9-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" -export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" +export PYTHONSTARTUP="${SPARK_HOME}/python/pyspark/shell.py" # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then @@ -82,4 +84,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" +exec "${SPARK_HOME}"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/run-example b/bin/run-example index 798e2caeb88ce..e1b0d5789bed6 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,11 +17,13 @@ # limitations under the License. # -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export SPARK_HOME="$FWDIR" -EXAMPLES_DIR="$FWDIR"/examples +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +EXAMPLES_DIR="${SPARK_HOME}"/examples -. "$FWDIR"/bin/load-spark-env.sh +. "${SPARK_HOME}"/bin/load-spark-env.sh if [ -n "$1" ]; then EXAMPLE_CLASS="$1" @@ -34,8 +36,8 @@ else exit 1 fi -if [ -f "$FWDIR/RELEASE" ]; then - JAR_PATH="${FWDIR}/lib" +if [ -f "${SPARK_HOME}/RELEASE" ]; then + JAR_PATH="${SPARK_HOME}/lib" else JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}" fi @@ -44,7 +46,7 @@ JAR_COUNT=0 for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do if [[ ! 
-e "$f" ]]; then - echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 + echo "Failed to find Spark examples assembly in ${SPARK_HOME}/lib or ${SPARK_HOME}/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 exit 1 fi @@ -67,7 +69,7 @@ if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -exec "$FWDIR"/bin/spark-submit \ +exec "${SPARK_HOME}"/bin/spark-submit \ --master $EXAMPLE_MASTER \ --class $EXAMPLE_CLASS \ "$SPARK_EXAMPLES_JAR" \ diff --git a/bin/spark-class b/bin/spark-class index 8cae6ccbabe7c..87d06693af4fe 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -17,10 +17,11 @@ # limitations under the License. # -# Figure out where Spark is installed -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$SPARK_HOME"/bin/load-spark-env.sh +. "${SPARK_HOME}"/bin/load-spark-env.sh # Find the java binary if [ -n "${JAVA_HOME}" ]; then @@ -36,10 +37,10 @@ fi # Find assembly jar SPARK_ASSEMBLY_JAR= -if [ -f "$SPARK_HOME/RELEASE" ]; then - ASSEMBLY_DIR="$SPARK_HOME/lib" +if [ -f "${SPARK_HOME}/RELEASE" ]; then + ASSEMBLY_DIR="${SPARK_HOME}/lib" else - ASSEMBLY_DIR="$SPARK_HOME/assembly/target/scala-$SPARK_SCALA_VERSION" + ASSEMBLY_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" fi GREP_OPTIONS= @@ -65,7 +66,7 @@ LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. if [ -n "$SPARK_PREPEND_CLASSES" ]; then - LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" + LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" fi export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" diff --git a/bin/spark-shell b/bin/spark-shell index 00ab7afd118b5..6583b5bd880ee 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -28,7 +28,10 @@ esac # Enter posix mode for bash set -o posix -export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, @@ -47,11 +50,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" + "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" + "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-sql b/bin/spark-sql index 4ea7bc6e39c07..970d12cbf51dd 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,6 +17,9 @@ # limitations under the License. 
# -export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" -exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" +exec "${SPARK_HOME}"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 255378b0f077c..023f9c162f4b8 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,9 +17,11 @@ # limitations under the License. # -SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" +exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/sparkR b/bin/sparkR index 464c29f369424..2c07a82e2173b 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,7 +17,10 @@ # limitations under the License. # -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" -source "$SPARK_HOME"/bin/load-spark-env.sh +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" -exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" +exec "${SPARK_HOME}"/bin/spark-submit sparkr-shell-main "$@" diff --git a/sbin/slaves.sh b/sbin/slaves.sh index cdad47ee2e594..c971aa3296b09 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -36,10 +36,11 @@ if [ $# -le 0 ]; then exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # If the slaves file is specified in the command line, # then it takes precedence over the definition in @@ -65,7 +66,7 @@ then shift fi -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$HOSTLIST" = "" ]; then if [ "$SPARK_SLAVES" = "" ]; then diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index e6bf544c14799..d8d9d00d64ebc 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -19,21 +19,12 @@ # should not be executable directly # also should not be passed any arguments, since we need original $* -# resolve links - $0 may be a softlink -this="${BASH_SOURCE:-$0}" -common_bin="$(cd -P -- "$(dirname -- "$this")" && pwd -P)" -script="$(basename -- "$this")" -this="$common_bin/$script" +# symlink and absolute path should rely on SPARK_HOME to resolve +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -# convert relative path to absolute path -config_bin="`dirname "$this"`" -script="`basename "$this"`" -config_bin="`cd "$config_bin"; pwd`" -this="$config_bin/$script" - -export SPARK_PREFIX="`dirname "$this"`"/.. 
-export SPARK_HOME="${SPARK_PREFIX}" -export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" -export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.9-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9-src.zip:${PYTHONPATH}" diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 0fbe795822fbf..6ab57df409529 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -37,10 +37,11 @@ if [ $# -le 1 ]; then exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # get arguments @@ -86,7 +87,7 @@ spark_rotate_log () fi } -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_IDENT_STRING" = "" ]; then export SPARK_IDENT_STRING="$USER" @@ -97,7 +98,7 @@ export SPARK_PRINT_LAUNCH_COMMAND="1" # get log directory if [ "$SPARK_LOG_DIR" = "" ]; then - export SPARK_LOG_DIR="$SPARK_HOME/logs" + export SPARK_LOG_DIR="${SPARK_HOME}/logs" fi mkdir -p "$SPARK_LOG_DIR" touch "$SPARK_LOG_DIR"/.spark_test > /dev/null 2>&1 @@ -137,7 +138,7 @@ run_command() { if [ "$SPARK_MASTER" != "" ]; then echo rsync from "$SPARK_MASTER" - rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' "$SPARK_MASTER/" "$SPARK_HOME" + rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' "$SPARK_MASTER/" "${SPARK_HOME}" fi spark_rotate_log "$log" @@ -145,12 +146,12 @@ run_command() { case "$mode" in (class) - nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & + nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & newpid="$!" ;; (submit) - nohup nice -n "$SPARK_NICENESS" "$SPARK_PREFIX"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & + nohup nice -n "$SPARK_NICENESS" "${SPARK_HOME}"/bin/spark-submit --class $command "$@" >> "$log" 2>&1 < /dev/null & newpid="$!" ;; @@ -205,13 +206,13 @@ case $option in else echo $pid file is present but $command not running exit 1 - fi + fi else echo $command not running. exit 2 - fi + fi ;; - + (*) echo $usage exit 1 diff --git a/sbin/spark-daemons.sh b/sbin/spark-daemons.sh index 5d9f2bb51cae0..dec2f4432df39 100755 --- a/sbin/spark-daemons.sh +++ b/sbin/spark-daemons.sh @@ -27,9 +27,10 @@ if [ $# -le 1 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/spark-daemon.sh" "$@" +exec "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/spark-daemon.sh" "$@" diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 1baf57cea09ee..6217f9bf28e3d 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -21,8 +21,9 @@ # Starts the master on this node. # Starts a worker on each node specified in conf/slaves -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi TACHYON_STR="" @@ -36,10 +37,10 @@ shift done # Load the Spark configuration -. 
"$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # Start Master -"$sbin"/start-master.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-master.sh $TACHYON_STR # Start Workers -"$sbin"/start-slaves.sh $TACHYON_STR +"${SPARK_HOME}/sbin"/start-slaves.sh $TACHYON_STR diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 9034e5715cc85..6851d99b7e8f4 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -24,10 +24,11 @@ # Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 $@ +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 $@ diff --git a/sbin/start-master.sh b/sbin/start-master.sh index a7f5d5702fd80..c20e19a8412df 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -19,8 +19,9 @@ # Starts the master on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi ORIGINAL_ARGS="$@" @@ -39,9 +40,9 @@ case $1 in shift done -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_MASTER_PORT" = "" ]; then SPARK_MASTER_PORT=7077 @@ -55,12 +56,12 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then - "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "$sbin"/../tachyon/bin/tachyon format -s - "$sbin"/../tachyon/bin/tachyon-start.sh master + "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon format -s + "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon-start.sh master fi diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index ef1fc573d5c65..4777e1668c703 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -21,12 +21,13 @@ # Rest server to handle driver requests for Mesos cluster mode. # Only one cluster dispatcher is needed per Mesos cluster. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then SPARK_MESOS_DISPATCHER_PORT=7077 @@ -37,4 +38,4 @@ if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" +"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/start-mesos-shuffle-service.sh b/sbin/start-mesos-shuffle-service.sh index 64580762c5dc4..1845845676029 100755 --- a/sbin/start-mesos-shuffle-service.sh +++ b/sbin/start-mesos-shuffle-service.sh @@ -26,10 +26,11 @@ # Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle service configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh index 4fddcf7f95d40..793e165be6c78 100755 --- a/sbin/start-shuffle-service.sh +++ b/sbin/start-shuffle-service.sh @@ -24,10 +24,11 @@ # Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" -exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 4c919ff76a8f5..21455648d1c6d 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -21,14 +21,14 @@ # # Environment Variables # -# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# SPARK_WORKER_INSTANCES The number of worker instances to run on this # slave. Default is 1. -# SPARK_WORKER_PORT The base port number for the first worker. If set, +# SPARK_WORKER_PORT The base port number for the first worker. If set, # subsequent workers will increment this number. If # unset, Spark will find a valid port number, but # with no guarantee of a predictable pattern. # SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first -# worker. Subsequent workers will increment this +# worker. Subsequent workers will increment this # number. Default is 8081. usage="Usage: start-slave.sh where is like spark://localhost:7077" @@ -39,12 +39,13 @@ if [ $# -lt 1 ]; then exit 1 fi -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" # First argument should be the master; we need to store it aside because we may # need to insert arguments between it and the other arguments @@ -71,7 +72,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } @@ -82,4 +83,3 @@ else start_instance $(( 1 + $i )) "$@" done fi - diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 24d6268815ed3..51ca81e053b70 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -19,16 +19,16 @@ # Starts a slave instance on each machine specified in the conf/slaves file. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" - +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -38,9 +38,8 @@ case $1 in shift done -. "$sbin/spark-config.sh" - -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" # Find the port number for the master if [ "$SPARK_MASTER_PORT" = "" ]; then @@ -52,11 +51,11 @@ if [ "$SPARK_MASTER_IP" = "" ]; then fi if [ "$START_TACHYON" == "true" ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" + "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" # set -t so we can call sudo - SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 + SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 fi # Launch the slaves -"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" +"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 5b0aeb177fff3..ad7e7c5277eb1 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -23,8 +23,9 @@ # Enter posix mode for bash set -o posix -# Figure out where Spark is installed -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # NOTE: This exact class name is matched downstream by SparkSubmit. # Any changes need to be reflected there. 
@@ -39,10 +40,10 @@ function usage { pattern+="\|=======" pattern+="\|--help" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "${SPARK_HOME}"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "Thrift server options:" - "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then @@ -52,4 +53,4 @@ fi export SUBMIT_USAGE_FUNCTION=usage -exec "$FWDIR"/sbin/spark-daemon.sh submit $CLASS 1 "$@" +exec "${SPARK_HOME}"/sbin/spark-daemon.sh submit $CLASS 1 "$@" diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 1a9abe07db844..4e476ca05cb05 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -20,23 +20,23 @@ # Stop all spark daemons. # Run this on the master node. - -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi # Load the Spark configuration -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" # Stop the slaves, then the master -"$sbin"/stop-slaves.sh -"$sbin"/stop-master.sh +"${SPARK_HOME}/sbin"/stop-slaves.sh +"${SPARK_HOME}/sbin"/stop-master.sh if [ "$1" == "--wait" ] then printf "Waiting for workers to shut down..." while true do - running=`$sbin/slaves.sh ps -ef | grep -v grep | grep deploy.worker.Worker` + running=`${SPARK_HOME}/sbin/slaves.sh ps -ef | grep -v grep | grep deploy.worker.Worker` if [ -z "$running" ] then printf "\nAll workers successfully shut down.\n" diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh index 6e6056359510f..14e3af4be910a 100755 --- a/sbin/stop-history-server.sh +++ b/sbin/stop-history-server.sh @@ -19,7 +19,8 @@ # Stops the history server on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1 +"${SPARK_HOME}/sbin/spark-daemon.sh" stop org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index 729702d92191e..e57962bb354d9 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -19,13 +19,14 @@ # Stops the master on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 -if [ -e "$sbin"/../tachyon/bin/tachyon ]; then - "$sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master +if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then + "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master fi diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh index cb65d95b5e524..5c0b4e051db38 100755 --- a/sbin/stop-mesos-dispatcher.sh +++ b/sbin/stop-mesos-dispatcher.sh @@ -18,10 +18,11 @@ # # Stop the Mesos Cluster dispatcher on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. 
"${SPARK_HOME}/sbin/spark-config.sh" -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 diff --git a/sbin/stop-mesos-shuffle-service.sh b/sbin/stop-mesos-shuffle-service.sh index 0e965d5ec5886..d23cad375e1bd 100755 --- a/sbin/stop-mesos-shuffle-service.sh +++ b/sbin/stop-mesos-shuffle-service.sh @@ -19,7 +19,8 @@ # Stops the Mesos external shuffle service on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh index 4cb6891ae27fa..50d69cf34e0a5 100755 --- a/sbin/stop-shuffle-service.sh +++ b/sbin/stop-shuffle-service.sh @@ -19,7 +19,8 @@ # Stops the external shuffle service on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh index 3d1da5b254f2a..685bcf59b33aa 100755 --- a/sbin/stop-slave.sh +++ b/sbin/stop-slave.sh @@ -21,23 +21,24 @@ # # Environment variables # -# SPARK_WORKER_INSTANCES The number of worker instances that should be +# SPARK_WORKER_INSTANCES The number of worker instances that should be # running on this slave. Default is 1. # Usage: stop-slave.sh # Stops all slaves on this worker machine -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 else for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + "${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) done fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 54c9bd46803a9..63956377629d6 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,16 +17,17 @@ # limitations under the License. # -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -. "$sbin/spark-config.sh" +. "${SPARK_HOME}/sbin/spark-config.sh" -. "$SPARK_PREFIX/bin/load-spark-env.sh" +. 
"${SPARK_HOME}/bin/load-spark-env.sh" # do before the below calls as they exec -if [ -e "$sbin"/../tachyon/bin/tachyon ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker +if [ -e "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon ]; then + "${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh +"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin"/stop-slave.sh diff --git a/sbin/stop-thriftserver.sh b/sbin/stop-thriftserver.sh index 4031a00d4a689..cf45058f882a0 100755 --- a/sbin/stop-thriftserver.sh +++ b/sbin/stop-thriftserver.sh @@ -19,7 +19,8 @@ # Stops the thrift server on the machine this script is executed on. -sbin="`dirname "$0"`" -sbin="`cd "$sbin"; pwd`" +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi -"$sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1 +"${SPARK_HOME}/sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1 From c09e5139874fb3626e005c8240cca5308b902ef3 Mon Sep 17 00:00:00 2001 From: tedyu Date: Wed, 4 Nov 2015 10:51:40 +0000 Subject: [PATCH 161/324] [SPARK-11442] Reduce numSlices for local metrics test of SparkListenerSuite In the thread, http://search-hadoop.com/m/q3RTtcQiFSlTxeP/test+failed+due+to+OOME&subj=test+failed+due+to+OOME, it was discussed that memory consumption for SparkListenerSuite should be brought down. This is an attempt in that direction by reducing numSlices for local metrics test. Author: tedyu Closes #9384 from tedyu/master. --- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index a9652d7e7d0b0..53102b9f1c936 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -212,14 +212,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match i } - val d = sc.parallelize(0 to 1e4.toInt, 64).map(w) + val numSlices = 16 + val d = sc.parallelize(0 to 1e3.toInt, numSlices).map(w) d.count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (1) val d2 = d.map { i => w(i) -> i * 2 }.setName("shuffle input 1") val d3 = d.map { i => w(i) -> (0 to (i % 5)) }.setName("shuffle input 2") - val d4 = d2.cogroup(d3, 64).map { case (k, (v1, v2)) => + val d4 = d2.cogroup(d3, numSlices).map { case (k, (v1, v2)) => w(k) -> (v1.size, v2.size) } d4.setName("A Cogroup") @@ -258,8 +259,8 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be (128) - sm.localBlocksFetched should be (128) + sm.totalBlocksFetched should be (2*numSlices) + sm.localBlocksFetched should be (2*numSlices) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0L) } From e328b69c31821e4b27673d7ef6182ab3b7a05ca8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 4 Nov 2015 08:28:33 -0800 Subject: [PATCH 162/324] [SPARK-9492][ML][R] LogisticRegression in R should provide 
model statistics Like ml ```LinearRegression```, ```LogisticRegression``` should provide a training summary including feature names and their coefficients. Author: Yanbo Liang Closes #9303 from yanboliang/spark-9492. --- R/pkg/inst/tests/test_mllib.R | 17 +++++++++++++++++ .../ml/classification/LogisticRegression.scala | 17 +++++++++++++---- .../org/apache/spark/ml/r/SparkRWrappers.scala | 7 ++++--- project/MimaExcludes.scala | 4 +++- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3331ce738358c..032cfef061fd3 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -67,3 +67,20 @@ test_that("summary coefficients match with native glm", { as.character(stats$features) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) + +test_that("summary coefficients match with native glm of family 'binomial'", { + df <- createDataFrame(sqlContext, iris) + training <- filter(df, df$Species != "setosa") + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = "binomial")) + coefs <- as.vector(stats$coefficients) + + rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] + rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit")))) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a1335e7a1bde8..f5fca686df144 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -378,6 +378,7 @@ class LogisticRegression(override val uid: String) model.transform(dataset), $(probabilityCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(logRegSummary) } @@ -452,7 +453,8 @@ class LogisticRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + new BinaryLogisticRegressionSummary( + this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) } /** @@ -614,9 +616,12 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance. */ def labelCol: String + /** Field in "predictions" which gives the features of each instance as a vector. */ + def featuresCol: String + } /** @@ -626,6 +631,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. 
*/ @Experimental @@ -633,8 +639,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( predictions: DataFrame, probabilityCol: String, labelCol: String, + featuresCol: String, val objectiveHistory: Array[Double]) - extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { } @@ -646,12 +653,14 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @transient override val predictions: DataFrame, override val probabilityCol: String, - override val labelCol: String) extends LogisticRegressionSummary { + override val labelCol: String, + override val featuresCol: String) extends LogisticRegressionSummary { private val sqlContext = predictions.sqlContext import sqlContext.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 24f76de806d8f..5be2f86936211 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -66,9 +66,10 @@ private[r] object SparkRWrappers { val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - case _: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No features names available for LogisticRegressionModel") // SPARK-9492 + case m: LogisticRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ec0e44b7f2d66..eeef96c378bdb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -59,7 +59,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticAggregator.count") + "org.apache.spark.ml.classification.LogisticAggregator.count"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") ) ++ Seq( // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. // This class is marked as `private` but MiMa still seems to be confused by the change. From 820064e613609bbf7edd726d982da1de60bf417a Mon Sep 17 00:00:00 2001 From: Pravin Gadakh Date: Wed, 4 Nov 2015 08:32:08 -0800 Subject: [PATCH 163/324] [SPARK-11380][DOCS] Replace example code in mllib-frequent-pattern-mining.md using include_example Author: Pravin Gadakh Author: Pravin Gadakh Closes #9340 from pravingadakh/SPARK-11380. 
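To spell out the mechanism this patch switches to: the runnable examples live under `examples/src/main/`, and the docs pull code into the page with the `{% include_example %}` tag; only the region between the `// $example on$` and `// $example off$` marker comments is rendered, so imports and boilerplate stay out of the published snippet. A minimal Scala skeleton of the convention is sketched below (the object name and body are invented for illustration; the real examples are the new files added later in this patch):

```scala
// Hypothetical skeleton showing the marker convention used by include_example.
object DocSnippetSkeleton {
  def main(args: Array[String]): Unit = {
    // setup code up here is kept out of the rendered documentation
    // $example on$
    val minSupport = 0.2 // only the lines between the markers appear in the doc page
    println(s"minSupport = $minSupport")
    // $example off$
  }
}
```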
--- docs/mllib-frequent-pattern-mining.md | 168 +----------------- .../mllib/JavaAssociationRulesExample.java | 56 ++++++ .../examples/mllib/JavaPrefixSpanExample.java | 55 ++++++ .../examples/mllib/JavaSimpleFPGrowth.java | 71 ++++++++ .../src/main/python/mllib/fpgrowth_example.py | 33 ++++ .../mllib/AssociationRulesExample.scala | 54 ++++++ .../examples/mllib/PrefixSpanExample.scala | 52 ++++++ .../spark/examples/mllib/SimpleFPGrowth.scala | 59 ++++++ 8 files changed, 387 insertions(+), 161 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java create mode 100644 examples/src/main/python/mllib/fpgrowth_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index f749eb4f2ff4f..fe42896a05d8e 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -52,31 +52,7 @@ details) from `transactions`. Refer to the [`FPGrowth` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) for details on the API. -{% highlight scala %} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.fpm.FPGrowth - -val data = sc.textFile("data/mllib/sample_fpgrowth.txt") - -val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) - -val fpg = new FPGrowth() - .setMinSupport(0.2) - .setNumPartitions(10) -val model = fpg.run(transactions) - -model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) -} - -val minConfidence = 0.8 -model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala %} @@ -95,46 +71,7 @@ details) from `transactions`. Refer to the [`FPGrowth` Java docs](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) for details on the API. 
-{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.fpm.AssociationRules; -import org.apache.spark.mllib.fpm.FPGrowth; -import org.apache.spark.mllib.fpm.FPGrowthModel; - -SparkConf conf = new SparkConf().setAppName("FP-growth Example"); -JavaSparkContext sc = new JavaSparkContext(conf); - -JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); - -JavaRDD> transactions = data.map( - new Function>() { - public List call(String line) { - String[] parts = line.split(" "); - return Arrays.asList(parts); - } - } -); - -FPGrowth fpg = new FPGrowth() - .setMinSupport(0.2) - .setNumPartitions(10); -FPGrowthModel model = fpg.run(transactions); - -for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); -} - -double minConfidence = 0.8; -for (AssociationRules.Rule rule - : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { - System.out.println( - rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java %} @@ -149,19 +86,7 @@ that stores the frequent itemsets with their frequencies. Refer to the [`FPGrowth` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.fpm.FPGrowth) for more details on the API. -{% highlight python %} -from pyspark.mllib.fpm import FPGrowth - -data = sc.textFile("data/mllib/sample_fpgrowth.txt") - -transactions = data.map(lambda line: line.strip().split(' ')) - -model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10) - -result = model.freqItemsets().collect() -for fi in result: - print(fi) -{% endhighlight %} +{% include_example python/mllib/fpgrowth_example.py %} @@ -177,27 +102,7 @@ that have a single item as the consequent. Refer to the [`AssociationRules` Scala docs](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) for details on the API. -{% highlight scala %} -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.fpm.AssociationRules -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset - -val freqItemsets = sc.parallelize(Seq( - new FreqItemset(Array("a"), 15L), - new FreqItemset(Array("b"), 35L), - new FreqItemset(Array("a", "b"), 12L) -)); - -val ar = new AssociationRules() - .setMinConfidence(0.8) -val results = ar.run(freqItemsets) - -results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala %} @@ -208,29 +113,7 @@ that have a single item as the consequent. Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/AssociationRules.html) for details on the API. 
-{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.fpm.AssociationRules; -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; - -JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( - new FreqItemset(new String[] {"a"}, 15L), - new FreqItemset(new String[] {"b"}, 35L), - new FreqItemset(new String[] {"a", "b"}, 12L) -)); - -AssociationRules arules = new AssociationRules() - .setMinConfidence(0.8); -JavaRDD> results = arules.run(freqItemsets); - -for (AssociationRules.Rule rule: results.collect()) { - System.out.println( - rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java %} @@ -278,24 +161,7 @@ that stores the frequent sequences with their frequencies. Refer to the [`PrefixSpan` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpan) and [`PrefixSpanModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.fpm.PrefixSpanModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.fpm.PrefixSpan - -val sequences = sc.parallelize(Seq( - Array(Array(1, 2), Array(3)), - Array(Array(1), Array(3, 2), Array(1, 2)), - Array(Array(1, 2), Array(5)), - Array(Array(6)) - ), 2).cache() -val prefixSpan = new PrefixSpan() - .setMinSupport(0.5) - .setMaxPatternLength(5) -val model = prefixSpan.run(sequences) -model.freqSequences.collect().foreach { freqSequence => -println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + ", " + freqSequence.freq) -} -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala %} @@ -309,27 +175,7 @@ that stores the frequent sequences with their frequencies. Refer to the [`PrefixSpan` Java docs](api/java/org/apache/spark/mllib/fpm/PrefixSpan.html) and [`PrefixSpanModel` Java docs](api/java/org/apache/spark/mllib/fpm/PrefixSpanModel.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.mllib.fpm.PrefixSpan; -import org.apache.spark.mllib.fpm.PrefixSpanModel; - -JavaRDD>> sequences = sc.parallelize(Arrays.asList( - Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), - Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), - Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), - Arrays.asList(Arrays.asList(6)) -), 2); -PrefixSpan prefixSpan = new PrefixSpan() - .setMinSupport(0.5) - .setMaxPatternLength(5); -PrefixSpanModel model = prefixSpan.run(sequences); -for (PrefixSpan.FreqSequence freqSeq: model.freqSequences().toJavaRDD().collect()) { - System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java %} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java new file mode 100644 index 0000000000000..4d0f989819ace --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaAssociationRulesExample.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; +// $example off$ + +import org.apache.spark.SparkConf; + +public class JavaAssociationRulesExample { + + public static void main(String[] args) { + + SparkConf sparkConf = new SparkConf().setAppName("JavaAssociationRulesExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( + new FreqItemset(new String[] {"a"}, 15L), + new FreqItemset(new String[] {"b"}, 35L), + new FreqItemset(new String[] {"a", "b"}, 12L) + )); + + AssociationRules arules = new AssociationRules() + .setMinConfidence(0.8); + JavaRDD> results = arules.run(freqItemsets); + + for (AssociationRules.Rule rule : results.collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); + } + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java new file mode 100644 index 0000000000000..68ec7c1e6ebe0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPrefixSpanExample.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; +// $example off$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.mllib.fpm.PrefixSpan; +import org.apache.spark.mllib.fpm.PrefixSpanModel; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaPrefixSpanExample { + + public static void main(String[] args) { + + SparkConf sparkConf = new SparkConf().setAppName("JavaPrefixSpanExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + for (PrefixSpan.FreqSequence freqSeq: model.freqSequences().toJavaRDD().collect()) { + System.out.println(freqSeq.javaSequence() + ", " + freqSeq.freq()); + } + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java new file mode 100644 index 0000000000000..72edaca5e95b1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSimpleFPGrowth.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +// $example off$ +import org.apache.spark.api.java.function.Function; +// $example on$ +import org.apache.spark.mllib.fpm.AssociationRules; +import org.apache.spark.mllib.fpm.FPGrowth; +import org.apache.spark.mllib.fpm.FPGrowthModel; +// $example off$ + +import org.apache.spark.SparkConf; + +public class JavaSimpleFPGrowth { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("FP-growth Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + + // $example on$ + JavaRDD data = sc.textFile("data/mllib/sample_fpgrowth.txt"); + + JavaRDD> transactions = data.map( + new Function>() { + public List call(String line) { + String[] parts = line.split(" "); + return Arrays.asList(parts); + } + } + ); + + FPGrowth fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10); + FPGrowthModel model = fpg.run(transactions); + + for (FPGrowth.FreqItemset itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + itemset.javaItems() + "], " + itemset.freq()); + } + + double minConfidence = 0.8; + for (AssociationRules.Rule rule + : model.generateAssociationRules(minConfidence).toJavaRDD().collect()) { + System.out.println( + rule.javaAntecedent() + " => " + rule.javaConsequent() + ", " + rule.confidence()); + } + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/fpgrowth_example.py b/examples/src/main/python/mllib/fpgrowth_example.py new file mode 100644 index 0000000000000..715f5268206cb --- /dev/null +++ b/examples/src/main/python/mllib/fpgrowth_example.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.mllib.fpm import FPGrowth +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="FPGrowth") + + # $example on$ + data = sc.textFile("data/mllib/sample_fpgrowth.txt") + transactions = data.map(lambda line: line.strip().split(' ')) + model = FPGrowth.train(transactions, minSupport=0.2, numPartitions=10) + result = model.freqItemsets().collect() + for fi in result: + print(fi) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala new file mode 100644 index 0000000000000..ca22ddafc3c48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.AssociationRules +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object AssociationRulesExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("AssociationRulesExample") + val sc = new SparkContext(conf) + + // $example on$ + val freqItemsets = sc.parallelize(Seq( + new FreqItemset(Array("a"), 15L), + new FreqItemset(Array("b"), 35L), + new FreqItemset(Array("a", "b"), 12L) + )) + + val ar = new AssociationRules() + .setMinConfidence(0.8) + val results = ar.run(freqItemsets) + + results.collect().foreach { rule => + println("[" + rule.antecedent.mkString(",") + + "=>" + + rule.consequent.mkString(",") + "]," + rule.confidence) + } + // $example off$ + } + +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala new file mode 100644 index 0000000000000..d237232c430ca --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.PrefixSpan +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object PrefixSpanExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("PrefixSpanExample") + val sc = new SparkContext(conf) + + // $example on$ + val sequences = sc.parallelize(Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6)) + ), 2).cache() + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(sequences) + model.freqSequences.collect().foreach { freqSequence => + println( + freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + + ", " + freqSequence.freq) + } + // $example off$ + } +} +// scalastyle:off println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala new file mode 100644 index 0000000000000..b4e06afa7410f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.fpm.FPGrowth +import org.apache.spark.rdd.RDD +// $example off$ + +import org.apache.spark.{SparkContext, SparkConf} + +object SimpleFPGrowth { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName("SimpleFPGrowth") + val sc = new SparkContext(conf) + + // $example on$ + val data = sc.textFile("data/mllib/sample_fpgrowth.txt") + + val transactions: RDD[Array[String]] = data.map(s => s.trim.split(' ')) + + val fpg = new FPGrowth() + .setMinSupport(0.2) + .setNumPartitions(10) + val model = fpg.run(transactions) + + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + } + + val minConfidence = 0.8 + model.generateAssociationRules(minConfidence).collect().foreach { rule => + println( + rule.antecedent.mkString("[", ",", "]") + + " => " + rule.consequent .mkString("[", ",", "]") + + ", " + rule.confidence) + } + // $example off$ + } +} +// scalastyle:on println From 9b214cea896056e7d0a69ae9d3c282e1f027d5b9 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 4 Nov 2015 08:36:55 -0800 Subject: [PATCH 164/324] [SPARK-11443] Reserve space lines The trim_codeblock(lines) function in include_example.rb removes some blank lines in the code. Author: Xusen Yin Closes #9400 from yinxusen/SPARK-11443. 
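The problem sits in the de-indentation step: `trim_codeblock` finds the smallest leading-space count in the extracted snippet and slices that prefix off every line, which mangles whitespace-only lines shorter than the prefix (in Ruby such a slice likely comes back `nil`), so blank lines disappear from the rendered snippet. The one-line fix below passes whitespace-only lines through untouched. Purely as an illustration of the intended behaviour, here is a Scala sketch of the same logic (the `TrimSnippet` name is invented; the real code is the Ruby method changed below):

```scala
// Scala sketch of the fixed behaviour: strip the common indentation, keep blank lines.
object TrimSnippet {
  def trim(lines: Seq[String]): Seq[String] = {
    val nonBlank = lines.filterNot(_.trim.isEmpty)
    val minIndent =
      if (nonBlank.isEmpty) 0 else nonBlank.map(_.takeWhile(_ == ' ').length).min
    lines.map { l =>
      if (l.trim.isEmpty) l          // blank lines pass through unchanged
      else l.substring(minIndent)    // non-blank lines lose only the shared prefix
    }
  }

  def main(args: Array[String]): Unit = {
    // The empty middle line survives, which is what SPARK-11443 restores.
    trim(Seq("    val a = 1", "", "    val b = a + 1")).foreach(println)
  }
}
```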
--- docs/_plugins/include_example.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 0f4184c7462be..6ee63a5ac69df 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -50,7 +50,7 @@ def trim_codeblock(lines) .map { |l| l[/\A */].size } .min - lines.map { |l| l[min_start_spaces .. -1] } + lines.map { |l| l.strip.size == 0 ? l : l[min_start_spaces .. -1] } end # Select lines according to labels in code. Currently we use "$example on$" and "$example off$" From 8790ee6d69e50ca84eb849742be48f2476743b5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 4 Nov 2015 09:07:22 -0800 Subject: [PATCH 165/324] [SPARK-10622][CORE][YARN] Differentiate dead from "mostly dead" executors. In YARN mode, when preemption is enabled, we may leave executors in a zombie state while we wait to retrieve the reason for which the executor exited. This is so that we don't account for failed tasks that were running on a preempted executor. The issue is that while we wait for this information, the scheduler might decide to schedule tasks on the executor, which will never be able to run them. Other side effects include the block manager still considering the executor available to cache blocks, for example. So, when we know that an executor went down but we don't know why, stop everything related to the executor, except its running tasks. Only when we know the reason for the exit (or give up waiting for it) we do update the running tasks. This is achieved by a new `disableExecutor()` method in the `Schedulable` interface. For managers that do not behave like this (i.e. every one but YARN), the existing `executorLost()` method will behave the same way it did before. On top of that change, a few minor changes that made debugging easier, and fixed some other minor issues: - The cluster-mode AM was printing a misleading log message every time an executor disconnected from the driver (because the akka actor system was shared between driver and AM). - Avoid sending unnecessary requests for an executor's exit reason when we already know it was explicitly disabled / killed. This avoids both multiple requests, and unnecessary requests that would just cause warning messages on the AM (in the explicit kill case). - Tone down a log message about the executor being lost when it exited normally (e.g. preemption) - Wake up the AM monitor thread when requests for executor loss reasons arrive too, so that we can more quickly remove executors from this zombie state. Author: Marcelo Vanzin Closes #8887 from vanzin/SPARK-10622. 
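In short, the patch introduces a two-step teardown: when an executor's connection drops for a reason we do not yet know, it is only disabled, meaning it receives no new resource offers and is otherwise treated as gone, while its already running tasks are left alone; only once the real loss reason arrives (or we give up waiting) are those tasks settled, so that, for example, tasks lost to preemption do not count as failures. The sketch below illustrates that flow in isolation; it is not the actual Spark code (the real changes are in `CoarseGrainedSchedulerBackend`, `TaskSchedulerImpl` and the YARN `ApplicationMaster` below), and the `ExecutorTracker` name and string-typed reason are invented for the example:

```scala
// Simplified illustration of "disable now, fail tasks later" for executors whose
// loss reason is still unknown. Not the real scheduler classes.
import scala.collection.mutable

class ExecutorTracker {
  private val alive = mutable.Set[String]()
  private val pendingLossReason = mutable.Set[String]()

  def register(execId: String): Unit = alive += execId

  /** Only executors that are alive and not awaiting a loss reason get new offers. */
  def schedulable: Set[String] = alive.diff(pendingLossReason).toSet

  /** Disconnect observed, reason unknown: stop offers, keep running tasks. */
  def disableExecutor(execId: String): Boolean = {
    val wasAlive = alive.contains(execId) && !pendingLossReason.contains(execId)
    if (wasAlive) {
      pendingLossReason += execId
    }
    wasAlive
  }

  /** Real reason known (or we gave up): remove the executor and settle its tasks. */
  def executorLost(execId: String, reason: String): Unit = {
    pendingLossReason -= execId
    alive -= execId
    // A real scheduler would use `reason` here to decide whether the executor's
    // running tasks count towards job failure (e.g. they do not for preemption).
  }
}
```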
--- .../spark/scheduler/ExecutorLossReason.scala | 9 +++ .../spark/scheduler/TaskSchedulerImpl.scala | 32 ++++++-- .../CoarseGrainedSchedulerBackend.scala | 37 ++++++++- .../cluster/YarnSchedulerBackend.scala | 9 +-- .../scheduler/TaskSchedulerImplSuite.scala | 36 +++++++++ .../spark/deploy/yarn/ApplicationMaster.scala | 79 +++++++++++-------- .../spark/deploy/yarn/YarnAllocator.scala | 5 ++ 7 files changed, 157 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 33edf25043850..47a5cbff4930b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,15 @@ private[spark] object ExecutorExited { } } +/** + * A loss reason that means we don't yet know why the executor exited. + * + * This is used by the task scheduler to remove state associated with the executor, but + * not yet fail any tasks that were running in the executor before the real loss reason + * is known. + */ +private [spark] object LossReasonPending extends ExecutorLossReason("Pending loss reason.") + private[spark] case class SlaveLost(_message: String = "Slave lost") extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1c7bfe89c02ac..43d7d80b7aae1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -468,11 +468,20 @@ private[spark] class TaskSchedulerImpl( removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) + executorIdToHost.get(executorId) match { + case Some(_) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. + logError("Lost an executor " + executorId + " (already removed): " + reason) + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -482,7 +491,11 @@ private[spark] class TaskSchedulerImpl( } } - /** Remove an executor from all our data structures and mark it as lost */ + /** + * Remove an executor from all our data structures and mark it as lost. If the executor's loss + * reason is not yet known, do not yet remove its association with its host nor update the status + * of any running tasks, since the loss reason defines whether we'll fail those tasks. 
+ */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { activeExecutorIds -= executorId val host = executorIdToHost(executorId) @@ -497,8 +510,11 @@ private[spark] class TaskSchedulerImpl( } } } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host, reason) + + if (reason != LossReasonPending) { + executorIdToHost -= executorId + rootPool.executorLost(executorId, host, reason) + } } def executorAdded(execId: String, host: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index ebce5021b19dc..f71d98feac050 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -73,6 +73,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // The number of pending tasks which is locality required protected var localityAwareTasks = 0 + // Executors that have been lost, but for which we don't yet know the real exit reason. + protected val executorsPendingLossReason = new HashSet[String] + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -184,7 +187,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { // Filter out executors under killing - val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq @@ -202,7 +205,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { // Filter out executors under killing - if (!executorsPendingToRemove.contains(executorId)) { + if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = Seq( new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) @@ -210,6 +213,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + private def executorIsAlive(executorId: String): Boolean = synchronized { + !executorsPendingToRemove.contains(executorId) && + !executorsPendingLossReason.contains(executorId) + } + // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { @@ -246,6 +254,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingToRemove -= executorId + executorsPendingLossReason -= executorId } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) @@ -256,6 +265,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Stop making resource offers for the given executor. The executor is marked as lost with + * the loss reason still pending. + * + * @return Whether executor was alive. 
+ */ + protected def disableExecutor(executorId: String): Boolean = { + val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { + if (executorIsAlive(executorId)) { + executorsPendingLossReason += executorId + true + } else { + false + } + } + + if (shouldDisable) { + logInfo(s"Disabling executor $executorId.") + scheduler.executorLost(executorId, LossReasonPending) + } + + shouldDisable + } + override def onStop() { reviveThread.shutdownNow() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index d75d6f673e84e..80da37b09b590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -115,15 +115,12 @@ private[spark] abstract class YarnSchedulerBackend( * (e.g., preemption), according to the application master, then we pass that information down * to the TaskSetManager to inform the TaskSetManager that tasks on that lost executor should * not count towards a job failure. - * - * TODO there's a race condition where while we are querying the ApplicationMaster for - * the executor loss reason, there is the potential that tasks will be scheduled on - * the executor that failed. We should fix this by having this onDisconnected event - * also "blacklist" executors so that tasks are not assigned to them. */ override def onDisconnected(rpcAddress: RpcAddress): Unit = { addressToExecutorId.get(rpcAddress).foreach { executorId => - yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + if (disableExecutor(executorId)) { + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index c2edd4c317d6e..2afb595e6f10d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -237,4 +237,40 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L } } + test("tasks are not re-scheduled while executor loss reason is pending") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = Seq(new WorkerOffer("executor0", "host0", 1)) + val e1Offers = Seq(new WorkerOffer("executor1", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // mark executor0 as dead but pending fail reason + taskScheduler.executorLost("executor0", LossReasonPending) + + // offer some more resources on a different executor, nothing should change + val taskDescriptions2 = taskScheduler.resourceOffers(e1Offers).flatten + assert(0 === taskDescriptions2.length) + + // provide the actual loss reason for executor0 + taskScheduler.executorLost("executor0", SlaveLost("oops")) + + // executor0's tasks should have failed now that the loss reason is known, so offering more + // resources should make them be scheduled on the new executor. + val taskDescriptions3 = taskScheduler.resourceOffers(e1Offers).flatten + assert(1 === taskDescriptions3.length) + assert("executor1" === taskDescriptions3(0).executorId) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 12ae350e4cef6..50ae7ffeec4c5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -87,8 +87,27 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + + // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() + // Steady state heartbeat interval. We want to be reasonably responsive without causing too many + // requests to RM. + private val heartbeatInterval = { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. + val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) + } + + // Initial wait interval before allocator poll, to allow for quicker ramp up when executors are + // being requested. + private val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + + // Next wait interval before allocator poll. + private var nextAllocationInterval = initialAllocationInterval + // Fields used in client mode. private var rpcEnv: RpcEnv = null private var amEndpoint: RpcEndpointRef = _ @@ -332,19 +351,6 @@ private[spark] class ApplicationMaster( } private def launchReporterThread(): Thread = { - // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. - val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - - // we want to be reasonably responsive without causing too many requests to RM. 
- val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) - - // we want to check more frequently for pending containers - val initialAllocationInterval = math.min(heartbeatInterval, - sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) - - var nextAllocationInterval = initialAllocationInterval - // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -377,19 +383,19 @@ private[spark] class ApplicationMaster( } try { val numPendingAllocate = allocator.getPendingAllocate.size - val sleepInterval = - if (numPendingAllocate > 0) { - val currentAllocationInterval = - math.min(heartbeatInterval, nextAllocationInterval) - nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow - currentAllocationInterval - } else { - nextAllocationInterval = initialAllocationInterval - heartbeatInterval - } - logDebug(s"Number of pending allocations is $numPendingAllocate. " + - s"Sleeping for $sleepInterval.") allocatorLock.synchronized { + val sleepInterval = + if (numPendingAllocate > 0 || allocator.getNumPendingLossReasonRequests > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. " + + s"Sleeping for $sleepInterval.") allocatorLock.wait(sleepInterval) } } catch { @@ -560,6 +566,11 @@ private[spark] class ApplicationMaster( userThread } + private def resetAllocatorInterval(): Unit = allocatorLock.synchronized { + nextAllocationInterval = initialAllocationInterval + allocatorLock.notifyAll() + } + /** * An [[RpcEndpoint]] that communicates with the driver's scheduler backend. */ @@ -581,11 +592,9 @@ private[spark] class ApplicationMaster( case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { case Some(a) => - allocatorLock.synchronized { - if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, - localityAwareTasks, hostToLocalTaskCount)) { - allocatorLock.notifyAll() - } + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { + resetAllocatorInterval() } case None => @@ -603,17 +612,19 @@ private[spark] class ApplicationMaster( case GetExecutorLossReason(eid) => Option(allocator) match { - case Some(a) => a.enqueueGetLossReasonRequest(eid, context) - case None => logWarning(s"Container allocator is not ready to find" + - s" executor loss reasons yet.") + case Some(a) => + a.enqueueGetLossReasonRequest(eid, context) + resetAllocatorInterval() + case None => + logWarning("Container allocator is not ready to find executor loss reasons yet.") } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") // In cluster mode, do not rely on the disassociated event to exit // This avoids potentially reporting incorrect exit codes if the driver fails if (!isClusterMode) { + logInfo(s"Driver terminated or disconnected! Shutting down. 
$remoteAddress") finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a0cf1b4aa469b..4d9e777cb4134 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -550,6 +550,10 @@ private[yarn] class YarnAllocator( private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + private[yarn] def getNumPendingLossReasonRequests: Int = synchronized { + pendingLossReasonRequests.size + } + /** * Split the pending container requests into 3 groups based on current localities of pending * tasks. @@ -582,6 +586,7 @@ private[yarn] class YarnAllocator( (localityMatched.toSeq, localityUnMatched.toSeq, localityFree.toSeq) } + } private object YarnAllocator { From 27feafccbd6945b000ca51b14c57912acbad9031 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 4 Nov 2015 09:11:54 -0800 Subject: [PATCH 166/324] [SPARK-11235][NETWORK] Add ability to stream data using network lib. The current interface used to fetch shuffle data is not very efficient for large buffers; it requires the receiver to buffer the entirety of the contents being downloaded in memory before processing the data. To use the network library to transfer large files (such as those that can be added using SparkContext addJar / addFile), this change adds a more efficient way of downloding data, by streaming the data and feeding it to a callback as data arrives. This is achieved by a custom frame decoder that replaces the current netty one; this decoder allows entering a mode where framing is skipped and data is instead provided directly to a callback. The existing netty classes (ByteToMessageDecoder and LengthFieldBasedFrameDecoder) could not be reused since their semantics do not allow for the interception approach the new decoder uses. Author: Marcelo Vanzin Closes #9206 from vanzin/SPARK-11235. 
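As a usage illustration of the new API, grounded only in the `StreamCallback` and `TransportClient.stream(streamId, callback)` signatures that appear in the diff below: a caller can stream a large remote blob straight to disk, handling each chunk as it arrives instead of buffering the whole payload in memory first. This is a sketch, not code from the patch; `FileStreamCallback`, the destination file and the stream id are placeholders:

```scala
// Sketch of a client-side consumer for the streaming API added by this patch.
import java.io.{File, FileOutputStream}
import java.nio.ByteBuffer

import org.apache.spark.network.client.StreamCallback

class FileStreamCallback(dest: File) extends StreamCallback {
  private val channel = new FileOutputStream(dest).getChannel

  /** Called repeatedly as frames arrive; only the current chunk is held in memory. */
  override def onData(streamId: String, buf: ByteBuffer): Unit = {
    while (buf.hasRemaining) {
      channel.write(buf)
    }
  }

  override def onComplete(streamId: String): Unit = channel.close()

  override def onFailure(streamId: String, cause: Throwable): Unit = {
    channel.close()
    dest.delete()
  }
}

// Usage, assuming a TransportClient obtained from a TransportClientFactory and a
// stream id handed out by the server side:
//   client.stream("someStreamId", new FileStreamCallback(new File("/tmp/blob")))
```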
--- .../spark/network/TransportContext.java | 3 +- .../spark/network/client/StreamCallback.java | 40 +++ .../network/client/StreamInterceptor.java | 76 ++++ .../spark/network/client/TransportClient.java | 41 +++ .../client/TransportResponseHandler.java | 47 ++- .../network/protocol/ChunkFetchSuccess.java | 16 +- .../spark/network/protocol/Message.java | 6 +- .../network/protocol/MessageDecoder.java | 9 + .../network/protocol/MessageEncoder.java | 27 +- .../network/protocol/ResponseWithBody.java | 40 +++ .../spark/network/protocol/StreamFailure.java | 80 +++++ .../spark/network/protocol/StreamRequest.java | 78 +++++ .../network/protocol/StreamResponse.java | 91 +++++ .../spark/network/server/StreamManager.java | 13 + .../server/TransportRequestHandler.java | 20 ++ .../apache/spark/network/util/NettyUtils.java | 9 +- .../network/util/TransportFrameDecoder.java | 154 +++++++++ .../apache/spark/network/ProtocolSuite.java | 8 + .../org/apache/spark/network/StreamSuite.java | 325 ++++++++++++++++++ .../util/TransportFrameDecoderSuite.java | 142 ++++++++ 20 files changed, 1196 insertions(+), 29 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java create mode 100644 network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java create mode 100644 network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java create mode 100644 network/common/src/test/java/org/apache/spark/network/StreamSuite.java create mode 100644 network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index b8d073fa16b4b..43900e6f2c972 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -39,6 +39,7 @@ import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.TransportFrameDecoder; /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to @@ -119,7 +120,7 @@ public TransportChannelHandler initializePipeline( TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", encoder) - .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java new file mode 100644 index 0000000000000..093fada320cc3 --- 
/dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link onComplete()} will be + * called. + *
<p>
    + * The network library guarantees that a single thread will call these methods at a time, but + * different call may be made by different threads. + */ +public interface StreamCallback { + /** Called upon receipt of stream data. */ + void onData(String streamId, ByteBuffer buf) throws IOException; + + /** Called when all data from the stream has been received. */ + void onComplete(String streamId) throws IOException; + + /** Called if there's an error reading data from the stream. */ + void onFailure(String streamId, Throwable cause) throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java new file mode 100644 index 0000000000000..02230a00e69fc --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.util.TransportFrameDecoder; + +/** + * An interceptor that is registered with the frame decoder to feed stream data to a + * callback. + */ +class StreamInterceptor implements TransportFrameDecoder.Interceptor { + + private final String streamId; + private final long byteCount; + private final StreamCallback callback; + + private volatile long bytesRead; + + StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + this.streamId = streamId; + this.byteCount = byteCount; + this.callback = callback; + this.bytesRead = 0; + } + + @Override + public void exceptionCaught(Throwable cause) throws Exception { + callback.onFailure(streamId, cause); + } + + @Override + public void channelInactive() throws Exception { + callback.onFailure(streamId, new ClosedChannelException()); + } + + @Override + public boolean handle(ByteBuf buf) throws Exception { + int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); + ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer(); + + int available = nioBuffer.remaining(); + callback.onData(streamId, nioBuffer); + bytesRead += available; + if (bytesRead > byteCount) { + RuntimeException re = new IllegalStateException(String.format( + "Read too many bytes? 
Expected %d, but read %d.", byteCount, bytesRead)); + callback.onFailure(streamId, re); + throw re; + } else if (bytesRead == byteCount) { + callback.onComplete(streamId); + } + + return bytesRead != byteCount; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index fbb8bb6b2f6c3..a0ba223e340a2 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -38,6 +38,7 @@ import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamRequest; import org.apache.spark.network.util.NettyUtils; /** @@ -159,6 +160,46 @@ public void operationComplete(ChannelFuture future) throws Exception { }); } + /** + * Request to stream the data with the given stream ID from the remote end. + * + * @param streamId The stream to fetch. + * @param callback Object to call with the stream data. + */ + public void stream(final String streamId, final StreamCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.debug("Sending stream request for {} to {}", streamId, serverAddr); + + // Need to synchronize here so that the callback is added to the queue and the RPC is + // written to the socket atomically, so that callbacks are called in the right order + // when responses arrive. + synchronized (this) { + handler.addStreamCallback(callback); + channel.writeAndFlush(new StreamRequest(streamId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, + timeTaken); + } else { + String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + callback.onFailure(streamId, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + }); + } + } + /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. 
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 94fc21af5e606..ed3f36af58048 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -19,7 +19,9 @@ import java.io.IOException; import java.util.Map; +import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; import io.netty.channel.Channel; @@ -32,8 +34,11 @@ import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportFrameDecoder; /** * Handler that processes server responses, in response to requests issued from a @@ -50,6 +55,8 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + private final Queue streamCallbacks; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -57,6 +64,7 @@ public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.streamCallbacks = new ConcurrentLinkedQueue(); this.timeOfLastRequestNs = new AtomicLong(0); } @@ -78,6 +86,10 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } + public void addStreamCallback(StreamCallback callback) { + streamCallbacks.offer(callback); + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. 
@@ -124,11 +136,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.buffer.release(); + resp.body.release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer); - resp.buffer.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); + resp.body.release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -161,6 +173,34 @@ public void handle(ResponseMessage message) { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } + } else if (message instanceof StreamResponse) { + StreamResponse resp = (StreamResponse) message; + StreamCallback callback = streamCallbacks.poll(); + if (callback != null) { + StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + } + } else { + logger.error("Could not find callback for StreamResponse."); + } + } else if (message instanceof StreamFailure) { + StreamFailure resp = (StreamFailure) message; + StreamCallback callback = streamCallbacks.poll(); + if (callback != null) { + try { + callback.onFailure(resp.streamId, new RuntimeException(resp.error)); + } catch (IOException ioe) { + logger.warn("Error in stream failure handler.", ioe); + } + } else { + logger.warn("Stream failure with unknown callback: {}", resp.error); + } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } @@ -175,4 +215,5 @@ public int numOutstandingRequests() { public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index c962fb7ecf76d..e6a7e9a8b4145 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,13 +30,12 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess implements ResponseMessage { +public final class ChunkFetchSuccess extends ResponseWithBody { public final StreamChunkId streamChunkId; - public final ManagedBuffer buffer; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + super(buffer, true); this.streamChunkId = streamChunkId; - this.buffer = buffer; } @Override @@ -53,6 +52,11 @@ public void encode(ByteBuf buf) { streamChunkId.encode(buf); } + @Override + public ResponseMessage createFailureResponse(String error) { + return new ChunkFetchFailure(streamChunkId, error); + } + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ public static ChunkFetchSuccess decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); @@ -63,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, buffer); + return Objects.hashCode(streamChunkId, body); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); } return false; } @@ -79,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", buffer) + .add("buffer", body) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d568370125fd4..d01598c20f16f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -27,7 +27,8 @@ public interface Message extends Encodable { /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), - RpcRequest(3), RpcResponse(4), RpcFailure(5); + RpcRequest(3), RpcResponse(4), RpcFailure(5), + StreamRequest(6), StreamResponse(7), StreamFailure(8); private final byte id; @@ -51,6 +52,9 @@ public static Type decode(ByteBuf buf) { case 3: return RpcRequest; case 4: return RpcResponse; case 5: return RpcFailure; + case 6: return StreamRequest; + case 7: return StreamResponse; + case 8: return StreamFailure; default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 81f8d7f96350f..3c04048f3821a 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,15 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case StreamRequest: + return StreamRequest.decode(in); + + case StreamResponse: + return StreamResponse.decode(in); + + case StreamFailure: + return StreamFailure.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 0f999f5dfe8d8..6cce97c807dc0 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -45,27 +45,32 @@ public final class MessageEncoder extends MessageToMessageEncoder { public void encode(ChannelHandlerContext ctx, Message in, List out) { Object body = null; long bodyLength = 0; + boolean isBodyInFrame = false; - // Only ChunkFetchSuccesses have data besides the header. + // Detect ResponseWithBody messages and get the data buffer out of them. 
// The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ChunkFetchSuccess) { - ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + if (in instanceof ResponseWithBody) { + ResponseWithBody resp = (ResponseWithBody) in; try { - bodyLength = resp.buffer.size(); - body = resp.buffer.convertToNetty(); + bodyLength = resp.body.size(); + body = resp.body.convertToNetty(); + isBodyInFrame = resp.isBodyInFrame; } catch (Exception e) { - // Re-encode this message as BlockFetchFailure. - logger.error(String.format("Error opening block %s for client %s", - resp.streamChunkId, ctx.channel().remoteAddress()), e); - encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + resp, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); return; } } Message.Type msgType = in.type(); - // All messages have the frame length, message type, and message itself. + // All messages have the frame length, message type, and message itself. The frame length + // may optionally include the length of the body data, depending on what message is being + // sent. int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); - long frameLength = headerLength + bodyLength; + long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0); ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeLong(frameLength); msgType.encode(header); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java new file mode 100644 index 0000000000000..67be77e39f711 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Abstract class for response messages that contain a large data portion kept in a separate + * buffer. These messages are treated especially by MessageEncoder. 
+ */ +public abstract class ResponseWithBody implements ResponseMessage { + public final ManagedBuffer body; + public final boolean isBodyInFrame; + + protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + public abstract ResponseMessage createFailureResponse(String error); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java new file mode 100644 index 0000000000000..e3dade2ebf905 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Message indicating an error when transferring a stream. + */ +public final class StreamFailure implements ResponseMessage { + public final String streamId; + public final String error; + + public StreamFailure(String streamId, String error) { + this.streamId = streamId; + this.error = error; + } + + @Override + public Type type() { return Type.StreamFailure; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + Encoders.Strings.encode(buf, error); + } + + public static StreamFailure decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + String error = Encoders.Strings.decode(buf); + return new StreamFailure(streamId, error); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, error); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamFailure) { + StreamFailure o = (StreamFailure) other; + return streamId.equals(o.streamId) && error.equals(o.error); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("error", error) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java new file mode 100644 index 0000000000000..821e8f53884d7 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Request to stream data from the remote end. + *
<p>
    + * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before + * the data can be streamed. + */ +public final class StreamRequest implements RequestMessage { + public final String streamId; + + public StreamRequest(String streamId) { + this.streamId = streamId; + } + + @Override + public Type type() { return Type.StreamRequest; } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(streamId); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + } + + public static StreamRequest decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + return new StreamRequest(streamId); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamRequest) { + StreamRequest o = (StreamRequest) other; + return streamId.equals(o.streamId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java new file mode 100644 index 0000000000000..ac5ab9a323a11 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link StreamRequest} when the stream has been successfully opened. + *
<p>
    + * Note the message itself does not contain the stream data. That is written separately by the + * sender. The receiver is expected to set a temporary channel handler that will consume the + * number of bytes this message says the stream has. + */ +public final class StreamResponse extends ResponseWithBody { + public final String streamId; + public final long byteCount; + + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } + + @Override + public Type type() { return Type.StreamResponse; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(streamId); + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + buf.writeLong(byteCount); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new StreamFailure(streamId, error); + } + + public static StreamResponse decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + long byteCount = buf.readLong(); + return new StreamResponse(streamId, byteCount, null); + } + + @Override + public int hashCode() { + return Objects.hashCode(byteCount, streamId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamResponse) { + StreamResponse o = (StreamResponse) other; + return byteCount == o.byteCount && streamId.equals(o.streamId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("byteCount", byteCount) + .toString(); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index aaa677c965640..3f0155957a140 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -46,6 +46,19 @@ public abstract class StreamManager { */ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + /** + * Called in response to a stream() request. The returned data is streamed to the client + * through a single TCP connection. + * + * Note the streamId argument is not related to the similarly named argument in the + * {@link #getChunk(long, int)} method. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + */ + public ManagedBuffer openStream(String streamId) { + throw new UnsupportedOperationException(); + } + /** * Associates a stream with a single client connection, which is guaranteed to be the only reader * of the stream. 
The getChunk() method will be called serially on this connection and once the diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 9b8b047b49a86..4f67bd573be21 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -35,6 +35,9 @@ import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.util.NettyUtils; /** @@ -92,6 +95,8 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof StreamRequest) { + processStreamRequest((StreamRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -117,6 +122,21 @@ private void processFetchRequest(final ChunkFetchRequest req) { respond(new ChunkFetchSuccess(req.streamChunkId, buf)); } + private void processStreamRequest(final StreamRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); + ManagedBuffer buf; + try { + buf = streamManager.openStream(req.streamId); + } catch (Exception e) { + logger.error(String.format( + "Error opening stream %s for request from %s", req.streamId, client), e); + respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(new StreamResponse(req.streamId, buf.size(), buf)); + } + private void processRpcRequest(final RpcRequest req) { try { rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 26c6399ce7dbc..caa7260bc8281 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -89,13 +89,8 @@ public static Class getServerChannelClass(IOMode mode) * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. * This is used before all decoders. */ - public static ByteToMessageDecoder createFrameDecoder() { - // maxFrameLength = 2G - // lengthFieldOffset = 0 - // lengthFieldLength = 8 - // lengthAdjustment = -8, i.e. exclude the 8 byte length itself - // initialBytesToStrip = 8, i.e. strip out the length field itself - return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + public static TransportFrameDecoder createFrameDecoder() { + return new TransportFrameDecoder(); } /** Returns the remote address on the channel or "<unknown remote>" if none exists. 
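Before the new frame decoder, one note on the server side: the only hook a server implements for streaming is the new StreamManager.openStream(), since TransportRequestHandler.processStreamRequest() above wraps the returned ManagedBuffer in a StreamResponse and turns an exception into a StreamFailure. A minimal sketch, again illustrative rather than part of the patch, with an assumed stream id and payload:

    // Sketch: serve a single in-memory stream; unknown ids fail the request.
    final ByteBuffer payload = ByteBuffer.wrap("hello, stream".getBytes(StandardCharsets.UTF_8));
    StreamManager streamManager = new StreamManager() {
      @Override
      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
        throw new UnsupportedOperationException();   // chunk fetching is not used in this sketch
      }

      @Override
      public ManagedBuffer openStream(String streamId) {
        if ("someNegotiatedStreamId".equals(streamId)) {
          return new NioManagedBuffer(payload);   // its size() becomes the StreamResponse byteCount
        }
        throw new IllegalArgumentException("Unknown stream: " + streamId);
      }
    };

In StreamSuite.setUp() below, an equivalent manager is exposed to the transport layer through RpcHandler.getStreamManager().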
*/ diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java new file mode 100644 index 0000000000000..272ea84e6180d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +/** + * A customized frame decoder that allows intercepting raw data. + *
<p>
+ * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
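+ * Every frame begins with an 8-byte length that includes those 8 bytes themselves; the rest of
+ * the frame is what MessageEncoder wrote, i.e. the message type, the encoded message and, when
+ * it is carried inside the frame, the message body.
+ * <p>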
    + * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's + * decoded, instead of building as many frames as the current buffer allows and dispatching + * all of them. This allows a child handler to install an interceptor if needed. + *
<p>
    + * If an interceptor is installed, framing stops, and data is instead fed directly to the + * interceptor. When the interceptor indicates that it doesn't need to read any more data, + * framing resumes. Interceptors should not hold references to the data buffers provided + * to their handle() method. + */ +public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { + + public static final String HANDLER_NAME = "frameDecoder"; + private static final int LENGTH_SIZE = 8; + private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + + private CompositeByteBuf buffer; + private volatile Interceptor interceptor; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf in = (ByteBuf) data; + + if (buffer == null) { + buffer = in.alloc().compositeBuffer(); + } + + buffer.writeBytes(in); + + while (buffer.isReadable()) { + feedInterceptor(); + if (interceptor != null) { + continue; + } + + ByteBuf frame = decodeNext(); + if (frame != null) { + ctx.fireChannelRead(frame); + } else { + break; + } + } + + // We can't discard read sub-buffers if there are other references to the buffer (e.g. + // through slices used for framing). This assumes that code that retains references + // will call retain() from the thread that called "fireChannelRead()" above, otherwise + // ref counting will go awry. + if (buffer != null && buffer.refCnt() == 1) { + buffer.discardReadComponents(); + } + } + + protected ByteBuf decodeNext() throws Exception { + if (buffer.readableBytes() < LENGTH_SIZE) { + return null; + } + + int frameLen = (int) buffer.readLong() - LENGTH_SIZE; + if (buffer.readableBytes() < frameLen) { + buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); + return null; + } + + Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); + Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); + + ByteBuf frame = buffer.readSlice(frameLen); + frame.retain(); + return frame; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (buffer != null) { + if (buffer.isReadable()) { + feedInterceptor(); + } + buffer.release(); + } + if (interceptor != null) { + interceptor.channelInactive(); + } + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (interceptor != null) { + interceptor.exceptionCaught(cause); + } + super.exceptionCaught(ctx, cause); + } + + public void setInterceptor(Interceptor interceptor) { + Preconditions.checkState(this.interceptor == null, "Already have an interceptor."); + this.interceptor = interceptor; + } + + private void feedInterceptor() throws Exception { + if (interceptor != null && !interceptor.handle(buffer)) { + interceptor = null; + } + } + + public static interface Interceptor { + + /** + * Handles data received from the remote end. + * + * @param data Buffer containing data. + * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor. + */ + boolean handle(ByteBuf data) throws Exception; + + /** Called if an exception is thrown in the channel pipeline. */ + void exceptionCaught(Throwable cause) throws Exception; + + /** Called if the channel is closed and the interceptor is still installed. 
*/ + void channelInactive() throws Exception; + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index d500bc3c98a78..22b451fc0e60e 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -39,6 +39,9 @@ import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; @@ -80,6 +83,7 @@ public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); testClientToServer(new RpcRequest(12345, new byte[0])); testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new StreamRequest("abcde")); } @Test @@ -92,6 +96,10 @@ public void responses() { testServerToClient(new RpcResponse(12345, new byte[1000])); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); + // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the + // channel and cannot be tested like this. + testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0))); + testServerToClient(new StreamFailure("anId", "this is an error")); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java new file mode 100644 index 0000000000000..6dcec831dec71 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import com.google.common.io.Files; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class StreamSuite { + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + + private static TransportServer server; + private static TransportClientFactory clientFactory; + private static File testFile; + private static File tempDir; + + private static ByteBuffer smallBuffer; + private static ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + @BeforeClass + public static void setUp() throws Exception { + tempDir = Files.createTempDir(); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final StreamManager streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public ManagedBuffer openStream(String streamId) { + switch (streamId) { + case "largeBuffer": + return new NioManagedBuffer(largeBuffer); + case "smallBuffer": + return new NioManagedBuffer(smallBuffer); + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + throw new IllegalArgumentException("Invalid stream: " + streamId); + } + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + 
server.close(); + clientFactory.close(); + if (tempDir != null) { + for (File f : tempDir.listFiles()) { + f.delete(); + } + tempDir.delete(); + } + } + + @Test + public void testSingleStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + + @Test + public void testMultipleStreams() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + for (int i = 0; i < 20; i++) { + StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], + TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } + } finally { + client.close(); + } + } + + @Test + public void testConcurrentStreams() throws Throwable { + ExecutorService executor = Executors.newFixedThreadPool(20); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + try { + List tasks = new ArrayList<>(); + for (int i = 0; i < 20; i++) { + StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length], + TimeUnit.SECONDS.toMillis(20)); + tasks.add(task); + executor.submit(task); + } + + executor.shutdown(); + assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, TimeUnit.SECONDS)); + for (StreamTask task : tasks) { + task.check(); + } + } finally { + executor.shutdownNow(); + client.close(); + } + } + + private static class StreamTask implements Runnable { + + private final TransportClient client; + private final String streamId; + private final long timeoutMs; + private Throwable error; + + StreamTask(TransportClient client, String streamId, long timeoutMs) { + this.client = client; + this.streamId = streamId; + this.timeoutMs = timeoutMs; + } + + @Override + public void run() { + ByteBuffer srcBuffer = null; + OutputStream out = null; + File outFile = null; + try { + ByteArrayOutputStream baos = null; + + switch (streamId) { + case "largeBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = largeBuffer; + break; + case "smallBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = smallBuffer; + break; + case "file": + outFile = File.createTempFile("data", ".tmp", tempDir); + out = new FileOutputStream(outFile); + break; + default: + throw new IllegalArgumentException(streamId); + } + + TestCallback callback = new TestCallback(out); + client.stream(streamId, callback); + waitForCompletion(callback); + + if (srcBuffer == null) { + assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + } else { + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] result = baos.toByteArray(); + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } catch (Throwable t) { + error = t; + } finally { + if (out != null) { + try { + out.close(); + } catch (Exception e) { + // ignore. 
+ } + } + if (outFile != null) { + outFile.delete(); + } + } + } + + public void check() throws Throwable { + if (error != null) { + throw error; + } + } + + private void waitForCompletion(TestCallback callback) throws Exception { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (callback) { + while (!callback.completed && now < deadline) { + callback.wait(deadline - now); + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", callback.completed); + assertNull(callback.error); + } + + } + + private static class TestCallback implements StreamCallback { + + private final OutputStream out; + public volatile boolean completed; + public volatile Throwable error; + + TestCallback(OutputStream out) { + this.out = out; + this.completed = false; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + byte[] tmp = new byte[buf.remaining()]; + buf.get(tmp); + out.write(tmp); + } + + @Override + public void onComplete(String streamId) throws IOException { + out.close(); + synchronized (this) { + completed = true; + notifyAll(); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) { + error = cause; + synchronized (this) { + completed = true; + notifyAll(); + } + } + + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java new file mode 100644 index 0000000000000..ca74f0a00cf9d --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.nio.ByteBuffer; +import java.util.Random; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class TransportFrameDecoderSuite { + + @Test + public void testFrameDecoding() throws Exception { + Random rnd = new Random(); + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + + final int frameCount = 100; + ByteBuf data = Unpooled.buffer(); + try { + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + while (data.isReadable()) { + int size = rnd.nextInt(16 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } finally { + data.release(); + } + } + + @Test + public void testInterception() throws Exception { + final int interceptedReads = 3; + TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + + byte[] data = new byte[8]; + ByteBuf len = Unpooled.copyLong(8 + data.length); + ByteBuf dataBuf = Unpooled.wrappedBuffer(data); + + try { + decoder.setInterceptor(interceptor); + for (int i = 0; i < interceptedReads; i++) { + decoder.channelRead(ctx, dataBuf); + dataBuf.release(); + dataBuf = Unpooled.wrappedBuffer(data); + } + decoder.channelRead(ctx, len); + decoder.channelRead(ctx, dataBuf); + verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); + verify(ctx).fireChannelRead(any(ByteBuffer.class)); + } finally { + len.release(); + dataBuf.release(); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testNegativeFrameSize() throws Exception { + testInvalidFrame(-1); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyFrame() throws Exception { + // 8 because frame size includes the frame length. + testInvalidFrame(8); + } + + @Test(expected = IllegalArgumentException.class) + public void testLargeFrame() throws Exception { + // Frame length includes the frame size field, so need to add a few more bytes. 
+ testInvalidFrame(Integer.MAX_VALUE + 9); + } + + private void testInvalidFrame(long size) throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ByteBuf frame = Unpooled.copyLong(size); + try { + decoder.channelRead(ctx, frame); + } finally { + frame.release(); + } + } + + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { + + private int remainingReads; + + MockInterceptor(int readCount) { + this.remainingReads = readCount; + } + + @Override + public boolean handle(ByteBuf data) throws Exception { + data.readerIndex(data.readerIndex() + data.readableBytes()); + assertFalse(data.isReadable()); + remainingReads -= 1; + return remainingReads != 0; + } + + @Override + public void exceptionCaught(Throwable cause) throws Exception { + + } + + @Override + public void channelInactive() throws Exception { + + } + + } + +} From cd1df662386c599a9d0968b9fc14f27b0883d285 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 09:32:30 -0800 Subject: [PATCH 167/324] [SPARK-11485][SQL] Make DataFrameHolder and DatasetHolder public. These two classes should be public, since they are used in public code. Author: Reynold Xin Closes #9445 from rxin/SPARK-11485. --- project/MimaExcludes.scala | 3 +++ .../scala/org/apache/spark/sql/DataFrameHolder.scala | 7 ++++++- .../scala/org/apache/spark/sql/DatasetHolder.scala | 11 ++++++++--- .../scala/org/apache/spark/sql/SQLImplicits.scala | 4 ++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eeef96c378bdb..90dc947d4e588 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -161,6 +161,9 @@ object MimaExcludes { "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") + ) ++ Seq( + // SPARK-11485 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index 2f19ec0403017..3b30337f1f877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql /** * A container for a [[DataFrame]], used for implicit conversions. * + * To use this, import implicit conversions in SQL: + * {{{ + * import sqlContext.implicits._ + * }}} + * * @since 1.3.0 */ -private[sql] case class DataFrameHolder(df: DataFrame) { +case class DataFrameHolder private[sql](private val df: DataFrame) { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 17817cbcc5e05..45f0098b92887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -18,11 +18,16 @@ package org.apache.spark.sql /** - * A container for a [[DataFrame]], used for implicit conversions. 
+ * A container for a [[Dataset]], used for implicit conversions. * - * @since 1.3.0 + * To use this, import implicit conversions in SQL: + * {{{ + * import sqlContext.implicits._ + * }}} + * + * @since 1.6.0 */ -private[sql] case class DatasetHolder[T](df: Dataset[T]) { +case class DatasetHolder[T] private[sql](private val df: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index f2904e270811e..6da46a5f7ef9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -52,6 +52,10 @@ abstract class SQLImplicits { DatasetHolder(_sqlContext.createDataset(rdd)) } + /** + * Creates a [[Dataset]] from a local Seq. + * @since 1.6.0 + */ implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) } From e0fc9c7e59848cb78f8d598898bfca004a3710d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Nov 2015 09:33:30 -0800 Subject: [PATCH 168/324] [SPARK-11197][SQL] add doc for run SQL on files directly Author: Wenchen Fan Closes #9467 from cloud-fan/doc. --- docs/sql-programming-guide.md | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 510b3599721a3..2fe5c36338899 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -882,6 +882,44 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") +### Run SQL on files directly + +Instead of using read API to load a file into DataFrame and query it, you can also query that +file directly with SQL. + +

<div class="codetabs">
+<div data-lang="scala"  markdown="1">
+
+{% highlight scala %}
+val df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java"  markdown="1">
+
+{% highlight java %}
+DataFrame df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`");
+{% endhighlight %}
+</div>
+
+<div data-lang="python"  markdown="1">
+
+{% highlight python %}
+df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="r"  markdown="1">
+
+{% highlight r %}
+df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users.parquet`")
+{% endhighlight %}
+
+</div>
+</div>
    + ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if From 3bd6f5d2ae503468de0e218d51c331e249a862bb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 09:34:52 -0800 Subject: [PATCH 169/324] [SPARK-11490][SQL] variance should alias var_samp instead of var_pop. stddev is an alias for stddev_samp. variance should be consistent with stddev. Also took the chance to remove internal Stddev and Variance, and only kept StddevSamp/StddevPop and VarianceSamp/VariancePop. Author: Reynold Xin Closes #9449 from rxin/SPARK-11490. --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 - .../spark/sql/catalyst/dsl/package.scala | 8 ---- .../expressions/aggregate/functions.scala | 29 ------------ .../expressions/aggregate/utils.scala | 12 ----- .../sql/catalyst/expressions/aggregates.scala | 45 +++++-------------- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 4 +- .../org/apache/spark/sql/functions.scala | 9 ++-- .../spark/sql/DataFrameAggregateSuite.scala | 17 +++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 14 +++--- 11 files changed, 32 insertions(+), 114 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 24c1a7b7ac5af..d4334d16289a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -187,11 +187,11 @@ object FunctionRegistry { expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), - expression[Stddev]("stddev"), + expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), - expression[Variance]("variance"), + expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), expression[Skewness]("skewness"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3c675672dab85..84e2b1366f626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,10 +297,8 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case Variance(e @ StringType()) => Variance(Cast(e, DoubleType)) case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 787f67a297e33..d8df66430a695 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,14 +159,6 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) - def stddev(e: Expression): Expression = Stddev(e) - def stddev_pop(e: Expression): Expression = StddevPop(e) - def stddev_samp(e: Expression): Expression = StddevSamp(e) - def variance(e: Expression): Expression = Variance(e) - def var_pop(e: Expression): Expression = VariancePop(e) - def var_samp(e: Expression): Expression = VarianceSamp(e) - def skewness(e: Expression): Expression = Skewness(e) - def kurtosis(e: Expression): Expression = Kurtosis(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index f2c3eca095115..10dc5e64b7ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -328,13 +328,6 @@ case class Min(child: Expression) extends DeclarativeAggregate { override val evaluateExpression = min } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev" -} - // Compute the population standard deviation of a column case class StddevPop(child: Expression) extends StddevAgg(child) { @@ -1274,28 +1267,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w } } -case class Variance(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "variance" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - - if (n == 0.0) Double.NaN else moments(2) / n - } -} - case class VarianceSamp(child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 564174f9b64e4..644c6211d5f31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -97,12 +97,6 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - case expressions.Stddev(child) => - aggregate.AggregateExpression2( - aggregateFunction = 
aggregate.Stddev(child), - mode = aggregate.Complete, - isDistinct = false) - case expressions.StddevPop(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.StddevPop(child), @@ -139,12 +133,6 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - case expressions.Variance(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Variance(child), - mode = aggregate.Complete, - isDistinct = false) - case expressions.VariancePop(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.VariancePop(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index bf59660c385ed..89d63abd9f272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -785,13 +785,6 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia } -// Compute the sample standard deviation of a column -case class Stddev(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"STDDEV($child)" - override def isSample: Boolean = true -} - // Compute the population standard deviation of a column case class StddevPop(child: Expression) extends StddevAgg1(child) { @@ -807,20 +800,21 @@ case class StddevSamp(child: Expression) extends StddevAgg1(child) { } case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) } case class ComputePartialStdFunction ( expr: Expression, base: AggregateExpression1 -) extends AggregateFunction1 { + ) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization private val computeType = DoubleType @@ -1048,25 +1042,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp override def toString: String = s"SKEWNESS($child)" } -// placeholder -case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "variance" - - override def toString: String = s"VARIANCE($child)" -} - // placeholder case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 
fc0ab632f9930..5e9c7efbbf160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1383,7 +1383,7 @@ class DataFrame private[sql]( val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> Stddev, + "stddev" -> StddevSamp, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index c2b2a4013d510..7cf66b65c8722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -96,10 +96,10 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min - case "stddev" | "std" => Stddev + case "stddev" | "std" => StddevSamp case "stddev_pop" => StddevPop case "stddev_samp" => StddevSamp - case "variance" => Variance + case "variance" => VarianceSamp case "var_pop" => VariancePop case "var_samp" => VarianceSamp case "sum" => Sum diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c8c52831668cd..c70c965a9b04c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -329,13 +329,12 @@ object functions { def skewness(e: Column): Column = Skewness(e.expr) /** - * Aggregate function: returns the unbiased sample standard deviation of - * the expression in a group. + * Aggregate function: alias for [[stddev_samp]]. * * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = Stddev(e.expr) + def stddev(e: Column): Column = StddevSamp(e.expr) /** * Aggregate function: returns the unbiased sample standard deviation of @@ -388,12 +387,12 @@ object functions { def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName)) /** - * Aggregate function: returns the population variance of the values in a group. + * Aggregate function: alias for [[var_samp]]. * * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = Variance(e.expr) + def variance(e: Column): Column = VarianceSamp(e.expr) /** * Aggregate function: returns the unbiased variance of the values in a group. 
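For reference, the behavioral change above can be checked directly from the DataFrame API. The sketch below assumes a running SQLContext named `sqlContext`; the DataFrame and the column name "x" are made up for illustration. With the values 1, 2 and 3, variance(x) now matches var_samp(x) (the unbiased estimate, dividing by n - 1), while var_pop(x) keeps dividing by n, mirroring how stddev already aliases stddev_samp.

    import org.apache.spark.sql.functions.{variance, var_samp, var_pop, stddev, stddev_samp, col}

    // Hypothetical data set: a single numeric column "x" holding 1, 2, 3.
    val df = sqlContext.range(1, 4).toDF("x")

    // variance and var_samp now agree: (1 + 0 + 1) / (3 - 1) = 1.0
    // var_pop still divides by n:      (1 + 0 + 1) / 3       = 0.666...
    df.agg(variance(col("x")), var_samp(col("x")), var_pop(col("x"))).show()

    // The same relationship holds for the standard deviation aliases.
    df.agg(stddev(col("x")), stddev_samp(col("x"))).show()
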
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9b23977c765dc..b0e2ffaa60687 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -226,23 +226,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) - val expectedVariance = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkVariance, expectedVariance, absTol) + checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) val sparkVariancePop = testData2.agg(var_pop('a)) - checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol) + checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) val sparkVarianceSamp = testData2.agg(var_samp('a)) - val expectedVarianceSamp = Row(4.0 / 5.0) - checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol) + checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol) val sparkSkewness = testData2.agg(skewness('a)) - val expectedSkewness = Row(0.0) - checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol) + checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol) val sparkKurtosis = testData2.agg(kurtosis('a)) - val expectedKurtosis = Row(-1.5) - checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol) - + checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) } test("zero moments") { @@ -251,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a)), - Row(0.0)) + Row(Double.NaN)) checkAnswer( emptyTableData.agg(var_samp('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6388a8b9c3720..5731a356243e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -536,7 +536,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3) + Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) ) } @@ -757,7 +757,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("variance") { val absTol = 1e-8 val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(4.0 / 6.0) + val expectedAnswer = Row(0.8) checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) } @@ -784,16 +784,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("stddev agg") { checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) } test("variance agg") { val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" + - "FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0)) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) + checkAggregatesWithTol( + sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 
/ 2.0, 1.0 / 4.0)), + absTol) } test("skewness and kurtosis agg") { From 987df4bfcafeca3633453c2d2f8e14d221fcef33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 10:04:51 -0800 Subject: [PATCH 170/324] Closes #9464 From de289bf279e14e47859b5fbcd70e97b9d0759f14 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Nov 2015 10:56:32 -0800 Subject: [PATCH 171/324] [SPARK-10304][SQL] Following up checking valid dir structure for partition discovery This patch follows up #8840. Author: Liang-Chi Hsieh Closes #9459 from viirya/detect_invalid_part_dir_following. --- .../datasources/PartitioningUtils.scala | 14 +++++++++++++- .../parquet/ParquetPartitionDiscoverySuite.scala | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 16dc23661c070..86bc3a1b6dab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -81,6 +81,8 @@ private[sql] object PartitioningUtils { parsePartition(path, defaultPartitionName, typeInference) }.unzip + // We create pairs of (path -> path's partition value) here + // If the corresponding partition value is None, the pair will be skiped val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _)) if (pathsWithPartitionValues.isEmpty) { @@ -89,11 +91,21 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. + + // Check if there is conflicting directory structure. + // For the paths such as: + // var paths = Seq( + // "hdfs://host:9000/invalidPath", + // "hdfs://host:9000/path/a=10/b=20", + // "hdfs://host:9000/path/a=10.5/b=hello") + // It will be recognised as conflicting directory structure: + // "hdfs://host:9000/invalidPath" + // "hdfs://host:9000/path" val basePaths = optBasePaths.flatMap(x => x) assert( basePaths.distinct.size == 1, "Conflicting directory structures detected. 
Suspicious paths:\b" + - basePaths.mkString("\n\t", "\n\t", "\n\n")) + basePaths.distinct.mkString("\n\t", "\n\t", "\n\n")) val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 67b6a37fa502e..61cc0da50865c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -88,6 +88,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) } assert(exception.getMessage().contains("Conflicting directory structures detected")) + + // Invalid + // Conflicting directory structure: + // "hdfs://host:9000/tmp/tables/partitionedTable" + // "hdfs://host:9000/tmp/tables/nonPartitionedTable1" + // "hdfs://host:9000/tmp/tables/nonPartitionedTable2" + paths = Seq( + "hdfs://host:9000/tmp/tables/partitionedTable", + "hdfs://host:9000/tmp/tables/partitionedTable/p=1/", + "hdfs://host:9000/tmp/tables/nonPartitionedTable1", + "hdfs://host:9000/tmp/tables/nonPartitionedTable2") + + exception = intercept[AssertionError] { + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + } + assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { From abf5e4285d97b148a32cf22f5287511198175cb6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 12:33:47 -0800 Subject: [PATCH 172/324] [SPARK-11504][SQL] API audit for distributeBy and localSort 1. Renamed localSort -> sortWithinPartitions to avoid ambiguity in "local" 2. distributeBy -> repartition to match the existing repartition. Author: Reynold Xin Closes #9470 from rxin/SPARK-11504. --- .../org/apache/spark/sql/DataFrame.scala | 132 ++++++++++-------- .../apache/spark/sql/CachedTableSuite.scala | 20 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 44 +++--- 3 files changed, 113 insertions(+), 83 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5e9c7efbbf160..d3a2249d7006c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -241,18 +241,6 @@ class DataFrame private[sql]( sb.toString() } - private[sql] def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - Sort(sortOrder, global = global, logicalPlan) - } - override def toString: String = { try { schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") @@ -619,6 +607,32 @@ class DataFrame private[sql]( plan.copy(condition = cond) } + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). 
+ * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { + sortWithinPartitions(sortCol, sortCols : _*) + } + + /** + * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * + * This is the same operation as "SORT BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def sortWithinPartitions(sortExprs: Column*): DataFrame = { + sortInternal(global = false, sortExprs) + } + /** * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. * {{{ @@ -645,7 +659,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortExprs: Column*): DataFrame = { - sortInternal(true, sortExprs) + sortInternal(global = true, sortExprs) } /** @@ -666,44 +680,6 @@ class DataFrame private[sql]( @scala.annotation.varargs def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) - /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into - * `numPartitions`. The resulting DataFrame is hash partitioned. - * @group dfops - * @since 1.6.0 - */ - def distributeBy(partitionExprs: Seq[Column], numPartitions: Int): DataFrame = { - RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, Some(numPartitions)) - } - - /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving - * the existing number of partitions. The resulting DataFrame is hash partitioned. - * @group dfops - * @since 1.6.0 - */ - def distributeBy(partitionExprs: Seq[Column]): DataFrame = { - RepartitionByExpression(partitionExprs.map { _.expr }, logicalPlan, None) - } - - /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def localSort(sortCol: String, sortCols: String*): DataFrame = localSort(sortCol, sortCols : _*) - - /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. - * @group dfops - * @since 1.6.0 - */ - @scala.annotation.varargs - def localSort(sortExprs: Column*): DataFrame = { - sortInternal(false, sortExprs) - } - /** * Selects column based on the column name and return it as a [[Column]]. * Note that the column name can also reference to a nested column like `a.b`. @@ -798,7 +774,9 @@ class DataFrame private[sql]( * SQL expressions. * * {{{ + * // The following are equivalent: * df.selectExpr("colA", "colB as newName", "abs(colC)") + * df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) * }}} * @group dfops * @since 1.3.0 @@ -1524,13 +1502,41 @@ class DataFrame private[sql]( /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. - * @group rdd + * @group dfops * @since 1.3.0 */ def repartition(numPartitions: Int): DataFrame = { Repartition(numPartitions, shuffle = true, logicalPlan) } + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). 
+ * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) + } + + /** + * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. The resulting DataFrame is hash partitioned. + * + * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). + * + * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def repartition(partitionExprs: Column*): DataFrame = { + RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) + } + /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. @@ -2016,6 +2022,12 @@ class DataFrame private[sql]( write.mode(SaveMode.Append).insertInto(tableName) } + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // End of deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. @@ -2045,10 +2057,16 @@ class DataFrame private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = global, logicalPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 605954b105d1e..dbcb011f603f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -379,8 +379,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. for (numPartitions <- 1 until 10 by 4) { - testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1") - testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2") + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") @@ -401,8 +401,20 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } // Distribute the tables into non-matching number of partitions. Need to shuffle. 
- testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1") - testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2") + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + sqlContext.dropTempTable("t1") + sqlContext.dropTempTable("t2") + + // One side of join is not partitioned in the desired way. Need to shuffle. + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a9e6413423118..84a616d0b9081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,79 +1044,79 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("distributeBy and localSort") { val original = testData.repartition(1) assert(original.rdd.partitions.length == 1) - val df = original.distributeBy(Column("key") :: Nil, 5) - assert(df.rdd.partitions.length == 5) + val df = original.repartition(5, $"key") + assert(df.rdd.partitions.length == 5) checkAnswer(original.select(), df.select()) - val df2 = original.distributeBy(Column("key") :: Nil, 10) - assert(df2.rdd.partitions.length == 10) + val df2 = original.repartition(10, $"key") + assert(df2.rdd.partitions.length == 10) checkAnswer(original.select(), df2.select()) // Group by the column we are distributed by. This should generate a plan with no exchange // between the aggregates - val df3 = testData.distributeBy(Column("key") :: Nil).groupBy("key").count() + val df3 = testData.repartition($"key").groupBy("key").count() verifyNonExchangingAgg(df3) - verifyNonExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil) + verifyNonExchangingAgg(testData.repartition($"key", $"value") .groupBy("key", "value").count()) // Grouping by just the first distributeBy expr, need to exchange. - verifyExchangingAgg(testData.distributeBy(Column("key") :: Column("value") :: Nil) + verifyExchangingAgg(testData.repartition($"key", $"value") .groupBy("key").count()) val data = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData2(i % 10, i))).toDF() // Distribute and order by. - val df4 = data.distributeBy(Column("a") :: Nil).localSort($"b".desc) + val df4 = data.repartition($"a").sortWithinPartitions($"b".desc) // Walk each partition and verify that it is sorted descending and does not contain all // the values. 
- df4.rdd.foreachPartition(p => { + df4.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true - p.foreach(r => { + p.foreach { r => val v: Int = r.getInt(1) if (previousValue != -1) { if (previousValue < v) throw new SparkException("Partition is not ordered.") if (v + 1 != previousValue) allSequential = false } previousValue = v - }) + } if (allSequential) throw new SparkException("Partition should not be globally ordered") - }) + } // Distribute and order by with multiple order bys - val df5 = data.distributeBy(Column("a") :: Nil, 2).localSort($"b".asc, $"a".asc) + val df5 = data.repartition(2, $"a").sortWithinPartitions($"b".asc, $"a".asc) // Walk each partition and verify that it is sorted ascending - df5.rdd.foreachPartition(p => { + df5.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true - p.foreach(r => { + p.foreach { r => val v: Int = r.getInt(1) if (previousValue != -1) { if (previousValue > v) throw new SparkException("Partition is not ordered.") if (v - 1 != previousValue) allSequential = false } previousValue = v - }) + } if (allSequential) throw new SparkException("Partition should not be all sequential") - }) + } // Distribute into one partition and order by. This partition should contain all the values. - val df6 = data.distributeBy(Column("a") :: Nil, 1).localSort($"b".asc) + val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc) // Walk each partition and verify that it is sorted descending and not globally sorted. - df6.rdd.foreachPartition(p => { + df6.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true - p.foreach(r => { + p.foreach { r => val v: Int = r.getInt(1) if (previousValue != -1) { if (previousValue > v) throw new SparkException("Partition is not ordered.") if (v - 1 != previousValue) allSequential = false } previousValue = v - }) + } if (!allSequential) throw new SparkException("Partition should contain all sequential values") - }) + } } test("fix case sensitivity of partition by") { From d19f4fda63d0800a85b3e1c8379160bbbf17b6a3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 13:44:07 -0800 Subject: [PATCH 173/324] [SPARK-11505][SQL] Break aggregate functions into multiple files functions.scala was getting pretty long. I broke it into multiple files. I also added explicit data types for some public vals, and renamed aggregate function pretty names to lower case, which is more consistent with rest of the functions. Author: Reynold Xin Closes #9471 from rxin/SPARK-11505. 
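As a usage note for the API renamed in the preceding patch (SPARK-11504), the sketch below shows the new repartition(numPartitions, cols) and sortWithinPartitions calls that replace distributeBy and localSort. It assumes a SQLContext named `sqlContext`; the "key"/"value" columns are hypothetical and only meant to mirror DISTRIBUTE BY / SORT BY in HiveQL.

    import org.apache.spark.sql.functions.col

    // Hypothetical input: 100 rows spread over 10 key groups.
    val logs = sqlContext.range(0, 100).selectExpr("id % 10 as key", "id as value")

    // Hash-partition by "key" into 8 partitions (previously distributeBy), then sort
    // every partition by "value" descending (previously localSort). This is the
    // DataFrame equivalent of DISTRIBUTE BY key SORT BY value DESC.
    val shaped = logs.repartition(8, col("key")).sortWithinPartitions(col("value").desc)

    shaped.explain()
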
--- .../unsafe/sort/UnsafeExternalSorter.java | 5 +- .../expressions/aggregate/Average.scala | 86 ++ .../aggregate/CentralMomentAgg.scala | 230 +++++ .../catalyst/expressions/aggregate/Corr.scala | 179 ++++ .../expressions/aggregate/Count.scala | 52 + .../expressions/aggregate/First.scala | 92 ++ ...ctions.scala => HyperLogLogPlusPlus.scala} | 933 ------------------ .../expressions/aggregate/Kurtosis.scala | 49 + .../catalyst/expressions/aggregate/Last.scala | 89 ++ .../catalyst/expressions/aggregate/Max.scala | 55 ++ .../catalyst/expressions/aggregate/Min.scala | 56 ++ .../expressions/aggregate/Skewness.scala | 48 + .../expressions/aggregate/Stddev.scala | 134 +++ .../catalyst/expressions/aggregate/Sum.scala | 75 ++ .../aggregate/{utils.scala => Utils.scala} | 0 .../expressions/aggregate/Variance.scala | 66 ++ .../sql/catalyst/expressions/aggregates.scala | 24 +- 17 files changed, 1223 insertions(+), 950 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/{functions.scala => HyperLogLogPlusPlus.scala} (72%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/{utils.scala => Utils.scala} (100%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 49a5a4b13b70d..509fb0a044c0c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -157,11 +157,14 @@ public void closeCurrentPage() { */ @Override public long spill(long size, MemoryConsumer trigger) throws IOException { + assert(inMemSorter != null); if (trigger != this) { if (readingIterator != null) { return readingIterator.spill(); + } else { + } - return 0L; + return 0L; // this should throw exception } if (inMemSorter == null || inMemSorter.numRecords() <= 0) { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala new file mode 100644 index 0000000000000..c8c20ada5fbc7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends DeclarativeAggregate { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) + case _ => DoubleType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) + case _ => DoubleType + } + + private val sum = AttributeReference("sum", sumDataType)() + private val count = AttributeReference("count", LongType)() + + override val aggBufferAttributes = sum :: count :: Nil + + override val initialValues = Seq( + /* sum = */ Cast(Literal(0), sumDataType), + /* count = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override val mergeExpressions = Seq( + /* sum = */ sum.left + sum.right, + /* count = */ count.left + count.right + ) + + // If all input are nulls, count will be 0 and we will get null after the division. 
+ override val evaluateExpression = child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Cast(sum, dt) / Cast(count, dt), resultType) + case _ => + Cast(sum, resultType) / Cast(count, resultType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala new file mode 100644 index 0000000000000..ef08b025ff556 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A central moment is the expected value of a specified power of the deviation of a random + * variable from the mean. Central moments are often used to characterize the properties of about + * the shape of a distribution. + * + * This class implements online, one-pass algorithms for computing the central moments of a set of + * points. + * + * Behavior: + * - null values are ignored + * - returns `Double.NaN` when the column contains `Double.NaN` values + * + * References: + * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." + * 2015. http://arxiv.org/abs/1510.04923 + * + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + * Algorithms for calculating variance (Wikipedia)]] + * + * @param child to compute central moments of. + */ +abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { + + /** + * The central moment order to be computed. + */ + protected def momentOrder: Int + + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** + * Size of aggregation buffer. + */ + private[this] val bufferSize = 5 + + override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => + AttributeReference(s"M$i", DoubleType)() + } + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + // buffer offsets + private[this] val nOffset = mutableAggBufferOffset + private[this] val meanOffset = mutableAggBufferOffset + 1 + private[this] val secondMomentOffset = mutableAggBufferOffset + 2 + private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 + private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 + + // frequently used values for online updates + private[this] var delta = 0.0 + private[this] var deltaN = 0.0 + private[this] var delta2 = 0.0 + private[this] var deltaN2 = 0.0 + private[this] var n = 0.0 + private[this] var mean = 0.0 + private[this] var m2 = 0.0 + private[this] var m3 = 0.0 + private[this] var m4 = 0.0 + + /** + * Initialize all moments to zero. + */ + override def initialize(buffer: MutableRow): Unit = { + for (aggIndex <- 0 until bufferSize) { + buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) + } + } + + /** + * Update the central moments buffer. + */ + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val v = Cast(child, DoubleType).eval(input) + if (v != null) { + val updateValue = v match { + case d: Double => d + } + + n = buffer.getDouble(nOffset) + mean = buffer.getDouble(meanOffset) + + n += 1.0 + buffer.setDouble(nOffset, n) + delta = updateValue - mean + deltaN = delta / n + mean += deltaN + buffer.setDouble(meanOffset, mean) + + if (momentOrder >= 2) { + m2 = buffer.getDouble(secondMomentOffset) + m2 += delta * (delta - deltaN) + buffer.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + delta2 = delta * delta + deltaN2 = deltaN * deltaN + m3 = buffer.getDouble(thirdMomentOffset) + m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) + buffer.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + m4 = buffer.getDouble(fourthMomentOffset) + m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + + delta * (delta * delta2 - deltaN * deltaN2) + buffer.setDouble(fourthMomentOffset, m4) + } + } + } + + /** + * Merge two central moment buffers. 
+ */ + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val n1 = buffer1.getDouble(nOffset) + val n2 = buffer2.getDouble(inputAggBufferOffset) + val mean1 = buffer1.getDouble(meanOffset) + val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) + + var secondMoment1 = 0.0 + var secondMoment2 = 0.0 + + var thirdMoment1 = 0.0 + var thirdMoment2 = 0.0 + + var fourthMoment1 = 0.0 + var fourthMoment2 = 0.0 + + n = n1 + n2 + buffer1.setDouble(nOffset, n) + delta = mean2 - mean1 + deltaN = if (n == 0.0) 0.0 else delta / n + mean = mean1 + deltaN * n2 + buffer1.setDouble(mutableAggBufferOffset + 1, mean) + + // higher order moments computed according to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics + if (momentOrder >= 2) { + secondMoment1 = buffer1.getDouble(secondMomentOffset) + secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) + m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 + buffer1.setDouble(secondMomentOffset, m2) + } + + if (momentOrder >= 3) { + thirdMoment1 = buffer1.getDouble(thirdMomentOffset) + thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) + m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * + (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) + buffer1.setDouble(thirdMomentOffset, m3) + } + + if (momentOrder >= 4) { + fourthMoment1 = buffer1.getDouble(fourthMomentOffset) + fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) + m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * + n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * + (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + + 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) + buffer1.setDouble(fourthMomentOffset, m4) + } + } + + /** + * Compute aggregate statistic from sufficient moments. + * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) + * needed to compute the aggregate stat. + */ + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + + override final def eval(buffer: InternalRow): Any = { + val n = buffer.getDouble(nOffset) + val mean = buffer.getDouble(meanOffset) + val moments = Array.ofDim[Double](momentOrder + 1) + moments(0) = 1.0 + moments(1) = 0.0 + if (momentOrder >= 2) { + moments(2) = buffer.getDouble(secondMomentOffset) + } + if (momentOrder >= 3) { + moments(3) = buffer.getDouble(thirdMomentOffset) + } + if (momentOrder >= 4) { + moments(4) = buffer.getDouble(fourthMomentOffset) + } + + getStatistic(n, mean, moments) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala new file mode 100644 index 0000000000000..832338378fb38 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Compute Pearson correlation between two expressions. + * When applied on empty data (i.e., count is zero), it returns NULL. + * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + */ +case class Corr( + left: Expression, + right: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate { + + override def children: Seq[Expression] = Seq(left, right) + + override def nullable: Boolean = false + + override def dataType: DataType = DoubleType + + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override def inputAggBufferAttributes: Seq[AttributeReference] = { + aggBufferAttributes.map(_.newInstance()) + } + + override val aggBufferAttributes: Seq[AttributeReference] = Seq( + AttributeReference("xAvg", DoubleType)(), + AttributeReference("yAvg", DoubleType)(), + AttributeReference("Ck", DoubleType)(), + AttributeReference("MkX", DoubleType)(), + AttributeReference("MkY", DoubleType)(), + AttributeReference("count", LongType)()) + + // Local cache of mutableAggBufferOffset(s) that will be used in update and merge + private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 + private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 + private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 + private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 + private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 + + // Local cache of inputAggBufferOffset(s) that will be used in update and merge + private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 + private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 + private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 + private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 + private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def initialize(buffer: MutableRow): Unit = { + buffer.setDouble(mutableAggBufferOffset, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) + buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) + buffer.setLong(mutableAggBufferOffsetPlus5, 0L) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val leftEval = 
left.eval(input) + val rightEval = right.eval(input) + + if (leftEval != null && rightEval != null) { + val x = leftEval.asInstanceOf[Double] + val y = rightEval.asInstanceOf[Double] + + var xAvg = buffer.getDouble(mutableAggBufferOffset) + var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer.getLong(mutableAggBufferOffsetPlus5) + + val deltaX = x - xAvg + val deltaY = y - yAvg + count += 1 + xAvg += deltaX / count + yAvg += deltaY / count + Ck += deltaX * (y - yAvg) + MkX += deltaX * (x - xAvg) + MkY += deltaY * (y - yAvg) + + buffer.setDouble(mutableAggBufferOffset, xAvg) + buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer.setLong(mutableAggBufferOffsetPlus5, count) + } + } + + // Merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) + + // We only go to merge two buffers if there is at least one record aggregated in buffer2. + // We don't need to check count in buffer1 because if count2 is more than zero, totalCount + // is more than zero too, then we won't get a divide by zero exception. + if (count2 > 0) { + var xAvg = buffer1.getDouble(mutableAggBufferOffset) + var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) + var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) + var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) + var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) + var count = buffer1.getLong(mutableAggBufferOffsetPlus5) + + val xAvg2 = buffer2.getDouble(inputAggBufferOffset) + val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) + val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) + val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) + val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) + + val totalCount = count + count2 + val deltaX = xAvg - xAvg2 + val deltaY = yAvg - yAvg2 + Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 + xAvg = (xAvg * count + xAvg2 * count2) / totalCount + yAvg = (yAvg * count + yAvg2 * count2) / totalCount + MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 + MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 + count = totalCount + + buffer1.setDouble(mutableAggBufferOffset, xAvg) + buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) + buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) + buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) + buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) + buffer1.setLong(mutableAggBufferOffsetPlus5, count) + } + } + + override def eval(buffer: InternalRow): Any = { + val count = buffer.getLong(mutableAggBufferOffsetPlus5) + if (count > 0) { + val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) + val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) + val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) + val corr = Ck / math.sqrt(MkX * MkY) + if (corr.isNaN) { + null + } else { + corr + } + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala new file mode 100644 index 0000000000000..54df96cd2446a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Count(child: Expression) extends DeclarativeAggregate { + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val count = AttributeReference("count", LongType)() + + override val aggBufferAttributes = count :: Nil + + override val initialValues = Seq( + /* count = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override val evaluateExpression = Cast(count, LongType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala new file mode 100644 index 0000000000000..9028143015853 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Returns the first value of `child` for a group of rows. 
If the first value of `child` + * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + */ +case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // First is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val first = AttributeReference("first", child.dataType)() + + private val valueSet = AttributeReference("valueSet", BooleanType)() + + override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + + override val initialValues: Seq[Literal] = Seq( + /* first = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) + ) + + override val updateExpressions: Seq[Expression] = { + if (ignoreNulls) { + Seq( + /* first = */ If(Or(valueSet, IsNull(child)), first, child), + /* valueSet = */ Or(valueSet, IsNotNull(child)) + ) + } else { + Seq( + /* first = */ If(valueSet, first, child), + /* valueSet = */ Literal.create(true, BooleanType) + ) + } + } + + override val mergeExpressions: Seq[Expression] = { + // For first, we can just check if valueSet.left is set to true. If it is set + // to true, we use first.right. If not, we use first.right (even if valueSet.right is + // false, we are safe to do so because first.right will be null in this case). 
+ Seq( + /* first = */ If(valueSet.left, first.left, first.right), + /* valueSet = */ Or(valueSet.left, valueSet.right) + ) + } + + override val evaluateExpression: AttributeReference = first + + override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala similarity index 72% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 10dc5e64b7ec9..8d341ee630bdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -22,636 +22,10 @@ import java.util import com.clearspring.analytics.hash.MurmurHash -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Average(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = child.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 4, s + 4) - case _ => DoubleType - } - - private val sumDataType = child.dataType match { - case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) - case _ => DoubleType - } - - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() - - override val aggBufferAttributes = sum :: count :: Nil - - override val initialValues = Seq( - /* sum = */ Cast(Literal(0), sumDataType), - /* count = */ Literal(0L) - ) - - override val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - - override val mergeExpressions = Seq( - /* sum = */ sum.left + sum.right, - /* count = */ count.left + count.right - ) - - // If all input are nulls, count will be 0 and we will get null after the division. 
- override val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) - case _ => - Cast(sum, resultType) / Cast(count, resultType) - } -} - -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = false - - // Return data type. - override def dataType: DataType = LongType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val count = AttributeReference("count", LongType)() - - override val aggBufferAttributes = count :: Nil - - override val initialValues = Seq( - /* count = */ Literal(0L) - ) - - override val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) - ) - - override val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override val evaluateExpression = Cast(count, LongType) -} - -/** - * Returns the first value of `child` for a group of rows. If the first value of `child` - * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already - * sorted column, if we do partial aggregation and final aggregation (when mergeExpression - * is used) its result will not be deterministic (unless the input table is sorted and has - * a single partition, and we use a single reducer to do the aggregation.). - * @param child - */ -case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // First is not a deterministic function. - override def deterministic: Boolean = false - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val first = AttributeReference("first", child.dataType)() - - private val valueSet = AttributeReference("valueSet", BooleanType)() - - override val aggBufferAttributes = first :: valueSet :: Nil - - override val initialValues = Seq( - /* first = */ Literal.create(null, child.dataType), - /* valueSet = */ Literal.create(false, BooleanType) - ) - - override val updateExpressions = { - if (ignoreNulls) { - Seq( - /* first = */ If(Or(valueSet, IsNull(child)), first, child), - /* valueSet = */ Or(valueSet, IsNotNull(child)) - ) - } else { - Seq( - /* first = */ If(valueSet, first, child), - /* valueSet = */ Literal.create(true, BooleanType) - ) - } - } - - override val mergeExpressions = { - // For first, we can just check if valueSet.left is set to true. If it is set - // to true, we use first.right. If not, we use first.right (even if valueSet.right is - // false, we are safe to do so because first.right will be null in this case). 
- Seq( - /* first = */ If(valueSet.left, first.left, first.right), - /* valueSet = */ Or(valueSet.left, valueSet.right) - ) - } - - override val evaluateExpression = first - - override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS"}" -} - -/** - * Returns the last value of `child` for a group of rows. If the last value of `child` - * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already - * sorted column, if we do partial aggregation and final aggregation (when mergeExpression - * is used) its result will not be deterministic (unless the input table is sorted and has - * a single partition, and we use a single reducer to do the aggregation.). - * @param child - */ -case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Last is not a deterministic function. - override def deterministic: Boolean = false - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val last = AttributeReference("last", child.dataType)() - - override val aggBufferAttributes = last :: Nil - - override val initialValues = Seq( - /* last = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(child), last, child) - ) - } else { - Seq( - /* last = */ child - ) - } - } - - override val mergeExpressions = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - } else { - Seq( - /* last = */ last.right - ) - } - } - - override val evaluateExpression = last - - override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}" -} - -case class Max(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val max = AttributeReference("max", child.dataType)() - - override val aggBufferAttributes = max :: Nil - - override val initialValues = Seq( - /* max = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) - ) - - override val mergeExpressions = { - val greatest = Greatest(Seq(max.left, max.right)) - Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) - ) - } - - override val evaluateExpression = max -} - -case class Min(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = child.dataType - - // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - - private val min = AttributeReference("min", child.dataType)() - - override val aggBufferAttributes = min :: Nil - - override val initialValues = Seq( - /* min = */ Literal.create(null, child.dataType) - ) - - override val updateExpressions = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) - ) - - override val mergeExpressions = { - val least = Least(Seq(min.left, min.right)) - Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) - ) - } - - override val evaluateExpression = min -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = false - override def prettyName: String = "stddev_pop" -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg(child) { - - override def isSample: Boolean = true - override def prettyName: String = "stddev_samp" -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - def isSample: Boolean - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - private val resultType = DoubleType - - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() - - override val aggBufferAttributes = count :: avg :: mk :: Nil - - override val initialValues = Seq( - /* count = */ Cast(Literal(0), resultType), - /* avg = */ Cast(Literal(0), resultType), - /* mk = */ Cast(Literal(0), resultType) - ) - - override val updateExpressions = { - val value = Cast(child, resultType) - val newCount = count + Cast(Literal(1), resultType) - - // update average - // avg = avg + (value - avg)/count - val newAvg = avg + (value - avg) / newCount - - // update sum of square of difference from mean - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - val newMk = mk + (value - avg) * (value - newAvg) - - Seq( - /* count = */ If(IsNull(child), count, newCount), - /* avg = */ If(IsNull(child), avg, newAvg), - /* mk = */ If(IsNull(child), mk, newMk) - ) - } - - override val mergeExpressions = { - - // count merge - val newCount = count.left + count.right - - // average merge - val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount - - // update sum of square differences - val newMk = { - val avgDelta = avg.right - avg.left - val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount - mk.left + mk.right + mkDelta - } - - Seq( - /* count = */ If(IsNull(count.left), count.right, - If(IsNull(count.right), count.left, newCount)), - /* avg = */ If(IsNull(avg.left), avg.right, - If(IsNull(avg.right), avg.left, newAvg)), - /* mk = */ If(IsNull(mk.left), mk.right, - If(IsNull(mk.right), mk.left, newMk)) - ) - } - - override val evaluateExpression = { - // when count == 0, return null - // when count == 1, return 0 - // when count >1 - // stddev_samp = sqrt (mk/(count -1)) - // stddev_pop = sqrt (mk/count) - val varCol = - if (isSample) { - mk / Cast((count - Cast(Literal(1), resultType)), resultType) - } else { - mk / count - } - - If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(varCol), resultType))) - } -} - -case class Sum(child: Expression) extends DeclarativeAggregate { - - override def children: Seq[Expression] = child :: Nil - - override def nullable: Boolean = true - - // Return data type. - override def dataType: DataType = resultType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) - - private val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. 
- case NullType => IntegerType - case _ => child.dataType - } - - private val sumDataType = resultType - - private val sum = AttributeReference("sum", sumDataType)() - - private val zero = Cast(Literal(0), sumDataType) - - override val aggBufferAttributes = sum :: Nil - - override val initialValues = Seq( - /* sum = */ Literal.create(null, sumDataType) - ) - - override val updateExpressions = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) - - override val mergeExpressions = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) - Seq( - /* sum = */ - Coalesce(Seq(add, sum.left)) - ) - } - - override val evaluateExpression = Cast(sum, resultType) -} - -/** - * Compute Pearson correlation between two expressions. - * When applied on empty data (i.e., count is zero), it returns NULL. - * - * Definition of Pearson correlation can be found at - * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient - * - * @param left one of the expressions to compute correlation with. - * @param right another expression to compute correlation with. - */ -case class Corr( - left: Expression, - right: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends ImperativeAggregate { - - def children: Seq[Expression] = Seq(left, right) - - def nullable: Boolean = false - - def dataType: DataType = DoubleType - - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - - def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - def inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - - val aggBufferAttributes: Seq[AttributeReference] = Seq( - AttributeReference("xAvg", DoubleType)(), - AttributeReference("yAvg", DoubleType)(), - AttributeReference("Ck", DoubleType)(), - AttributeReference("MkX", DoubleType)(), - AttributeReference("MkY", DoubleType)(), - AttributeReference("count", LongType)()) - - // Local cache of mutableAggBufferOffset(s) that will be used in update and merge - private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1 - private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2 - private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3 - private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4 - private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5 - - // Local cache of inputAggBufferOffset(s) that will be used in update and merge - private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1 - private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2 - private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3 - private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4 - private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5 - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def initialize(buffer: MutableRow): Unit = { - buffer.setDouble(mutableAggBufferOffset, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0) - buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0) - 
buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0) - buffer.setLong(mutableAggBufferOffsetPlus5, 0L) - } - - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val leftEval = left.eval(input) - val rightEval = right.eval(input) - - if (leftEval != null && rightEval != null) { - val x = leftEval.asInstanceOf[Double] - val y = rightEval.asInstanceOf[Double] - - var xAvg = buffer.getDouble(mutableAggBufferOffset) - var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer.getLong(mutableAggBufferOffsetPlus5) - - val deltaX = x - xAvg - val deltaY = y - yAvg - count += 1 - xAvg += deltaX / count - yAvg += deltaY / count - Ck += deltaX * (y - yAvg) - MkX += deltaX * (x - xAvg) - MkY += deltaY * (y - yAvg) - - buffer.setDouble(mutableAggBufferOffset, xAvg) - buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - // Merge counters from other partitions. Formula can be found at: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val count2 = buffer2.getLong(inputAggBufferOffsetPlus5) - - // We only go to merge two buffers if there is at least one record aggregated in buffer2. - // We don't need to check count in buffer1 because if count2 is more than zero, totalCount - // is more than zero too, then we won't get a divide by zero exception. 
- if (count2 > 0) { - var xAvg = buffer1.getDouble(mutableAggBufferOffset) - var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1) - var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2) - var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3) - var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4) - var count = buffer1.getLong(mutableAggBufferOffsetPlus5) - - val xAvg2 = buffer2.getDouble(inputAggBufferOffset) - val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1) - val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2) - val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3) - val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4) - - val totalCount = count + count2 - val deltaX = xAvg - xAvg2 - val deltaY = yAvg - yAvg2 - Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 - xAvg = (xAvg * count + xAvg2 * count2) / totalCount - yAvg = (yAvg * count + yAvg2 * count2) / totalCount - MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 - MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 - count = totalCount - - buffer1.setDouble(mutableAggBufferOffset, xAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg) - buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck) - buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX) - buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY) - buffer1.setLong(mutableAggBufferOffsetPlus5, count) - } - } - - override def eval(buffer: InternalRow): Any = { - val count = buffer.getLong(mutableAggBufferOffsetPlus5) - if (count > 0) { - val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2) - val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3) - val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4) - val corr = Ck / math.sqrt(MkX * MkY) - if (corr.isNaN) { - null - } else { - corr - } - } else { - null - } - } -} - // scalastyle:off /** * HyperLogLog++ (HLL++) is a state of the art cardinality estimation algorithm. This class @@ -1058,310 +432,3 @@ object HyperLogLogPlusPlus { ) // scalastyle:on } - -/** - * A central moment is the expected value of a specified power of the deviation of a random - * variable from the mean. Central moments are often used to characterize the properties of about - * the shape of a distribution. - * - * This class implements online, one-pass algorithms for computing the central moments of a set of - * points. - * - * Behavior: - * - null values are ignored - * - returns `Double.NaN` when the column contains `Double.NaN` values - * - * References: - * - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." - * 2015. http://arxiv.org/abs/1510.04923 - * - * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - * Algorithms for calculating variance (Wikipedia)]] - * - * @param child to compute central moments of. - */ -abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { - - /** - * The central moment order to be computed. - */ - protected def momentOrder: Int - - override def children: Seq[Expression] = Seq(child) - - override def nullable: Boolean = false - - override def dataType: DataType = DoubleType - - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. 
- // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) - - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - /** - * Size of aggregation buffer. - */ - private[this] val bufferSize = 5 - - override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => - AttributeReference(s"M$i", DoubleType)() - } - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - // buffer offsets - private[this] val nOffset = mutableAggBufferOffset - private[this] val meanOffset = mutableAggBufferOffset + 1 - private[this] val secondMomentOffset = mutableAggBufferOffset + 2 - private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 - private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 - - // frequently used values for online updates - private[this] var delta = 0.0 - private[this] var deltaN = 0.0 - private[this] var delta2 = 0.0 - private[this] var deltaN2 = 0.0 - private[this] var n = 0.0 - private[this] var mean = 0.0 - private[this] var m2 = 0.0 - private[this] var m3 = 0.0 - private[this] var m4 = 0.0 - - /** - * Initialize all moments to zero. - */ - override def initialize(buffer: MutableRow): Unit = { - for (aggIndex <- 0 until bufferSize) { - buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) - } - } - - /** - * Update the central moments buffer. - */ - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val v = Cast(child, DoubleType).eval(input) - if (v != null) { - val updateValue = v match { - case d: Double => d - } - - n = buffer.getDouble(nOffset) - mean = buffer.getDouble(meanOffset) - - n += 1.0 - buffer.setDouble(nOffset, n) - delta = updateValue - mean - deltaN = delta / n - mean += deltaN - buffer.setDouble(meanOffset, mean) - - if (momentOrder >= 2) { - m2 = buffer.getDouble(secondMomentOffset) - m2 += delta * (delta - deltaN) - buffer.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - delta2 = delta * delta - deltaN2 = deltaN * deltaN - m3 = buffer.getDouble(thirdMomentOffset) - m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) - buffer.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - m4 = buffer.getDouble(fourthMomentOffset) - m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + - delta * (delta * delta2 - deltaN * deltaN2) - buffer.setDouble(fourthMomentOffset, m4) - } - } - } - - /** - * Merge two central moment buffers. 
- */ - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val n1 = buffer1.getDouble(nOffset) - val n2 = buffer2.getDouble(inputAggBufferOffset) - val mean1 = buffer1.getDouble(meanOffset) - val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) - - var secondMoment1 = 0.0 - var secondMoment2 = 0.0 - - var thirdMoment1 = 0.0 - var thirdMoment2 = 0.0 - - var fourthMoment1 = 0.0 - var fourthMoment2 = 0.0 - - n = n1 + n2 - buffer1.setDouble(nOffset, n) - delta = mean2 - mean1 - deltaN = if (n == 0.0) 0.0 else delta / n - mean = mean1 + deltaN * n2 - buffer1.setDouble(mutableAggBufferOffset + 1, mean) - - // higher order moments computed according to: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics - if (momentOrder >= 2) { - secondMoment1 = buffer1.getDouble(secondMomentOffset) - secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) - m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 - buffer1.setDouble(secondMomentOffset, m2) - } - - if (momentOrder >= 3) { - thirdMoment1 = buffer1.getDouble(thirdMomentOffset) - thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) - m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * - (n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) - buffer1.setDouble(thirdMomentOffset, m3) - } - - if (momentOrder >= 4) { - fourthMoment1 = buffer1.getDouble(fourthMomentOffset) - fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) - m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * - n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * - (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + - 4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) - buffer1.setDouble(fourthMomentOffset, m4) - } - } - - /** - * Compute aggregate statistic from sufficient moments. - * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) - * needed to compute the aggregate stat. 
- */ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double - - override final def eval(buffer: InternalRow): Any = { - val n = buffer.getDouble(nOffset) - val mean = buffer.getDouble(meanOffset) - val moments = Array.ofDim[Double](momentOrder + 1) - moments(0) = 1.0 - moments(1) = 0.0 - if (momentOrder >= 2) { - moments(2) = buffer.getDouble(secondMomentOffset) - } - if (momentOrder >= 3) { - moments(3) = buffer.getDouble(thirdMomentOffset) - } - if (momentOrder >= 4) { - moments(4) = buffer.getDouble(fourthMomentOffset) - } - - getStatistic(n, mean, moments) - } -} - -case class VarianceSamp(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "variance_samp" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) - } -} - -case class VariancePop(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "variance_pop" - - override protected val momentOrder = 2 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - - if (n == 0.0) Double.NaN else moments(2) / n - } -} - -case class Skewness(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "skewness" - - override protected val momentOrder = 3 - - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { - Double.NaN - } else { - math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) - } - } -} - -case class Kurtosis(child: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override 
def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def prettyName: String = "kurtosis" - - override protected val momentOrder = 4 - - // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { - require(moments.length == momentOrder + 1, - s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") - val m2 = moments(2) - val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { - Double.NaN - } else { - n * m4 / (m2 * m2) - 3.0 - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala new file mode 100644 index 0000000000000..6da39e7143447 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class Kurtosis(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "kurtosis" + + override protected val momentOrder = 4 + + // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m4 = moments(4) + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + n * m4 / (m2 * m2) - 3.0 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala new file mode 100644 index 0000000000000..8636bfe8d07aa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * Returns the last value of `child` for a group of rows. If the last value of `child` + * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already + * sorted column, if we do partial aggregation and final aggregation (when mergeExpression + * is used) its result will not be deterministic (unless the input table is sorted and has + * a single partition, and we use a single reducer to do the aggregation.). + */ +case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { + + def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + + private val ignoreNulls: Boolean = ignoreNullsExpr match { + case Literal(b: Boolean, BooleanType) => b + case _ => + throw new AnalysisException("The second argument of First should be a boolean literal.") + } + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val last = AttributeReference("last", child.dataType)() + + override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + + override val initialValues: Seq[Literal] = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions: Seq[Expression] = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(child), last, child) + ) + } else { + Seq( + /* last = */ child + ) + } + } + + override val mergeExpressions: Seq[Expression] = { + if (ignoreNulls) { + Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + } else { + Seq( + /* last = */ last.right + ) + } + } + + override val evaluateExpression: AttributeReference = last + + override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala new file mode 100644 index 0000000000000..b9d75ad452838 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Max(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val max = AttributeReference("max", child.dataType)() + + override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override val initialValues: Seq[Literal] = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions: Seq[Expression] = Seq( + /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + ) + + override val mergeExpressions: Seq[Expression] = { + val greatest = Greatest(Seq(max.left, max.right)) + Seq( + /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + ) + } + + override val evaluateExpression: AttributeReference = max +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala new file mode 100644 index 0000000000000..5ed9cd348daba --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +case class Min(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. 
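+  // Any input type is accepted; null inputs are ignored by the update and merge expressions
+  // below (for example, merging a buffer holding null with one holding 3 yields 3).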
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val min = AttributeReference("min", child.dataType)() + + override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override val initialValues: Seq[Expression] = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions: Seq[Expression] = Seq( + /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + ) + + override val mergeExpressions: Seq[Expression] = { + val least = Least(Seq(min.left, min.right)) + Seq( + /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + ) + } + + override val evaluateExpression: AttributeReference = min +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala new file mode 100644 index 0000000000000..0def7ddfd9d3d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class Skewness(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "skewness" + + override protected val momentOrder = 3 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") + val m2 = moments(2) + val m3 = moments(3) + if (n == 0.0 || m2 == 0.0) { + Double.NaN + } else { + math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala new file mode 100644 index 0000000000000..3f47ffe13cbc8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { + + def isSample: Boolean + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. 
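+  // The aggregation buffer tracks (count, avg, mk), where mk accumulates the sum of squared
+  // differences from the running mean. As an illustrative trace (example values only), the
+  // inputs 1.0, 2.0, 4.0 produce buffers (1, 1.0, 0.0), (2, 1.5, 0.5) and (3, 2.33.., 4.67..);
+  // the final stddev_samp is sqrt(4.67 / 2), about 1.53, and the final stddev_pop is
+  // sqrt(4.67 / 3), about 1.25.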
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val count = AttributeReference("count", resultType)() + private val avg = AttributeReference("avg", resultType)() + private val mk = AttributeReference("mk", resultType)() + + override val aggBufferAttributes = count :: avg :: mk :: Nil + + override val initialValues: Seq[Expression] = Seq( + /* count = */ Cast(Literal(0), resultType), + /* avg = */ Cast(Literal(0), resultType), + /* mk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions: Seq[Expression] = { + val value = Cast(child, resultType) + val newCount = count + Cast(Literal(1), resultType) + + // update average + // avg = avg + (value - avg)/count + val newAvg = avg + (value - avg) / newCount + + // update sum ofference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + val newMk = mk + (value - avg) * (value - newAvg) + + Seq( + /* count = */ If(IsNull(child), count, newCount), + /* avg = */ If(IsNull(child), avg, newAvg), + /* mk = */ If(IsNull(child), mk, newMk) + ) + } + + override val mergeExpressions: Seq[Expression] = { + + // count merge + val newCount = count.left + count.right + + // average merge + val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount + + // update sum of square differences + val newMk = { + val avgDelta = avg.right - avg.left + val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount + mk.left + mk.right + mkDelta + } + + Seq( + /* count = */ If(IsNull(count.left), count.right, + If(IsNull(count.right), count.left, newCount)), + /* avg = */ If(IsNull(avg.left), avg.right, + If(IsNull(avg.right), avg.left, newAvg)), + /* mk = */ If(IsNull(mk.left), mk.right, + If(IsNull(mk.right), mk.left, newMk)) + ) + } + + override val evaluateExpression: Expression = { + // when count == 0, return null + // when count == 1, return 0 + // when count >1 + // stddev_samp = sqrt (mk/(count -1)) + // stddev_pop = sqrt (mk/count) + val varCol = + if (isSample) { + mk / Cast(count - Cast(Literal(1), resultType), resultType) + } else { + mk / count + } + + If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala new file mode 100644 index 0000000000000..7f8adbc56ad1d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Sum(child: Expression) extends DeclarativeAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select sum(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType + case _ => child.dataType + } + + private val sumDataType = resultType + + private val sum = AttributeReference("sum", sumDataType)() + + private val zero = Cast(Literal(0), sumDataType) + + override val aggBufferAttributes = sum :: Nil + + override val initialValues: Seq[Expression] = Seq( + /* sum = */ Literal.create(null, sumDataType) + ) + + override val updateExpressions: Seq[Expression] = Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + + override val mergeExpressions: Seq[Expression] = { + val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) + Seq( + /* sum = */ + Coalesce(Seq(add, sum.left)) + ) + } + + override val evaluateExpression: Expression = Cast(sum, resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala new file mode 100644 index 0000000000000..ec63534e5290a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.expressions._ + +case class VarianceSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "var_samp" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + } +} + +case class VariancePop(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "var_pop" + + override protected val momentOrder = 2 + + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") + + if (n == 0.0) Double.NaN else moments(2) / n + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 89d63abd9f272..3dcf7915d77b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -549,7 +549,7 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg case _ => child.dataType } - override def toString: String = s"SUM(DISTINCT $child)" + override def toString: String = s"sum(distinct $child)" override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) override def asPartial: SplitEvaluation = { @@ -646,7 +646,7 @@ case class First( override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS"})" + override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() @@ -707,7 +707,7 @@ case class Last( override def references: AttributeSet = child.references override def 
nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}" + override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() @@ -756,7 +756,7 @@ case class Corr(left: Expression, right: Expression) extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { override def nullable: Boolean = false override def dataType: DoubleType.type = DoubleType - override def toString: String = s"CORRELATION($left, $right)" + override def toString: String = s"corr($left, $right)" override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) override def newInstance(): AggregateFunction1 = { throw new UnsupportedOperationException( @@ -788,14 +788,14 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia // Compute the population standard deviation of a column case class StddevPop(child: Expression) extends StddevAgg1(child) { - override def toString: String = s"STDDEV_POP($child)" + override def toString: String = s"stddev_pop($child)" override def isSample: Boolean = false } // Compute the sample standard deviation of a column case class StddevSamp(child: Expression) extends StddevAgg1(child) { - override def toString: String = s"STDDEV_SAMP($child)" + override def toString: String = s"stddev_samp($child)" override def isSample: Boolean = true } @@ -1019,8 +1019,6 @@ case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExp override def foldable: Boolean = false override def prettyName: String = "kurtosis" - - override def toString: String = s"KURTOSIS($child)" } // placeholder @@ -1038,8 +1036,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp override def foldable: Boolean = false override def prettyName: String = "skewness" - - override def toString: String = s"SKEWNESS($child)" } // placeholder @@ -1056,9 +1052,7 @@ case class VariancePop(child: Expression) extends UnaryExpression with Aggregate override def foldable: Boolean = false - override def prettyName: String = "variance_pop" - - override def toString: String = s"VAR_POP($child)" + override def prettyName: String = "var_pop" } // placeholder @@ -1075,7 +1069,5 @@ case class VarianceSamp(child: Expression) extends UnaryExpression with Aggregat override def foldable: Boolean = false - override def prettyName: String = "variance_samp" - - override def toString: String = s"VAR_SAMP($child)" + override def prettyName: String = "var_samp" } From 701fb5052080fa8c0a79ad7c1e65693ccf444787 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 4 Nov 2015 14:03:31 -0800 Subject: [PATCH 174/324] [SPARK-10949] Update Snappy version to 1.1.2 This is an updated version of #8995 by a-roberts. Original description follows: Snappy now supports concatenation of serialized streams, this patch contains a version number change and the "does not support" test is now a "supports" test. Snappy 1.1.2 changelog mentions: > snappy-java-1.1.2 (22 September 2015) > This is a backward compatible release for 1.1.x. > Add AIX (32-bit) support. > There is no upgrade for the native libraries of the other platforms. 
> A major change since 1.1.1 is a support for reading concatenated results of SnappyOutputStream(s) > snappy-java-1.1.2-RC2 (18 May 2015) > Fix #107: SnappyOutputStream.close() is not idempotent > snappy-java-1.1.2-RC1 (13 May 2015) > SnappyInputStream now supports reading concatenated compressed results of SnappyOutputStream > There has been no compressed format change since 1.0.5.x. So You can read the compressed results > interchangeablly between these versions. > Fixes a problem when java.io.tmpdir does not exist. Closes #8995. Author: Adam Roberts Author: Josh Rosen Closes #9439 from JoshRosen/update-snappy. --- .../org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 4 ++-- .../main/scala/org/apache/spark/io/CompressionCodec.scala | 5 +++++ .../scala/org/apache/spark/io/CompressionCodecSuite.scala | 6 ++---- pom.xml | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e19b37864293c..6a0a89e81c321 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -254,8 +254,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); - final boolean fastMergeIsSupported = - !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + final boolean fastMergeIsSupported = !compressionEnabled || + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 9dc36704a676d..ca74eedf89be5 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -47,6 +47,11 @@ trait CompressionCodec { private[spark] object CompressionCodec { private val configKey = "spark.io.compression.codec" + + private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { + codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + } + private val shortCompressionCodecNames = Map( "lz4" -> classOf[LZ4CompressionCodec].getName, "lzf" -> classOf[LZFCompressionCodec].getName, diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index cbdb33c89d0fb..1553ab60bddaa 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -100,12 +100,10 @@ class CompressionCodecSuite extends SparkFunSuite { testCodec(codec) } - test("snappy does not support concatenation of serialized streams") { + test("snappy supports concatenation of serialized streams") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) - intercept[Exception] { - testConcatenationOfSerializedStreams(codec) - } + testConcatenationOfSerializedStreams(codec) } test("bad 
compression codec") { diff --git a/pom.xml b/pom.xml index 762bfc7282335..f5a3e44fc0a34 100644 --- a/pom.xml +++ b/pom.xml @@ -165,7 +165,7 @@ org.scala-lang 1.9.13 2.4.4 - 1.1.1.7 + 1.1.2 1.1.2 1.2.0-incubating 1.10 From 1b6a5d4af9691c3f7f3ebee3146dc13d12a0e047 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Nov 2015 14:45:02 -0800 Subject: [PATCH 175/324] [SPARK-11493] remove bitset from BytesToBytesMap Since we have 4 bytes as number of records in the beginning of a page, the address can not be zero, so we do not need the bitset. For performance concerns, the bitset could help speed up false lookup if the slot is empty (because bitset is smaller than longArray, cache hit rate will be higher). In practice, the map is filled with 35% - 70% (use 50% as average), so only half of the false lookups can benefit of it, all others will pay the cost of load the bitset (still need to access the longArray anyway). For aggregation, we always need to access the longArray (insert a new key after false lookup), also confirmed by a benchmark. For broadcast hash join, there could be a regression, but a simple benchmark showed that it may not (most of lookup are false): ``` sqlContext.range(1<<20).write.parquet("small") df = sqlContext.read.parquet('small') for i in range(3): t = time.time() df2 = sqlContext.range(1<<26).selectExpr("id * 1111111111 % 987654321 as id2") df2.join(df, df.id == df2.id2).count() print time.time() -t ``` Having bitset (used time in seconds): ``` 17.5404241085 10.2758829594 10.5786800385 ``` After removing bitset (used time in seconds): ``` 21.8939979076 12.4132959843 9.97224712372 ``` cc rxin nongli Author: Davies Liu Closes #9452 from davies/remove_bitset. --- .../spark/unsafe/map/BytesToBytesMap.java | 58 +++------ .../apache/spark/unsafe/bitset/BitSet.java | 113 ------------------ .../spark/unsafe/bitset/BitSetSuite.java | 88 -------------- 3 files changed, 15 insertions(+), 244 deletions(-) delete mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java delete mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index e36709c6fc849..07241c827c2ae 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -35,7 +35,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -123,12 +122,6 @@ public final class BytesToBytesMap extends MemoryConsumer { */ private boolean canGrowArray = true; - /** - * A {@link BitSet} used to track location of the map where the key is set. - * Size of the bitset should be half of the size of the long array. - */ - @Nullable private BitSet bitset; - private final double loadFactor; /** @@ -427,7 +420,6 @@ public Location lookup(Object keyBase, long keyOffset, int keyLength) { * This is a thread-safe version of `lookup`, could be used by multiple threads. 
*/ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { - assert(bitset != null); assert(longArray != null); if (enablePerfMetrics) { @@ -440,7 +432,7 @@ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location l if (enablePerfMetrics) { numProbes++; } - if (!bitset.isSet(pos)) { + if (longArray.get(pos * 2) == 0) { // This is a new key. loc.with(pos, hashcode, false); return; @@ -644,7 +636,6 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, assert (!isDefined) : "Can only set value once for a key"; assert (keyLength % 8 == 0); assert (valueLength % 8 == 0); - assert(bitset != null); assert(longArray != null); if (numElements == MAX_CAPACITY || !canGrowArray) { @@ -678,7 +669,6 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, Platform.putInt(base, offset, Platform.getInt(base, offset) + 1); pageCursor += recordLength; numElements++; - bitset.set(pos); final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); longArray.set(pos * 2, storedKeyAddress); @@ -734,7 +724,6 @@ private void allocate(int capacity) { assert (capacity <= MAX_CAPACITY); acquireMemory(capacity * 16); longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); - bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); this.mask = capacity - 1; @@ -749,7 +738,6 @@ public void freeArray() { long used = longArray.memoryBlock().size(); longArray = null; releaseMemory(used); - bitset = null; } } @@ -795,9 +783,7 @@ public long getTotalMemoryConsumption() { for (MemoryBlock dataPage : dataPages) { totalDataPagesSize += dataPage.size(); } - return totalDataPagesSize + - ((bitset != null) ? bitset.memoryBlock().size() : 0L) + - ((longArray != null) ? longArray.memoryBlock().size() : 0L); + return totalDataPagesSize + ((longArray != null) ? longArray.memoryBlock().size() : 0L); } private void updatePeakMemoryUsed() { @@ -852,7 +838,6 @@ public int getNumDataPages() { */ @VisibleForTesting void growAndRehash() { - assert(bitset != null); assert(longArray != null); long resizeStartTime = -1; @@ -861,39 +846,26 @@ void growAndRehash() { } // Store references to the old data structures to be used when we re-hash final LongArray oldLongArray = longArray; - final BitSet oldBitSet = bitset; - final int oldCapacity = (int) oldBitSet.capacity(); + final int oldCapacity = (int) oldLongArray.size() / 2; // Allocate the new data structures - try { - allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); - } catch (OutOfMemoryError oom) { - longArray = oldLongArray; - bitset = oldBitSet; - throw oom; - } + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) - for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { - final long keyPointer = oldLongArray.get(pos * 2); - final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + for (int i = 0; i < oldLongArray.size(); i += 2) { + final long keyPointer = oldLongArray.get(i); + if (keyPointer == 0) { + continue; + } + final int hashcode = (int) oldLongArray.get(i + 1); int newPos = hashcode & mask; int step = 1; - boolean keepGoing = true; - - // No need to check for equality here when we insert so this has one less if branch than - // the similar code path in addWithoutResize. 
- while (keepGoing) { - if (!bitset.isSet(newPos)) { - bitset.set(newPos); - longArray.set(newPos * 2, keyPointer); - longArray.set(newPos * 2 + 1, hashcode); - keepGoing = false; - } else { - newPos = (newPos + step) & mask; - step++; - } + while (longArray.get(newPos * 2) != 0) { + newPos = (newPos + step) & mask; + step++; } + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); } releaseMemory(oldLongArray.memoryBlock().size()); diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java deleted file mode 100644 index 7c124173b0bbb..0000000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.bitset; - -import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; - -/** - * A fixed size uncompressed bit set backed by a {@link LongArray}. - * - * Each bit occupies exactly one bit of storage. - */ -public final class BitSet { - - /** A long array for the bits. */ - private final LongArray words; - - /** Length of the long array. */ - private final int numWords; - - private final Object baseObject; - private final long baseOffset; - - /** - * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be - * multiple of 8 bytes (i.e. 64 bits). - */ - public BitSet(MemoryBlock memory) { - words = new LongArray(memory); - assert (words.size() <= Integer.MAX_VALUE); - numWords = (int) words.size(); - baseObject = words.memoryBlock().getBaseObject(); - baseOffset = words.memoryBlock().getBaseOffset(); - } - - public MemoryBlock memoryBlock() { - return words.memoryBlock(); - } - - /** - * Returns the number of bits in this {@code BitSet}. - */ - public long capacity() { - return numWords * 64; - } - - /** - * Sets the bit at the specified index to {@code true}. - */ - public void set(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - BitSetMethods.set(baseObject, baseOffset, index); - } - - /** - * Sets the bit at the specified index to {@code false}. - */ - public void unset(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - BitSetMethods.unset(baseObject, baseOffset, index); - } - - /** - * Returns {@code true} if the bit is set at the specified index. 
- */ - public boolean isSet(int index) { - assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; - return BitSetMethods.isSet(baseObject, baseOffset, index); - } - - /** - * Returns the index of the first bit that is set to true that occurs on or after the - * specified starting index. If no such bit exists then {@code -1} is returned. - *

-   * To iterate over the true bits in a BitSet, use the following loop:
-   * <pre>
-   * <code>
-   *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
-   *    // operate on index i here
-   *  }
-   * </code>
-   * </pre>
    - * - * @param fromIndex the index to start checking from (inclusive) - * @return the index of the next set bit, or -1 if there is no such bit - */ - public int nextSetBit(int fromIndex) { - return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); - } - - /** - * Returns {@code true} if any bit is set. - */ - public boolean anySet() { - return BitSetMethods.anySet(baseObject, baseOffset, numWords); - } - -} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java deleted file mode 100644 index 14e38683df4ab..0000000000000 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.bitset; - -import org.junit.Assert; -import org.junit.Test; - -import org.apache.spark.unsafe.memory.MemoryBlock; - -public class BitSetSuite { - - private static BitSet createBitSet(int capacity) { - Assert.assertEquals(0, capacity % 64); - return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); - } - - @Test - public void basicOps() { - BitSet bs = createBitSet(64); - Assert.assertEquals(64, bs.capacity()); - - // Make sure the bit set starts empty. - for (int i = 0; i < bs.capacity(); i++) { - Assert.assertFalse(bs.isSet(i)); - } - // another form of asserting that the bit set is empty - Assert.assertFalse(bs.anySet()); - - // Set every bit and check it. - for (int i = 0; i < bs.capacity(); i++) { - bs.set(i); - Assert.assertTrue(bs.isSet(i)); - } - - // Unset every bit and check it. 
- for (int i = 0; i < bs.capacity(); i++) { - Assert.assertTrue(bs.isSet(i)); - bs.unset(i); - Assert.assertFalse(bs.isSet(i)); - } - - // Make sure anySet() can detect any set bit - bs = createBitSet(256); - bs.set(64); - Assert.assertTrue(bs.anySet()); - } - - @Test - public void traversal() { - BitSet bs = createBitSet(256); - - Assert.assertEquals(-1, bs.nextSetBit(0)); - Assert.assertEquals(-1, bs.nextSetBit(10)); - Assert.assertEquals(-1, bs.nextSetBit(64)); - - bs.set(10); - Assert.assertEquals(10, bs.nextSetBit(0)); - Assert.assertEquals(10, bs.nextSetBit(1)); - Assert.assertEquals(10, bs.nextSetBit(10)); - Assert.assertEquals(-1, bs.nextSetBit(11)); - - bs.set(11); - Assert.assertEquals(10, bs.nextSetBit(10)); - Assert.assertEquals(11, bs.nextSetBit(11)); - - // Skip a whole word and find it - bs.set(190); - Assert.assertEquals(190, bs.nextSetBit(12)); - - Assert.assertEquals(-1, bs.nextSetBit(191)); - Assert.assertEquals(-1, bs.nextSetBit(256)); - } -} From 411ff6afb485c9d8cfc667c9346f836f2529ea9f Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 4 Nov 2015 15:28:19 -0800 Subject: [PATCH 176/324] [SPARK-10028][MLLIB][PYTHON] Add Python API for PrefixSpan Author: Yu ISHIKAWA Closes #9469 from yu-iskw/SPARK-10028. --- .../api/python/PrefixSpanModelWrapper.scala | 32 +++++++++ .../mllib/api/python/PythonMLLibAPI.scala | 23 ++++++- python/pyspark/mllib/fpm.py | 69 ++++++++++++++++++- 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala new file mode 100644 index 0000000000000..0027602a04f81 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PrefixSpanModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.mllib.fpm.PrefixSpanModel +import org.apache.spark.rdd.RDD + +/** + * A Wrapper of PrefixSpanModel to provide helper method for Python + */ +private[python] class PrefixSpanModelWrapper(model: PrefixSpanModel[Any]) + extends PrefixSpanModel(model.freqSequences) { + + def getFreqSequences: RDD[Array[Any]] = { + SerDe.fromTuple2RDD(model.freqSequences.map(x => (x.javaSequence, x.freq))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 21e55938fa7aa..40c41806cdfea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -35,7 +35,7 @@ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.evaluation.RankingMetrics import org.apache.spark.mllib.feature._ -import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel, PrefixSpan} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed._ import org.apache.spark.mllib.optimization._ @@ -557,6 +557,27 @@ private[python] class PythonMLLibAPI extends Serializable { new FPGrowthModelWrapper(model) } + /** + * Java stub for Python mllib PrefixSpan.train(). This stub returns a handle + * to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see + * the Py4J documentation. + */ + def trainPrefixSpanModel( + data: JavaRDD[java.util.ArrayList[java.util.ArrayList[Any]]], + minSupport: Double, + maxPatternLength: Int, + localProjDBSize: Int ): PrefixSpanModelWrapper = { + val prefixSpan = new PrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(localProjDBSize) + + val trainData = data.rdd.map(_.asScala.toArray.map(_.asScala.toArray)) + val model = prefixSpan.run(trainData) + new PrefixSpanModelWrapper(model) + } + /** * Java stub for Normalizer.transform() */ diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdabba9602a8c..2039decc0cb3c 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -23,7 +23,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc -__all__ = ['FPGrowth', 'FPGrowthModel'] +__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @@ -85,6 +85,73 @@ class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ +@inherit_doc +@ignore_unicode_prefix +class PrefixSpanModel(JavaModelWrapper): + """ + .. note:: Experimental + + Model fitted by PrefixSpan + + >>> data = [ + ... [["a", "b"], ["c"]], + ... [["a"], ["c", "b"], ["a", "b"]], + ... [["a", "b"], ["e"]], + ... [["f"]]] + >>> rdd = sc.parallelize(data, 2) + >>> model = PrefixSpan.train(rdd) + >>> sorted(model.freqSequences().collect()) + [FreqSequence(sequence=[[u'a']], freq=3), FreqSequence(sequence=[[u'a'], [u'a']], freq=1), ... + + .. versionadded:: 1.6.0 + """ + + @since("1.6.0") + def freqSequences(self): + """Gets frequence sequences""" + return self.call("getFreqSequences").map(lambda x: PrefixSpan.FreqSequence(x[0], x[1])) + + +class PrefixSpan(object): + """ + .. 
note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: + Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth + ([[http://doi.org/10.1109/ICDE.2001.914830]]). + + .. versionadded:: 1.6.0 + """ + + @classmethod + @since("1.6.0") + def train(cls, data, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000): + """ + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param data: The input data set, each element contains a sequnce of itemsets. + :param minSupport: the minimal support level of the sequential pattern, any pattern appears + more than (minSupport * size-of-the-dataset) times will be output (default: `0.1`) + :param maxPatternLength: the maximal length of the sequential pattern, any pattern appears + less than maxPatternLength will be output. (default: `10`) + :param maxLocalProjDBSize: The maximum number of items (including delimiters used in + the internal storage format) allowed in a projected database before local + processing. If a projected database exceeds this size, another + iteration of distributed prefix growth is run. (default: `32000000`) + """ + model = callMLlibFunc("trainPrefixSpanModel", + data, minSupport, maxPatternLength, maxLocalProjDBSize) + return PrefixSpanModel(model) + + class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])): + """ + Represents a (sequence, freq) tuple. + + .. versionadded:: 1.6.0 + """ + + def _test(): import doctest import pyspark.mllib.fpm From b6e0a5ae6f243139f11c9cbbf18cddd3f25db208 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 4 Nov 2015 16:49:25 -0800 Subject: [PATCH 177/324] [SPARK-11510][SQL] Remove SQL aggregation tests for higher order statistics We have some aggregate function tests in both DataFrameAggregateSuite and SQLQuerySuite. The two have almost the same coverage and we should just remove the SQL one. Author: Reynold Xin Closes #9475 from rxin/SPARK-11510. 
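
A hedged illustration (not part of the patch) of why the SQL-side tests were redundant: the SQL text and the DataFrame API resolve to the same aggregate expressions, so keeping only DataFrameAggregateSuite loses no coverage. The snippet assumes a Spark 1.6 shell session with `sqlContext` in scope; the table name and data are made up for the example.

```
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)).toDF("a", "b")
df.registerTempTable("testData2")

// SQL form, previously exercised by SQLQuerySuite
val fromSql = sqlContext.sql("SELECT var_samp(a) FROM testData2").head().getDouble(0)
// DataFrame form, exercised by DataFrameAggregateSuite
val fromDf = df.agg(var_samp($"a")).head().getDouble(0)

assert(math.abs(fromSql - fromDf) < 1e-8)  // both evaluate to 4.0 / 5.0
```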
--- .../spark/sql/DataFrameAggregateSuite.scala | 97 ++++++------------- .../org/apache/spark/sql/SQLQuerySuite.scala | 77 --------------- .../spark/sql/StringFunctionsSuite.scala | 1 - 3 files changed, 28 insertions(+), 147 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b0e2ffaa60687..2e679e7bc4e0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -83,13 +83,8 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("average") { checkAnswer( - testData2.agg(avg('a)), - Row(2.0)) - - // Also check mean - checkAnswer( - testData2.agg(mean('a)), - Row(2.0)) + testData2.agg(avg('a), mean('a)), + Row(2.0, 2.0)) checkAnswer( testData2.agg(avg('a), sumDistinct('a)), // non-partial @@ -98,6 +93,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) + checkAnswer( decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) @@ -168,44 +164,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - checkAnswer( emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("stddev") { - val testData2ADev = math.sqrt(4/5.0) - + val testData2ADev = math.sqrt(4 / 5.0) checkAnswer( - testData2.agg(stddev('a)), - Row(testData2ADev)) - - checkAnswer( - testData2.agg(stddev_pop('a)), - Row(math.sqrt(4/6.0))) - - checkAnswer( - testData2.agg(stddev_samp('a)), - Row(testData2ADev)) + testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() == 0) - - checkAnswer( - emptyTableData.agg(stddev('a)), - Row(null)) - checkAnswer( - emptyTableData.agg(stddev_pop('a)), - Row(null)) - - checkAnswer( - emptyTableData.agg(stddev_samp('a)), - Row(null)) + emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), + Row(null, null, null)) } test("zero sum") { @@ -227,6 +202,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) + val sparkVariancePop = testData2.agg(var_pop('a)) checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol) @@ -241,52 +217,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("zero moments") { - val emptyTableData = Seq((1, 2)).toDF("a", "b") - assert(emptyTableData.count() === 1) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - + val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(0.0)) + input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + input.agg( + expr("variance(a)"), + 
expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) } test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - assert(emptyTableData.count() === 0) - - checkAnswer( - emptyTableData.agg(variance('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_samp('a)), - Row(Double.NaN)) - - checkAnswer( - emptyTableData.agg(var_pop('a)), - Row(Double.NaN)) checkAnswer( - emptyTableData.agg(skewness('a)), - Row(Double.NaN)) + emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) checkAnswer( - emptyTableData.agg(kurtosis('a)), - Row(Double.NaN)) + emptyTableData.agg( + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5731a356243e5..3de277a79a52c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -726,83 +726,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("stddev") { - checkAnswer( - sql("SELECT STDDEV(a) FROM testData2"), - Row(math.sqrt(4.0 / 5.0)) - ) - } - - test("stddev_pop") { - checkAnswer( - sql("SELECT STDDEV_POP(a) FROM testData2"), - Row(math.sqrt(4.0 / 6.0)) - ) - } - - test("stddev_samp") { - checkAnswer( - sql("SELECT STDDEV_SAMP(a) FROM testData2"), - Row(math.sqrt(4/5.0)) - ) - } - - test("var_samp") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_SAMP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 5.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("variance") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2") - val expectedAnswer = Row(0.8) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("var_pop") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT VAR_POP(a) FROM testData2") - val expectedAnswer = Row(4.0 / 6.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("skewness") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT skewness(a) FROM testData2") - val expectedAnswer = Row(0.0) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("kurtosis") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT kurtosis(a) FROM testData2") - val expectedAnswer = Row(-1.5) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - - test("stddev agg") { - checkAnswer( - sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0)))) - } - - test("variance agg") { - val absTol = 1e-8 - checkAggregatesWithTol( - sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"), - (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)), - absTol) - } - - test("skewness and kurtosis agg") { - val absTol = 1e-8 - val sparkAnswer = sql("SELECT a, skewness(b), kurtosis(b) FROM testData2 GROUP BY a") - val expectedAnswer = (1 to 3).map(i => Row(i, 0.0, -2.0)) - checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol) - } - test("inner join where, one match per 
row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e12e6bea30260..e2090b0a83ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.Decimal class StringFunctionsSuite extends QueryTest with SharedSQLContext { From ce5e6a2849ae860689fa3e7d5aaa12216945ea99 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 4 Nov 2015 16:58:38 -0800 Subject: [PATCH 178/324] [SPARK-11491] Update build to use Scala 2.10.5 Spark should build against Scala 2.10.5, since that includes a fix for Scaladoc that will fix doc snapshot publishing: https://issues.scala-lang.org/browse/SI-8479 Author: Josh Rosen Closes #9450 from JoshRosen/upgrade-to-scala-2.10.5. --- LICENSE | 10 +++++----- dev/audit-release/README.md | 2 +- dev/audit-release/audit_release.py | 2 +- docker/spark-test/base/Dockerfile | 2 +- docs/_config.yml | 2 +- pom.xml | 4 ++-- project/SparkBuild.scala | 2 +- 7 files changed, 12 insertions(+), 12 deletions(-) diff --git a/LICENSE b/LICENSE index 790476ece15bd..0db2d14465bd3 100644 --- a/LICENSE +++ b/LICENSE @@ -250,11 +250,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.4 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.10.4 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.10.5 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.10.5 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.10:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.10:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.10:0.7.1 - http://spire-math.org) diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 38becda0eae92..f72f8c653a265 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. 
``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.4 \ +$ SCALA_VERSION=2.10.5 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 0b7069f6e116a..27d1dd784ce2e 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" RELEASE_VERSION = "1.1.1" -SCALA_VERSION = "2.10.4" +SCALA_VERSION = "2.10.5" SCALA_BINARY_VERSION = "2.10" # Do not set these diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index 5dbdb8b22a44f..7ba0de603dc7d 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* -ENV SCALA_VERSION 2.10.4 +ENV SCALA_VERSION 2.10.5 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index c59cc465ef89d..2c70b76be8b7a 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -17,7 +17,7 @@ include: SPARK_VERSION: 1.6.0-SNAPSHOT SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.4" +SCALA_VERSION: "2.10.5" MESOS_VERSION: 0.21.0 SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/pom.xml b/pom.xml index f5a3e44fc0a34..4ed1c0c82dee6 100644 --- a/pom.xml +++ b/pom.xml @@ -159,7 +159,7 @@ 3.1 3.4.1 - 2.10.4 + 2.10.5 2.10 ${scala.version} org.scala-lang @@ -2422,7 +2422,7 @@ !scala-2.11 - 2.10.4 + 2.10.5 2.10 ${scala.version} org.scala-lang diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 766edd9500c30..75c36930decef 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -316,7 +316,7 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", - scalaVersion := "2.10.4", + scalaVersion := "2.10.5", retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", From a752ddad7fe1d0f01b51f7551ec017ff87e1eea5 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang Date: Wed, 4 Nov 2015 17:16:00 -0800 Subject: [PATCH 179/324] [SPARK-11398] [SQL] unnecessary def dialectClassName in HiveContext, and misleading dialect conf at the start of spark-sql MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. def dialectClassName in HiveContext is unnecessary. In HiveContext, if conf.dialect == "hiveql", getSQLDialect() will return new HiveQLDialect(this); else it will use super.getSQLDialect(). Then in super.getSQLDialect(), it calls dialectClassName, which is overriden in HiveContext and still return super.dialectClassName. So we'll never reach the code "classOf[HiveQLDialect].getCanonicalName" of def dialectClassName in HiveContext. 2. When we start bin/spark-sql, the default context is HiveContext, and the corresponding dialect is hiveql. However, if we type "set spark.sql.dialect;", the result is "sql", which is inconsistent with the actual dialect and is misleading. 
For example, we can use sql like "create table" which is only allowed in hiveql, but this dialect conf shows it's "sql". Although this problem will not cause any execution error, it's misleading to spark sql users. Therefore I think we should fix it. In this pr, while procesing “set spark.sql.dialect” in SetCommand, I use "conf.dialect" instead of "getConf()" for the case of key == SQLConf.DIALECT.key, so that it will return the right dialect conf. Author: Zhenhua Wang Closes #9349 from wzhfy/dialect. --- .../scala/org/apache/spark/sql/execution/commands.scala | 6 +++++- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 ------ .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 7 +++++++ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 856607615ae87..e5f60b15e7359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -156,7 +156,11 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm val runFunc = (sqlContext: SQLContext) => { val value = try { - sqlContext.getConf(key) + if (key == SQLConf.DIALECT.key) { + sqlContext.conf.dialect + } else { + sqlContext.getConf(key) + } } catch { case _: NoSuchElementException => "" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 83a81cf5d1fcf..1f5135320326c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -555,12 +555,6 @@ class HiveContext private[hive]( override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - protected[sql] override def dialectClassName = if (conf.dialect == "hiveql") { - classOf[HiveQLDialect].getCanonicalName - } else { - super.dialectClassName - } - protected[sql] override def getSQLDialect(): ParserDialect = { if (conf.dialect == "hiveql") { new HiveQLDialect(this) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index fd380641dcc71..af48d478953b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -335,6 +335,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SQL dialect at the start of HiveContext") { + val hiveContext = new HiveContext(sqlContext.sparkContext) + val dialectConf = "spark.sql.dialect" + checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql")) + assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect]) + } + test("SQL Dialect Switching") { assert(getSQLDialect().getClass === classOf[HiveQLDialect]) setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) From d0b56339625727744e2c30fc2167bc6a457d37f7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 4 Nov 2015 17:19:52 -0800 Subject: [PATCH 180/324] [SPARK-11307] Reduce memory consumption of OutputCommitCoordinator OutputCommitCoordinator uses a map in a place where an array would suffice, increasing its memory consumption for result stages with millions of 
tasks. This patch replaces that map with an array. The only tricky part of this is reasoning about the range of possible array indexes in order to make sure that we never index out of bounds. Author: Josh Rosen Closes #9274 from JoshRosen/SPARK-11307. --- .../apache/spark/scheduler/DAGScheduler.scala | 8 +++- .../scheduler/OutputCommitCoordinator.scala | 40 ++++++++++++------- .../OutputCommitCoordinatorSuite.scala | 2 +- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5673fbf2c8fea..a1f0fd05f661a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -947,7 +947,13 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - outputCommitCoordinator.stageStart(stage.id) + stage match { + case s: ShuffleMapStage => + outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) + case s: ResultStage => + outputCommitCoordinator.stageStart( + stage = s.id, maxPartitionId = s.rdd.partitions.length - 1) + } val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try { stage match { case s: ShuffleMapStage => diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index add0dedc03f44..4d146678174f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -47,6 +47,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private type PartitionId = Int private type TaskAttemptNumber = Int + private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing * output for that partition. @@ -56,9 +58,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = - mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] + private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -95,9 +95,21 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } } - // Called by DAGScheduler - private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() + /** + * Called by the DAGScheduler when a stage starts. + * + * @param stage the stage id. + * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. + * the maximum possible value of `context.partitionId`). 
+ */ + private[scheduler] def stageStart( + stage: StageId, + maxPartitionId: Int): Unit = { + val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) + java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) + synchronized { + authorizedCommittersByStage(stage) = arr + } } // Called by DAGScheduler @@ -122,10 +134,10 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + if (authorizedCommitters(partition) == attemptNumber) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - authorizedCommitters.remove(partition) + authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -145,16 +157,16 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => - authorizedCommitters.get(partition) match { - case Some(existingCommitter) => - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - case None => + authorizedCommitters(partition) match { + case NO_AUTHORIZED_COMMITTER => logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + s"partition=$partition") authorizedCommitters(partition) = attemptNumber true + case existingCommitter => + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") + false } case None => logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 48456a9cd6e7b..7345508bfe995 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -171,7 +171,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 - outputCommitCoordinator.stageStart(stage) + outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) From 81498dd5c86ca51d2fb351c8ef52cbb28e6844f4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 4 Nov 2015 21:30:21 -0800 Subject: [PATCH 181/324] [SPARK-11425] [SPARK-11486] Improve hybrid aggregation After aggregation, the dataset could be smaller than inputs, so it's better to do hash based aggregation for all inputs, then using sort based aggregation to merge them. Author: Davies Liu Closes #9383 from davies/fix_switch. 
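
A minimal, self-contained sketch of the control flow this change introduces, not the actual TungstenAggregationIterator code: hash-aggregate every input row, "spill" the map whenever it can no longer grow (modelled here by a simple size cap), reset it and keep hashing, and only fall back to one sort-based merge of the spilled runs after all input is consumed. Names such as `hybridAggregate` and `maxMapSize` are invented for the illustration.

```
import scala.collection.mutable

def hybridAggregate[K, V](
    input: Iterator[(K, V)],
    merge: (V, V) => V,
    maxMapSize: Int = 1024)(implicit ord: Ordering[K]): Iterator[(K, V)] = {
  val spills = mutable.ArrayBuffer.empty[Seq[(K, V)]]
  var map = mutable.HashMap.empty[K, V]

  def spill(): Unit = if (map.nonEmpty) {
    spills += map.toSeq.sortBy(_._1)    // sort the run by grouping key before setting it aside
    map = mutable.HashMap.empty[K, V]   // reset the map and keep doing hash aggregation
  }

  // Hash-based aggregation over every input row, spilling when the map is "full".
  input.foreach { case (k, v) =>
    map(k) = map.get(k).map(merge(_, v)).getOrElse(v)
    if (map.size >= maxMapSize) spill()
  }

  if (spills.isEmpty) {
    map.iterator                        // common case: everything fit, no sorting needed
  } else {
    spill()                             // flush the last partial map as one more run
    // Sort-based merge of the runs (a k-way merge in Spark; a plain sort here),
    // combining adjacent entries that share the same key.
    val out = mutable.ArrayBuffer.empty[(K, V)]
    spills.flatten.sortBy(_._1).foreach { case (k, v) =>
      if (out.nonEmpty && ord.equiv(out.last._1, k)) {
        out(out.length - 1) = (k, merge(out.last._2, v))
      } else {
        out += ((k, v))
      }
    }
    out.iterator
  }
}

// e.g. word counting: hybridAggregate(words.map(w => (w, 1)).iterator, (a: Int, b: Int) => a + b)
```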
--- .../spark/unsafe/map/BytesToBytesMap.java | 46 +++-- .../unsafe/sort/UnsafeExternalSorter.java | 39 ++-- .../unsafe/sort/UnsafeInMemorySorter.java | 15 +- .../UnsafeFixedWidthAggregationMap.java | 9 +- .../sql/execution/UnsafeKVExternalSorter.java | 23 ++- .../TungstenAggregationIterator.scala | 171 +++++------------- .../UnsafeFixedWidthAggregationMapSuite.scala | 64 +++---- .../execution/AggregationQuerySuite.scala | 12 +- 8 files changed, 165 insertions(+), 214 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 07241c827c2ae..6656fd1d0bc59 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,6 +20,7 @@ import javax.annotation.Nullable; import java.io.File; import java.io.IOException; +import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; @@ -638,7 +639,11 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, assert (valueLength % 8 == 0); assert(longArray != null); - if (numElements == MAX_CAPACITY || !canGrowArray) { + + if (numElements == MAX_CAPACITY + // The map could be reused from last spill (because of no enough memory to grow), + // then we don't try to grow again if hit the `growthThreshold`. + || !canGrowArray && numElements > growthThreshold) { return false; } @@ -730,25 +735,18 @@ private void allocate(int capacity) { } /** - * Free the memory used by longArray. + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent and can be called multiple times. */ - public void freeArray() { + public void free() { updatePeakMemoryUsed(); if (longArray != null) { long used = longArray.memoryBlock().size(); longArray = null; releaseMemory(used); } - } - - /** - * Free all allocated memory associated with this map, including the storage for keys and values - * as well as the hash map array itself. - * - * This method is idempotent and can be called multiple times. - */ - public void free() { - freeArray(); Iterator dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { MemoryBlock dataPage = dataPagesIterator.next(); @@ -833,6 +831,28 @@ public int getNumDataPages() { return dataPages.size(); } + /** + * Returns the underline long[] of longArray. + */ + public long[] getArray() { + assert(longArray != null); + return (long[]) longArray.memoryBlock().getBaseObject(); + } + + /** + * Reset this map to initialized state. + */ + public void reset() { + numElements = 0; + Arrays.fill(getArray(), 0); + while (dataPages.size() > 0) { + MemoryBlock dataPage = dataPages.removeLast(); + freePage(dataPage); + } + currentPage = null; + pageCursor = 0; + } + /** * Grows the size of the hash table and re-hash everything. 
*/ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 509fb0a044c0c..cba043bc48cc8 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -79,9 +79,13 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - UnsafeInMemorySorter inMemorySorter) { - return new UnsafeExternalSorter(taskMemoryManager, blockManager, + UnsafeInMemorySorter inMemorySorter) throws IOException { + UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + sorter.spill(Long.MAX_VALUE, sorter); + // The external sorter will be used to insert records, in-memory sorter is not needed. + sorter.inMemSorter = null; + return sorter; } public static UnsafeExternalSorter create( @@ -124,7 +128,6 @@ private UnsafeExternalSorter( acquireMemory(inMemSorter.getMemoryUsage()); } else { this.inMemSorter = existingInMemorySorter; - // will acquire after free the map } this.peakMemoryUsedBytes = getMemoryUsage(); @@ -157,12 +160,9 @@ public void closeCurrentPage() { */ @Override public long spill(long size, MemoryConsumer trigger) throws IOException { - assert(inMemSorter != null); if (trigger != this) { if (readingIterator != null) { return readingIterator.spill(); - } else { - } return 0L; // this should throw exception } @@ -388,25 +388,38 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, inMemSorter.insertRecord(recordAddress, prefix); } + /** + * Merges another UnsafeExternalSorters into this one, the other one will be emptied. + * + * @throws IOException + */ + public void merge(UnsafeExternalSorter other) throws IOException { + other.spill(); + spillWriters.addAll(other.spillWriters); + // remove them from `spillWriters`, or the files will be deleted in `cleanupResources`. + other.spillWriters.clear(); + other.cleanupResources(); + } + /** * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(inMemSorter != null); - readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); - int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 
1 : 0); if (spillWriters.isEmpty()) { + assert(inMemSorter != null); + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); } - spillWriters.clear(); - spillMerger.addSpillIfNotEmpty(readingIterator); - + if (inMemSorter != null) { + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + spillMerger.addSpillIfNotEmpty(readingIterator); + } return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1480f0681ed9c..d57213b9b8bfc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,9 +19,9 @@ import java.util.Comparator; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; -import org.apache.spark.memory.TaskMemoryManager; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -77,13 +77,20 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { */ private int pos = 0; + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]); + } + public UnsafeInMemorySorter( final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - int initialSize) { - assert (initialSize > 0); - this.array = new long[initialSize * 2]; + long[] array) { + this.array = array; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index d4b6d75b4d981..a2f99d566d471 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -236,16 +236,13 @@ public void printPerfMetrics() { /** * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] - * that can be used to insert more records to do external sorting. * - * The only memory that is allocated is the address/prefix array, 16 bytes per record. - * - * Note that this destroys the map, and as a result, the map cannot be used anymore after this. + * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used + * to insert records. 
*/ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { - UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( + return new UnsafeKVExternalSorter( groupingKeySchema, aggregationBufferSchema, SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); - return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 845f2ae6859b7..e2898ef2e2158 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -83,11 +83,10 @@ public UnsafeKVExternalSorter( /* initialSize */ 4096, pageSizeBytes); } else { - // The memory needed for UnsafeInMemorySorter should be less than longArray in map. - map.freeArray(); - // The memory used by UnsafeInMemorySorter will be counted later (end of this block) + // During spilling, the array in map will not be used, so we can borrow that and use it + // as the underlying array for the in-memory sorter (it's always large enough). final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); + taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. @@ -123,10 +122,9 @@ public UnsafeKVExternalSorter( pageSizeBytes, inMemSorter); - sorter.spill(); - map.free(); - // counting the memory used UnsafeInMemorySorter - taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter); + // Reset the map so we can re-use it to insert new records. The inMemSorter will not be used + // anymore, so the underlying array can be used by the map again. + map.reset(); } } @@ -142,6 +140,15 @@ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException { value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); } + /** + * Merges another UnsafeKVExternalSorter into `this`; the other sorter will be emptied. + * + * @throws IOException + */ + public void merge(UnsafeKVExternalSorter other) throws IOException { + sorter.merge(other.sorter); + } + /** * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 713a4db0cd59b..ce8d592c368ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -34,14 +34,18 @@ import org.apache.spark.sql.types.StructType * * This iterator first uses hash-based aggregation to process input rows. It uses * a hash map to store groups and their corresponding aggregation buffers. If we - * this map cannot allocate memory from memory manager, - * it switches to sort-based aggregation. The process of the switch has the following step: + * this map cannot allocate memory from the memory manager, it spills the map to disk + * and creates a new one.
After processing all the input, it then merges all the spills + * together using an external sorter and does sort-based aggregation. + * + * The process has the following steps: + * - Step 0: Do hash-based aggregation. * - Step 1: Sort all entries of the hash map based on values of grouping expressions and * spill them to disk. - * - Step 2: Create a external sorter based on the spilled sorted map entries. - * - Step 3: Redirect all input rows to the external sorter. - * - Step 4: Get a sorted [[KVIterator]] from the external sorter. - * - Step 5: Initialize sort-based aggregation. + * - Step 2: Create an external sorter based on the spilled sorted map entries and reset the map. + * - Step 3: Get a sorted [[KVIterator]] from the external sorter. + * - Step 4: Repeat step 0 until no more input. + * - Step 5: Initialize sort-based aggregation on the sorted iterator. * Then, this iterator works in the way of sort-based aggregation. * * The code of this class is organized as follows: @@ -488,9 +492,10 @@ class TungstenAggregationIterator( // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in - // hashMap. If we could not allocate more memory for the map, we switch to - // sort-based aggregation (by calling switchToSortBasedAggregation). - private def processInputs(): Unit = { + // hashMap. If there is not enough memory, it will use multiple hash maps, spilling + // each one after it becomes full, then merge these spills using sort, and finally do sort- + // based aggregation. + private def processInputs(fallbackStartsAt: Int): Unit = { if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. @@ -502,44 +507,40 @@ class TungstenAggregationIterator( processRow(buffer, newInput) } } else { - while (!sortBased && inputIter.hasNext) { + var i = 0 + while (inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + var buffer: UnsafeRow = null + if (i < fallbackStartsAt) { + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + } if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { - processRow(buffer, newInput) + val sorter = hashMap.destructAndCreateExternalSorter() + if (externalSorter == null) { + externalSorter = sorter + } else { + externalSorter.merge(sorter) + } + i = 0 + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // Failed to allocate the first page + throw new OutOfMemoryError("Not enough memory for aggregation") + } } + processRow(buffer, newInput) + i += 1 } - } - } - // This function is only used for testing. It basically the same as processInputs except - // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have - // been processed.
- private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { - var i = 0 - while (!sortBased && inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = if (i < fallbackStartsAt) { - hashMap.getAggregationBufferFromUnsafeRow(groupingKey) - } else { - null - } - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { - processRow(buffer, newInput) + if (externalSorter != null) { + val sorter = hashMap.destructAndCreateExternalSorter() + externalSorter.merge(sorter) + hashMap.free() + + switchToSortBasedAggregation() } - i += 1 } } @@ -561,88 +562,8 @@ class TungstenAggregationIterator( /** * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ - private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Step 1: Get the ExternalSorter containing sorted entries of the map. - externalSorter = hashMap.destructAndCreateExternalSorter() - - // Step 2: If we have aggregate function with mode Partial or Complete, - // we need to process input rows to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from an input row. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - // Note: Since we spill the sorter's contents immediately after creating it, we must insert - // something into the sorter here to ensure that we acquire at least a page of memory. - // This is done through `externalSorter.insertKV`, which will trigger the page allocation. - // Otherwise, children operators may steal the window of opportunity and starve our sorter. - - if (needsProcess) { - // First, we create a buffer. - val buffer = createNewAggregationBuffer() - - // Process firstKey and firstInput. - // Initialize buffer. - buffer.copyFrom(initialAggregationBuffer) - processRow(buffer, firstInput) - externalSorter.insertKV(firstKey, buffer) - - // Process the rest of input rows. - while (inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - buffer.copyFrom(initialAggregationBuffer) - processRow(buffer, newInput) - externalSorter.insertKV(groupingKey, buffer) - } - } else { - // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. - // We need to project the aggregation buffer part from an input row. - val buffer = createNewAggregationBuffer() - // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to - // extract the aggregation buffer. In practice, however, we extract it positionally by relying - // on it being present at the end of the row. The reason for this relates to how the different - // aggregates handle input binding. - // - // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers, - // so its correctness does not rely on attribute bindings. 
When we fall back to sort-based - // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset) - // need to be updated and any internal state in the aggregate functions themselves must be - // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset - // this state and update the offsets. - // - // The updated ImperativeAggregate will have different attribute ids for its - // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual - // ImperativeAggregate evaluation, but it means that - // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the - // attributes in `originalInputAttributes`, which is why we can't use those attributes here. - // - // For more details, see the discussion on PR #9038. - val bufferExtractor = newMutableProjection( - originalInputAttributes.drop(initialInputBufferOffset), - originalInputAttributes)() - bufferExtractor.target(buffer) - - // Insert firstKey and its buffer. - bufferExtractor(firstInput) - externalSorter.insertKV(firstKey, buffer) - - // Insert the rest of input rows. - while (inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - bufferExtractor(newInput) - externalSorter.insertKV(groupingKey, buffer) - } - } // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. val newAggregationMode = aggregationMode match { @@ -762,15 +683,7 @@ class TungstenAggregationIterator( /** * Start processing input rows. */ - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) // If we did not switch to sort-based aggregation in processInputs, // we pre-load the first key-value pair from the map (to make hasNext idempotent). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index a38623623a441..7ceaee38d131b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -170,9 +170,6 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("test external sorting") { - // Memory consumption in the beginning of the task. - val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -189,35 +186,33 @@ class UnsafeFixedWidthAggregationMapSuite buf.setInt(0, keyString.length) assert(buf != null) } - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. 
val additionalKeys = randomStrings(1024) - val keyConverter = UnsafeProjection.create(groupKeySchema) - val valueConverter = UnsafeProjection.create(aggBufferSchema) - additionalKeys.zipWithIndex.foreach { case (str, i) => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) val out = new scala.collection.mutable.ArrayBuffer[String] val iter = sorter.sortedIterator() while (iter.next()) { - assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) - out += iter.getKey.getString(0) + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) } assert(out === (keys ++ additionalKeys).sorted) - map.free() } @@ -232,25 +227,21 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. val additionalKeys = randomStrings(1024) - val keyConverter = UnsafeProjection.create(groupKeySchema) - val valueConverter = UnsafeProjection.create(aggBufferSchema) - additionalKeys.zipWithIndex.foreach { case (str, i) => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) val out = new scala.collection.mutable.ArrayBuffer[String] val iter = sorter.sortedIterator() @@ -262,16 +253,12 @@ class UnsafeFixedWidthAggregationMapSuite out += key.getString(0) } - assert(out === (additionalKeys).sorted) - + assert(out === additionalKeys.sorted) map.free() } testWithMemoryLeakDetection("test external sorting with empty records") { - // Memory consumption in the beginning of the task. - val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, StructType(Nil), @@ -281,7 +268,6 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - (1 to 10).foreach { i => val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) assert(buf != null) @@ -292,13 +278,15 @@ class UnsafeFixedWidthAggregationMapSuite // Add more keys to the sorter and make sure the results come out sorted. 
(1 to 4096).foreach { i => - sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + map.getAggregationBufferFromUnsafeRow(UnsafeRow.createFromByteArray(0, 0)) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) var count = 0 val iter = sorter.sortedIterator() @@ -309,9 +297,8 @@ class UnsafeFixedWidthAggregationMapSuite count += 1 } - // 1 record was from the map and 4096 records were explicitly inserted. - assert(count === 4097) - + // 1 record per map, spilled 42 times. + assert(count === 42) map.free() } @@ -345,6 +332,7 @@ class UnsafeFixedWidthAggregationMapSuite var sorter: UnsafeKVExternalSorter = null try { sorter = map.destructAndCreateExternalSorter() + map.free() } finally { if (sorter != null) { sorter.cleanupResources() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 74061db0f28af..ea80060e370e0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -22,13 +22,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { @@ -702,6 +701,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } } + test("no aggregation function (SPARK-11486)") { + val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") + .groupBy("s").count() + .groupBy().count() + checkAnswer(df, Row(20) :: Nil) + } + test("udaf with all data types") { val struct = StructType( From 6f81eae24f83df51a99d4bb2629dd7daadc01519 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 5 Nov 2015 09:08:53 +0000 Subject: [PATCH 182/324] [SPARK-11440][CORE][STREAMING][BUILD] Declare rest of @Experimental items non-experimental if they've existed since 1.2.0 Remove `Experimental` annotations in core, streaming for items that existed in 1.2.0 or before. 
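As a quick, non-authoritative sketch of how two of these now-stable APIs are typically used (the path and record length below are hypothetical, and the same methods are mirrored in PySpark):

```python
from pyspark import SparkContext

sc = SparkContext(appName="binary-records-sketch")

# Fixed-length binary records: each element is a byte array of `recordLength` bytes.
# The path is a placeholder; point it at any directory of binary files.
records = sc.binaryRecords("hdfs:///data/frames", recordLength=16)

# Approximate count that returns within ~2 seconds at the default 95% confidence.
print("approximately %d records" % records.countApprox(timeout=2000, confidence=0.95))

sc.stop()
```

The full list of affected items follows.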
The changes are: * SparkContext * binary{Files,Records} : 1.2.0 * submitJob : 1.0.0 * JavaSparkContext * binary{Files,Records} : 1.2.0 * DoubleRDDFunctions, JavaDoubleRDD * {mean,sum}Approx : 1.0.0 * PairRDDFunctions, JavaPairRDD * sampleByKeyExact : 1.2.0 * countByKeyApprox : 1.0.0 * PairRDDFunctions * countApproxDistinctByKey : 1.1.0 * RDD * countApprox, countByValueApprox, countApproxDistinct : 1.0.0 * JavaRDDLike * countApprox : 1.0.0 * PythonHadoopUtil.Converter : 1.1.0 * PortableDataStream : 1.2.0 (related to binaryFiles) * BoundedDouble : 1.0.0 * PartialResult : 1.0.0 * StreamingContext, JavaStreamingContext * binaryRecordsStream : 1.2.0 * HiveContext * analyze : 1.2.0 Author: Sean Owen Closes #9396 from srowen/SPARK-11440. --- .../src/main/scala/org/apache/spark/SparkContext.scala | 10 +--------- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 7 ------- .../scala/org/apache/spark/api/java/JavaPairRDD.scala | 9 --------- .../scala/org/apache/spark/api/java/JavaRDDLike.scala | 5 ----- .../org/apache/spark/api/java/JavaSparkContext.scala | 7 ------- .../org/apache/spark/api/python/PythonHadoopUtil.scala | 3 --- .../org/apache/spark/input/PortableDataStream.scala | 2 -- .../scala/org/apache/spark/partial/BoundedDouble.scala | 4 ---- .../scala/org/apache/spark/partial/PartialResult.scala | 3 --- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 ---- .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 7 ------- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +------- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 2 -- .../org/apache/spark/streaming/StreamingContext.scala | 3 --- .../streaming/api/java/JavaStreamingContext.scala | 3 --- 15 files changed, 2 insertions(+), 75 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a6857b4c7d882..7421821e2601b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -45,7 +45,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} @@ -870,8 +870,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * :: Experimental :: - * * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file * (useful for binary data) * @@ -902,7 +900,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ - @Experimental def binaryFiles( path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { @@ -922,8 +919,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * :: Experimental :: - * * Load data from a flat binary file, assuming the length of each record is constant. 
* * '''Note:''' We ensure that the byte array for each record in the resulting RDD @@ -936,7 +931,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @return An RDD of data with values, represented as byte arrays */ - @Experimental def binaryRecords( path: String, recordLength: Int, @@ -1963,10 +1957,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * :: Experimental :: * Submit a job for execution and return a FutureJob holding the result. */ - @Experimental def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index a650df605b92e..c32aefac465bc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -24,7 +24,6 @@ import scala.reflect.ClassTag import org.apache.spark.Partitioner import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -209,25 +208,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) srdd.meanApprox(timeout, confidence) /** - * :: Experimental :: * Approximate operation to return the mean within a timeout. */ - @Experimental def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. */ - @Experimental def sumApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] = srdd.sumApprox(timeout, confidence) /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. */ - @Experimental def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 8344f6368ac48..0b0c6e5bb8cc1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} @@ -159,7 +158,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * @@ -169,14 +167,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need * two additional passes. 
*/ - @Experimental def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * @@ -188,7 +184,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * Use Utils.random.nextLong as the default seed for the random number generator. */ - @Experimental def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) @@ -300,20 +295,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index fc817cdd6a3f8..871be0b1f39ea 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -28,7 +28,6 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap @@ -436,20 +435,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def count(): Long = rdd.count() /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ - @Experimental def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = rdd.countApprox(timeout, confidence) /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. 
*/ - @Experimental def countApprox(timeout: Long): PartialResult[BoundedDouble] = rdd.countApprox(timeout) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 609496ccdfef1..4f54cd69e2175 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.AccumulatorParam._ -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} @@ -266,8 +265,6 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.binaryFiles(path, minPartitions)) /** - * :: Experimental :: - * * Read a directory of binary files from HDFS, a local file system (available on all nodes), * or any Hadoop-supported file system URI as a byte array. Each file is read as a single * record and returned in a key-value pair, where the key is the path of each file, @@ -294,19 +291,15 @@ class JavaSparkContext(val sc: SparkContext) * * @note Small files are preferred; very large files but may cause bad performance. */ - @Experimental def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) /** - * :: Experimental :: - * * Load data from a flat binary file, assuming the length of each record is constant. * * @param path Directory to the input data files * @return An RDD of data with values, represented as byte arrays */ - @Experimental def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { new JavaRDD(sc.binaryRecords(path, recordLength)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index a7dfa1d257cf2..d2beef2a0dd43 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -24,17 +24,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} /** - * :: Experimental :: * A trait for use with reading custom classes in PySpark. Implement this trait and add custom * transformation code by overriding the convert method. 
*/ -@Experimental trait Converter[T, + U] extends Serializable { def convert(obj: T): U } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index e2ffc3b64e5db..33e4ee0215817 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,7 +27,6 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil /** @@ -129,7 +128,6 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat * @note TaskAttemptContext is not serializable resulting in the confBytes construct * @note CombineFileSplit is not serializable resulting in the splitBytes construct */ -@Experimental class PortableDataStream( isplit: CombineFileSplit, context: TaskAttemptContext, diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index aed0353344427..48b9434153172 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -17,13 +17,9 @@ package org.apache.spark.partial -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: * A Double value with error bars and associated confidence. */ -@Experimental class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { override def toString(): String = "[%.3f, %.3f]".format(low, high) } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index 53c4b32c95ab3..25cb7490aa9c9 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -17,9 +17,6 @@ package org.apache.spark.partial -import org.apache.spark.annotation.Experimental - -@Experimental class PartialResult[R](initialVal: R, isFinal: Boolean) { private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None private var failure: Option[Exception] = None diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 926bce6f15a2a..7fbaadcea3a3b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -74,10 +74,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } /** - * :: Experimental :: * Approximate operation to return the mean within a timeout. */ - @Experimental def meanApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { @@ -87,10 +85,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } /** - * :: Experimental :: * Approximate operation to return the sum within a timeout. 
*/ - @Experimental def sumApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = self.withScope { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a981b63942e6d..c6181902ace6d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -274,7 +274,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * ::Experimental:: * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * @@ -289,7 +288,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * @param seed seed for the random number generator * @return RDD containing the sampled subset */ - @Experimental def sampleByKeyExact( withReplacement: Boolean, fractions: Map[K, Double], @@ -384,19 +382,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * :: Experimental :: * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ - @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[Map[K, BoundedDouble]] = self.withScope { self.map(_._1).countByValueApprox(timeout, confidence) } /** - * :: Experimental :: - * * Return approximate number of distinct values for each key in this RDD. * * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: @@ -413,7 +407,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * If `sp` equals 0, the sparse representation is skipped. * @param partitioner Partitioner to use for the resulting RDD. */ - @Experimental def countApproxDistinctByKey( p: Int, sp: Int, diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a97bb174438a5..800ef53cbef07 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -1119,11 +1119,9 @@ abstract class RDD[T: ClassTag]( def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum /** - * :: Experimental :: * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ - @Experimental def countApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope { @@ -1152,10 +1150,8 @@ abstract class RDD[T: ClassTag]( } /** - * :: Experimental :: * Approximate version of countByValue(). */ - @Experimental def countByValueApprox(timeout: Long, confidence: Double = 0.95) (implicit ord: Ordering[T] = null) : PartialResult[Map[T, BoundedDouble]] = withScope { @@ -1174,7 +1170,6 @@ abstract class RDD[T: ClassTag]( } /** - * :: Experimental :: * Return approximate number of distinct elements in the RDD. 
* * The algorithm used is based on streamlib's implementation of "HyperLogLog in Practice: @@ -1190,7 +1185,6 @@ abstract class RDD[T: ClassTag]( * @param sp The precision value for the sparse set, between 0 and 32. * If `sp` equals 0, the sparse representation is skipped. */ - @Experimental def countApproxDistinct(p: Int, sp: Int): Long = withScope { require(p >= 4, s"p ($p) must be >= 4") require(sp <= 32, s"sp ($sp) must be <= 32") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1f5135320326c..670d6a78e36e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -36,7 +36,6 @@ import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ @@ -356,7 +355,6 @@ class HiveContext private[hive]( * * @since 1.2.0 */ - @Experimental def analyze(tableName: String) { val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 051f53de64cd5..97113835f3bd0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -445,8 +445,6 @@ class StreamingContext private[streaming] ( } /** - * :: Experimental :: - * * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as flat binary files, assuming a fixed length per record, * generating one byte array per record. 
Files must be written to the monitored directory @@ -459,7 +457,6 @@ class StreamingContext private[streaming] ( * @param directory HDFS directory to monitor for new file * @param recordLength length of each record in bytes */ - @Experimental def binaryRecordsStream( directory: String, recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 13f371f29603a..8f21c79a760c1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -222,8 +222,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { } /** - * :: Experimental :: - * * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as flat binary files with fixed record lengths, * yielding byte arrays @@ -234,7 +232,6 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * @param directory HDFS directory to monitor for new files * @param recordLength The length at which to split the records */ - @Experimental def binaryRecordsStream(directory: String, recordLength: Int): JavaDStream[Array[Byte]] = { ssc.binaryRecordsStream(directory, recordLength) } From 859dff56eb0f8c63c86e7e900a12340c199e6247 Mon Sep 17 00:00:00 2001 From: Nick Evans Date: Thu, 5 Nov 2015 09:18:20 +0000 Subject: [PATCH 183/324] [SPARK-11378][STREAMING] make StreamingContext.awaitTerminationOrTimeout return properly This adds a failing test checking that `awaitTerminationOrTimeout` returns the expected value, and then fixes that failing test with the addition of a `return`. tdas zsxwing Author: Nick Evans Closes #9336 from manygrams/fix_await_termination_or_timeout. --- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/tests.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 975c75473214a..8be56c9915265 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -218,7 +218,7 @@ def awaitTerminationOrTimeout(self, timeout): @param timeout: time to wait in seconds """ - self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) + return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) def stop(self, stopSparkContext=True, stopGraceFully=False): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index f7fa481d50235..179479625bca4 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -596,6 +596,13 @@ def setupFunc(): self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) self.assertTrue(self.setupCalled) + def test_await_termination_or_timeout(self): + self._add_input_stream() + self.ssc.start() + self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001)) + self.ssc.stop(False) + self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001)) + class CheckpointTests(unittest.TestCase): From 7bdc92197cce0edc0110dc9c2158e6e3f42c72ee Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 5 Nov 2015 09:23:09 +0000 Subject: [PATCH 184/324] [SPARK-11449][CORE] PortableDataStream should be a factory ```PortableDataStream``` maintains some internal state. 
This makes it tricky to reuse a stream (one needs to call ```close``` on both the ```PortableDataStream``` and the ```InputStream``` it produces). This PR removes all state from ```PortableDataStream``` and effectively turns it into an ```InputStream```/```Array[Byte]``` factory. This makes the user responsible for managing the ```InputStream``` it returns. cc srowen Author: Herman van Hovell Closes #9417 from hvanhovell/SPARK-11449. --- .../spark/input/PortableDataStream.scala | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 33e4ee0215817..280e7a5fe893c 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.JavaConverters._ -import com.google.common.io.ByteStreams +import com.google.common.io.{Closeables, ByteStreams} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} @@ -82,7 +82,6 @@ private[spark] abstract class StreamBasedRecordReader[T]( if (!processed) { val fileIn = new PortableDataStream(split, context, index) value = parseStream(fileIn) - fileIn.close() // if it has not been open yet, close does nothing key = fileIn.getPath processed = true true @@ -134,12 +133,6 @@ class PortableDataStream( index: Integer) extends Serializable { - // transient forces file to be reopened after being serialization - // it is also used for non-serializable classes - - @transient private var fileIn: DataInputStream = null - @transient private var isOpen = false - private val confBytes = { val baos = new ByteArrayOutputStream() SparkHadoopUtil.get.getConfigurationFromJobContext(context). @@ -175,40 +168,34 @@ class PortableDataStream( } /** - * Create a new DataInputStream from the split and context + * Create a new DataInputStream from the split and context. The user of this method is responsible + * for closing the stream after usage. */ def open(): DataInputStream = { - if (!isOpen) { - val pathp = split.getPath(index) - val fs = pathp.getFileSystem(conf) - fileIn = fs.open(pathp) - isOpen = true - } - fileIn + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fs.open(pathp) } /** * Read the file as a byte array */ def toArray(): Array[Byte] = { - open() - val innerBuffer = ByteStreams.toByteArray(fileIn) - close() - innerBuffer + val stream = open() + try { + ByteStreams.toByteArray(stream) + } finally { + Closeables.close(stream, true) + } } /** - * Close the file (if it is currently open) + * Closing the PortableDataStream is not needed anymore. The user either can use the + * PortableDataStream to get a DataInputStream (which the user needs to close after usage), + * or a byte array. 
*/ + @deprecated("Closing the PortableDataStream is not needed anymore.", "1.6.0") def close(): Unit = { - if (isOpen) { - try { - fileIn.close() - isOpen = false - } catch { - case ioe: java.io.IOException => // do nothing - } - } } def getPath(): String = path From a94671a027c29bacea37f56b95eccb115638abff Mon Sep 17 00:00:00 2001 From: a1singh Date: Thu, 5 Nov 2015 12:51:10 +0000 Subject: [PATCH 185/324] [SPARK-11506][MLLIB] Removed redundant operation in Online LDA implementation In file LDAOptimizer.scala: line 441: since "idx" was never used, replaced unrequired zipWithIndex.foreach with foreach. - nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + nonEmptyDocs.foreach { case (_, termCounts: Vector) => Author: a1singh Closes #9456 from a1singh/master. --- .../scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 38486e949bbcf..17c0609800e90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -438,7 +438,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val stat = BDM.zeros[Double](k, vocabSize) var gammaPart = List[BDV[Double]]() - nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + nonEmptyDocs.foreach { case (_, termCounts: Vector) => val ids: List[Int] = termCounts match { case v: DenseVector => (0 until v.size).toList case v: SparseVector => v.indices.toList From 77488fb8e586103ba4c0858b73e1715f1a66671f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 5 Nov 2015 23:49:44 +0800 Subject: [PATCH 186/324] [MINOR][SQL] A minor log line fix `jars` in the log line is an array, so `$jars` doesn't print its content. Author: Cheng Lian Closes #9494 from liancheng/minor.log-fix. --- .../src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 670d6a78e36e4..2d72b959af134 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -309,7 +309,8 @@ class HiveContext private[hive]( .map(_.toURI.toURL) logInfo( - s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using $jars") + s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion " + + s"using ${jars.mkString(":")}") new IsolatedClientLoader( version = metaVersion, execJars = jars.toSeq, From 72634f27e3110fd7f5bfca498752f69d0b1f873c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 5 Nov 2015 08:59:06 -0800 Subject: [PATCH 187/324] [MINOR][ML][DOC] Rename weights to coefficients in user guide We should use ```coefficients``` rather than ```weights``` in user guide that freshman can get the right conventional name at the outset. mengxr vectorijk Author: Yanbo Liang Closes #9493 from yanboliang/docs-coefficients. 
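For readers skimming the diff below, the user-facing effect is just the renamed accessor; a minimal Python sketch (the `training` DataFrame with "label" and "features" columns is assumed to exist):

```python
from pyspark.ml.classification import LogisticRegression

# `training` is assumed to be an existing DataFrame of labeled feature vectors.
lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
lrModel = lr.fit(training)

# The guide now refers to these as coefficients rather than weights.
print("Coefficients: " + str(lrModel.coefficients))
print("Intercept: " + str(lrModel.intercept))
```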
--- docs/ml-linear-methods.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 4e94e2f9c708d..16e2ee71293ae 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -71,8 +71,8 @@ val lr = new LogisticRegression() // Fit the model val lrModel = lr.fit(training) -// Print the weights and intercept for logistic regression -println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") +// Print the coefficients and intercept for logistic regression +println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") {% endhighlight %} @@ -105,8 +105,8 @@ public class LogisticRegressionWithElasticNetExample { // Fit the model LogisticRegressionModel lrModel = lr.fit(training); - // Print the weights and intercept for logistic regression - System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + // Print the coefficients and intercept for logistic regression + System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); } } {% endhighlight %} @@ -124,8 +124,8 @@ lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) # Fit the model lrModel = lr.fit(training) -# Print the weights and intercept for logistic regression -print("Weights: " + str(lrModel.weights)) +# Print the coefficients and intercept for logistic regression +print("Coefficients: " + str(lrModel.coefficients)) print("Intercept: " + str(lrModel.intercept)) {% endhighlight %} @@ -258,8 +258,8 @@ val lr = new LinearRegression() // Fit the model val lrModel = lr.fit(training) -// Print the weights and intercept for linear regression -println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") +// Print the coefficients and intercept for linear regression +println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") // Summarize the model over the training set and print out some metrics val trainingSummary = lrModel.summary @@ -302,8 +302,8 @@ public class LinearRegressionWithElasticNetExample { // Fit the model LinearRegressionModel lrModel = lr.fit(training); - // Print the weights and intercept for linear regression - System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + // Print the coefficients and intercept for linear regression + System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept()); // Summarize the model over the training set and print out some metrics LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); @@ -330,8 +330,8 @@ lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) # Fit the model lrModel = lr.fit(training) -# Print the weights and intercept for linear regression -print("Weights: " + str(lrModel.weights)) +# Print the coefficients and intercept for linear regression +print("Coefficients: " + str(lrModel.coefficients)) print("Intercept: " + str(lrModel.intercept)) # Linear regression model summary is not yet supported in Python. From 2e86cf1b01ae0ed69f72bf8054330440d432eeb7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 5 Nov 2015 09:00:03 -0800 Subject: [PATCH 188/324] [SPARK-11527][ML][PYSPARK] PySpark AFTSurvivalRegressionModel should expose coefficients/intercept/scale PySpark ```AFTSurvivalRegressionModel``` should expose coefficients/intercept/scale. 
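A minimal PySpark sketch of the three newly exposed attributes, assuming an existing `sqlContext` (the tiny dataset below is purely illustrative):

```python
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.mllib.linalg import Vectors

# Toy survival data: (label, censor, features); censor == 1.0 means the event occurred.
training = sqlContext.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
    (0.273, 1.0, Vectors.dense(0.520, 1.151)),
    (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])

model = AFTSurvivalRegression().fit(training)

# The attributes exposed by this patch.
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
```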
mengxr vectorijk Author: Yanbo Liang Closes #9492 from yanboliang/spark-11527. --- python/pyspark/ml/regression.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ab26616f4a01d..d7b4fd92c3817 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -824,6 +824,30 @@ class AFTSurvivalRegressionModel(JavaModel): .. versionadded:: 1.6.0 """ + @property + @since("1.6.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("1.6.0") + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + @property + @since("1.6.0") + def scale(self): + """ + Model scale parameter. + """ + return self._call_java("scale") + def predictQuantiles(self, features): """ Predicted Quantiles From a4b5cefcf1a196e6b257e6127d6b43a7e50200ac Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Thu, 5 Nov 2015 09:35:49 -0800 Subject: [PATCH 189/324] [SPARK-11501][CORE][YARN] Propagate spark.rpc config to executors spark.rpc is supposed to be configurable but currently is not (it doesn't get propagated to executors because RpcEnv.create is done before driver properties are fetched). Author: Nishkam Ravi Closes #9460 from nishkamravi2/master_akka. --- core/src/main/scala/org/apache/spark/SparkConf.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f023e4b21cb40..19633a3ce6a02 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -629,6 +629,7 @@ private[spark] object SparkConf extends Logging { name.startsWith("spark.akka") || (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || + name.startsWith("spark.rpc") || isSparkPortConf(name) } From b072ff4d1d05fc212cd7036d1897a032a395f0b3 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 5 Nov 2015 09:41:14 -0800 Subject: [PATCH 190/324] [SPARK-11474][SQL] change fetchSize to fetchsize In DefaultDataSource.scala, the method override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation receives `parameters` as a CaseInsensitiveMap. After the line parameters.foreach(kv => properties.setProperty(kv._1, kv._2)), `properties` ends up with all lower-case keys, so fetchSize becomes fetchsize. However, the compute method in JDBCRDD reads val fetchSize = properties.getProperty("fetchSize", "0").toInt, so the fetch size is always 0 and never gets set correctly. Author: Huaxin Gao Closes #9473 from huaxingao/spark-11474. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 730d88b024cb1..018a009fbda6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -347,6 +347,7 @@ private[sql] class JDBCRDD( /** * Runs the SQL query against the JDBC driver.
+ * */ override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = new Iterator[InternalRow] { @@ -368,7 +369,7 @@ private[sql] class JDBCRDD( val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty("fetchSize", "0").toInt + val fetchSize = properties.getProperty("fetchsize", "0").toInt stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() From 9da7ceed81b0afce7deb8f39f3a6d565d401a391 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 5 Nov 2015 09:56:18 -0800 Subject: [PATCH 191/324] [SPARK-11473][ML] R-like summary statistics with intercept for OLS via normal equation solver Follow up [SPARK-9836](https://issues.apache.org/jira/browse/SPARK-9836), we should also support summary statistics for ```intercept```. Author: Yanbo Liang Closes #9485 from yanboliang/spark-11473. --- .../spark/ml/optim/WeightedLeastSquares.scala | 35 ++++++++++--------- .../ml/regression/LinearRegression.scala | 22 +++++++----- .../ml/regression/LinearRegressionSuite.scala | 16 ++++----- python/pyspark/ml/regression.py | 16 ++++----- 4 files changed, 48 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index e612a2122ed62..8617722ae542f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -75,7 +75,7 @@ private[ml] class WeightedLeastSquares( val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() logInfo(s"Number of instances: ${summary.count}.") - val k = summary.k + val k = if (fitIntercept) summary.k + 1 else summary.k val triK = summary.triK val wSum = summary.wSum val bBar = summary.bBar @@ -86,14 +86,6 @@ private[ml] class WeightedLeastSquares( val aaBar = summary.aaBar val aaValues = aaBar.values - if (fitIntercept) { - // shift centers - // A^T A - aBar aBar^T - BLAS.spr(-1.0, aBar, aaValues) - // A^T b - bBar aBar - BLAS.axpy(-bBar, aBar, abBar) - } - // add regularization to diagonals var i = 0 var j = 2 @@ -111,21 +103,32 @@ private[ml] class WeightedLeastSquares( j += 1 } - val x = new DenseVector(CholeskyDecomposition.solve(aaBar.values, abBar.values)) + val aa = if (fitIntercept) { + Array.concat(aaBar.values, aBar.values, Array(1.0)) + } else { + aaBar.values + } + val ab = if (fitIntercept) { + Array.concat(abBar.values, Array(bBar)) + } else { + abBar.values + } + + val x = CholeskyDecomposition.solve(aa, ab) + + val aaInv = CholeskyDecomposition.inverse(aa, k) - val aaInv = CholeskyDecomposition.inverse(aaBar.values, k) // aaInv is a packed upper triangular matrix, here we get all elements on diagonal val diagInvAtWA = new DenseVector((1 to k).map { i => aaInv(i + (i - 1) * i / 2 - 1) / wSum }.toArray) - // compute intercept - val intercept = if (fitIntercept) { - bBar - BLAS.dot(aBar, x) + val (coefficients, intercept) = if (fitIntercept) { + (new DenseVector(x.slice(0, x.length - 1)), x.last) } else { - 0.0 + (new DenseVector(x), 0.0) } - new WeightedLeastSquaresModel(x, intercept, diagInvAtWA) + new WeightedLeastSquaresModel(coefficients, intercept, diagInvAtWA) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c51e30483ab3d..6638313818703 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -511,8 +511,7 @@ class LinearRegressionSummary private[regression] ( } /** - * Standard error of estimated coefficients. - * Note that standard error of estimated intercept is not supported currently. + * Standard error of estimated coefficients and intercept. */ lazy val coefficientStandardErrors: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -532,21 +531,26 @@ class LinearRegressionSummary private[regression] ( } } - /** T-statistic of estimated coefficients. - * Note that t-statistic of estimated intercept is not supported currently. - */ + /** + * T-statistic of estimated coefficients and intercept. + */ lazy val tValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { throw new UnsupportedOperationException( "No t-statistic available for this LinearRegressionModel") } else { - model.coefficients.toArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + val estimate = if (model.getFitIntercept) { + Array.concat(model.coefficients.toArray, Array(model.intercept)) + } else { + model.coefficients.toArray + } + estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } } - /** Two-sided p-value of estimated coefficients. - * Note that p-value of estimated intercept is not supported currently. - */ + /** + * Two-sided p-value of estimated coefficients and intercept. + */ lazy val pValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { throw new UnsupportedOperationException( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index fbf83e8922861..a1d86fe8fedad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -621,13 +621,13 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) val devianceResidualsR = Array(-0.35566, 0.34504) - val seCoefR = Array(0.0011756, 0.0009032) - val tValsR = Array(3998, 7971) - val pValsR = Array(0, 0) + val seCoefR = Array(0.0011756, 0.0009032, 0.0018489) + val tValsR = Array(3998, 7971, 3407) + val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => - assert(x._1 ~== x._2 absTol 1E-3) } + assert(x._1 ~== x._2 absTol 1E-5) } model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => - assert(x._1 ~== x._2 absTol 1E-3) } + assert(x._1 ~== x._2 absTol 1E-5) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } @@ -789,9 +789,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val coefficientsR = Vectors.dense(Array(6.080, -0.600)) val interceptR = 18.080 val devianceResidualsR = Array(-1.358, 1.920) - val seCoefR = Array(5.556, 1.960) - val tValsR = Array(1.094, -0.306) - val pValsR = Array(0.471, 0.811) + val seCoefR = Array(5.556, 1.960, 9.608) + val tValsR = Array(1.094, -0.306, 1.882) + val pValsR = Array(0.471, 
0.811, 0.311) assert(model.coefficients ~== coefficientsR absTol 1E-3) assert(model.intercept ~== interceptR absTol 1E-3) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index d7b4fd92c3817..7648bf13266bf 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -55,15 +55,15 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) - >>> model.transform(test0).head().prediction - -1.0 - >>> model.weights - DenseVector([1.0]) - >>> model.intercept - 0.0 + >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 + True + >>> abs(model.coefficients[0] - 1.0) < 0.001 + True + >>> abs(model.intercept - 0.0) < 0.001 + True >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) - >>> model.transform(test1).head().prediction - 1.0 + >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001 + True >>> lr.setParams("vector") Traceback (most recent call last): ... From c76865c6220e3e7b2a266bbc4935567ef55303d8 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 5 Nov 2015 11:30:44 -0800 Subject: [PATCH 192/324] [SPARK-11484][WEBUI] Using proxyBase set by spark AM Use the proxyBase set by the AM, if not found then use env. This is to fix the issue if somebody accidentally set APPLICATION_WEB_PROXY_BASE to wrong proxyBase Author: Srinivasa Reddy Vundela Closes #9448 from vundela/master. --- .../src/main/scala/org/apache/spark/ui/UIUtils.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 68a9f912a5d2c..25dcb604d9e5f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -143,14 +143,10 @@ private[spark] object UIUtils extends Logging { // Yarn has to go through a proxy so the base uri is provided and has to be on all links def uiRoot: String = { - if (System.getenv("APPLICATION_WEB_PROXY_BASE") != null) { - System.getenv("APPLICATION_WEB_PROXY_BASE") - } else if (System.getProperty("spark.ui.proxyBase") != null) { - System.getProperty("spark.ui.proxyBase") - } - else { - "" - } + // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. + sys.props.get("spark.ui.proxyBase") + .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .getOrElse("") } def prependBaseUri(basePath: String = "", resource: String = ""): String = { From 6b87acd6649a3390b5c2c4fcb61e58d125d0d87c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 11:58:13 -0800 Subject: [PATCH 193/324] [SPARK-11513][SQL] Remove implicit conversion from LogicalPlan to DataFrame This internal implicit conversion has been a source of confusion for a lot of new developers. Author: Reynold Xin Closes #9479 from rxin/SPARK-11513. 
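As a rough illustration of the pattern this patch adopts, the sketch below replaces an implicit plan-to-DataFrame conversion with explicit wrapping through a small private helper. `MiniPlan`, `MiniFrame`, and the operators are illustrative stand-ins, not the real Spark or Catalyst classes.

```scala
// Self-contained sketch of the explicit-wrapping pattern (not real Spark code).
sealed trait MiniPlan
case class Scan(table: String) extends MiniPlan
case class Limit(n: Int, child: MiniPlan) extends MiniPlan
case class Filter(predicate: String, child: MiniPlan) extends MiniPlan

class MiniFrame(val plan: MiniPlan) {
  // Analogue of the private DataFrame.withPlan helper introduced by this patch:
  // every operation states explicitly that it builds a new frame from a plan,
  // instead of relying on an implicit MiniPlan => MiniFrame conversion.
  @inline private def withPlan(newPlan: => MiniPlan): MiniFrame = new MiniFrame(newPlan)

  def filter(predicate: String): MiniFrame = withPlan { Filter(predicate, plan) }
  def limit(n: Int): MiniFrame = withPlan { Limit(n, plan) }
}

object MiniFrameExample extends App {
  val df = new MiniFrame(Scan("people")).filter("age > 21").limit(10)
  println(df.plan) // Limit(10,Filter(age > 21,Scan(people)))
}
```

Making the wrapping explicit keeps the construction of each result visible at the call site, which is the source of confusion the commit message refers to.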
--- .../org/apache/spark/sql/DataFrame.scala | 123 +++++++++++------- .../scala/org/apache/spark/sql/Dataset.scala | 5 +- 2 files changed, 78 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d3a2249d7006c..6336dee7be6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -147,14 +147,6 @@ class DataFrame private[sql]( queryExecution.analyzed } - /** - * An implicit conversion function internal to this class for us to avoid doing - * "new DataFrame(...)" everywhere. - */ - @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) - } - protected[sql] def resolve(colName: String): NamedExpression = { queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( @@ -235,7 +227,7 @@ class DataFrame private[sql]( // For Data that has more than "numRows" records if (hasMoreData) { val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows ${rowsString}\n") + sb.append(s"only showing top $numRows $rowsString\n") } sb.toString() @@ -332,7 +324,7 @@ class DataFrame private[sql]( */ def explain(extended: Boolean): Unit = { val explain = ExplainCommand(queryExecution.logical, extended = extended) - explain.queryExecution.executedPlan.executeCollect().foreach { + withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println @@ -370,7 +362,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = show(numRows, true) + def show(numRows: Int): Unit = show(numRows, truncate = true) /** * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters @@ -445,7 +437,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def join(right: DataFrame): DataFrame = { + def join(right: DataFrame): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } @@ -520,21 +512,25 @@ class DataFrame private[sql]( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] // Project only one of the join columns. - val joinedCols = usingColumns.map(col => joined.right.resolve(col)) + val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) val condition = usingColumns.map { col => - catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + catalyst.expressions.EqualTo( + withPlan(joined.left).resolve(col), + withPlan(joined.right).resolve(col)) }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => catalyst.expressions.And(cond, eqTo) } - Project( - joined.output.filterNot(joinedCols.contains(_)), - Join( - joined.left, - joined.right, - joinType = JoinType(joinType), - condition) - ) + withPlan { + Project( + joined.output.filterNot(joinedCols.contains(_)), + Join( + joined.left, + joined.right, + joinType = JoinType(joinType), + condition) + ) + } } /** @@ -581,19 +577,20 @@ class DataFrame private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. 
- val plan = Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + val plan = withPlan( + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. if (!sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { - return plan + return withPlan(plan) } // If left/right have no output set intersection, return the plan. - val lanalyzed = this.logicalPlan.queryExecution.analyzed - val ranalyzed = right.logicalPlan.queryExecution.analyzed + val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed + val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { - return plan + return withPlan(plan) } // Otherwise, find the trivially true predicates and automatically resolves them to both sides. @@ -602,9 +599,14 @@ class DataFrame private[sql]( val cond = plan.condition.map { _.transform { case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + catalyst.expressions.EqualTo( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} - plan.copy(condition = cond) + + withPlan { + plan.copy(condition = cond) + } } /** @@ -707,7 +709,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def as(alias: String): DataFrame = Subquery(alias, logicalPlan) + def as(alias: String): DataFrame = withPlan { + Subquery(alias, logicalPlan) + } /** * (Scala-specific) Returns a new [[DataFrame]] with an alias set. @@ -739,7 +743,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = { + def select(cols: Column*): DataFrame = withPlan { val namedExpressions = cols.map { // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to @@ -798,7 +802,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan) + def filter(condition: Column): DataFrame = withPlan { + Filter(condition.expr, logicalPlan) + } /** * Filters rows using the given SQL expression. @@ -1039,7 +1045,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + def limit(n: Int): DataFrame = withPlan { + Limit(Literal(n), logicalPlan) + } /** * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. @@ -1047,7 +1055,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + def unionAll(other: DataFrame): DataFrame = withPlan { + Union(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. @@ -1055,7 +1065,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + def intersect(other: DataFrame): DataFrame = withPlan { + Intersect(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. 
@@ -1063,7 +1075,9 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + def except(other: DataFrame): DataFrame = withPlan { + Except(logicalPlan, other.logicalPlan) + } /** * Returns a new [[DataFrame]] by sampling a fraction of rows. @@ -1074,7 +1088,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan) } @@ -1102,7 +1116,7 @@ class DataFrame private[sql]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) }.toArray } @@ -1162,8 +1176,10 @@ class DataFrame private[sql]( f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } /** @@ -1190,8 +1206,10 @@ class DataFrame private[sql]( } val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + withPlan { + Generate(generator, join = true, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } ///////////////////////////////////////////////////////////////////////////// @@ -1309,7 +1327,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def dropDuplicates(colNames: Seq[String]): DataFrame = { + def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan { val groupCols = colNames.map(resolve) val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => @@ -1355,7 +1373,7 @@ class DataFrame private[sql]( * @since 1.3.1 */ @scala.annotation.varargs - def describe(cols: String*): DataFrame = { + def describe(cols: String*): DataFrame = withPlan { // The list of summary statistics to compute, in the form of expressions. 
val statistics = List[(String, Expression => Expression)]( @@ -1505,7 +1523,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def repartition(numPartitions: Int): DataFrame = { + def repartition(numPartitions: Int): DataFrame = withPlan { Repartition(numPartitions, shuffle = true, logicalPlan) } @@ -1519,7 +1537,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = { + def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan { RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) } @@ -1533,7 +1551,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): DataFrame = { + def repartition(partitionExprs: Column*): DataFrame = withPlan { RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) } @@ -1545,7 +1563,7 @@ class DataFrame private[sql]( * @group rdd * @since 1.4.0 */ - def coalesce(numPartitions: Int): DataFrame = { + def coalesce(numPartitions: Int): DataFrame = withPlan { Repartition(numPartitions, shuffle = false, logicalPlan) } @@ -2066,7 +2084,14 @@ class DataFrame private[sql]( SortOrder(expr, Ascending) } } - Sort(sortOrder, global = global, logicalPlan) + withPlan { + Sort(sortOrder, global = global, logicalPlan) + } + } + + /** A convenient function to wrap a logical plan and produce a DataFrame. */ + @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7b75aeec4cf3a..500227e93a472 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -107,13 +107,16 @@ class Dataset[T] private( * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] * objects that allow fields to be accessed by ordinal or name. */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) - /** * Returns this Dataset. * @since 1.6.0 */ + // This is declared with parentheses to prevent the Scala compiler from treating + // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset. def toDS(): Dataset[T] = this /** From f80f7b69a3f81d0ea879a31c769d17ffbbac74aa Mon Sep 17 00:00:00 2001 From: "Ehsan M.Kermani" Date: Thu, 5 Nov 2015 12:11:57 -0800 Subject: [PATCH 194/324] [SPARK-10265][DOCUMENTATION, ML] Fixed @Since annotation to ml.regression Here is my first commit. Author: Ehsan M.Kermani Closes #8728 from ehsanmok/SinceAnn. 
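A minimal sketch of the annotation pattern applied across these files follows. The `Since` class below is a local stand-in for Spark's internal `org.apache.spark.annotation.Since`, and `ExampleRegressor` is an illustrative class, not a real estimator.

```scala
import scala.annotation.StaticAnnotation

// Local stand-in for org.apache.spark.annotation.Since (internal to Spark),
// defined here only so the sketch is self-contained.
class Since(version: String) extends StaticAnnotation

// The annotation goes on the class, the primary constructor and its parameter,
// and every public member, each recording the release that first exposed it.
@Since("1.4.0")
final class ExampleRegressor @Since("1.4.0") (@Since("1.4.0") val uid: String) {

  @Since("1.4.0")
  def this() = this("exampleReg")

  @Since("1.4.0")
  def setMaxIter(value: Int): this.type = this // setter sketch; a real setter stores the param

  @Since("1.6.0")
  val addedLater: Int = 0 // members introduced in a later release carry their own version
}
```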
--- .../ml/regression/DecisionTreeRegressor.scala | 20 +++++++++-- .../spark/ml/regression/GBTRegressor.scala | 33 ++++++++++++++++--- .../ml/regression/IsotonicRegression.scala | 26 +++++++++++++-- .../ml/regression/LinearRegression.scala | 28 ++++++++++++++-- .../ml/regression/RandomForestRegressor.scala | 30 ++++++++++++++--- 5 files changed, 119 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 88b79a4eb82be..04420fc6e8251 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} @@ -36,30 +36,39 @@ import org.apache.spark.sql.DataFrame * for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class DecisionTreeRegressor(override val uid: String) +final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) // Override parameter setters from parent trait for Java API compatibility. - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { @@ -78,9 +87,11 @@ final class DecisionTreeRegressor(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ @@ -93,6 +104,7 @@ object DecisionTreeRegressor { * It supports both continuous and categorical features. 
* @param rootNode Root of the decision tree */ +@Since("1.4.0") @Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, @@ -115,10 +127,12 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 65b5b3e0727df..07144cc7cfbd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} @@ -42,54 +42,65 @@ import org.apache.spark.sql.types.DoubleType * learning algorithm for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class GBTRegressor(override val uid: String) +final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeRegressorParams: - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." 
*/ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTRegressor.setImpurity should NOT be used") this } // Parameters from TreeEnsembleParams: - + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) } // Parameters from GBTParams: - + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTRegressor: @@ -100,6 +111,7 @@ final class GBTRegressor(override val uid: String) * (default = squared) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", @@ -108,9 +120,11 @@ final class GBTRegressor(override val uid: String) setDefault(lossType -> "squared") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ @@ -135,13 +149,16 @@ final class GBTRegressor(override val uid: String) GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) } @@ -154,6 +171,7 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.4.0") @Experimental final class GBTRegressionModel private[ml]( override val uid: String, @@ -172,11 +190,14 @@ final class GBTRegressionModel private[ml]( * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. 
*/ + @Since("1.4.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -194,11 +215,13 @@ final class GBTRegressionModel private[ml]( blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressionModel = { copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"GBTRegressionModel (uid=$uid) with $numTrees trees" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index f4a17c8f9a582..a1fe01b047108 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} @@ -124,32 +124,42 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Since("1.5.0") @Experimental -class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel] - with IsotonicRegressionBase { +class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setIsotonic(value: Boolean): this.type = set(isotonic, value) /** @group setParam */ + @Since("1.5.0") def setWeightCol(value: String): this.type = set(weightCol, value) /** @group setParam */ + @Since("1.5.0") def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + @Since("1.5.0") override def fit(dataset: DataFrame): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
@@ -163,6 +173,7 @@ class IsotonicRegression(override val uid: String) extends Estimator[IsotonicReg copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true) } @@ -178,6 +189,7 @@ class IsotonicRegression(override val uid: String) extends Estimator[IsotonicReg * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Since("1.5.0") @Experimental class IsotonicRegressionModel private[ml] ( override val uid: String, @@ -185,27 +197,34 @@ class IsotonicRegressionModel private[ml] ( extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setFeatureIndex(value: Int): this.type = set(featureIndex, value) /** Boundaries in increasing order for which predictions are known. */ + @Since("1.5.0") def boundaries: Vector = Vectors.dense(oldModel.boundaries) /** * Predictions associated with the boundaries at the same index, monotone because of isotonic * regression. */ + @Since("1.5.0") def predictions: Vector = Vectors.dense(oldModel.predictions) + @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegressionModel = { copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => @@ -217,6 +236,7 @@ class IsotonicRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6638313818703..913140e581983 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -24,9 +24,9 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import breeze.stats.distributions.StudentsT import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim.WeightedLeastSquares +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -61,11 +61,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ +@Since("1.3.0") @Experimental -class LinearRegression(override val uid: String) +class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) /** @@ -73,6 +75,7 @@ class LinearRegression(override val uid: String) * Default is 0.0. 
* @group setParam */ + @Since("1.3.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -81,6 +84,7 @@ class LinearRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -93,6 +97,7 @@ class LinearRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) @@ -103,6 +108,7 @@ class LinearRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -111,6 +117,7 @@ class LinearRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.3.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -120,6 +127,7 @@ class LinearRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -129,6 +137,7 @@ class LinearRegression(override val uid: String) * Default is empty, so all instances have weight one. * @group setParam */ + @Since("1.6.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") @@ -139,6 +148,7 @@ class LinearRegression(override val uid: String) * selected automatically. * @group setParam */ + @Since("1.6.0") def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") @@ -329,6 +339,7 @@ class LinearRegression(override val uid: String) model.setSummary(trainingSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } @@ -336,6 +347,7 @@ class LinearRegression(override val uid: String) * :: Experimental :: * Model produced by [[LinearRegression]]. */ +@Since("1.3.0") @Experimental class LinearRegressionModel private[ml] ( override val uid: String, @@ -355,6 +367,7 @@ class LinearRegressionModel private[ml] ( * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is * thrown if `trainingSummary == None`. */ + @Since("1.5.0") def summary: LinearRegressionTrainingSummary = trainingSummary match { case Some(summ) => summ case None => @@ -369,6 +382,7 @@ class LinearRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. */ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** @@ -402,6 +416,7 @@ class LinearRegressionModel private[ml] ( dot(features, coefficients) + intercept } + @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) @@ -416,6 +431,7 @@ class LinearRegressionModel private[ml] ( * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. 
*/ +@Since("1.5.0") @Experimental class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, @@ -428,6 +444,7 @@ class LinearRegressionTrainingSummary private[regression] ( extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { /** Number of training iterations until termination */ + @Since("1.5.0") val totalIterations = objectiveHistory.length } @@ -437,6 +454,7 @@ class LinearRegressionTrainingSummary private[regression] ( * Linear regression results evaluated on a dataset. * @param predictions predictions outputted by the model's `transform` method. */ +@Since("1.5.0") @Experimental class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, @@ -455,33 +473,39 @@ class LinearRegressionSummary private[regression] ( * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] */ + @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. */ + @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. */ + @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. */ + @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ + @Since("1.5.0") val r2: Double = metrics.r2 /** Residuals (label - predicted value) */ + @Since("1.5.0") @transient lazy val residuals: DataFrame = { val t = udf { (pred: Double, label: Double) => label - pred } predictions.select(t(col(predictionCol), col(labelCol)).as("residuals")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 64fc17247cce6..71e40b513ee0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} @@ -37,44 +37,55 @@ import org.apache.spark.sql.functions._ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ +@Since("1.4.0") @Experimental -final class RandomForestRegressor(override val uid: String) +final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) // Override parameter setters from parent trait for Java API compatibility. 
// Parameters from TreeRegressorParams: - + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: - + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: - + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -91,15 +102,19 @@ final class RandomForestRegressor(override val uid: String) new RandomForestRegressionModel(trees, numFeatures) } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -111,6 +126,7 @@ object RandomForestRegressor { * @param _trees Decision trees in the ensemble. * @param numFeatures Number of features used by this model */ +@Since("1.4.0") @Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, @@ -128,11 +144,13 @@ final class RandomForestRegressionModel private[ml] ( private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. 
private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -150,10 +168,12 @@ final class RandomForestRegressionModel private[ml] ( _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" } From 14ee0f5726f96e2c4c28ac328d43fd85a0630b48 Mon Sep 17 00:00:00 2001 From: Travis Hegner Date: Thu, 5 Nov 2015 12:35:23 -0800 Subject: [PATCH 195/324] [SPARK-10648] Oracle dialect to handle nonspecific numeric types This is the alternative/agreed upon solution to PR #8780. Creating an OracleDialect to handle the nonspecific numeric types that can be defined in oracle. Author: Travis Hegner Closes #9495 from travishegner/OracleDialect. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 88ae83957a708..f9a6a09b6270d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -139,6 +139,7 @@ object JdbcDialects { registerDialect(DB2Dialect) registerDialect(MsSqlServerDialect) registerDialect(DerbyDialect) + registerDialect(OracleDialect) /** @@ -315,3 +316,27 @@ case object DerbyDialect extends JdbcDialect { } +/** + * :: DeveloperApi :: + * Default Oracle dialect, mapping a nonspecific numeric type to a general decimal type. + */ +@DeveloperApi +case object OracleDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + if (sqlType == Types.NUMERIC && size == 0) { + // This is sub-optimal as we have to pick a precision/scale in advance whereas the data + // in Oracle is allowed to have different precision/scale for each value. + Some(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else { + None + } + } +} From 8a5314efd19fb8f8a194a373fd994b954cc1fd47 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 13:34:36 -0800 Subject: [PATCH 196/324] [SPARK-11532][SQL] Remove implicit conversion from Expression to Column Author: Reynold Xin Closes #9500 from rxin/SPARK-11532. 
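This mirrors the earlier DataFrame change: the implicit Expression-to-Column conversion is replaced by a private `withExpr` helper. Below is a self-contained sketch of the same idea, with `Expr` and `Col` as illustrative stand-ins rather than the real Catalyst classes.

```scala
// Simplified stand-ins for Catalyst expressions and Column (not real Spark classes).
sealed trait Expr
case class Lit(value: Any) extends Expr
case class EqualTo(left: Expr, right: Expr) extends Expr
case class Not(child: Expr) extends Expr

class Col(val expr: Expr) {
  // Explicit wrapping helper, the analogue of Column.withExpr in this patch.
  private def withExpr(newExpr: Expr): Col = new Col(newExpr)

  // Each operator now states explicitly that it builds a new Col from an expression,
  // instead of relying on an implicit Expr => Col conversion.
  def ===(other: Any): Col = withExpr { EqualTo(expr, Lit(other)) }
  def unary_! : Col = withExpr { Not(expr) }
}

object ColExample extends App {
  val c = new Col(Lit("age"))
  println((c === 21).expr) // EqualTo(Lit(age),Lit(21))
  println((!c).expr)       // Not(Lit(age))
}
```

As with `withPlan`, the construction of every new Column is visible at the definition site rather than hidden behind an implicit conversion.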
--- .../scala/org/apache/spark/sql/Column.scala | 118 ++++++++++-------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c73f696962de5..c32c93897ce0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -68,7 +68,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { }) /** Creates a column based on the given expression. */ - implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: Expression): Column = new Column(newExpr) override def toString: String = expr.prettyString @@ -99,7 +99,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def apply(extraction: Any): Column = UnresolvedExtractValue(expr, lit(extraction).expr) + def apply(extraction: Any): Column = withExpr { + UnresolvedExtractValue(expr, lit(extraction).expr) + } /** * Unary minus, i.e. negate the expression. @@ -115,7 +117,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_- : Column = UnaryMinus(expr) + def unary_- : Column = withExpr { UnaryMinus(expr) } /** * Inversion of boolean expression, i.e. NOT. @@ -131,7 +133,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def unary_! : Column = Not(expr) + def unary_! : Column = withExpr { Not(expr) } /** * Equality test. @@ -147,7 +149,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def === (other: Any): Column = { + def === (other: Any): Column = withExpr { val right = lit(other).expr if (this.expr == right) { logWarning( @@ -188,7 +190,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + def !== (other: Any): Column = withExpr{ Not(EqualTo(expr, lit(other).expr)) } /** * Inequality test. @@ -205,7 +207,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group java_expr_ops * @since 1.3.0 */ - def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr)) + def notEqual(other: Any): Column = withExpr { Not(EqualTo(expr, lit(other).expr)) } /** * Greater than. @@ -221,7 +223,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def > (other: Any): Column = GreaterThan(expr, lit(other).expr) + def > (other: Any): Column = withExpr { GreaterThan(expr, lit(other).expr) } /** * Greater than. @@ -252,7 +254,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def < (other: Any): Column = LessThan(expr, lit(other).expr) + def < (other: Any): Column = withExpr { LessThan(expr, lit(other).expr) } /** * Less than. @@ -282,7 +284,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr) + def <= (other: Any): Column = withExpr { LessThanOrEqual(expr, lit(other).expr) } /** * Less than or equal to. 
@@ -312,7 +314,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr) + def >= (other: Any): Column = withExpr { GreaterThanOrEqual(expr, lit(other).expr) } /** * Greater than or equal to an expression. @@ -335,7 +337,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr) + def <=> (other: Any): Column = withExpr { EqualNullSafe(expr, lit(other).expr) } /** * Equality test that is safe for null values. @@ -368,7 +370,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => - CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) + withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) } case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -398,7 +400,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { - CaseWhen(branches :+ lit(value).expr) + withExpr { CaseWhen(branches :+ lit(value).expr) } } else { throw new IllegalArgumentException( "otherwise() can only be applied once on a Column previously generated by when()") @@ -424,7 +426,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.5.0 */ - def isNaN: Column = IsNaN(expr) + def isNaN: Column = withExpr { IsNaN(expr) } /** * True if the current expression is null. @@ -432,7 +434,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNull: Column = IsNull(expr) + def isNull: Column = withExpr { IsNull(expr) } /** * True if the current expression is NOT null. @@ -440,7 +442,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def isNotNull: Column = IsNotNull(expr) + def isNotNull: Column = withExpr { IsNotNull(expr) } /** * Boolean OR. @@ -455,7 +457,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def || (other: Any): Column = Or(expr, lit(other).expr) + def || (other: Any): Column = withExpr { Or(expr, lit(other).expr) } /** * Boolean OR. @@ -485,7 +487,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def && (other: Any): Column = And(expr, lit(other).expr) + def && (other: Any): Column = withExpr { And(expr, lit(other).expr) } /** * Boolean AND. @@ -515,7 +517,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def + (other: Any): Column = Add(expr, lit(other).expr) + def + (other: Any): Column = withExpr { Add(expr, lit(other).expr) } /** * Sum of this expression and another expression. @@ -545,7 +547,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def - (other: Any): Column = Subtract(expr, lit(other).expr) + def - (other: Any): Column = withExpr { Subtract(expr, lit(other).expr) } /** * Subtraction. Subtract the other expression from this expression. 
@@ -575,7 +577,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def * (other: Any): Column = Multiply(expr, lit(other).expr) + def * (other: Any): Column = withExpr { Multiply(expr, lit(other).expr) } /** * Multiplication of this expression and another expression. @@ -605,7 +607,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def / (other: Any): Column = Divide(expr, lit(other).expr) + def / (other: Any): Column = withExpr { Divide(expr, lit(other).expr) } /** * Division this expression by another expression. @@ -628,7 +630,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def % (other: Any): Column = Remainder(expr, lit(other).expr) + def % (other: Any): Column = withExpr { Remainder(expr, lit(other).expr) } /** * Modulo (a.k.a. remainder) expression. @@ -657,7 +659,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = In(expr, list.map(lit(_).expr)) + def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** * SQL like expression. @@ -665,7 +667,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def like(literal: String): Column = Like(expr, lit(literal).expr) + def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** * SQL RLIKE expression (LIKE with Regex). @@ -673,7 +675,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + def rlike(literal: String): Column = withExpr { RLike(expr, lit(literal).expr) } /** * An expression that gets an item at position `ordinal` out of an array, @@ -682,7 +684,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key)) + def getItem(key: Any): Column = withExpr { UnresolvedExtractValue(expr, Literal(key)) } /** * An expression that gets a field by name in a [[StructType]]. @@ -690,7 +692,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName)) + def getField(fieldName: String): Column = withExpr { + UnresolvedExtractValue(expr, Literal(fieldName)) + } /** * An expression that returns a substring. @@ -700,7 +704,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr) + def substr(startPos: Column, len: Column): Column = withExpr { + Substring(expr, startPos.expr, len.expr) + } /** * An expression that returns a substring. @@ -710,7 +716,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr) + def substr(startPos: Int, len: Int): Column = withExpr { + Substring(expr, lit(startPos).expr, lit(len).expr) + } /** * Contains the other element. 
@@ -718,7 +726,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def contains(other: Any): Column = Contains(expr, lit(other).expr) + def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** * String starts with. @@ -726,7 +734,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr) + def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** * String starts with another string literal. @@ -742,7 +750,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr) + def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** * String ends with another string literal. @@ -777,9 +785,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = expr match { - case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias)() + def as(alias: String): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } } /** @@ -792,7 +802,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Seq[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Assigns the given aliases to the results of a table generating function. @@ -804,7 +814,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + def as(aliases: Array[String]): Column = withExpr { MultiAlias(expr, aliases) } /** * Gives the column an alias. @@ -819,9 +829,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = expr match { - case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias.name)() + def as(alias: Symbol): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } } /** @@ -834,7 +846,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String, metadata: Metadata): Column = { + def as(alias: String, metadata: Metadata): Column = withExpr { Alias(expr, alias)(explicitMetadata = Some(metadata)) } @@ -852,10 +864,12 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = expr match { - // keeps the name of expression if possible when do cast. - case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) - case _ => Cast(expr, to) + def cast(to: DataType): Column = withExpr { + expr match { + // keeps the name of expression if possible when do cast. 
+      case ne: NamedExpression => UnresolvedAlias(Cast(expr, to)) +      case _ => Cast(expr, to) +    } }  /** @@ -885,7 +899,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ -  def desc: Column = SortOrder(expr, Descending) +  def desc: Column = withExpr { SortOrder(expr, Descending) }  /** * Returns an ordering used in sorting. @@ -900,7 +914,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ -  def asc: Column = SortOrder(expr, Ascending) +  def asc: Column = withExpr { SortOrder(expr, Ascending) }  /** * Prints the expression to the console for debugging purpose. @@ -927,7 +941,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ -  def bitwiseOR(other: Any): Column = BitwiseOr(expr, lit(other).expr) +  def bitwiseOR(other: Any): Column = withExpr { BitwiseOr(expr, lit(other).expr) }  /** * Compute bitwise AND of this expression with another expression. @@ -938,7 +952,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ -  def bitwiseAND(other: Any): Column = BitwiseAnd(expr, lit(other).expr) +  def bitwiseAND(other: Any): Column = withExpr { BitwiseAnd(expr, lit(other).expr) }  /** * Compute bitwise XOR of this expression with another expression. @@ -949,7 +963,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ -  def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) +  def bitwiseXOR(other: Any): Column = withExpr { BitwiseXor(expr, lit(other).expr) }  /** * Define a windowing column. From b9455d1f1810e1e3f472014f665ad3ad3122bcc0 Mon Sep 17 00:00:00 2001 From: adrian555 Date: Thu, 5 Nov 2015 14:47:38 -0800 Subject: [PATCH 197/324] [SPARK-11260][SPARKR] with() function support Author: adrian555 Author: Adrian Zhuang Closes #9443 from adrian555/with. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 30 ++++++++++++++++++++++++------ R/pkg/R/generics.R | 4 ++++ R/pkg/R/utils.R | 13 +++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 9 +++++++++ 5 files changed, 51 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index cd9537a2655f0..56b8ed0bf271b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -83,6 +83,7 @@ exportMethods("arrange", "unique", "unpersist", "where", + "with", "withColumn", "withColumnRenamed", "write.df") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index df5bc8137187b..44ce9414da5cf 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2126,11 +2126,29 @@ setMethod("as.data.frame", setMethod("attach", signature(what = "DataFrame"), function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { - cols <- columns(what) - stopifnot(length(cols) > 0) - newEnv <- new.env() - for (i in 1:length(cols)) { - assign(x = cols[i], value = what[, cols[i]], envir = newEnv) - } + newEnv <- assignNewEnv(what) attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) }) + +#' Evaluate an R expression in an environment constructed from a DataFrame +#' with() allows access to columns of a DataFrame by simply referring to +#' their name. It appends every column of a DataFrame into a new +#' environment. Then, the given expression is evaluated in this new +#' environment. 
+#' +#' @rdname with +#' @title Evaluate an R expression in an environment constructed from a DataFrame +#' @param data (DataFrame) DataFrame to use for constructing an environment. +#' @param expr (expression) Expression to evaluate. +#' @param ... arguments to be passed to future methods. +#' @examples +#' \dontrun{ +#' with(irisDf, nrow(Sepal_Width)) +#' } +#' @seealso \link{attach} +setMethod("with", + signature(data = "DataFrame"), + function(data, expr, ...) { + newEnv <- assignNewEnv(data) + eval(substitute(expr), envir = newEnv, enclos = newEnv) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0b35340e48e42..083d37fee28a4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1043,3 +1043,7 @@ setGeneric("as.data.frame") #' @rdname attach #' @export setGeneric("attach") + +#' @rdname with +#' @export +setGeneric("with") diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0b9e2957fe9a5..db3b2c4bbd799 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -623,3 +623,16 @@ convertNamedListToEnv <- function(namedList) { } env } + +# Assign a new environment for attach() and with() methods +assignNewEnv <- function(data) { + stopifnot(class(data) == "DataFrame") + cols <- columns(data) + stopifnot(length(cols) > 0) + + env <- new.env() + for (i in 1:length(cols)) { + assign(x = cols[i], value = data[, cols[i]], envir = env) + } + env +} \ No newline at end of file diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b4a4d03b2643b..816315b1e4e13 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1494,6 +1494,15 @@ test_that("attach() on a DataFrame", { expect_error(age) }) +test_that("with() on a DataFrame", { + df <- createDataFrame(sqlContext, iris) + expect_error(Sepal_Length) + sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) + expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") + sum2 <- with(df, distinct(Sepal_Length)) + expect_equal(nrow(sum2), 35) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From d9e30c59cede7f57786bb19e64ba422eda43bdcb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 5 Nov 2015 14:53:16 -0800 Subject: [PATCH 198/324] [SPARK-10656][SQL] completely support special chars in DataFrame The main problem is that we interpret column names with special handling of `.` for DataFrame. This enables us to write something like `df("a.b")` to get the field `b` of `a`. However, we don't need this feature in `DataFrame.apply("*")` or `DataFrame.withColumnRenamed`. In these two cases, the column name is already the final name, so we don't need any extra processing to interpret it. The solution is simple: use `queryExecution.analyzed.output` to get the resolved columns directly, instead of using `DataFrame.resolve`. close https://github.com/apache/spark/pull/8811 Author: Wenchen Fan Closes #9462 from cloud-fan/special-chars. 
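For illustration only (not part of the original patch): the behavior this change targets can be sketched as below, mirroring the new DataFrameSuite test. The snippet assumes a Spark 1.x `sqlContext` with its implicits imported; the column names are arbitrary.

    import sqlContext.implicits._

    // Column names containing `.` and other special characters are taken literally
    // by apply("*") and withColumnRenamed, instead of being parsed as field access.
    val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
    df.select(df("*")).collect()        // expected: Array(Row(1, "a"))
    df.withColumnRenamed("d^'a.", "a")  // renames the second column without error
    // df("a.b") on a struct column `a` still resolves the nested field `b`.
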
--- .../scala/org/apache/spark/sql/DataFrame.scala | 16 ++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6336dee7be6a3..f2d4db5550273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -698,7 +698,7 @@ class DataFrame private[sql]( */ def col(colName: String): Column = colName match { case "*" => - Column(ResolvedStar(schema.fieldNames.map(resolve))) + Column(ResolvedStar(queryExecution.analyzed.output)) case _ => val expr = resolve(colName) Column(expr) @@ -1259,13 +1259,17 @@ class DataFrame private[sql]( */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldRename = schema.exists(f => resolver(f.name, existingName)) + val output = queryExecution.analyzed.output + val shouldRename = output.exists(f => resolver(f.name, existingName)) if (shouldRename) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, existingName)) Column(name).as(newName) else Column(name) + val columns = output.map { col => + if (resolver(col.name, existingName)) { + Column(col).as(newName) + } else { + Column(col) + } } - select(colNames : _*) + select(columns : _*) } else { this } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 84a616d0b9081..f3a7aa280367a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1128,4 +1128,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-10656: completely support special chars") { + val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") + checkAnswer(df.select(df("*")), Row(1, "a")) + checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) + } } From b6974f8fed1726a381636e996834111a8e7ced8d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 15:34:05 -0800 Subject: [PATCH 199/324] [SPARK-11536][SQL] Remove the internal implicit conversion from Expression to Column in functions.scala Author: Reynold Xin Closes #9505 from rxin/SPARK-11536. --- .../org/apache/spark/sql/functions.scala | 580 +++++++++--------- 1 file changed, 299 insertions(+), 281 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c70c965a9b04c..04627589886a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -51,7 +51,7 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + private def withExpr(expr: Expression): Column = Column(expr) /** * Returns a [[Column]] based on the given column name. @@ -128,7 +128,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) } /** * Aggregate function: returns the approximate number of distinct items in a group. 
@@ -144,7 +144,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withExpr { + ApproxCountDistinct(e.expr, rsd) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -162,7 +164,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = Average(e.expr) + def avg(e: Column): Column = withExpr { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -178,8 +180,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = + def corr(column1: Column, column2: Column): Column = withExpr { Corr(column1.expr, column2.expr) + } /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. @@ -187,8 +190,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(columnName1: String, columnName2: String): Column = + def corr(columnName1: String, columnName2: String): Column = { corr(Column(columnName1), Column(columnName2)) + } /** * Aggregate function: returns the number of items in a group. @@ -196,10 +200,12 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = e.expr match { - // Turn count(*) into count(1) - case s: Star => Count(Literal(1)) - case _ => Count(e.expr) + def count(e: Column): Column = withExpr { + e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } } /** @@ -217,8 +223,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = + def countDistinct(expr: Column, exprs: Column*): Column = withExpr { CountDistinct((expr +: exprs).map(_.expr)) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -236,7 +243,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = First(e.expr) + def first(e: Column): Column = withExpr { First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -252,7 +259,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = Kurtosis(e.expr) + def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) } /** * Aggregate function: returns the last value in a group. @@ -260,7 +267,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = Last(e.expr) + def last(e: Column): Column = withExpr { Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -276,7 +283,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = Max(e.expr) + def max(e: Column): Column = withExpr { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -310,7 +317,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = Min(e.expr) + def min(e: Column): Column = withExpr { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -326,7 +333,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = Skewness(e.expr) + def skewness(e: Column): Column = withExpr { Skewness(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. 
@@ -334,7 +341,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = StddevSamp(e.expr) + def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) } /** * Aggregate function: returns the unbiased sample standard deviation of @@ -343,7 +350,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = StddevSamp(e.expr) + def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) } /** * Aggregate function: returns the population standard deviation of @@ -352,7 +359,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = StddevPop(e.expr) + def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) } /** * Aggregate function: returns the sum of all values in the expression. @@ -360,7 +367,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = Sum(e.expr) + def sum(e: Column): Column = withExpr { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -376,7 +383,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) } /** * Aggregate function: returns the sum of distinct values in the expression. @@ -392,7 +399,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = VarianceSamp(e.expr) + def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -400,7 +407,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = VarianceSamp(e.expr) + def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) } /** * Aggregate function: returns the population variance of the values in a group. @@ -408,7 +415,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = VariancePop(e.expr) + def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -429,9 +436,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) - } + def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } /** * Window function: returns the rank of rows within a window partition, without any gaps. 
@@ -446,9 +451,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def denseRank(): Column = { - UnresolvedWindowFunction("dense_rank", Nil) - } + def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -460,9 +463,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int): Column = { - lag(e, offset, null) - } + def lag(e: Column, offset: Int): Column = lag(e, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -474,9 +475,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int): Column = { - lag(columnName, offset, null) - } + def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -502,7 +501,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int, defaultValue: Any): Column = { + def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) } @@ -516,9 +515,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int): Column = { - lead(columnName, offset, null) - } + def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -530,9 +527,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int): Column = { - lead(e, offset, null) - } + def lead(e: Column, offset: Int): Column = { lead(e, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -558,7 +553,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int, defaultValue: Any): Column = { + def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) } @@ -572,9 +567,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = { - UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) - } + def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -589,9 +582,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def percentRank(): Column = { - UnresolvedWindowFunction("percent_rank", Nil) - } + def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -606,9 +597,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = { - UnresolvedWindowFunction("rank", Nil) - } + def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } /** * Window function: returns a sequential number starting at 1 within a window partition. 
@@ -618,9 +607,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rowNumber(): Column = { - UnresolvedWindowFunction("row_number", Nil) - } + def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -632,7 +619,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def abs(e: Column): Column = Abs(e.expr) + def abs(e: Column): Column = withExpr { Abs(e.expr) } /** * Creates a new array column. The input columns must all have the same data type. @@ -641,7 +628,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) } /** * Creates a new array column. The input columns must all have the same data type. @@ -679,14 +666,14 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs */ - def inputFileName(): Column = InputFileName() + def inputFileName(): Column = withExpr { InputFileName() } /** * Return true iff the column is NaN. @@ -694,7 +681,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def isNaN(e: Column): Column = IsNaN(e.expr) + def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -711,7 +698,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -721,7 +708,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) + def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) } /** * Unary minus, i.e. negate the expression. @@ -760,7 +747,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def rand(seed: Long): Column = Rand(seed) + def rand(seed: Long): Column = withExpr { Rand(seed) } /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. @@ -776,7 +763,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def randn(seed: Long): Column = Randn(seed) + def randn(seed: Long): Column = withExpr { Randn(seed) } /** * Generate a column with i.i.d. samples from the standard normal distribution. @@ -794,7 +781,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = SparkPartitionID() + def sparkPartitionId(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -802,7 +789,7 @@ object functions { * @group math_funcs * @since 1.3.0 */ - def sqrt(e: Column): Column = Sqrt(e.expr) + def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) } /** * Computes the square root of the specified float value. 
@@ -823,9 +810,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = { - CreateStruct(cols.map(_.expr)) - } + def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. @@ -858,7 +843,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = { + def when(condition: Column, value: Any): Column = withExpr { CaseWhen(Seq(condition.expr, lit(value).expr)) } @@ -868,7 +853,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) } /** * Parses the expression string into the column that it represents, similar to @@ -893,7 +878,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(e: Column): Column = Acos(e.expr) + def acos(e: Column): Column = withExpr { Acos(e.expr) } /** * Computes the cosine inverse of the given column; the returned angle is in the range @@ -911,7 +896,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(e: Column): Column = Asin(e.expr) + def asin(e: Column): Column = withExpr { Asin(e.expr) } /** * Computes the sine inverse of the given column; the returned angle is in the range @@ -928,7 +913,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(e: Column): Column = Atan(e.expr) + def atan(e: Column): Column = withExpr { Atan(e.expr) } /** * Computes the tangent inverse of the given column. @@ -945,7 +930,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -982,7 +967,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1000,7 +985,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + def atan2(l: Double, r: Column): Column = atan2(lit(l), r) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1018,7 +1003,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(e: Column): Column = Bin(e.expr) + def bin(e: Column): Column = withExpr { Bin(e.expr) } /** * An expression that returns the string representation of the binary value of the given long @@ -1035,7 +1020,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(e: Column): Column = Cbrt(e.expr) + def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) } /** * Computes the cube-root of the given column. @@ -1051,7 +1036,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = Ceil(e.expr) + def ceil(e: Column): Column = withExpr { Ceil(e.expr) } /** * Computes the ceiling of the given column. 
@@ -1067,8 +1052,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def conv(num: Column, fromBase: Int, toBase: Int): Column = + def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + } /** * Computes the cosine of the given value. @@ -1076,7 +1062,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(e: Column): Column = Cos(e.expr) + def cos(e: Column): Column = withExpr { Cos(e.expr) } /** * Computes the cosine of the given column. @@ -1092,7 +1078,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(e: Column): Column = Cosh(e.expr) + def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** * Computes the hyperbolic cosine of the given column. @@ -1108,7 +1094,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(e: Column): Column = Exp(e.expr) + def exp(e: Column): Column = withExpr { Exp(e.expr) } /** * Computes the exponential of the given column. @@ -1124,7 +1110,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(e: Column): Column = Expm1(e.expr) + def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** * Computes the exponential of the given column. @@ -1140,7 +1126,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def factorial(e: Column): Column = Factorial(e.expr) + def factorial(e: Column): Column = withExpr { Factorial(e.expr) } /** * Computes the floor of the given value. @@ -1148,7 +1134,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = Floor(e.expr) + def floor(e: Column): Column = withExpr { Floor(e.expr) } /** * Computes the floor of the given column. @@ -1166,7 +1152,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = { + def greatest(exprs: Column*): Column = withExpr { require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1189,7 +1175,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def hex(column: Column): Column = Hex(column.expr) + def hex(column: Column): Column = withExpr { Hex(column.expr) } /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -1198,7 +1184,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = Unhex(column.expr) + def unhex(column: Column): Column = withExpr { Unhex(column.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1206,7 +1192,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1239,7 +1225,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + def hypot(l: Column, r: Double): Column = hypot(l, lit(r)) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1255,7 +1241,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + def hypot(l: Double, r: Column): Column = hypot(lit(l), r) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. 
@@ -1273,7 +1259,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = { + def least(exprs: Column*): Column = withExpr { require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1296,7 +1282,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(e: Column): Column = Log(e.expr) + def log(e: Column): Column = withExpr { Log(e.expr) } /** * Computes the natural logarithm of the given column. @@ -1312,7 +1298,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -1328,7 +1314,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(e: Column): Column = Log10(e.expr) + def log10(e: Column): Column = withExpr { Log10(e.expr) } /** * Computes the logarithm of the given value in base 10. @@ -1344,7 +1330,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(e: Column): Column = Log1p(e.expr) + def log1p(e: Column): Column = withExpr { Log1p(e.expr) } /** * Computes the natural logarithm of the given column plus one. @@ -1360,7 +1346,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(expr: Column): Column = Log2(expr.expr) + def log2(expr: Column): Column = withExpr { Log2(expr.expr) } /** * Computes the logarithm of the given value in base 2. @@ -1376,7 +1362,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -1408,7 +1394,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + def pow(l: Column, r: Double): Column = pow(l, lit(r)) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1424,7 +1410,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + def pow(l: Double, r: Column): Column = pow(lit(l), r) /** * Returns the value of the first argument raised to the power of the second argument. 
@@ -1440,7 +1426,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + def pmod(dividend: Column, divisor: Column): Column = withExpr { + Pmod(dividend.expr, divisor.expr) + } /** * Returns the double value that is closest in value to the argument and @@ -1449,7 +1437,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(e: Column): Column = Rint(e.expr) + def rint(e: Column): Column = withExpr { Rint(e.expr) } /** * Returns the double value that is closest in value to the argument and @@ -1466,7 +1454,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = round(e.expr, 0) + def round(e: Column): Column = round(e, 0) /** * Round the value of `e` to `scale` decimal places if `scale` >= 0 @@ -1475,7 +1463,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** * Shift the the given value numBits left. If the given value is a long value, this function @@ -1484,7 +1472,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } /** * Shift the the given value numBits right. If the given value is a long value, it will return @@ -1493,7 +1481,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + def shiftRight(e: Column, numBits: Int): Column = withExpr { + ShiftRight(e.expr, lit(numBits).expr) + } /** * Unsigned shift the the given value numBits right. If the given value is a long value, @@ -1502,8 +1492,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRightUnsigned(e: Column, numBits: Int): Column = + def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr { ShiftRightUnsigned(e.expr, lit(numBits).expr) + } /** * Computes the signum of the given value. @@ -1511,7 +1502,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(e: Column): Column = Signum(e.expr) + def signum(e: Column): Column = withExpr { Signum(e.expr) } /** * Computes the signum of the given column. @@ -1527,7 +1518,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(e: Column): Column = Sin(e.expr) + def sin(e: Column): Column = withExpr { Sin(e.expr) } /** * Computes the sine of the given column. @@ -1543,7 +1534,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(e: Column): Column = Sinh(e.expr) + def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** * Computes the hyperbolic sine of the given column. @@ -1559,7 +1550,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(e: Column): Column = Tan(e.expr) + def tan(e: Column): Column = withExpr { Tan(e.expr) } /** * Computes the tangent of the given column. @@ -1575,7 +1566,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(e: Column): Column = Tanh(e.expr) + def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** * Computes the hyperbolic tangent of the given column. 
@@ -1591,7 +1582,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toDegrees(e: Column): Column = ToDegrees(e.expr) + def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. @@ -1607,7 +1598,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toRadians(e: Column): Column = ToRadians(e.expr) + def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. @@ -1628,7 +1619,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def md5(e: Column): Column = Md5(e.expr) + def md5(e: Column): Column = withExpr { Md5(e.expr) } /** * Calculates the SHA-1 digest of a binary column and returns the value @@ -1637,7 +1628,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def sha1(e: Column): Column = Sha1(e.expr) + def sha1(e: Column): Column = withExpr { Sha1(e.expr) } /** * Calculates the SHA-2 family of hash functions of a binary column and @@ -1652,7 +1643,7 @@ object functions { def sha2(e: Column, numBits: Int): Column = { require(Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - Sha2(e.expr, lit(numBits).expr) + withExpr { Sha2(e.expr, lit(numBits).expr) } } /** @@ -1662,7 +1653,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def crc32(e: Column): Column = Crc32(e.expr) + def crc32(e: Column): Column = withExpr { Crc32(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // String functions @@ -1675,7 +1666,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = Ascii(e.expr) + def ascii(e: Column): Column = withExpr { Ascii(e.expr) } /** * Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -1684,7 +1675,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = Base64(e.expr) + def base64(e: Column): Column = withExpr { Base64(e.expr) } /** * Concatenates multiple input string columns together into a single string column. 
@@ -1693,7 +1684,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } /** * Concatenates multiple input string columns together into a single string column, @@ -1703,7 +1694,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat_ws(sep: String, exprs: Column*): Column = { + def concat_ws(sep: String, exprs: Column*): Column = withExpr { ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) } @@ -1715,7 +1706,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + def decode(value: Column, charset: String): Column = withExpr { + Decode(value.expr, lit(charset).expr) + } /** * Computes the first argument into a binary from a string using the provided character set @@ -1725,7 +1718,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + def encode(value: Column, charset: String): Column = withExpr { + Encode(value.expr, lit(charset).expr) + } /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, @@ -1737,7 +1732,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def format_number(x: Column, d: Int): Column = withExpr { + FormatNumber(x.expr, lit(d).expr) + } /** * Formats the arguments in printf-style and returns the result as a string column. @@ -1746,7 +1743,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def format_string(format: String, arguments: Column*): Column = { + def format_string(format: String, arguments: Column*): Column = withExpr { FormatString((lit(format) +: arguments).map(_.expr): _*) } @@ -1759,7 +1756,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def initcap(e: Column): Column = InitCap(e.expr) + def initcap(e: Column): Column = withExpr { InitCap(e.expr) } /** * Locate the position of the first occurrence of substr column in the given string. @@ -1771,7 +1768,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + def instr(str: Column, substring: String): Column = withExpr { + StringInstr(str.expr, lit(substring).expr) + } /** * Computes the length of a given string or binary column. @@ -1779,7 +1778,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def length(e: Column): Column = Length(e.expr) + def length(e: Column): Column = withExpr { Length(e.expr) } /** * Converts a string column to lower case. @@ -1787,14 +1786,14 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def lower(e: Column): Column = Lower(e.expr) + def lower(e: Column): Column = withExpr { Lower(e.expr) } /** * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr) } /** * Locate the position of the first occurrence of substr. 
@@ -1804,7 +1803,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column): Column = { + def locate(substr: String, str: Column): Column = withExpr { new StringLocate(lit(substr).expr, str.expr) } @@ -1817,7 +1816,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column, pos: Int): Column = { + def locate(substr: String, str: Column, pos: Int): Column = withExpr { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } @@ -1827,7 +1826,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = { + def lpad(str: Column, len: Int, pad: String): Column = withExpr { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1837,7 +1836,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = StringTrimLeft(e.expr) + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** * Extract a specific(idx) group identified by a java regex, from the specified string column. @@ -1845,7 +1844,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } @@ -1855,7 +1854,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) } @@ -1866,7 +1865,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = UnBase64(e.expr) + def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** * Right-padded with pad to a length of len. @@ -1874,7 +1873,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = { + def rpad(str: Column, len: Int, pad: String): Column = withExpr { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1884,7 +1883,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, n: Int): Column = { + def repeat(str: Column, n: Int): Column = withExpr { StringRepeat(str.expr, lit(n).expr) } @@ -1894,9 +1893,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def reverse(str: Column): Column = { - StringReverse(str.expr) - } + def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } /** * Trim the spaces from right end for the specified string value. @@ -1904,7 +1901,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = StringTrimRight(e.expr) + def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** * * Return the soundex code for the specified expression. @@ -1912,7 +1909,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def soundex(e: Column): Column = SoundEx(e.expr) + def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** * Splits str around pattern (pattern is a regular expression). 
@@ -1921,7 +1918,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = { + def split(str: Column, pattern: String): Column = withExpr { StringSplit(str.expr, lit(pattern).expr) } @@ -1933,8 +1930,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def substring(str: Column, pos: Int, len: Int): Column = + def substring(str: Column, pos: Int, len: Int): Column = withExpr { Substring(str.expr, lit(pos).expr, lit(len).expr) + } /** * Returns the substring from string str before count occurrences of the delimiter delim. @@ -1944,8 +1942,9 @@ object functions { * * @group string_funcs */ - def substring_index(str: Column, delim: String, count: Int): Column = + def substring_index(str: Column, delim: String, count: Int): Column = withExpr { SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + } /** * Translate any character in the src by a character in replaceString. @@ -1956,8 +1955,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def translate(src: Column, matchingString: String, replaceString: String): Column = + def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + } /** * Trim the spaces from both ends for the specified string column. @@ -1965,7 +1965,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = StringTrim(e.expr) + def trim(e: Column): Column = withExpr { StringTrim(e.expr) } /** * Converts a string column to upper case. @@ -1973,7 +1973,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def upper(e: Column): Column = Upper(e.expr) + def upper(e: Column): Column = withExpr { Upper(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // DateTime functions @@ -1985,8 +1985,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = + def add_months(startDate: Column, numMonths: Int): Column = withExpr { AddMonths(startDate.expr, Literal(numMonths)) + } /** * Returns the current date as a date column. @@ -1994,7 +1995,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_date(): Column = CurrentDate() + def current_date(): Column = withExpr { CurrentDate() } /** * Returns the current timestamp as a timestamp column. 
@@ -2002,7 +2003,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_timestamp(): Column = CurrentTimestamp() + def current_timestamp(): Column = withExpr { CurrentTimestamp() } /** * Converts a date/timestamp/string to a value of string in the format specified by the date @@ -2017,71 +2018,72 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_format(dateExpr: Column, format: String): Column = + def date_format(dateExpr: Column, format: String): Column = withExpr { DateFormatClass(dateExpr.expr, Literal(format)) + } /** * Returns the date that is `days` days after `start` * @group datetime_funcs * @since 1.5.0 */ - def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) } /** * Returns the date that is `days` days before `start` * @group datetime_funcs * @since 1.5.0 */ - def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) } /** * Returns the number of days from `start` to `end`. * @group datetime_funcs * @since 1.5.0 */ - def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) } /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def year(e: Column): Column = Year(e.expr) + def year(e: Column): Column = withExpr { Year(e.expr) } /** * Extracts the quarter as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def quarter(e: Column): Column = Quarter(e.expr) + def quarter(e: Column): Column = withExpr { Quarter(e.expr) } /** * Extracts the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def month(e: Column): Column = Month(e.expr) + def month(e: Column): Column = withExpr { Month(e.expr) } /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) } /** * Extracts the day of the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofyear(e: Column): Column = DayOfYear(e.expr) + def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) } /** * Extracts the hours as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def hour(e: Column): Column = Hour(e.expr) + def hour(e: Column): Column = withExpr { Hour(e.expr) } /** * Given a date column, returns the last day of the month which the given date belongs to. @@ -2091,21 +2093,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def last_day(e: Column): Column = LastDay(e.expr) + def last_day(e: Column): Column = withExpr { LastDay(e.expr) } /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def minute(e: Column): Column = Minute(e.expr) + def minute(e: Column): Column = withExpr { Minute(e.expr) } /* * Returns number of months between dates `date1` and `date2`. 
* @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + def months_between(date1: Column, date2: Column): Column = withExpr { + MonthsBetween(date1.expr, date2.expr) + } /** * Given a date column, returns the first date which is later than the value of the date column @@ -2120,21 +2124,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + def next_day(date: Column, dayOfWeek: String): Column = withExpr { + NextDay(date.expr, lit(dayOfWeek).expr) + } /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def second(e: Column): Column = Second(e.expr) + def second(e: Column): Column = withExpr { Second(e.expr) } /** * Extracts the week number as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def weekofyear(e: Column): Column = WeekOfYear(e.expr) + def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2143,7 +2149,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def from_unixtime(ut: Column): Column = withExpr { + FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2152,14 +2160,18 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + def from_unixtime(ut: Column, f: String): Column = withExpr { + FromUnixTime(ut.expr, Literal(f)) + } /** * Gets current Unix timestamp in seconds. * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(): Column = withExpr { + UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), @@ -2167,7 +2179,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(s: Column): Column = withExpr { + UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Convert time string with given pattern @@ -2176,7 +2190,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } /** * Converts the column into DateType. @@ -2184,7 +2198,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = ToDate(e.expr) + def to_date(e: Column): Column = withExpr { ToDate(e.expr) } /** * Returns date truncated to the unit specified by the format. @@ -2195,22 +2209,27 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + def trunc(date: Column, format: String): Column = withExpr { + TruncDate(date.expr, Literal(format)) + } /** * Assumes given timestamp is UTC and converts to given timezone. 
* @group datetime_funcs * @since 1.5.0 */ - def from_utc_timestamp(ts: Column, tz: String): Column = - FromUTCTimestamp(ts.expr, Literal(tz).expr) + def from_utc_timestamp(ts: Column, tz: String): Column = withExpr { + FromUTCTimestamp(ts.expr, Literal(tz)) + } /** * Assumes given timestamp is in given timezone and converts to UTC. * @group datetime_funcs * @since 1.5.0 */ - def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + def to_utc_timestamp(ts: Column, tz: String): Column = withExpr { + ToUTCTimestamp(ts.expr, Literal(tz)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions @@ -2221,8 +2240,9 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def array_contains(column: Column, value: Any): Column = + def array_contains(column: Column, value: Any): Column = withExpr { ArrayContains(column.expr, Literal(value)) + } /** * Creates a new row for each element in the given array or map column. @@ -2230,7 +2250,7 @@ object functions { * @group collection_funcs * @since 1.3.0 */ - def explode(e: Column): Column = Explode(e.expr) + def explode(e: Column): Column = withExpr { Explode(e.expr) } /** * Returns length of array or map. @@ -2238,7 +2258,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = Size(e.expr) + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, @@ -2256,7 +2276,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2296,11 +2316,10 @@ object functions { * @deprecated As of 1.5.0, since it's redundant with udf() */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). @@ -2435,147 +2454,146 @@ object functions { } ////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = { + def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. 
- * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. 
+ * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2597,7 +2615,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def callUDF(udfName: String, cols: Column*): Column = { + def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } @@ -2618,7 +2636,7 @@ object functions { * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ @deprecated("Use callUDF", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = { + def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). val exprs = new Array[Expression](cols.size) From 244010624200eddea6dfd1b2c89f40be45212e96 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 16:34:10 -0800 Subject: [PATCH 200/324] [SPARK-11542] [SPARKR] fix glm with long fomular Because deparse() will break the long string into multiple lines, the deserialization will fail Author: Davies Liu Closes #9510 from davies/fix_glm. 
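Why deparse() causes this: base R's deparse() wraps its output at width.cutoff characters (60 by default), so a formula longer than that comes back as a character vector with several elements rather than a single string, and handing that multi-element vector to the JVM backend where one string is expected is what makes the deserialization fail, as noted above. A small sketch of the behaviour and of the collapse idiom the patch adopts (the formula below is illustrative, not taken from the patch):

    # Illustrative formula; any expression longer than ~60 characters will do.
    f <- SomeVeryLongResponseName ~ APredictorWithALongName + AnotherLongPredictorName

    length(deparse(f))                 # > 1: deparse() wraps long expressions into multiple lines
    paste(deparse(f), collapse = "")   # single string, safe to hand to the backend call in the patch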
--- R/pkg/R/mllib.R | 3 ++- R/pkg/inst/tests/test_mllib.R | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 60bfadb8e7503..b0d73dd93a79d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -48,8 +48,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, standardize = TRUE, solver = "auto") { family <- match.arg(family) + formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + "fitRModelFormula", formula, data@sdf, family, lambda, alpha, standardize, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032cfef061fd3..4761e285a2479 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -33,6 +33,18 @@ test_that("glm and predict", { expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") }) +test_that("glm should work with long formula", { + training <- createDataFrame(sqlContext, iris) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) From 07414afac9a100ede1dee5f3d45a657802c8bd2a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 17:02:22 -0800 Subject: [PATCH 201/324] [SPARK-11537] [SQL] fix negative hours/minutes/seconds Currently, if the Timestamp is before epoch (1970/01/01), the hours, minutes and seconds will be negative (also rounding up). Author: Davies Liu Closes #9502 from davies/neg_hour. --- .../sql/catalyst/util/DateTimeUtils.scala | 23 ++++++++++++------- .../catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 781ed1688a327..f5fff90e5a542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -392,29 +392,36 @@ object DateTimeUtils { Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + /** + * Returns the microseconds since year zero (-17999) from microseconds since epoch. + */ + def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + microsec + toYearZero * MICROS_PER_DAY + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. 
*/ - def getHours(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 3600) % 24).toInt + def getHours(microsec: SQLTimestamp): Int = { + val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + ((localTs / MICROS_PER_SECOND / 3600) % 24).toInt } /** * Returns the minute value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getMinutes(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 60) % 60).toInt + def getMinutes(microsec: SQLTimestamp): Int = { + val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + ((localTs / MICROS_PER_SECOND / 60) % 60).toInt } /** * Returns the second value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getSeconds(timestamp: SQLTimestamp): Int = { - ((timestamp / 1000 / 1000) % 60).toInt + def getSeconds(microsec: SQLTimestamp): Int = { + ((absoluteMicroSecond(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 46335941b62d6..64d15e6b910c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -358,6 +358,19 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } + test("hours / miniute / seconds") { + Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), + Timestamp.valueOf("2015-06-11 20:13:40.789"), + Timestamp.valueOf("1900-06-11 12:14:50.789"), + Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => + val us = fromJavaTimestamp(t) + assert(toJavaTimestamp(us) === t) + assert(getHours(us) === t.getHours) + assert(getMinutes(us) === t.getMinutes) + assert(getSeconds(us) === t.getSeconds) + } + } + test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) From 6091e91fca58078a0f1d9c35d68c0ae7205a534c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 17:10:35 -0800 Subject: [PATCH 202/324] Revert "[SPARK-11469][SQL] Allow users to define nondeterministic udfs." This reverts commit 9cf56c96b7d02a14175d40b336da14c2e1c88339. 
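Before the long, largely mechanical diff that follows, a short recap of what is being taken out: the reverted commit had added an isDeterministic flag to ScalaUDF (feeding its deterministic and foldable properties) and a nondeterministic method on UserDefinedFunction, plus a register(name, udf) overload on UDFRegistration. The sketch below is reconstructed from the deleted test code further down, so it only compiles against the pre-revert API; sqlContext is assumed to be the usual test SQLContext.

    import org.apache.spark.sql.functions.{col, udf}

    val plusOne = udf((x: Int) => x + 1)

    // Default: treated as deterministic, so the optimizer may constant-fold
    // calls on literal input (see the deleted "foldable udf" test below).
    val det = sqlContext.range(1, 2).select(plusOne(col("id")).as("b"))

    // Pre-revert API only: opt the UDF out of determinism-based optimizations.
    val nonDet = sqlContext.range(1, 2).select(plusOne.nondeterministic(col("id")).as("b"))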
--- project/MimaExcludes.scala | 47 ----- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../apache/spark/sql/UDFRegistration.scala | 164 ++++++++---------- .../spark/sql/UserDefinedFunction.scala | 13 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 105 ----------- .../datasources/parquet/ParquetIOSuite.scala | 4 +- 6 files changed, 78 insertions(+), 262 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90dc947d4e588..40f5c9fec8bb8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -114,53 +114,6 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"), - ProblemFilters.exclude[MissingMethodProblem]( - 
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a04af7f1dd877..11c7950c0613b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,18 +30,13 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil, - isDeterministic: Boolean = true) + inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true override def toString: String = s"UDF(${children.mkString(",")})" - override def foldable: Boolean = deterministic && children.forall(_.foldable) - - override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic) - // scalastyle:off /** This method has been generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f5b95e13e47bc..fc4d0938c533a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -58,10 +58,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. - * @param udaf the UDAF that needs to be registered. + * @param udaf the UDAF needs to be registered. * @return the registered UDAF. - * - * @since 1.5.0 */ def register( name: String, @@ -71,22 +69,6 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } - /** - * Register a user-defined function (UDF). - * - * @param name the name of the UDF. - * @param udf the UDF that needs to be registered. - * @return the registered UDF. 
- * - * @since 1.6.0 - */ - def register( - name: String, - udf: UserDefinedFunction): UserDefinedFunction = { - functionRegistry.registerFunction(name, udf.builder) - udf - } - // scalastyle:off /* register 0-22 were generated by this script @@ -104,9 +86,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try($inputTypes).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) }""") } @@ -136,9 +118,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -149,9 +131,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -162,9 +144,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -175,9 +157,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -188,9 +170,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends 
Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -201,9 +183,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -214,9 +196,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -227,9 +209,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -240,9 +222,9 @@ class UDFRegistration private[sql] 
(sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -253,9 +235,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -266,9 +248,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -279,9 +261,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: 
Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -292,9 +274,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -305,9 +287,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + 
UserDefinedFunction(func, dataType) } /** @@ -318,9 +300,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -331,9 +313,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -344,9 +326,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -357,9 +339,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -370,9 +352,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -383,9 +365,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -396,9 +378,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -409,9 +391,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -422,9 +404,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 1319391db5375..0f8cd280b5acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,20 +44,11 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil, - deterministic: Boolean = true) { + inputTypes: Seq[DataType] = Nil) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) } - - protected[sql] def builder: Seq[Expression] => ScalaUDF = { - (exprs: Seq[Expression]) => - ScalaUDF(f, dataType, exprs, inputTypes, deterministic) - } - - def nondeterministic: UserDefinedFunction = - UserDefinedFunction(f, dataType, inputTypes, deterministic = false) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6e510f0b8aff4..e0435a0dba6ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -193,107 +191,4 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } - - private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { - val udfs = df.queryExecution.optimizedPlan.collect { - case p: logical.Project => p.projectList.flatMap { - case e => e.collect { - case udf: ScalaUDF => udf - } - } - }.flatten - assert(udfs.length === expectedNumUDFs) - } - - test("foldable udf") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sql("SELECT 1 as a") - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 0) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using UDFRegistration") { - import org.apache.spark.sql.functions._ - - val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) - sqlContext.udf.register("plusOne2", myUDF.nondeterministic) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne1", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne2", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using udf function") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - // nondeterministicUDF will not be foldable. 
- val df = sql("SELECT 1 as a") - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("override a registered udf") { - sqlContext.udf.register("intExpected", (x: Int) => x) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) - - sqlContext.udf.register("intExpected", (x: Int) => x + 1) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f14b2886a9ecb..72744799897be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) From 8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 5 Nov 2015 17:59:01 -0800 Subject: [PATCH 203/324] [SPARK-11514][ML] Pass random seed to spark.ml DecisionTree* cc jkbradley Author: Yu ISHIKAWA Closes #9486 from yu-iskw/SPARK-11514. 
--- .../ml/classification/DecisionTreeClassifier.scala | 4 +++- .../spark/ml/regression/DecisionTreeRegressor.scala | 4 +++- .../scala/org/apache/spark/ml/tree/treeParams.scala | 11 ++++++----- .../classification/DecisionTreeClassifierSuite.scala | 1 + .../ml/regression/DecisionTreeRegressorSuite.scala | 1 + 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b0157f7ce24ec..c478aea44ace8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String) override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 04420fc6e8251..477030d9ea3ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 281ba6eeffa92..1da97db9277d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { +private[ml] trait DecisionTreeParams extends PredictorParams + with HasCheckpointInterval with HasSeed { /** * Maximum depth of the tree (>= 0). 
@@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 815f6fd997584..92b8f84144ab0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setImpurity("gini") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 868fb8eecb8bb..e0d5afa7a7e97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } From 468ad0ae874d5cf55712ee976faf77f19c937ccb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 5 Nov 2015 18:03:12 -0800 Subject: [PATCH 204/324] [SPARK-11457][STREAMING][YARN] Fix incorrect AM proxy filter conf recovery from checkpoint Currently Yarn AM proxy filter configuration is recovered from checkpoint file when Spark Streaming application is restarted, which will lead to some unwanted behaviors: 1. Wrong RM address if RM is redeployed from failure. 2. Wrong proxyBase, since app id is updated, old app id for proxyBase is wrong. So instead of recovering from checkpoint file, these configurations should be reloaded each time when app started. This problem only exists in Yarn cluster mode, for Yarn client mode, these configurations will be updated with RPC message `AddWebUIFilter`. Please help to review tdas harishreedharan vanzin , thanks a lot. Author: jerryshao Closes #9412 from jerryshao/SPARK-11457. 
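A minimal sketch of the recovery rule this patch adopts (simplified, standalone, with a hypothetical helper; the real change lives in Checkpoint.scala below): proxy-filter keys are always taken from the freshly reloaded configuration rather than from the checkpoint, so a redeployed RM address and the new application's proxyBase win.

```scala
// Hypothetical illustration of the merge, not the actual Checkpoint code.
val filterClass = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
val filterPrefix = s"spark.$filterClass.param."

def recoverConf(checkpointed: Map[String, String], reloaded: Map[String, String]): Map[String, String] = {
  // keep everything from the checkpoint, but let the current AM proxy filter params override it
  checkpointed ++ reloaded.filter { case (k, _) =>
    k.startsWith(filterPrefix) && k.length > filterPrefix.length
  }
}
```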
--- .../org/apache/spark/streaming/Checkpoint.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b7de6dde61c63..0cd55d9aec2cd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,7 +55,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.port", "spark.master", "spark.yarn.keytab", - "spark.yarn.principal") + "spark.yarn.principal", + "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") @@ -66,6 +67,16 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) newSparkConf.set(prop, value) } } + + // Add Yarn proxy filter specific configurations to the recovered SparkConf + val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val filterPrefix = s"spark.$filter.param." + newReloadConf.getAll.foreach { case (k, v) => + if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) { + newSparkConf.set(k, v) + } + } + newSparkConf } From 5e31db70bb783656ba042863fcd3c223e17a8f81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 5 Nov 2015 18:05:58 -0800 Subject: [PATCH 205/324] [SPARK-11538][BUILD] Force guava 14 in sbt build. sbt's version resolution code always picks the most recent version, and we don't want that for guava. Author: Marcelo Vanzin Closes #9508 from vanzin/SPARK-11538. --- project/SparkBuild.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 75c36930decef..b75ed13a78c68 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -207,7 +207,8 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ + ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -291,6 +292,14 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + * Overrides to work around sbt's dependency resolution being different from Maven's. + */ +object DependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") +} + /** This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. From 3cc2c053b5d68c747a30bd58cf388b87b1922f13 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 18:12:54 -0800 Subject: [PATCH 206/324] [SPARK-11540][SQL] API audit for QueryExecutionListener. Author: Reynold Xin Closes #9509 from rxin/SPARK-11540. 
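Since the audit below mostly tightens signatures and docs, a short usage sketch may help: a listener implementing the two callbacks exactly as they are declared in this patch. The registration entry point is assumed from the manager's scaladoc in the diff that follows, not shown in this commit.

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

val timingListener = new QueryExecutionListener {
  // both callbacks may be invoked from different threads, so keep them side-effect light
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    println(s"$funcName succeeded in ${durationNs / 1e6} ms")
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}
// sqlContext.listenerManager.register(timingListener)  // assumed entry point per the scaladoc below
```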
--- .../spark/sql/execution/QueryExecution.scala | 30 +++--- .../sql/util/QueryExecutionListener.scala | 101 ++++++++++-------- 2 files changed, 72 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index fc9174549e642..c2142d03f422b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import com.google.common.annotations.VisibleForTesting + import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -25,31 +27,33 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. + * + * While this is not a public class, we should avoid changing the function names for the sake of + * changing them, because a lot of developers use the feature for debugging. */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { - val analyzer = sqlContext.analyzer - val optimizer = sqlContext.optimizer - val planner = sqlContext.planner - val cacheManager = sqlContext.cacheManager - val prepareForExecution = sqlContext.prepareForExecution - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + @VisibleForTesting + def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical) - lazy val analyzed: LogicalPlan = analyzer.execute(logical) lazy val withCachedData: LogicalPlan = { assertAnalyzed() - cacheManager.useCachedData(analyzed) + sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) + + lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(sqlContext) - planner.plan(optimizedPlan).next() + sqlContext.planner.plan(optimizedPlan).next() } + // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. 
Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() @@ -57,11 +61,11 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } - def simpleString: String = + def simpleString: String = { s"""== Physical Plan == |${stringOrError(executedPlan)} """.stripMargin.trim - + } override def toString: String = { def output = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 909a8abd225b8..ac432e2baa3c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -19,36 +19,38 @@ package org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer +import scala.util.control.NonFatal -import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.execution.QueryExecution /** + * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they will be used in a non - * thread-safe way. + * Note that implementations should guarantee thread-safety as they can be invoked by + * multiple different threads. */ @Experimental trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * - * @param funcName the name of the action that triggered this query. + * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. - * @param duration the execution time for this query in nanoseconds. + * @param durationNs the execution time for this query in nanoseconds. */ @DeveloperApi - def onSuccess(funcName: String, qe: QueryExecution, duration: Long) + def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, @@ -56,34 +58,20 @@ trait QueryExecutionListener { * @param exception the exception that failed this query. */ @DeveloperApi - def onFailure(funcName: String, qe: QueryExecution, exception: Exception) + def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit } -@Experimental -class ExecutionListenerManager extends Logging { - private[this] val listeners = ListBuffer.empty[QueryExecutionListener] - private[this] val lock = new ReentrantReadWriteLock() - - /** Acquires a read lock on the cache for the duration of `f`. */ - private def readLock[A](f: => A): A = { - val rl = lock.readLock() - rl.lock() - try f finally { - rl.unlock() - } - } - /** Acquires a write lock on the cache for the duration of `f`. 
*/ - private def writeLock[A](f: => A): A = { - val wl = lock.writeLock() - wl.lock() - try f finally { - wl.unlock() - } - } +/** + * :: Experimental :: + * + * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. + */ +@Experimental +class ExecutionListenerManager private[sql] () extends Logging { /** - * Registers the specified QueryExecutionListener. + * Registers the specified [[QueryExecutionListener]]. */ @DeveloperApi def register(listener: QueryExecutionListener): Unit = writeLock { @@ -91,7 +79,7 @@ class ExecutionListenerManager extends Logging { } /** - * Unregisters the specified QueryExecutionListener. + * Unregisters the specified [[QueryExecutionListener]]. */ @DeveloperApi def unregister(listener: QueryExecutionListener): Unit = writeLock { @@ -99,38 +87,59 @@ class ExecutionListenerManager extends Logging { } /** - * clears out all registered QueryExecutionListeners. + * Removes all the registered [[QueryExecutionListener]]. */ @DeveloperApi def clear(): Unit = writeLock { listeners.clear() } - private[sql] def onSuccess( - funcName: String, - qe: QueryExecution, - duration: Long): Unit = readLock { - withErrorHandling { listener => - listener.onSuccess(funcName, qe, duration) + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } } } - private[sql] def onFailure( - funcName: String, - qe: QueryExecution, - exception: Exception): Unit = readLock { - withErrorHandling { listener => - listener.onFailure(funcName, qe, exception) + private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } } } + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + + /** A lock to prevent updating the list of listeners while we are traversing through them. */ + private[this] val lock = new ReentrantReadWriteLock() + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { for (listener <- listeners) { try { f(listener) } catch { - case e: Exception => logWarning("error executing query execution listener", e) + case NonFatal(e) => logWarning("Error executing query execution listener", e) } } } + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } } From eec74ba8bde7f9446cc38e687bda103e85669d35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 19:02:18 -0800 Subject: [PATCH 207/324] [SPARK-7542][SQL] Support off-heap index/sort buffer This brings the support of off-heap memory for array inside BytesToBytesMap and InMemorySorter, then we could allocate all the memory from off-heap for execution. Closes #8068 Author: Davies Liu Closes #9477 from davies/unsafe_timsort. 
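The core of the change below is mechanical: every place that indexed a raw `long[]` pointer array now goes through `LongArray`, whose backing `MemoryBlock` is handed out by a `MemoryConsumer` and may live on or off the JVM heap. A stripped-down sketch of the pattern (hypothetical types, not the classes in the diff) shows why the call sites barely change:

```scala
// Illustrative only: a minimal stand-in for the LongArray indirection introduced here.
trait PointerArray {
  def size: Long
  def get(i: Int): Long
  def set(i: Int, value: Long): Unit
}

// On-heap flavour, equivalent to the old bare long[] usage.
final class HeapPointerArray(backing: Array[Long]) extends PointerArray {
  def size: Long = backing.length.toLong
  def get(i: Int): Long = backing(i)
  def set(i: Int, value: Long): Unit = backing(i) = value
}

// Call sites read and write through the interface, so swapping in an off-heap
// implementation (one addressing a raw base object + offset) needs no changes.
def swap(a: PointerArray, i: Int, j: Int): Unit = {
  val tmp = a.get(i); a.set(i, a.get(j)); a.set(j, tmp)
}
```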
--- .../apache/spark/memory/MemoryConsumer.java | 36 +++++----- .../spark/memory/TaskMemoryManager.java | 6 +- .../shuffle/sort/ShuffleExternalSorter.java | 26 +++---- .../shuffle/sort/ShuffleInMemorySorter.java | 67 ++++++++++--------- .../shuffle/sort/ShuffleSortDataFormat.java | 38 +++++++---- .../spark/unsafe/map/BytesToBytesMap.java | 18 +++-- .../unsafe/sort/UnsafeExternalSorter.java | 28 +++----- .../unsafe/sort/UnsafeInMemorySorter.java | 66 +++++++++++------- .../unsafe/sort/UnsafeSortDataFormat.java | 47 +++++++------ .../spark/memory/TaskMemoryManagerSuite.java | 23 ------- .../spark/memory/TestMemoryConsumer.java | 45 +++++++++++++ .../sort/ShuffleInMemorySorterSuite.java | 16 +++-- .../sort/UnsafeExternalSorterSuite.java | 1 - .../sort/UnsafeInMemorySorterSuite.java | 12 ++-- .../sql/execution/UnsafeKVExternalSorter.java | 3 +- .../apache/spark/unsafe/array/LongArray.java | 18 ++++- .../spark/unsafe/array/LongArraySuite.java | 4 ++ 17 files changed, 265 insertions(+), 189 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 008799cc77395..8fbdb72832adf 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -20,6 +20,7 @@ import java.io.IOException; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -28,9 +29,9 @@ */ public abstract class MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; + protected final TaskMemoryManager taskMemoryManager; private final long pageSize; - private long used; + protected long used; protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; @@ -74,26 +75,29 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Acquire `size` bytes memory. - * - * If there is not enough memory, throws OutOfMemoryError. + * Allocates a LongArray of `size`. */ - protected void acquireMemory(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); - if (got < size) { - taskMemoryManager.releaseExecutionMemory(got, this); + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); } - used += got; + used += required; + return new LongArray(page); } /** - * Release `size` bytes memory. + * Frees a LongArray. 
*/ - protected void releaseMemory(long size) { - used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); } /** @@ -109,7 +113,7 @@ protected MemoryBlock allocatePage(long required) { long got = 0; if (page != null) { got = page.size(); - freePage(page); + taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 4230575446d31..6440f9c0f30de 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -137,7 +137,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got < required) { // Call spill() on other consumers to release memory for (MemoryConsumer c: consumers) { - if (c != null && c != consumer && c.getUsed() > 0) { + if (c != consumer && c.getUsed() > 0) { try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -173,7 +173,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { } } - consumers.add(consumer); + if (consumer != null) { + consumers.add(consumer); + } logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); return got; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 400d8520019b9..9affff80143d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -39,6 +39,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -114,8 +115,7 @@ public ShuffleExternalSorter( this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - acquireMemory(initialSize * 8L); - this.inMemSorter = new ShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); this.peakMemoryUsedBytes = getMemoryUsage(); } @@ -301,9 +301,8 @@ private long freeMemory() { public void cleanupResources() { freeMemory(); if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { @@ -321,9 +320,10 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -331,16 +331,9 @@ private void growPointerArrayIfNecessary() throws IOException 
{ } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -404,9 +397,8 @@ public SpillInfo[] closeAndGetSpills() throws IOException { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index e630575d1ae19..58ad88e1ed87b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -19,11 +19,14 @@ import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; final class ShuffleInMemorySorter { - private final Sorter sorter; + private final Sorter sorter; private static final class SortComparator implements Comparator { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { @@ -32,24 +35,34 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { } private static final SortComparator SORT_COMPARATOR = new SortComparator(); + private final MemoryConsumer consumer; + /** * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] array; + private LongArray array; /** * The position in the pointer array where new records can be inserted. */ private int pos = 0; - public ShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + this.consumer = consumer; assert (initialSize > 0); - this.array = new long[initialSize]; + this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } + public void free() { + if (array != null) { + consumer.freeArray(array); + array = null; + } + } + public int numRecords() { return pos; } @@ -58,30 +71,25 @@ public void reset() { pos = 0; } - private int newLength() { - // Guard against overflow: - return array.length <= Integer.MAX_VALUE / 2 ? 
(array.length * 2) : Integer.MAX_VALUE; - } - - /** - * Returns the memory needed to expand - */ - public long getMemoryToExpand() { - return ((long) (newLength() - array.length)) * 8; - } - - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L + ); + consumer.freeArray(array); + array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.length; + return pos < array.size(); } public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } /** @@ -96,14 +104,9 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (array.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort pointer array has reached maximum size"); - } else { - expandPointerArray(); - } + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = - PackedRecordPointer.packPointer(recordPointer, partitionId); + array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; } @@ -112,12 +115,12 @@ public void insertRecord(long recordPointer, int partitionId) { */ public static final class ShuffleSorterIterator { - private final long[] pointerArray; + private final LongArray pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public ShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, LongArray pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -127,7 +130,7 @@ public boolean hasNext() { } public void loadNext() { - packedRecordPointer.set(pointerArray[position]); + packedRecordPointer.set(pointerArray.get(position)); position++; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8a1e5aec6ff0e..8f4e3229976dc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,16 +17,19 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; -final class ShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); private ShuffleSortDataFormat() { } @Override - public PackedRecordPointer getKey(long[] data, int pos) { + public PackedRecordPointer getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. 
throw new UnsupportedOperationException(); } @@ -37,31 +40,38 @@ public PackedRecordPointer newKey() { } @Override - public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.set(data[pos]); + public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) { + reuse.set(data.get(pos)); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - final long temp = data[pos0]; - data[pos0] = data[pos1]; - data[pos1] = temp; + public void swap(LongArray data, int pos0, int pos1) { + final long temp = data.get(pos0); + data.set(pos0, data.get(pos1)); + data.set(pos1, temp); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos] = src[srcPos]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos, src.get(srcPos)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos, dst, dstPos, length); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8, + length * 8 + ); } @Override - public long[] allocate(int length) { - return new long[length]; + public LongArray allocate(int length) { + // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length])); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6656fd1d0bc59..04694dc54418c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,7 +20,6 @@ import javax.annotation.Nullable; import java.io.File; import java.io.IOException; -import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; @@ -724,11 +723,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { */ private void allocate(int capacity) { assert (capacity >= 0); - // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - acquireMemory(capacity * 16); - longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); + longArray = allocateArray(capacity * 2); + longArray.zeroOut(); this.growthThreshold = (int) (capacity * loadFactor); this.mask = capacity - 1; @@ -743,9 +741,8 @@ private void allocate(int capacity) { public void free() { updatePeakMemoryUsed(); if (longArray != null) { - long used = longArray.memoryBlock().size(); + freeArray(longArray); longArray = null; - releaseMemory(used); } Iterator dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { @@ -834,9 +831,9 @@ public int getNumDataPages() { /** * Returns the underline long[] of longArray. 
*/ - public long[] getArray() { + public LongArray getArray() { assert(longArray != null); - return (long[]) longArray.memoryBlock().getBaseObject(); + return longArray; } /** @@ -844,7 +841,8 @@ public long[] getArray() { */ public void reset() { numElements = 0; - Arrays.fill(getArray(), 0); + longArray.zeroOut(); + while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); @@ -887,7 +885,7 @@ void growAndRehash() { longArray.set(newPos * 2, keyPointer); longArray.set(newPos * 2 + 1, hashcode); } - releaseMemory(oldLongArray.memoryBlock().size()); + freeArray(oldLongArray); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index cba043bc48cc8..9a7b2ad06cab6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,6 +32,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; @@ -123,9 +124,8 @@ private UnsafeExternalSorter( this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { - this.inMemSorter = - new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); - acquireMemory(inMemSorter.getMemoryUsage()); + this.inMemSorter = new UnsafeInMemorySorter( + this, taskMemoryManager, recordComparator, prefixComparator, initialSize); } else { this.inMemSorter = existingInMemorySorter; } @@ -277,9 +277,8 @@ public void cleanupResources() { deleteSpillFiles(); freeMemory(); if (inMemSorter != null) { - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } } } @@ -293,9 +292,10 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -303,16 +303,9 @@ private void growPointerArrayIfNecessary() throws IOException { } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -498,9 +491,8 @@ public void loadNext() throws IOException { nextUpstream = null; assert(inMemSorter != null); - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index d57213b9b8bfc..a218ad4623f46 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,8 +19,10 @@ import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; /** @@ -62,15 +64,16 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { } } + private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; - private final Sorter sorter; + private final Sorter sorter; private final Comparator sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] array; + private LongArray array; /** * The position in the sort buffer where new records can be inserted. @@ -78,22 +81,33 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private int pos = 0; public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, int initialSize) { - this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]); + this(consumer, memoryManager, recordComparator, prefixComparator, + consumer.allocateArray(initialSize * 2)); } public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - long[] array) { - this.array = array; + LongArray array) { + this.consumer = consumer; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + this.array = array; + } + + /** + * Free the memory used by pointer array. + */ + public void free() { + consumer.freeArray(array); } public void reset() { @@ -107,26 +121,26 @@ public int numRecords() { return pos / 2; } - private int newLength() { - return array.length < Integer.MAX_VALUE / 2 ? 
(array.length * 2) : Integer.MAX_VALUE; - } - - public long getMemoryToExpand() { - return (long) (newLength() - array.length) * 8L; - } - public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } public boolean hasSpaceForAnotherRecord() { - return pos + 2 <= array.length; + return pos + 2 <= array.size(); } - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + if (newArray.size() < array.size()) { + throw new OutOfMemoryError("Not enough memory to grow pointer array"); + } + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L); + consumer.freeArray(array); + array = newArray; } /** @@ -138,11 +152,11 @@ public void expandPointerArray() { */ public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(); + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = recordPointer; + array.set(pos, recordPointer); pos++; - array[pos] = keyPrefix; + array.set(pos, keyPrefix); pos++; } @@ -150,7 +164,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private final TaskMemoryManager memoryManager; private final int sortBufferInsertPosition; - private final long[] sortBuffer; + private final LongArray sortBuffer; private int position = 0; private Object baseObject; private long baseOffset; @@ -160,7 +174,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private SortedIterator( TaskMemoryManager memoryManager, int sortBufferInsertPosition, - long[] sortBuffer) { + LongArray sortBuffer) { this.memoryManager = memoryManager; this.sortBufferInsertPosition = sortBufferInsertPosition; this.sortBuffer = sortBuffer; @@ -188,11 +202,11 @@ public int numRecordsLeft() { @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer[position]; + final long recordPointer = sortBuffer.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer[position + 1]; + keyPrefix = sortBuffer.get(position + 1); position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d09c728a7a638..d3137f5f31c25 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -17,6 +17,9 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; /** @@ -26,14 +29,14 @@ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. 
*/ -final class UnsafeSortDataFormat extends SortDataFormat { +final class UnsafeSortDataFormat extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @@ -44,37 +47,43 @@ public RecordPointerAndKeyPrefix newKey() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { - reuse.recordPointer = data[pos * 2]; - reuse.keyPrefix = data[pos * 2 + 1]; + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data.get(pos * 2); + reuse.keyPrefix = data.get(pos * 2 + 1); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - long tempPointer = data[pos0 * 2]; - long tempKeyPrefix = data[pos0 * 2 + 1]; - data[pos0 * 2] = data[pos1 * 2]; - data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; - data[pos1 * 2] = tempPointer; - data[pos1 * 2 + 1] = tempKeyPrefix; + public void swap(LongArray data, int pos0, int pos1) { + long tempPointer = data.get(pos0 * 2); + long tempKeyPrefix = data.get(pos0 * 2 + 1); + data.set(pos0 * 2, data.get(pos1 * 2)); + data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1)); + data.set(pos1 * 2, tempPointer); + data.set(pos1 * 2 + 1, tempKeyPrefix); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos * 2] = src[srcPos * 2]; - dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos * 2, src.get(srcPos * 2)); + dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 16, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 16, + length * 16); } @Override - public long[] allocate(int length) { + public LongArray allocate(int length) { assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - return new long[length * 2]; + // This is used as temporary buffer, it's fine to allocate from JVM heap. 
+ return new LongArray(MemoryBlock.fromLongArray(new long[length * 2])); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index dab7b0592cb4e..c731317395612 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.memory; -import java.io.IOException; - import org.junit.Assert; import org.junit.Test; @@ -27,27 +25,6 @@ public class TaskMemoryManagerSuite { - class TestMemoryConsumer extends MemoryConsumer { - TestMemoryConsumer(TaskMemoryManager memoryManager) { - super(memoryManager); - } - - @Override - public long spill(long size, MemoryConsumer trigger) throws IOException { - long used = getUsed(); - releaseMemory(used); - return used; - } - - void use(long size) { - acquireMemory(size); - } - - void free(long size) { - releaseMemory(size); - } - } - @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java new file mode 100644 index 0000000000000..8ae3642738509 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.memory; + +import java.io.IOException; + +public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager) { + super(memoryManager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + long used = getUsed(); + free(used); + return used; + } + + void use(long size) { + long got = taskMemoryManager.acquireExecutionMemory(size, this); + used += got; + } + + void free(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory(size, this); + } +} + + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 2293b1bbc113e..faa5a863ee630 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -25,13 +25,19 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); @@ -40,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -63,7 +69,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -104,7 +110,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cfead0e5924b8..11c3a7be38875 
100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -390,7 +390,6 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - // The first page is pre-allocated on instantiation if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 642f6585f8a15..a203a09648ac0 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -23,6 +23,7 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -44,9 +45,11 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -69,6 +72,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { }; final TaskMemoryManager memoryManager = new TaskMemoryManager( new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: @@ -102,7 +106,7 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index e2898ef2e2158..8c9b9c85e37fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,8 +85,9 @@ public UnsafeKVExternalSorter( } else { // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array 
for in-memory sorter (it's always large enough). + // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 74105050e4191..1a3cdff638264 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -39,7 +39,6 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); @@ -51,6 +50,14 @@ public MemoryBlock memoryBlock() { return memory; } + public Object getBaseObject() { + return baseObj; + } + + public long getBaseOffset() { + return baseOffset; + } + /** * Returns the number of elements this array can hold. */ @@ -58,6 +65,15 @@ public long size() { return length; } + /** + * Fill this all with 0L. + */ + public void zeroOut() { + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); + } + } + /** * Sets the value at position {@code index}. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 5974cf91ff993..fb8e53b3348f3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -34,5 +34,9 @@ public void basicTest() { Assert.assertEquals(2, arr.size()); Assert.assertEquals(1L, arr.get(0)); Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); } } From 363a476c3fefb0263e63fd24df0b2779a64f79ec Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 5 Nov 2015 21:42:32 -0800 Subject: [PATCH 208/324] [SPARK-11528] [SQL] Typed aggregations for Datasets This PR adds the ability to do typed SQL aggregations. We will likely also want to provide an interface to allow users to do aggregations on objects, but this is deferred to another PR. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds.groupBy(_._1).agg(sum("_2").as[Int]).collect() res0: Array(("a", 30), ("b", 3), ("c", 1)) ``` Author: Michael Armbrust Closes #9499 from marmbrus/dataset-agg. 
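For illustration, a slightly fuller sketch drawn from the tests added in this patch (it assumes `sqlContext.implicits._` and the SQL functions are in scope; the tuples shown are the ones the new `DatasetSuite` tests assert, order not guaranteed):

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

// Several typed aggregations at once: the result is a Dataset of tuples,
// one per unique key, with one extra element per aggregate column.
ds.groupBy(_._1)
  .agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long])
  .collect()
// contains ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)

// count() is provided directly and returns a Dataset[(K, Long)].
ds.groupBy(_._1).count().collect()
// contains ("a", 2L), ("b", 2L), ("c", 1L)
```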
--- .../expressions/namedExpressions.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 93 ++++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++ 4 files changed, 132 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8957df0be6814..9ab5c299d0f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -254,6 +254,10 @@ case class AttributeReference( } override def toString: String = s"$name#${exprId.id}$typeSuffix" + + // Since the expression id is not in the first constructor it is missing from the default + // tree string. + override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 500227e93a472..4bca9c3b3fe54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType * @since 1.6.0 */ @Experimental -class Dataset[T] private( +class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, unresolvedEncoder: Encoder[T]) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 96d6e9dd548e5..b8fc373dffcf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,16 +17,25 @@ package org.apache.spark.sql +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution /** + * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing * [[Dataset]]. + * + * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, + * making this change to the class hierarchy would break some function signatures. As such, this + * class should be considered a preview of the final API. Changes will be made to the interface + * after Spark 1.6. 
*/ +@Experimental class GroupedDataset[K, T] private[sql]( private val kEncoder: Encoder[K], private val tEncoder: Encoder[T], @@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql]( private val groupingAttributes: Seq[Attribute]) extends Serializable { private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.resolve(groupingAttributes) + case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) case other => throw new UnsupportedOperationException("Only expression encoders are currently supported") } @@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql]( throw new UnsupportedOperationException("Only expression encoders are currently supported") } + /** Encoders for built in aggregations. */ + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext + private def groupedData = + new GroupedData( + new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. @@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } + // To ensure valid overloading. + protected def agg(expr: Column, exprs: Column*): DataFrame = + groupedData.agg(expr, exprs: _*) + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + * TODO: does not handle aggrecations that return nonflat results, + */ + protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { + val aliases = (groupingAttributes ++ columns.map(_.expr)).map { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + + val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) + val execution = new QueryExecution(sqlContext, unresolvedPlan) + + val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) + + // Rebind the encoders to the nested schema that will be produced by the aggregation. + val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map { + case (e: ExpressionEncoder[_], a) if !e.flat => + e.nested(a).resolve(execution.analyzed.output) + case (e, a) => + e.unbind(a :: Nil).resolve(execution.analyzed.output) + } + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key + * and the result of computing this aggregation over all elements in the group. + */ + def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. 
+ */ + def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + */ + def agg[A1, A2, A3]( + col1: TypedColumn[A1], + col2: TypedColumn[A2], + col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + */ + def agg[A1, A2, A3, A4]( + col1: TypedColumn[A1], + col2: TypedColumn[A2], + col3: TypedColumn[A3], + col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present + * for that key. + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + /** * Applies the given function to each cogrouped data. For each unique group, the function will * be passed the grouping key and 2 iterators containing all elements in the group from diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3e9b621cfd67f..d61e17edc64ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) } + test("typed aggregation: expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int]), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), + ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + } + + test("typed aggregation: expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), + ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum("_2").as[Int], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double]), + ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + } + test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() From bc5d6c03893a9bd340d6b94d3550e25648412241 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 22:03:26 -0800 Subject: [PATCH 209/324] [SPARK-11541][SQL] Break JdbcDialects.scala into multiple files and mark various dialects as private. Author: Reynold Xin Closes #9511 from rxin/SPARK-11541. 
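Because the concrete dialects become `private`, external code should go through the public `JdbcDialect` / `JdbcDialects.registerDialect` API rather than referring to, say, `PostgresDialect` directly. A minimal sketch of a custom dialect (the URL prefix and type mapping below are made up for illustration and are not part of this patch):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical dialect for a fictional "jdbc:mydb" driver.
case object MyDbDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Map Spark SQL's StringType to a driver-specific column type on write;
  // returning None falls back to the default JDBC mapping.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("VARCHAR(4000)", java.sql.Types.VARCHAR))
    case _ => None
  }
}

// Registration applies to all new JDBC DataFrames; re-registering an
// existing dialect moves it to the front of the lookup list.
JdbcDialects.registerDialect(MyDbDialect)
```

The built-in dialects moved out of `JdbcDialects.scala` below follow this same pattern, just as package-private objects.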
--- project/MimaExcludes.scala | 19 +- .../org/apache/spark/sql/GroupedData.scala | 2 +- .../spark/sql/jdbc/AggregatedDialect.scala | 44 ++++ .../apache/spark/sql/jdbc/DB2Dialect.scala | 32 +++ .../apache/spark/sql/jdbc/DerbyDialect.scala | 44 ++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 190 +----------------- .../spark/sql/jdbc/MsSqlServerDialect.scala | 41 ++++ .../apache/spark/sql/jdbc/MySQLDialect.scala | 48 +++++ .../apache/spark/sql/jdbc/OracleDialect.scala | 45 +++++ .../spark/sql/jdbc/PostgresDialect.scala | 54 +++++ 10 files changed, 332 insertions(+), 187 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 40f5c9fec8bb8..dacef911e397e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -116,7 +116,24 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") ) ++ Seq( // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), + // SPARK-11541 mark various JDBC dialects as private + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 7cf66b65c8722..f9eab5c2e965b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.NumericType class GroupedData protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], - private val groupType: GroupedData.GroupType) { + groupType: GroupedData.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala new file mode 100644 index 0000000000000..467d8d62d1b7f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{DataType, MetadataBuilder} + +/** + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * + * @param dialects List of dialects. + */ +private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala new file mode 100644 index 0000000000000..b1cb0e55026be --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{BooleanType, StringType, DataType} + + +private object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala new file mode 100644 index 0000000000000..84f68e779c38c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object DerbyDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case t: DecimalType if t.precision > 31 => + Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f9a6a09b6270d..14bfea4e3e287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.jdbc -import java.sql.Types - import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -115,11 +113,10 @@ abstract class JdbcDialect { @DeveloperApi object JdbcDialects { - private var dialects = List[JdbcDialect]() - /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. * Readding an existing dialect will cause a move-to-front. + * * @param dialect The new dialect. */ def registerDialect(dialect: JdbcDialect) : Unit = { @@ -128,12 +125,15 @@ object JdbcDialects { /** * Unregister a dialect. Does nothing if the dialect is not registered. + * * @param dialect The jdbc dialect. */ def unregisterDialect(dialect : JdbcDialect) : Unit = { dialects = dialects.filterNot(_ == dialect) } + private[this] var dialects = List[JdbcDialect]() + registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) @@ -141,7 +141,6 @@ object JdbcDialects { registerDialect(DerbyDialect) registerDialect(OracleDialect) - /** * Fetch the JdbcDialect class corresponding to a given database url. */ @@ -156,187 +155,8 @@ object JdbcDialects { } /** - * :: DeveloperApi :: - * AggregatedDialect can unify multiple dialects into one virtual Dialect. - * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. - * @param dialects List of dialects. - */ -@DeveloperApi -class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - - require(dialects.nonEmpty) - - override def canHandle(url : String): Boolean = - dialects.map(_.canHandle(url)).reduce(_ && _) - - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = { - dialects.flatMap(_.getJDBCType(dt)).headOption - } -} - -/** - * :: DeveloperApi :: * NOOP dialect object, always returning the neutral element. 
*/ -@DeveloperApi -case object NoopDialect extends JdbcDialect { +private object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } - -/** - * :: DeveloperApi :: - * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. - */ -@DeveloperApi -case object PostgresDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Option(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("json")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - case _ => None - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } - -} - -/** - * :: DeveloperApi :: - * Default mysql dialect to read bit/bitsets correctly. - */ -@DeveloperApi -case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - Option(LongType) - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - Option(BooleanType) - } else None - } - - override def quoteIdentifier(colName: String): String = { - s"`$colName`" - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } -} - -/** - * :: DeveloperApi :: - * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. - * By default string, and boolean gets mapped to db2 invalid types TEXT, and BIT(1). - */ -@DeveloperApi -case object DB2Dialect extends JdbcDialect { - - override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. 
- */ -@DeveloperApi -case object MsSqlServerDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (typeName.contains("datetimeoffset")) { - // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Apache Derby dialect, mapping real on read - * and string/byte/short/boolean/decimal on write. - */ -@DeveloperApi -case object DerbyDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.REAL) Option(FloatType) else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL - case (t: DecimalType) if (t.precision > 31) => - Some(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) - case _ => None - } - -} - -/** - * :: DeveloperApi :: - * Default Oracle dialect, mapping a nonspecific numeric type to a general decimal type. - */ -@DeveloperApi -case object OracleDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - if (sqlType == Types.NUMERIC && size == 0) { - // This is sub-optimal as we have to pick a precision/scale in advance whereas the data - // in Oracle is allowed to have different precision/scale for each value. - Some(DecimalType(DecimalType.MAX_PRECISION, 10)) - } else { - None - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala new file mode 100644 index 0000000000000..3eb722b070d5d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ + + +private object MsSqlServerDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Option(StringType) + } else { + None + } + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala new file mode 100644 index 0000000000000..da413ed1f08b5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types.{BooleanType, LongType, DataType, MetadataBuilder} + + +private case object MySQLDialect extends JdbcDialect { + + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. 
+ md.putLong("binarylong", 1) + Option(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Option(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala new file mode 100644 index 0000000000000..4165c382689f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object OracleDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + if (sqlType == Types.NUMERIC && size == 0) { + // This is sub-optimal as we have to pick a precision/scale in advance whereas the data + // in Oracle is allowed to have different precision/scale for each value. + Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else { + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala new file mode 100644 index 0000000000000..e701a7fcd9e16 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object PostgresDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Option(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("json")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { + Option(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} From 253e87e8ab8717ffef40a6d0d376b1add155ef90 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 6 Nov 2015 06:38:49 -0800 Subject: [PATCH 210/324] [SPARK-11453][SQL][FOLLOW-UP] remove DecimalLit A cleanup for https://github.com/apache/spark/pull/9085. The `DecimalLit` is very similar to `FloatLit`, we can just keep one of them. Also added low level unit test at `SqlParserSuite` Author: Wenchen Fan Closes #9482 from cloud-fan/parser. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 23 ++++++++----------- .../apache/spark/sql/catalyst/SqlParser.scala | 20 ++++------------ .../spark/sql/catalyst/SqlParserSuite.scala | 21 +++++++++++++++++ 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 04ac4f20c66ec..bdc52c08acb66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -78,10 +78,6 @@ private[sql] abstract class AbstractSparkSQLParser } class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString: String = chars - } - case class DecimalLit(chars: String) extends Token { override def toString: String = chars } @@ -106,17 +102,16 @@ class SqlLexical extends StdLexical { } override lazy val token: Parser[Token] = - ( rep1(digit) ~ ('.' ~> digit.*).? ~ (exp ~> sign.? ~ rep1(digit)) ^^ { - case i ~ None ~ (sig ~ rest) => - DecimalLit(i.mkString + "e" + sig.mkString + rest.mkString) - case i ~ Some(d) ~ (sig ~ rest) => - DecimalLit(i.mkString + "." + d.mkString + "e" + sig.mkString + rest.mkString) - } + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." 
+ i2.mkString + s) } | digit.* ~ identChar ~ (identChar | digit).* ^^ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ { case chars => StringLit(chars mkString "") } @@ -133,8 +128,10 @@ class SqlLexical extends StdLexical { override def identChar: Parser[Elem] = letter | elem('_') - private lazy val sign: Parser[Elem] = elem("s", c => c == '+' || c == '-') - private lazy val exp: Parser[Elem] = elem("e", c => c == 'E' || c == 'e') + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 440e9e28fa783..cd717c09f8e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -334,27 +334,15 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { - case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) - } - | sign.? ~ unsignedDecimal ^^ { - case s ~ d => Literal(toDecimalOrDouble(s.getOrElse("") + d)) - } + | sign.? ~ unsignedFloat ^^ + { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) protected lazy val unsignedFloat: Parser[String] = ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) ) - protected lazy val unsignedDecimal: Parser[String] = - ( "." ~> decimalLit ^^ { u => "0." + u } - | elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - ) - - def decimalLit: Parser[String] = - elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - protected lazy val sign: Parser[String] = ("+" | "-") protected lazy val integral: Parser[String] = @@ -477,7 +465,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | (ident <~ "."). 
+ <~ "*" ^^ { case target => UnresolvedStar(Option(target))} + | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))} | primary ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index ea28bfa021bed..9ff893b84775b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -126,4 +126,25 @@ class SqlParserSuite extends PlanTest { checkSingleUnit("13.123456789", "second") checkSingleUnit("-13.123456789", "second") } + + test("support scientific notation") { + def assertRight(input: String, output: Double): Unit = { + val parsed = SqlParser.parse("SELECT " + input) + val expected = Project( + UnresolvedAlias( + Literal(output) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + assertRight("9.0e1", 90) + assertRight(".9e+2", 90) + assertRight("0.9e+2", 90) + assertRight("900e-1", 90) + assertRight("900.0E-1", 90) + assertRight("9.e+1", 90) + + intercept[RuntimeException](SqlParser.parse("SELECT .e3")) + } } From cf69ce136590fea51843bc54f44f0f45c7d0ac36 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 6 Nov 2015 14:51:53 +0000 Subject: [PATCH 211/324] [SPARK-11511][STREAMING] Fix NPE when an InputDStream is not used Just ignored `InputDStream`s that have null `rememberDuration` in `DStreamGraph.getMaxInputStreamRememberDuration`. Author: Shixiong Zhu Closes #9476 from zsxwing/SPARK-11511. --- .../apache/spark/streaming/DStreamGraph.scala | 3 ++- .../spark/streaming/StreamingContextSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 1b0b7890b3b00..7829f5e887995 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -167,7 +167,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { * safe remember duration which can be used to perform cleanup operations. 
*/ def getMaxInputStreamRememberDuration(): Duration = { - inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + // If an InputDStream is not used, its `rememberDuration` will be null and we can ignore them + inputStreams.map(_.rememberDuration).filter(_ != null).maxBy(_.milliseconds) } @throws(classOf[IOException]) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index c7a877142b374..860fac29c0ee0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -780,6 +780,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo "Please don't use queueStream when checkpointing is enabled.")) } + test("Creating an InputDStream but not using it should not crash") { + ssc = new StreamingContext(master, appName, batchDuration) + val input1 = addInputStream(ssc) + val input2 = addInputStream(ssc) + val output = new TestOutputStream(input2) + output.register() + val batchCount = new BatchCounter(ssc) + ssc.start() + // Just wait for completing 2 batches to make sure it triggers + // `DStream.getMaxInputStreamRememberDuration` + batchCount.waitUntilBatchesCompleted(2, 10000) + // Throw the exception if crash + ssc.awaitTerminationOrTimeout(1) + ssc.stop() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) From 574141a29835ce78d68c97bb54336cf4fd3c39d3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 Nov 2015 10:52:04 -0800 Subject: [PATCH 212/324] [SPARK-9162] [SQL] Implement code generation for ScalaUDF JIRA: https://issues.apache.org/jira/browse/SPARK-9162 Currently ScalaUDF extends CodegenFallback and doesn't provide code generation implementation. This path implements code generation for ScalaUDF. Author: Liang-Chi Hsieh Closes #9270 from viirya/scalaudf-codegen. 
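A minimal sketch of the user-facing path that now exercises the generated code (the UDF name here is hypothetical; any Scala function registered through `sqlContext.udf` is planned as a `ScalaUDF` expression):

```scala
// Registering a Scala closure; using it in SQL plans a ScalaUDF expression,
// which with this patch emits Java source instead of using CodegenFallback.
sqlContext.udf.register("plusOne", (x: Int) => x + 1)

// The generated code converts each argument to its Scala type, invokes the
// closure, and converts the (boxed) result back to the Catalyst type.
sqlContext.sql("SELECT plusOne(1)").head().getInt(0)  // 1 + 1 = 2
```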
--- .../sql/catalyst/expressions/ScalaUDF.scala | 85 ++++++++++++++++++- .../scala/org/apache/spark/sql/UDFSuite.scala | 41 +++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 11c7950c0613b..3388cc20a9803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType /** @@ -31,7 +31,7 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputTypes: Seq[DataType] = Nil) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true @@ -60,6 +60,10 @@ case class ScalaUDF( */ + // Accessors used in genCode + def userDefinedFunc(): AnyRef = function + def getChildren(): Seq[Expression] = children + private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] @@ -960,6 +964,83 @@ case class ScalaUDF( } // scalastyle:on + + // Generate codes used to convert the arguments to Scala type for user-defined funtions + private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + val scalaUDFClassName = classOf[ScalaUDF].getName + + val converterTerm = ctx.freshName("converter") + val expressionIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, converterTerm, + s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + + s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());") + converterTerm + } + + override def genCode( + ctx: CodeGenContext, + ev: GeneratedExpressionCode): String = { + + ctx.references += this + + val scalaUDFClassName = classOf[ScalaUDF].getName + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + + // Generate codes used to convert the returned value of user-defined functions to Catalyst type + val catalystConverterTerm = ctx.freshName("catalystConverter") + val catalystConverterTermIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, catalystConverterTerm, + s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToCatalystConverter((($scalaUDFClassName)expressions" + + s"[$catalystConverterTermIdx]).dataType());") + + val resultTerm = ctx.freshName("result") + + // This must be called before children expressions' codegen + // because ctx.references is used in genCodeForConverter + val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + + // Initialize user-defined function + val funcClassName = 
s"scala.Function${children.size}" + + val funcTerm = ctx.freshName("udf") + val funcExpressionIdx = ctx.references.size - 1 + ctx.addMutableState(funcClassName, funcTerm, + s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" + + s"[$funcExpressionIdx]).userDefinedFunc());") + + // codegen for children expressions + val evals = children.map(_.gen(ctx)) + + // Generate the codes for expressions and calling user-defined function + // We need to get the boxedType of dataType's javaType here. Because for the dataType + // such as IntegerType, its javaType is `int` and the returned type of user-defined + // function is Object. Trying to convert an Object to `int` will cause casting exception. + val evalCode = evals.map(_.code).mkString + val funcArguments = converterTerms.zip(evals).map { + case (converter, eval) => s"$converter.apply(${eval.value})" + }.mkString(",") + val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + + s".apply($funcTerm.apply($funcArguments));" + + evalCode + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + Boolean ${ev.isNull}; + + $callFunc + + ${ev.value} = $resultTerm; + ${ev.isNull} = $resultTerm == null; + """ + } + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0435a0dba6ad..9837fa6bdb357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -191,4 +191,45 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } + + test("udf in different types") { + sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + sqlContext.udf.register("decimalDataFunc", + (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) + sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + sqlContext.udf.register("arrayDataFunc", + (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) + sqlContext.udf.register("mapDataFunc", + (data: scala.collection.Map[Int, String]) => { data }) + sqlContext.udf.register("complexDataFunc", + (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(key, value) AS t from testData) tmp").toDF(), + testData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT decimalDataFunc(a, b) AS t FROM decimalData) tmp + """.stripMargin).toDF(), decimalData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT binaryDataFunc(a, b) AS t FROM binaryData) tmp + """.stripMargin).toDF(), binaryData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp + """.stripMargin).toDF(), arrayData.toDF()) + checkAnswer( + sql(""" + | SELECT mapDataFunc(data) AS t FROM mapData + """.stripMargin).toDF(), mapData.toDF()) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp + """.stripMargin).toDF(), complexData.select("m", "a", "b")) + } } From c048929c6a9f7ce57f384037cd6c0bf5751c447a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 6 Nov 2015 11:11:36 -0800 Subject: [PATCH 213/324] [SPARK-10978][SQL][FOLLOW-UP] More comprehensive tests for PR #9399 This PR adds test cases that test various column pruning and filter push-down cases. Author: Cheng Lian Closes #9468 from liancheng/spark-10978.follow-up. 
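As background for the assertions below: these tests lean on the `BaseRelation.unhandledFilters()` hook introduced by SPARK-10978. A rough sketch of that contract, assuming the `Array[Filter] => Array[Filter]` signature (the trait name is made up; `SimpleTextRelation` is assumed to handle only `GreaterThan`, which its test-only implementation does purely for testing purposes):

```scala
import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan}

// A relation reports which of the pushed-down filters it cannot evaluate
// itself; Spark keeps a Filter operator on top only for those, while the
// handled ones may be dropped from the physical plan.
trait GreaterThanOnlyRelation extends BaseRelation {
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(_.isInstanceOf[GreaterThan])
}
```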
--- .../spark/sql/sources/FilteredScanSuite.scala | 21 +- .../SimpleTextHadoopFsRelationSuite.scala | 335 ++++++++++++++++-- .../sql/sources/SimpleTextRelation.scala | 11 + 3 files changed, 321 insertions(+), 46 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 7541e723029bf..2cad964e55b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.existentials import org.apache.spark.rdd.RDD -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - +import org.apache.spark.unsafe.types.UTF8String class FilteredScanSource extends RelationProvider { override def createRelation( @@ -130,7 +129,7 @@ object ColumnsRequired { var set: Set[String] = Set.empty } -class FilteredScanSuite extends DataSourceTest with SharedSQLContext { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper { protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { @@ -144,9 +143,6 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { | to '10' |) """.stripMargin) - - // UDF for testing filter push-down - caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3) } sqlTest( @@ -276,14 +272,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c")) testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c")) - // Columns only referenced by UDF filter must be required, as UDF filters can't be pushed down. - testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c")) + // Filters referencing multiple columns are not convertible, all referenced columns must be + // required. + testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c")) - // A query with an unconvertible filter, an unhandled filter, and a handled filter. + // A query with an inconvertible filter, an unhandled filter, and a handled filter. 
testPushDown( """SELECT a | FROM oneToTenFiltered - | WHERE udf_gt3(b) + | WHERE a + b > 9 | AND b < 16 | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index d945408341fc9..9251a69f31a47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -17,15 +17,21 @@ package org.apache.spark.sql.sources +import java.io.File + import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.execution.PhysicalRDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper} +import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, Row, execution} +import org.apache.spark.util.Utils -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper { import testImplicits._ override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName @@ -70,43 +76,304 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } - private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName) - private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName) - - test("unhandledFilters") { - withTempPath { dir => - - val path = dir.getCanonicalPath - writer.save(s"$path/p=0") - writer.save(s"$path/p=1") - - val isOdd = udf((_: Int) % 2 == 1) - val df = reader.load(path) - .filter( - // This filter is inconvertible - isOdd('a) && - // This filter is convertible but unhandled - 'a > 1 && - // This filter is convertible and handled - 'b > "val_1" && - // This filter references a partiiton column, won't be pushed down - 'p === 1 - ).select('a, 'p) - val rawScan = df.queryExecution.executedPlan collect { + private var tempPath: File = _ + + private var partitionedDF: DataFrame = _ + + private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + + protected override def beforeAll(): Unit = { + this.tempPath = Utils.createTempDir() + + val df = sqlContext.range(10).select( + 'id cast IntegerType as 'a, + ('id cast IntegerType) * 2 as 'b, + concat(lit("val_"), 'id) as 'c + ) + + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=0") + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=1") + + partitionedDF = partitionedReader.load(tempPath.getCanonicalPath) + } + + override protected def afterAll(): Unit = { + Utils.deleteRecursively(tempPath) + } + + private def partitionedWriter(df: DataFrame) = + df.write.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + private def partitionedReader = + sqlContext.read.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + /** + * Constructs test cases that 
test column pruning and filter push-down. + * + * For filter push-down, the following filters are not pushed-down. + * + * 1. Partitioning filters don't participate filter push-down, they are handled separately in + * `DataSourceStrategy` + * + * 2. Catalyst filter `Expression`s that cannot be converted to data source `Filter`s are not + * pushed down (e.g. UDF and filters referencing multiple columns). + * + * 3. Catalyst filter `Expression`s that can be converted to data source `Filter`s but cannot be + * handled by the underlying data source are not pushed down (e.g. returned from + * `BaseRelation.unhandledFilters()`). + * + * Note that for [[SimpleTextRelation]], all data source [[Filter]]s other than [[GreaterThan]] + * are unhandled. We made this assumption in [[SimpleTextRelation.unhandledFilters()]] only + * for testing purposes. + * + * @param projections Projection list of the query + * @param filter Filter condition of the query + * @param requiredColumns Expected names of required columns + * @param pushedFilters Expected data source [[Filter]]s that are pushed down + * @param inconvertibleFilters Expected Catalyst filter [[Expression]]s that cannot be converted + * to data source [[Filter]]s + * @param unhandledFilters Expected Catalyst flter [[Expression]]s that can be converted to data + * source [[Filter]]s but cannot be handled by the data source relation + * @param partitioningFilters Expected Catalyst filter [[Expression]]s that reference partition + * columns + * @param expectedRawScanAnswer Expected query result of the raw table scan returned by the data + * source relation + * @param expectedAnswer Expected query result of the full query + */ + def testPruningAndFiltering( + projections: Seq[Column], + filter: Column, + requiredColumns: Seq[String], + pushedFilters: Seq[Filter], + inconvertibleFilters: Seq[Column], + unhandledFilters: Seq[Column], + partitioningFilters: Seq[Column])( + expectedRawScanAnswer: => Seq[Row])( + expectedAnswer: => Seq[Row]): Unit = { + test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") { + val df = partitionedDF.where(filter).select(projections: _*) + val queryExecution = df.queryExecution + val executedPlan = queryExecution.executedPlan + + val rawScan = executedPlan.collect { case p: PhysicalRDD => p } match { - case Seq(p) => p + case Seq(scan) => scan + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") } - val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType) + markup("Checking raw scan answer") + checkAnswer( + DataFrame(sqlContext, LogicalRDD(rawScan.output, rawScan.rdd)(sqlContext)), + expectedRawScanAnswer) - assertResult(Set((2, 1), (3, 1))) { - rawScan.execute().collect() - .map { CatalystTypeConverters.convertToScala(_, outputSchema) } - .map { case Row(a, p) => (a, p) }.toSet + markup("Checking full query answer") + checkAnswer(df, expectedAnswer) + + markup("Checking required columns") + assert(requiredColumns === SimpleTextRelation.requiredColumns) + + val nonPushedFilters = { + val boundFilters = executedPlan.collect { + case f: execution.Filter => f + } match { + case Nil => Nil + case Seq(f) => splitConjunctivePredicates(f.condition) + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + + // Unbound these bound filters so that we can easily compare them with expected results. 
+ boundFilters.map { + _.transform { case a: AttributeReference => UnresolvedAttribute(a.name) } + }.toSet } - checkAnswer(df, Row(3, 1)) + markup("Checking pushed filters") + assert(SimpleTextRelation.pushedFilters === pushedFilters.toSet) + + val expectedInconvertibleFilters = inconvertibleFilters.map(_.expr).toSet + val expectedUnhandledFilters = unhandledFilters.map(_.expr).toSet + val expectedPartitioningFilters = partitioningFilters.map(_.expr).toSet + + markup("Checking unhandled and inconvertible filters") + assert(expectedInconvertibleFilters ++ expectedUnhandledFilters === nonPushedFilters) + + markup("Checking partitioning filters") + val actualPartitioningFilters = splitConjunctivePredicates(filter.expr).filter { + _.references.contains(UnresolvedAttribute("p")) + }.toSet + + // Partitioning filters are handled separately and don't participate filter push-down. So they + // shouldn't be part of non-pushed filters. + assert(expectedPartitioningFilters.intersect(nonPushedFilters).isEmpty) + assert(expectedPartitioningFilters === actualPartitioningFilters) } } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'p > 0, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Nil, + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('c, 'p), + filter = 'a < 3 && 'p > 0, + requiredColumns = Seq("c", "a"), + pushedFilters = Nil, + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 3), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row("val_0", 1, 0), + Row("val_1", 1, 1), + Row("val_2", 1, 2), + Row("val_3", 1, 3), + Row("val_4", 1, 4), + Row("val_5", 1, 5), + Row("val_6", 1, 6), + Row("val_7", 1, 7), + Row("val_8", 1, 8), + Row("val_9", 1, 9)) + } { + Seq( + Row("val_0", 1), + Row("val_1", 1), + Row("val_2", 1)) + } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'a > 8, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(18, 0), + Row(18, 1)) + } { + Seq( + Row(18, 0), + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8 && 'p > 0, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(18, 1)) + } { + Seq( + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b"), + 
pushedFilters = Seq(GreaterThan("c", "val_7")), + inconvertibleFilters = Nil, + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1), + Row(18, 1)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("c", "val_7")), + inconvertibleFilters = Seq('a % 2 === 0), + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1, 8), + Row(18, 1, 9)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 7 && 'a < 9, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("a", 7)), + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 9), + partitioningFilters = Nil + ) { + Seq( + Row(16, 0, 8), + Row(16, 1, 8), + Row(18, 0, 9), + Row(18, 1, 9)) + } { + Seq( + Row(16, 0), + Row(16, 1)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index da09e1b00ae48..bdc48a383bbbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -128,6 +128,9 @@ class SimpleTextRelation( filters: Array[Filter], inputFiles: Array[FileStatus]): RDD[Row] = { + SimpleTextRelation.requiredColumns = requiredColumns + SimpleTextRelation.pushedFilters = filters.toSet + val fields = this.dataSchema.map(_.dataType) val inputAttributes = this.dataSchema.toAttributes val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name)) @@ -191,6 +194,14 @@ class SimpleTextRelation( } } +object SimpleTextRelation { + // Used to test column pruning + var requiredColumns: Seq[String] = Nil + + // Used to test filter push-down + var pushedFilters: Set[Filter] = Set.empty +} + /** * A simple example [[HadoopFsRelationProvider]]. */ From 8211aab0793cf64202b99be4f31bb8a9ae77050d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 6 Nov 2015 11:13:51 -0800 Subject: [PATCH 214/324] [SPARK-9858][SQL] Add an ExchangeCoordinator to estimate the number of post-shuffle partitions for aggregates and joins (follow-up) https://issues.apache.org/jira/browse/SPARK-9858 This PR is the follow-up work of https://github.com/apache/spark/pull/9276. It addresses JoshRosen's comments. Author: Yin Huai Closes #9453 from yhuai/numReducer-followUp. --- .../plans/physical/partitioning.scala | 8 - .../apache/spark/sql/execution/Exchange.scala | 40 +++-- .../sql/execution/ExchangeCoordinator.scala | 31 ++-- .../apache/spark/sql/CachedTableSuite.scala | 150 ++++++++++++++---- 4 files changed, 167 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 9312c8123e92e..86b9417477ba3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -165,11 +165,6 @@ sealed trait Partitioning { * produced by `A` could have also been produced by `B`. 
*/ def guarantees(other: Partitioning): Boolean = this == other - - def withNumPartitions(newNumPartitions: Int): Partitioning = { - throw new IllegalStateException( - s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}") - } } object Partitioning { @@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def withNumPartitions(newNumPartitions: Int): HashPartitioning = { - HashPartitioning(expressions, newNumPartitions) - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 0f72ec6cc107a..a4ce328c1a9eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -242,7 +242,7 @@ case class Exchange( // update the number of post-shuffle partitions. specifiedPartitionStartIndices.foreach { indices => assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = newPartitioning.withNumPartitions(indices.length) + newPartitioning = UnknownPartitioning(indices.length) } new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } @@ -262,7 +262,7 @@ case class Exchange( object Exchange { def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, None: Option[ExchangeCoordinator]) + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } } @@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child.outputPartitioning match { case hash: HashPartitioning => true case collection: PartitioningCollection => - collection.partitionings.exists(_.isInstanceOf[HashPartitioning]) + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) case _ => false } } @@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // First check if the existing partitions of the children all match. This means they are // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. - // TODO: this should be a cost based decision. For example, a big relation should probably - // maintain its existing number of partitions and smaller partitions should be shuffled. - // defaultPartitions is arbitrary. - val numPartitions = children.head.outputPartitioning.numPartitions + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max val useExistingPartitioning = children.zip(requiredChildDistributions).forall { case (child, distribution) => { child.outputPartitioning.guarantees( - createPartitioning(distribution, numPartitions)) + createPartitioning(distribution, maxChildrenNumPartitions)) } } children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. children } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. 
+ val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + children.zip(requiredChildDistributions).map { case (child, distribution) => { val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + createPartitioning(distribution, numPartitions) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { - Exchange(targetPartitioning, child) + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. + case Exchange(_, c, _) => Exchange(targetPartitioning, c) + case _ => Exchange(targetPartitioning, child) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala index 8dbd69e1f44b8..827fdd278460a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.{Map => JMap, HashMap => JHashMap} +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator( * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be * called in the `doPrepare` method of an [[Exchange]] operator. */ + @GuardedBy("this") def registerExchange(exchange: Exchange): Unit = synchronized { exchanges += exchange } @@ -109,7 +111,7 @@ private[sql] class ExchangeCoordinator( */ private[sql] def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit // a stage when the number of partitions of this dependency is 0. assert(mapOutputStatistics.length <= numExchanges) @@ -121,6 +123,8 @@ private[sql] class ExchangeCoordinator( val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum // The max at here is to make sure that when we have an empty table, we // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. val maxPostShuffleInputSize = math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) @@ -135,6 +139,12 @@ private[sql] class ExchangeCoordinator( // Make sure we do get the same number of pre-shuffle partitions for those stages. val distinctNumPreShufflePartitions = mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. 
Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). assert( distinctNumPreShufflePartitions.length == 1, "There should be only one distinct value of the number pre-shuffle partitions " + @@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator( partitionStartIndices.toArray } + @GuardedBy("this") private def doEstimationIfNecessary(): Unit = synchronized { // It is unlikely that this method will be called from multiple threads // (when multiple threads trigger the execution of THIS physical) @@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator( // Wait for the finishes of those submitted map stages. val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) - i = 0 - while (i < submittedStageFutures.length) { + var j = 0 + while (j < submittedStageFutures.length) { // This call is a blocking call. If the stage has not finished, we will wait at here. - mapOutputStatistics(i) = submittedStageFutures(i).get() - i += 1 + mapOutputStatistics(j) = submittedStageFutures(j).get() + j += 1 } // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the @@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator( Some(estimatePartitionStartIndices(mapOutputStatistics)) } - i = 0 - while (i < numExchanges) { - val exchange = exchanges(i) + var k = 0 + while (k < numExchanges) { + val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) newPostShuffleRDDs.put(exchange, rdd) - i += 1 + k += 1 } // Finally, we set postShuffleRDDs and estimated. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index dbcb011f603f7..bce94dafad755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SharedSQLContext { +class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { @@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) sqlContext.uncacheTable("orderedTable") + sqlContext.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. 
for (numPartitions <- 1 until 10 by 4) { - testData.repartition(numPartitions, $"key").registerTempTable("t1") - testData2.repartition(numPartitions, $"a").registerTempTable("t2") + withTempTable("t1", "t2") { + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") - // Joining them should result in no exchanges. - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) - checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), - sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // Grouping on the partition key should result in no exchanges - verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) - checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), - sql("SELECT count(*) FROM testData GROUP BY key")) + // One side of join is not partitioned in the desired way. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) sqlContext.uncacheTable("t1") sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") } - // Distribute the tables into non-matching number of partitions. Need to shuffle. 
- testData.repartition(6, $"key").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(12, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // One side of join is not partitioned in the desired way. Need to shuffle. - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + // One side of join is not partitioned in the desired way. Since the number of partitions of + // the side that has already partitioned is smaller than the side that is not partitioned, + // we shuffle both side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 2) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // repartition's column ordering is different from group by column ordering. + // But they use the same set of columns. + withTempTable("t1") { + testData.repartition(6, $"value", $"key").registerTempTable("t1") + sqlContext.cacheTable("t1") + + val query = sql("SELECT value, key from t1 group by key, value") + verifyNumExchanges(query, 0) + checkAnswer( + query, + testData.distinct().select($"value", $"key")) + sqlContext.uncacheTable("t1") + } + + // repartition's column ordering is different from join condition's column ordering. + // We will still shuffle because hashcodes of a row depend on the column ordering. + // If we do not shuffle, we may actually partition two tables in totally two different way. + // See PartitioningSuite for more details. 
+ withTempTable("t1", "t2") { + val df1 = testData + df1.repartition(6, $"value", $"key").registerTempTable("t1") + val df2 = testData2.select($"a", $"b".cast("string")) + df2.repartition(6, $"a", $"b").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = + sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } } } From 62bb290773c9f9fa53cbe6d4eedc6e153761a763 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 6 Nov 2015 20:05:18 +0000 Subject: [PATCH 215/324] Typo fixes + code readability improvements Author: Jacek Laskowski Closes #9501 from jaceklaskowski/typos-with-style. --- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 14 ++++++-------- .../org/apache/spark/scheduler/DAGScheduler.scala | 12 +++++++++--- .../apache/spark/scheduler/ShuffleMapTask.scala | 10 +++++----- .../scala/org/apache/spark/scheduler/TaskSet.scala | 2 +- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index d841f05ec52cf..0453614f6a1d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -88,8 +88,8 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed - * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD * creates. * @param inputFormatClass Storage format of the data to be read. 
@@ -123,7 +123,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableConfiguration(conf)) .asInstanceOf[Broadcast[SerializableConfiguration]], - None /* initLocalJobConfFuncOpt */, + initLocalJobConfFuncOpt = None, inputFormatClass, keyClass, valueClass, @@ -184,8 +184,9 @@ class HadoopRDD[K, V]( protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] - if (newInputFormat.isInstanceOf[Configurable]) { - newInputFormat.asInstanceOf[Configurable].setConf(conf) + newInputFormat match { + case c: Configurable => c.setConf(conf) + case _ => } newInputFormat } @@ -195,9 +196,6 @@ class HadoopRDD[K, V]( // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(jobConf) - } val inputSplits = inputFormat.getSplits(jobConf, minPartitions) val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a1f0fd05f661a..4a9518fff4e7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -541,8 +541,7 @@ class DAGScheduler( } /** - * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object - * can be used to block until the the job finishes executing or can be used to cancel the job. + * Submit an action job to the scheduler. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -551,6 +550,11 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal */ def submitJob[T, U]( rdd: RDD[T], @@ -584,7 +588,7 @@ class DAGScheduler( /** * Run an action job on the given RDD and pass all the results to the resultHandler function as - * they arrive. Throws an exception if the job fials, or returns normally if successful. + * they arrive. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -593,6 +597,8 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + * + * @throws Exception when the job fails */ def runJob[T, U]( rdd: RDD[T], diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f478f9982afef..ea97ef0e746d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -27,11 +27,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter /** -* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner -* specified in the ShuffleDependency). -* -* See [[org.apache.spark.scheduler.Task]] for more information. -* + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * * @param stageId id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index be8526ba9b94f..517c8991aed78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -29,7 +29,7 @@ private[spark] class TaskSet( val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + stageAttemptId + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } From 49f1a820372d1cba41f3f00d07eb5728f2ed6705 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 6 Nov 2015 20:06:24 +0000 Subject: [PATCH 216/324] [SPARK-10116][CORE] XORShiftRandom.hashSeed is random in high bits https://issues.apache.org/jira/browse/SPARK-10116 This is really trivial, just happened to notice it -- if `XORShiftRandom.hashSeed` is really supposed to have random bits throughout (as the comment implies), it needs to do something for the conversion to `long`. mengxr mkolod Author: Imran Rashid Closes #8314 from squito/SPARK-10116. 
--- R/pkg/inst/tests/test_sparkSQL.R | 8 +-- .../spark/util/random/XORShiftRandom.scala | 6 ++- .../java/org/apache/spark/JavaAPISuite.java | 20 ++++--- .../spark/rdd/PairRDDFunctionsSuite.scala | 52 +++++++++++++------ .../util/random/XORShiftRandomSuite.scala | 15 ++++++ .../MultilayerPerceptronClassifierSuite.scala | 5 +- .../spark/ml/feature/Word2VecSuite.scala | 16 ++++-- .../clustering/StreamingKMeansSuite.scala | 13 +++-- python/pyspark/ml/feature.py | 20 +++---- python/pyspark/ml/recommendation.py | 6 +-- python/pyspark/mllib/recommendation.py | 4 +- python/pyspark/sql/dataframe.py | 6 +-- .../catalyst/expressions/RandomSuite.scala | 8 +-- .../apache/spark/sql/JavaDataFrameSuite.java | 6 ++- .../apache/spark/sql/DataFrameStatSuite.scala | 4 +- 15 files changed, 128 insertions(+), 61 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 816315b1e4e13..92cff1fba7193 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -875,9 +875,9 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") - expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01) expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") - expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) + expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01) }) test_that("string operators", { @@ -1458,8 +1458,8 @@ test_that("sampleBy() on a DataFrame", { fractions <- list("0" = 0.1, "1" = 0.2) sample <- sampleBy(df, "key", fractions, 0) result <- collect(orderBy(count(groupBy(sample, "key")), "key")) - expect_identical(as.list(result[1, ]), list(key = "0", count = 2)) - expect_identical(as.list(result[2, ]), list(key = "1", count = 10)) + expect_identical(as.list(result[1, ]), list(key = "0", count = 3)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) }) test_that("SQL error message is returned from JVM", { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 85fb923cd9bc7..e8cdb6e98bf36 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { private[spark] object XORShiftRandom { /** Hash seeds to have 0/1 bits throughout. 
*/ - private def hashSeed(seed: Long): Long = { + private[random] def hashSeed(seed: Long): Long = { val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() - MurmurHash3.bytesHash(bytes) + val lowBits = MurmurHash3.bytesHash(bytes) + val highBits = MurmurHash3.bytesHash(bytes, lowBits) + (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL) } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index fd8f7f39b7cc8..4d4e9820500e7 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -146,21 +146,29 @@ public void intersection() { public void sample() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD sample20 = rdd.sample(true, 0.2, 3); + // the seeds here are "magic" to make this work out nicely + JavaRDD sample20 = rdd.sample(true, 0.2, 8); Assert.assertEquals(2, sample20.count()); - JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 5); + JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 2); Assert.assertEquals(2, sample20WithoutReplacement.count()); } @Test public void randomSplit() { - List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + List ints = new ArrayList<>(1000); + for (int i = 0; i < 1000; i++) { + ints.add(i); + } JavaRDD rdd = sc.parallelize(ints); JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); + // the splits aren't perfect -- not enough data for them to be -- just check they're about right Assert.assertEquals(3, splits.length); - Assert.assertEquals(1, splits[0].count()); - Assert.assertEquals(2, splits[1].count()); - Assert.assertEquals(7, splits[2].count()); + long s0 = splits[0].count(); + long s1 = splits[1].count(); + long s2 = splits[2].count(); + Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); + Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @Test diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 1321ec84735b5..7d2cfcca9436a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.util.Progressable @@ -578,17 +579,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { + def assertBinomialSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { + if (exact) { + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new BinomialDistribution(trials, p) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } + } + } + + def assertPoissonSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { if (exact) { - return expected == actual + assert(actual == math.ceil(p * trials).toInt) + } 
else { + val dist = new PoissonDistribution(p * trials) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -613,8 +633,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) + val trials = stratifiedData.countByKey() val fractions = Map("1" -> samplingRate, "0" -> samplingRate) val sample = if (exact) { stratifiedData.sampleByKeyExact(false, fractions, seed) @@ -623,8 +642,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } val sampleCounts = sample.countByKey() val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + sampleCounts.foreach { case (k, v) => + assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt, + p = samplingRate) + } assert(takeSample.size === takeSample.toSet.size) takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } } @@ -635,6 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { + val trials = stratifiedData.countByKey() val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -646,7 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val sampleCounts = sample.countByKey() val takeSample = sample.collect() sampleCounts.foreach { case (k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate) } val groupedByKey = takeSample.groupBy(_._1) for ((key, v) <- groupedByKey) { @@ -657,7 +679,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { assert(v.toSet.size <= expectedSampleSize(key)) } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate) } } } diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index d26667bf720cf..a5b50fce5c0a9 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -65,4 +65,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { val random = new XORShiftRandom(0L) assert(random.nextInt() != 0) } + + test ("hashSeed has random bits throughout") { + val totalBitCount = (0 until 10).map { seed => + val hashed = XORShiftRandom.hashSeed(seed) + val bitCount = java.lang.Long.bitCount(hashed) + // make sure we have roughly equal numbers of 0s and 1s. 
Mostly just check that we + // don't have all 0s or 1s in the high bits + bitCount should be > 20 + bitCount should be < 44 + bitCount + }.sum + // and over all the seeds, very close to equal numbers of 0s & 1s + totalBitCount should be > (32 * 10 - 30) + totalBitCount should be < (32 * 10 + 30) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 17db8c44777d4..a326432d017fc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -61,8 +61,9 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 1), 2) val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 @@ -70,7 +71,7 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(11L) // currently this seed is ignored .setMaxIter(numIterations) val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a2e46f2029956..23dfdaa9f8fc6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -66,9 +66,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } @@ -99,8 +102,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { val realVectors = model.getVectors.sort("word").select("vector").map { case Row(v: Vector) => v }.collect() + // These expectations are just magic values, characterizing the current + // behavior. 
The test needs to be updated to be more general, see SPARK-11502 + val magicExpected = Seq( + Vectors.dense(0.3326166272163391, -0.5603077411651611, -0.2309209555387497), + Vectors.dense(0.32463887333869934, -0.9306551218032837, 1.393115520477295), + Vectors.dense(-0.27150997519493103, 0.4372006058692932, -0.13465698063373566) + ) - realVectors.zip(expectedVectors).foreach { + realVectors.zip(magicExpected).foreach { case (real, expected) => assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") } @@ -122,7 +132,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823) val (synonyms, similarity) = model.findSynonyms("a", 2).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 3645d29dccdb2..65e37c64d404e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -98,9 +98,16 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers - // NOTE exact assignment depends on the initialization! - assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) - assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + // cluster ordering is arbitrary, so choose closest cluster + val d0 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(0)) + val d1 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(1)) + val (c0, c1) = if (d0 < d1) { + (centers(0), centers(1)) + } else { + (centers(1), centers(0)) + } + assert(c0 ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(c1 ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) } test("detecting dying clusters") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c7b6dd926c3e8..b02d41b52ab25 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1788,21 +1788,21 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has +----+--------------------+ |word| vector| +----+--------------------+ - | a|[-0.3511952459812...| - | b|[0.29077222943305...| - | c|[0.02315592765808...| + | a|[0.09461779892444...| + | b|[1.15474212169647...| + | c|[-0.3794820010662...| +----+--------------------+ ... >>> model.findSynonyms("a", 2).show() - +----+-------------------+ - |word| similarity| - +----+-------------------+ - | b|0.29255685145799626| - | c|-0.5414068302988307| - +----+-------------------+ + +----+--------------------+ + |word| similarity| + +----+--------------------+ + | b| 0.16782984556103436| + | c|-0.46761559092107646| + +----+--------------------+ ... >>> model.transform(doc).head().model - DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461]) .. 
versionadded:: 1.4.0 """ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index ec5748a1cfe94..b44c66f73cc49 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -76,11 +76,11 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, prediction=0.39...) + Row(user=0, item=2, prediction=-0.13807615637779236) >>> predictions[1] - Row(user=1, item=0, prediction=3.19...) + Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] - Row(user=2, item=0, prediction=-1.15...) + Row(user=2, item=0, prediction=-1.5018409490585327) .. versionadded:: 1.4.0 """ diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index b9442b0d16c0f..93e47a797f490 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -101,12 +101,12 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) >>> model = ALS.train(df, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3baff8147753d..765a4511b64bc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -436,7 +436,7 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. 
>>> df.sample(False, 0.5, 42).count() - 1 + 2 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxsize) @@ -463,8 +463,8 @@ def sampleBy(self, col, fractions, seed=None): +---+-----+ |key|count| +---+-----+ - | 0| 3| - | 1| 8| + | 0| 5| + | 1| 9| +---+-----+ """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 4a644d136f09c..b7a0d44fa7e57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -24,12 +24,12 @@ import org.apache.spark.SparkFunSuite class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) - checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) + checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) + checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) } test("SPARK-9127 codegen with long seed") { - checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) - checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 49f516e86d754..40bff57a17a03 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -257,7 +257,9 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; - Assert.assertArrayEquals(expected, actual); + Assert.assertEquals(0, actual[0].getLong(0)); + Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); + Assert.assertEquals(1, actual[1].getLong(0)); + Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 6524abcf5e97f..b15af42caa3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -41,7 +41,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), - Seq(16, 23, 88, 100).map(Row(_)) + Seq(3, 17, 27, 58, 62).map(Row(_)) ) } @@ -186,6 +186,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 5), Row(1, 8))) + Seq(Row(0, 6), Row(1, 11))) } } From f328fedafd7bd084470a5e402de0429b5b7f8cd7 Mon Sep 17 00:00:00 2001 From: Herman 
van Hovell Date: Fri, 6 Nov 2015 12:21:53 -0800 Subject: [PATCH 217/324] [SPARK-11450] [SQL] Add Unsafe Row processing to Expand This PR enables the Expand operator to process and produce Unsafe Rows. Author: Herman van Hovell Closes #9414 from hvanhovell/SPARK-11450. --- .../sql/catalyst/expressions/Projection.scala | 6 ++- .../apache/spark/sql/execution/Expand.scala | 19 ++++--- .../spark/sql/execution/basicOperators.scala | 8 +-- .../spark/sql/execution/ExpandSuite.scala | 54 +++++++++++++++++++ 4 files changed, 73 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a6fe730f6dad4..79dabe8e925ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -128,7 +128,11 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a458881f40948..55e95769d3faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,14 +41,21 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + private[this] val projection = { + if (outputsUnsafeRows) { + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + } else { + (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() + } + } + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => - // TODO Move out projection objects creation and transfer to - // workers via closure. However we can't assume the Projection - // is serializable because of the code gen, so we have to - // create the projections within each of the partition processing. 
- val groups = projections.map(ee => newProjection(ee, child.output)).toArray - + val groups = projections.map(projection).toArray new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d5a803f8c4b24..799650a4f784f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -67,16 +67,10 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - /** Rewrite the project list to use unsafe expressions as needed. */ - protected val unsafeProjectList = projectList.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(unsafeProjectList, child.output) + val project = UnsafeProjection.create(projectList, child.output) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala new file mode 100644 index 0000000000000..faef76d52ae75 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType + +class ExpandSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private def testExpand(f: SparkPlan => SparkPlan): Unit = { + val input = (1 to 1000).map(Tuple1.apply) + val projections = Seq.tabulate(2) { i => + Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil + } + val attributes = projections.head.map(_.toAttribute) + checkAnswer( + input.toDF(), + plan => Expand(projections, attributes, f(plan)), + input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) + ) + } + + test("inheriting child row type") { + val exprs = AttributeReference("a", IntegerType, false)() :: Nil + val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) + assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") + } + + test("expanding UnsafeRows") { + testExpand(ConvertToUnsafe) + } + + test("expanding SafeRows") { + testExpand(identity) + } +} From 3a652f691b220fada0286f8d0a562c5657973d4d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 6 Nov 2015 14:47:41 -0800 Subject: [PATCH 218/324] [SPARK-11561][SQL] Rename text data source's column name to value. Author: Reynold Xin Closes #9527 from rxin/SPARK-11561. --- .../sql/execution/datasources/text/DefaultSource.scala | 6 ++---- .../spark/sql/execution/datasources/text/TextSuite.scala | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 52c4421d7e87e..4b8b8e4e74dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -30,14 +30,12 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} -import org.apache.spark.sql.columnar.MutableUnsafeRow import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration /** @@ -78,7 +76,7 @@ private[sql] class TextRelation( extends HadoopFsRelation(maybePartitionSpec) { /** Data schema is always a single column, named "text". */ - override def dataSchema: StructType = new StructType().add("text", StringType) + override def dataSchema: StructType = new StructType().add("value", StringType) /** This is an internal data source that outputs internal row format. 
*/ override val needConversion: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 0a2306c06646c..914e516613f9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -65,7 +65,7 @@ class TextSuite extends QueryTest with SharedSQLContext { /** Verifies data and schema. */ private def verifyFrame(df: DataFrame): Unit = { // schema - assert(df.schema == new StructType().add("text", StringType)) + assert(df.schema == new StructType().add("value", StringType)) // verify content val data = df.collect() From c447c9d54603890db7399fb80adc9fae40b71f64 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Nov 2015 14:51:03 -0800 Subject: [PATCH 219/324] [SPARK-11217][ML] save/load for non-meta estimators and transformers This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes: * class name * uid * timestamp * paramMap The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases. ~~~scala instance.save("path") instance.write.context(sqlContext).overwrite().save("path") Instance.load("path") ~~~ The param handling is different from the design doc. We didn't save default and user-set params separately, and when we load it back, all parameters are user-set. This does cause issues. But it also cause other issues if we modify the default params. TODOs: * [x] Java test * [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers cc jkbradley Author: Xiangrui Meng Closes #9454 from mengxr/SPARK-11217. 
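A minimal usage sketch against the Binarizer wiring added below (the path and threshold are illustrative, not part of this patch; assumes an active SparkContext):

~~~scala
// Illustrative only: persist a transformer with the default params writer, then load it back.
import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

binarizer.write.overwrite().save("/tmp/binarizer")  // writes JSON metadata under <path>/metadata
val restored = Binarizer.load("/tmp/binarizer")     // params come back, all treated as user-set
~~~

Loading instantiates the class reflectively through its single-String (uid) constructor, so any type relying on DefaultParamsReader needs such a constructor.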
--- .../apache/spark/ml/feature/Binarizer.scala | 11 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../org/apache/spark/ml/util/ReadWrite.scala | 220 ++++++++++++++++++ .../ml/util/JavaDefaultReadWriteSuite.java | 74 ++++++ .../spark/ml/feature/BinarizerSuite.scala | 11 +- .../spark/ml/util/DefaultReadWriteTest.scala | 110 +++++++++ .../apache/spark/ml/util/TempDirectory.scala | 45 ++++ 7 files changed, 469 insertions(+), 4 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index edad754436455..e5c25574d4b11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with Writable with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,4 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object Binarizer extends Readable[Binarizer] { + + override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8361406f87299..c9325709187c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala new file mode 100644 index 0000000000000..ea790e0dddc7f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +/** + * Trait for [[Writer]] and [[Reader]]. + */ +private[util] sealed trait BaseReadWrite { + private var optionSQLContext: Option[SQLContext] = None + + /** + * Sets the SQL context to use for saving/loading. + */ + @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { + SQLContext.getOrCreate(SparkContext.getOrCreate()) + } +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class Writer extends BaseReadWrite { + + protected var shouldOverwrite: Boolean = false + + /** + * Saves the ML instances to the input path. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit + + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for classes that provide [[Writer]]. + */ +@Since("1.6.0") +trait Writable { + + /** + * Returns a [[Writer]] instance for this ML instance. + */ + @Since("1.6.0") + def write: Writer + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) +} + +/** + * Abstract class for utility classes that can load ML instances. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class Reader[T] extends BaseReadWrite { + + /** + * Loads the ML component from the input path. + */ + @Since("1.6.0") + def load(path: String): T + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for objects that provide [[Reader]]. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait Readable[T] { + + /** + * Returns a [[Reader]] instance for this class. + */ + @Since("1.6.0") + def read: Reader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) +} + +/** + * Default [[Writer]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. 
This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @param instance object to save + */ +private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { + + /** + * Saves the ML component to the input path. + */ + override def save(path: String): Unit = { + val sc = sqlContext.sparkContext + + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val jsonParams = params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } +} + +/** + * Default [[Reader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @tparam T ML instance type + */ +private[ml] class DefaultParamsReader[T] extends Reader[T] { + + /** + * Loads the ML component from the input path. + */ + override def load(path: String): T = { + implicit val format = DefaultFormats + val sc = sqlContext.sparkContext + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + val cls = Utils.classForName((metadata \ "class").extract[String]) + val uid = (metadata \ "uid").extract[String] + val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] + (metadata \ "paramMap") match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + } + instance.asInstanceOf[T] + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 0000000000000..c39538014be81 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.save(outputPath); + try { + instance.save(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + SQLContext sqlContext = new SQLContext(jsc); + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 2086043983661..9dfa1439cc303 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Double] = _ @@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x === y, "The feature value is not correct after binarization.") } } + + test("read/write") { + val binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.1) + testDefaultReadWrite(binarizer) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala new file mode 100644 index 0000000000000..4545b0f281f5a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -0,0 +1,110 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +trait DefaultReadWriteTest extends TempDirectory { self: Suite => + + /** + * Checks "overwrite" option and params. + * @param instance ML instance to test saving/loading + * @tparam T ML instance type + */ + def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = { + val uid = instance.uid + val path = new File(tempDir, uid).getPath + + instance.save(path) + intercept[IOException] { + instance.save(path) + } + instance.write.overwrite().save(path) + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val newInstance = loader.load(path) + + assert(newInstance.uid === instance.uid) + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") + } + } + + val load = instance.getClass.getMethod("load", classOf[String]) + val another = load.invoke(instance, path).asInstanceOf[T] + assert(another.uid === instance.uid) + } +} + +class MyParams(override val uid: String) extends Params with Writable { + + final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") + final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") + final val longParam: LongParam = new LongParam(this, "longParam", "doc") + final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") + final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") + final val doubleArrayParam: DoubleArrayParam = + new DoubleArrayParam(this, "doubleArrayParam", "doc") + final val stringArrayParam: StringArrayParam = + new StringArrayParam(this, "stringArrayParam", "doc") + + setDefault(intParamWithDefault -> 0) + set(intParam -> 1) + set(floatParam -> 2.0f) + set(doubleParam -> 3.0) + set(longParam -> 4L) + set(stringParam -> "5") + set(intArrayParam -> Array(6, 7)) + set(doubleArrayParam -> Array(8.0, 9.0)) + set(stringArrayParam -> Array("10", "11")) + + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override def 
write: Writer = new DefaultParamsWriter(this) +} + +object MyParams extends Readable[MyParams] { + + override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + + override def load(path: String): MyParams = read.load(path) +} + +class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("default read/write") { + val myParams = new MyParams("my_params") + testDefaultReadWrite(myParams) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala new file mode 100644 index 0000000000000..2742026a69c2e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.scalatest.{BeforeAndAfterAll, Suite} + +import org.apache.spark.util.Utils + +/** + * Trait that creates a temporary directory before all tests and deletes it after all. + */ +trait TempDirectory extends BeforeAndAfterAll { self: Suite => + + private var _tempDir: File = _ + + /** Returns the temporary directory as a [[File]] instance. */ + protected def tempDir: File = _tempDir + + override def beforeAll(): Unit = { + super.beforeAll() + _tempDir = Utils.createTempDir(this.getClass.getName) + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(_tempDir) + super.afterAll() + } +} From f6680cdc5d2912dea9768ef5c3e2cc101b06daf8 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 6 Nov 2015 15:24:33 -0800 Subject: [PATCH 220/324] [SPARK-11555] spark on yarn spark-class --num-workers doesn't work I tested the various options with both spark-submit and spark-class of specifying number of executors in both client and cluster mode where it applied. --num-workers, --num-executors, spark.executor.instances, SPARK_EXECUTOR_INSTANCES, default nothing supplied Author: Thomas Graves Closes #9523 from tgravescs/SPARK-11555. --- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 2 +- .../org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1165061db21e3..a9f4374357356 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -81,7 +81,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orNull // If dynamic allocation is enabled, start at the configured initial number of executors. 
// Default to minExecutors if no initialExecutors is set. - numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors) principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) .orNull diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 561ad79ee0228..a290ebeec9001 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -392,8 +392,11 @@ object YarnSparkHadoopUtil { /** * Getting the initial target number of executors depends on whether dynamic allocation is * enabled. + * If not using dynamic allocation it gets the number of executors reqeusted by the user. */ - def getInitialTargetExecutorNumber(conf: SparkConf): Int = { + def getInitialTargetExecutorNumber( + conf: SparkConf, + numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { if (Utils.isDynamicAllocationEnabled(conf)) { val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) val initialNumExecutors = @@ -406,7 +409,7 @@ object YarnSparkHadoopUtil { initialNumExecutors } else { val targetNumExecutors = - sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(DEFAULT_NUMBER_EXECUTORS) + sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) // System property can override environment variable. conf.getInt("spark.executor.instances", targetNumExecutors) } From 7e9a9e603abce8689938bdd62d04b29299644aa4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 6 Nov 2015 15:37:07 -0800 Subject: [PATCH 221/324] [SPARK-11269][SQL] Java API support & test cases for Dataset This simply brings https://github.com/apache/spark/pull/9358 up-to-date. Author: Wenchen Fan Author: Reynold Xin Closes #9528 from rxin/dataset-java. --- .../spark/sql/catalyst/encoders/Encoder.scala | 123 +++++- .../sql/catalyst/expressions/objects.scala | 21 ++ .../scala/org/apache/spark/sql/Dataset.scala | 126 ++++++- .../org/apache/spark/sql/DatasetHolder.scala | 6 +- .../org/apache/spark/sql/GroupedDataset.scala | 17 + .../org/apache/spark/sql/SQLContext.scala | 4 + .../apache/spark/sql/JavaDatasetSuite.java | 357 ++++++++++++++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 2 +- 8 files changed, 644 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 329a132d3d8b2..f05e18288de2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.encoders - - import scala.reflect.ClassTag -import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils +import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions._ /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. 
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable { /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } + +object Encoder { + import scala.reflect.runtime.universe._ + + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + + def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { + tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2)]] + } + + def tuple[T1, T2, T3]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + } + + def tuple[T1, T2, T3, T4]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + } + + def tuple[T1, T2, T3, T4, T5]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4], + enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + } + + private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + assert(encoders.length > 1) + // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. 
+ assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t: ObjectType, _) => + Invoke( + BoundReference(0, ObjectType(cls), true), + s"_${index + 1}", + t) + } + } + + val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.constructExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + enc.constructExpression.transformUp { + case BoundReference(ordinal, dt, _) => + GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + } + } + } + + val constructExpression = + NewInstance(cls, constructExpressions, false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + + def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] + + private def getTypeTag[T](c: Class[T]): TypeTag[T] = { + import scala.reflect.api + + // val mirror = runtimeMirror(c.getClassLoader) + val mirror = rootMirror + val sym = mirror.staticClass(c.getName) + val tpe = sym.selfType + TypeTag(mirror, new api.TypeCreator { + def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) = + if (m eq mirror) tpe.asInstanceOf[U # Type] + else throw new IllegalArgumentException( + s"Type tag defined in $mirror cannot be migrated to other mirrors.") + }) + } + + def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + ExpressionEncoder[(T1, T2)]() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 81855289762c6..4f58464221b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" } } + +case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) + extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val row = child.gen(ctx) + s""" + ${row.code} + final boolean ${ev.isNull} = ${row.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; + } + """ + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4bca9c3b3fe54..fecbdac9a6004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} + import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner @@ -151,18 +155,37 @@ class Dataset[T] private[sql]( def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) /** + * (Scala-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = + filter(t => func.call(t).booleanValue()) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + map(t => func.call(t))(encoder) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ @@ -177,30 +200,77 @@ class Dataset[T] private[sql]( logicalPlan)) } + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def mapPartitions[U]( + f: FlatMapFunction[java.util.Iterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + mapPartitions(func)(encoder) + } + + /** + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) + /** + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterable[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } + /* ************** * * Side effects * * ************** */ /** + * (Scala-specific) * Runs `func` on each element of this Dataset. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** + * (Java-specific) + * Runs `func` on each element of this Dataset. + * @since 1.6.0 + */ + def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + + /** + * (Scala-specific) * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + /** + * (Java-specific) + * Runs `func` on each partition of this Dataset. 
+ * @since 1.6.0 + */ + def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + foreachPartition(it => func.call(it.asJava)) + /* ************* * * Aggregation * * ************* */ /** + * (Scala-specific) * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 @@ -208,6 +278,15 @@ class Dataset[T] private[sql]( def reduce(func: (T, T) => T): T = rdd.reduce(func) /** + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given function + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) + + /** + * (Scala-specific) * Aggregates the elements of each partition, and then the results for all the partitions, using a * given associative and commutative function and a neutral "zero value". * @@ -221,6 +300,15 @@ class Dataset[T] private[sql]( def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) /** + * (Java-specific) + * Aggregates the elements of each partition, and then the results for all the partitions, using a + * given associative and commutative function and a neutral "zero value". + * @since 1.6.0 + */ + def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + + /** + * (Scala-specific) * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ @@ -258,6 +346,14 @@ class Dataset[T] private[sql]( keyAttributes) } + /** + * (Java-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * @since 1.6.0 + */ + def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(f.call(_))(encoder) + /* ****************** * * Typed Relational * * ****************** */ @@ -267,8 +363,7 @@ class Dataset[T] private[sql]( * {{{ * df.select($"colA", $"colB" + 1) * }}} - * @group dfops - * @since 1.3.0 + * @since 1.6.0 */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs @@ -279,7 +374,7 @@ class Dataset[T] private[sql]( * * {{{ * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(e[Int]("value + 1")) + * val newDS = ds.select(expr("value + 1").as[Int]) * }}} * @since 1.6.0 */ @@ -405,6 +500,8 @@ class Dataset[T] private[sql]( * This type of join can be useful both for preserving type-safety with the original object * types as well as working with relational data where either side of the join has column * names in common. + * + * @since 1.6.0 */ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { val left = this.logicalPlan @@ -438,12 +535,31 @@ class Dataset[T] private[sql]( * Gather to Driver Actions * * ************************** */ - /** Returns the first element in this [[Dataset]]. */ + /** + * Returns the first element in this [[Dataset]]. + * @since 1.6.0 + */ def first(): T = rdd.first() - /** Collects the elements to an Array. */ + /** + * Collects the elements to an Array. + * @since 1.6.0 + */ def collect(): Array[T] = rdd.collect() + /** + * (Java-specific) + * Collects the elements to a Java list. + * + * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at + * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method + * instead and keep the generic type for result. 
+ * + * @since 1.6.0 + */ + def collectAsList(): java.util.List[T] = + rdd.collect().toSeq.asJava + /** Returns the first `num` elements of this [[Dataset]] as an Array. */ def take(num: Int): Array[T] = rdd.take(num) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 45f0098b92887..08097e9f02084 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -27,9 +27,9 @@ package org.apache.spark.sql * * @since 1.6.0 */ -case class DatasetHolder[T] private[sql](private val df: Dataset[T]) { +case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDS(): Dataset[T] = df + // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. + def toDS(): Dataset[T] = ds } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b8fc373dffcf5..b2803d5a9a1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql +import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} @@ -104,6 +108,12 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } + def mapGroups[U]( + f: JFunction2[K, JIterator[T], JIterator[U]], + encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + // To ensure valid overloading. protected def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) @@ -196,4 +206,11 @@ class GroupedDataset[K, T] private[sql]( this.logicalPlan, other.logicalPlan)) } + + def cogroup[U, R]( + other: GroupedDataset[K, U], + f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5ad3871093fc8..5598731af5fcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -508,6 +508,10 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
new file mode 100644
index 0000000000000..a9493d576d179
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+import scala.Tuple5;
+import org.junit.*;
+
+import org.apache.spark.Accumulator;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.catalyst.encoders.Encoder;
+import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.test.TestSQLContext;
+
+import static org.apache.spark.sql.functions.*;
+
+public class JavaDatasetSuite implements Serializable {
+  private transient JavaSparkContext jsc;
+  private transient TestSQLContext context;
+  private transient Encoder$ e = Encoder$.MODULE$;
+
+  @Before
+  public void setUp() {
+    // Trigger static initializer of TestData
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
+  }
+
+  @After
+  public void tearDown() {
+    context.sparkContext().stop();
+    context = null;
+    jsc = null;
+  }
+
+  private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+    return new Tuple2<T1, T2>(t1, t2);
+  }
+
+  @Test
+  public void testCollect() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    String[] collected = (String[]) ds.collect();
+    Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+  }
+
+  @Test
+  public void testCommonOperation() {
+    List<String> data = Arrays.asList("hello", "world");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Assert.assertEquals("hello", ds.first());
+
+    Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+      @Override
+      public Boolean call(String v) throws Exception {
+        return v.startsWith("h");
+      }
+    });
+    Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
+
+
+    Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+      @Override
+      public Integer call(String v) throws Exception {
+        return v.length();
+      }
+    }, e.INT());
+    Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
+
+    Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+      @Override
+      public Iterable<String> call(Iterator<String> it) throws Exception {
+        List<String> ls = new LinkedList<String>();
+        while (it.hasNext()) {
+          ls.add(it.next().toUpperCase());
+        }
+        return ls;
+      }
+    }, e.STRING());
+    Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
+
+    Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
+      @Override
+      public Iterable<String> call(String s) throws Exception {
+        List<String> ls = new LinkedList<String>();
+        for (char c : s.toCharArray()) {
+          ls.add(String.valueOf(c));
+        }
+        return ls;
+      }
+    }, e.STRING());
+    Assert.assertEquals(
+      Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
+      flatMapped.collectAsList());
+  }
+
+  @Test
+  public void testForeach() {
+    final Accumulator<Integer> accum = jsc.accumulator(0);
+    List<String> data = Arrays.asList("a", "b", "c");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+
+    ds.foreach(new VoidFunction<String>() {
+      @Override
+      public void call(String s) throws Exception {
+        accum.add(1);
+      }
+    });
+    Assert.assertEquals(3, accum.value().intValue());
+  }
+
+  @Test
+  public void testReduce() {
+    List<Integer> data = Arrays.asList(1, 2, 3);
+    Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+    int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1, Integer v2) throws Exception {
+        return v1 + v2;
+      }
+    });
+    Assert.assertEquals(6, reduced);
+
+    int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1, Integer v2) throws Exception {
+        return v1 * v2;
+      }
+    });
+    Assert.assertEquals(6, folded);
+  }
+
+  @Test
+  public void testGroupBy() {
+    List<String> data = Arrays.asList("a", "foo", "bar");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+      @Override
+      public Integer call(String v) throws Exception {
+        return v.length();
+      }
+    }, e.INT());
+
+    Dataset<String> mapped = grouped.mapGroups(
+      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (data.hasNext()) {
+            sb.append(data.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+
+    List<Integer> data2 = Arrays.asList(2, 6, 10);
+    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+      @Override
+      public Integer call(Integer v) throws Exception {
+        return v / 2;
+      }
+    }, e.INT());
+
+    Dataset<String> cogrouped = grouped.cogroup(
+      grouped2,
+      new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(
+            Integer key,
+            Iterator<String> left,
+            Iterator<Integer> right) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (left.hasNext()) {
+            sb.append(left.next());
+          }
+          sb.append("#");
+          while (right.hasNext()) {
+            sb.append(right.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
+  }
+
+  @Test
+  public void testGroupByColumn() {
+    List<String> data = Arrays.asList("a", "foo", "bar");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+
+    Dataset<String> mapped = grouped.mapGroups(
+      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (data.hasNext()) {
+            sb.append(data.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+  }
+
+  @Test
+  public void testSelect() {
+    List<Integer> data = Arrays.asList(2, 6);
+    Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+    Dataset<Tuple2<Integer, String>> selected = ds.select(
+      expr("value + 1").as(e.INT()),
+      col("value").cast("string").as(e.STRING()));
+
+    Assert.assertEquals(
+      Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
+      selected.collectAsList());
+  }
+
+  @Test
+  public void testSetOperation() {
+    List<String> data = Arrays.asList("abc", "abc", "xyz");
+    Dataset<String> ds = context.createDataset(data, e.STRING());
+
+    Assert.assertEquals(
+      Arrays.asList("abc", "xyz"),
+      sort(ds.distinct().collectAsList().toArray(new String[0])));
+
+    List<String> data2 = Arrays.asList("xyz", "foo", "foo");
+    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+
+    Dataset<String> intersected = ds.intersect(ds2);
+    Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
+
+    Dataset<String> unioned = ds.union(ds2);
+    Assert.assertEquals(
+      Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"),
+      sort(unioned.collectAsList().toArray(new String[0])));
+
+    Dataset<String> subtracted = ds.subtract(ds2);
+    Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
+  }
+
+  private <T extends Comparable<T>> List<T> sort(T[] data) {
+    Arrays.sort(data);
+    return Arrays.asList(data);
+  }
+
+  @Test
+  public void testJoin() {
+    List<Integer> data = Arrays.asList(1, 2, 3);
+    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    List<Integer> data2 = Arrays.asList(2, 3, 4);
+    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+
+    Dataset<Tuple2<Integer, Integer>> joined =
+      ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
+    Assert.assertEquals(
+      Arrays.asList(tuple2(2, 2), tuple2(3, 3)),
+      joined.collectAsList());
+  }
+
+  @Test
+  public void testTupleEncoder() {
+    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
+    Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+    Assert.assertEquals(data2, ds2.collectAsList());
+
+    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    List<Tuple3<Integer, Long, String>> data3 =
+      Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
+    Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
+    Assert.assertEquals(data3, ds3.collectAsList());
+
+    Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
+      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+    List<Tuple4<Integer, String, Long, String>> data4 =
+      Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
+    Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
+    Assert.assertEquals(data4, ds4.collectAsList());
+
+    Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
+      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+    List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
+      Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
+    Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
+      context.createDataset(data5, encoder5);
+    Assert.assertEquals(data5, ds5.collectAsList());
+  }
+
+  @Test
+  public void testNestedTupleEncoder() {
+    // test ((int, string), string)
+    Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
+      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+    List<Tuple2<Tuple2<Integer, String>, String>> data =
+      Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
+    Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+
+    // test (int, (string, string, long))
+    Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
+      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+    List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
+      Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
+    Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
+      context.createDataset(data2, encoder2);
+    Assert.assertEquals(data2, ds2.collectAsList());
+
+    // test (int, ((string, long), string))
+    Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
+      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+    List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
+      Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
+    Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
+      context.createDataset(data3, encoder3);
+    Assert.assertEquals(data3, ds3.collectAsList());
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 32443557fb8e0..e3b0346f857d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -59,7 +59,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("foreach") {
     val ds = Seq(1, 2, 3).toDS()
     val acc = sparkContext.accumulator(0)
-    ds.foreach(acc +=)
+    ds.foreach(acc += _)
     assert(acc.value == 6)
   }
 

From 1ab72b08601a1c8a674bdd3fab84d9804899b2c7 Mon Sep 17 00:00:00 2001
From: Nong Li
Date: Fri, 6 Nov 2015 15:48:20 -0800
Subject: [PATCH 222/324] =?UTF-8?q?[SPARK-11410]=20[PYSPARK]=20Add=20pytho?=
 =?UTF-8?q?n=20bindings=20for=20repartition=20and=20sortW=E2=80=A6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…ithinPartitions.

Author: Nong Li

Closes #9504 from nongli/spark-11410.
---
 python/pyspark/sql/dataframe.py | 117 +++++++++++++++++++++++++++-----
 1 file changed, 101 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 765a4511b64bc..b97c94dad834a 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -422,6 +422,67 @@ def repartition(self, numPartitions):
         """
         return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
 
+    @since(1.3)
+    def repartition(self, numPartitions, *cols):
+        """
+        Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
+        resulting DataFrame is hash partitioned.
+
+        ``numPartitions`` can be an int to specify the target number of partitions or a Column.
+        If it is a Column, it will be used as the first partitioning column. If not specified,
+        the default number of partitions is used.
+
+        .. versionchanged:: 1.6
+           Added optional arguments to specify the partitioning columns. Also made numPartitions
+           optional if partitioning columns are specified.
+ + >>> df.repartition(10).rdd.getNumPartitions() + 10 + >>> data = df.unionAll(df).repartition("age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 2|Alice| + | 5| Bob| + | 5| Bob| + +---+-----+ + >>> data = data.repartition(7, "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data.rdd.getNumPartitions() + 7 + >>> data = data.repartition("name", "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + else: + return DataFrame( + self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions, ) + cols + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -589,6 +650,26 @@ def join(self, other, on=None, how=None): jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) + @since(1.6) + def sortWithinPartitions(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. 
+ + >>> df.sortWithinPartitions("age", ascending=False).show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): @@ -613,22 +694,7 @@ def sort(self, *cols, **kwargs): >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ - if not cols: - raise ValueError("should sort by at least one column") - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - jcols = [_to_java_column(c) for c in cols] - ascending = kwargs.get('ascending', True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() - for asc, jc in zip(ascending, jcols)] - else: - raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) - - jdf = self._jdf.sort(self._jseq(jcols)) + jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) orderBy = sort @@ -650,6 +716,25 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + def _sort_cols(self, cols, kwargs): + """ Return a JVM Seq of Columns that describes the sort order + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + return self._jseq(jcols) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. From 6d0ead322e72303c6444c6ac641378a4690cde96 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 Nov 2015 16:04:20 -0800 Subject: [PATCH 223/324] [SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule The second PR for SPARK-9241, this adds support for multiple distinct columns to the new aggregation code path. This PR solves the multiple DISTINCT column problem by rewriting these Aggregates into an Expand-Aggregate-Aggregate combination. See the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-9241) for some information on this. The advantages over the - competing - [first PR](https://github.com/apache/spark/pull/9280) are: - This can use the faster TungstenAggregate code path. - It is impossible to OOM due to an ```OpenHashSet``` allocating to much memory. However, this will multiply the number of input rows by the number of distinct clauses (plus one), and puts a lot more memory pressure on the aggregation code path itself. The location of this Rule is a bit funny, and should probably change when the old aggregation path is changed. cc yhuai - Could you also tell me where to add tests for this? Author: Herman van Hovell Closes #9406 from hvanhovell/SPARK-9241-rewriter. 
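
To make the rewrite concrete, here is an illustrative spark-shell sketch (not part of this patch's diff). It assumes the 1.6-era shell's `sqlContext`; the table and column names are invented for the example, and the plan shape in the comments paraphrases the rule described above rather than exact optimizer output.

```scala
// Hypothetical example: two DISTINCT aggregates plus one regular aggregate.
import sqlContext.implicits._

val t = Seq((1, "x", "u", 10), (1, "x", "v", 20), (2, "y", "u", 30))
  .toDF("key", "a", "b", "c")
t.registerTempTable("t")

val q = sqlContext.sql(
  "SELECT key, COUNT(DISTINCT a), COUNT(DISTINCT b), SUM(c) FROM t GROUP BY key")

// Conceptually, the rewriter turns the single Aggregate into:
//
//   Aggregate(key)                      <- second aggregate: combines per-group results
//     Aggregate(key, a, b, gid)         <- first aggregate: de-duplicates each distinct group
//       Expand(projections =            <- one projection per distinct clause, plus one for
//         (key, null, null, gid = 0, c),   the regular aggregates; columns that do not belong
//         (key, a,    null, gid = 1, null),to a projection's group are nulled out
//         (key, null, b,    gid = 2, null))
//         Relation t
//
// so every input row is emitted (number of distinct clauses + 1) times, as noted above.
q.explain(true)
```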
--- .../expressions/aggregate/Count.scala | 2 + .../expressions/aggregate/Utils.scala | 186 +++++++++++++++++- .../expressions/aggregate/interfaces.scala | 6 + .../sql/catalyst/optimizer/Optimizer.scala | 6 +- .../plans/logical/basicOperators.scala | 80 ++++---- .../spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 238 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 54df96cd2446a..ec0c8b483a909 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override val evaluateExpression = Cast(count, LongType) + + override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index 644c6211d5f31..39010c3be6d4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -41,7 +42,7 @@ object Utils { private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { + val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { case expressions.Average(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Average(child), @@ -144,7 +145,8 @@ object Utils { aggregateFunction = aggregate.VarianceSamp(child), mode = aggregate.Complete, isDistinct = false) - } + }) + // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => @@ -156,6 +158,7 @@ object Utils { } // Check if there are multiple distinct columns. + // TODO remove this. val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression2 => agg @@ -213,3 +216,178 @@ object Utils { case other => None } } + +/** + * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * TODO Expression cannocalization + * TODO Eliminate foldable expressions from distinct clauses. 
+ * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate + * operator. Perhaps this is a good thing? It is much simpler to plan later on... + */ +object MultipleDistinctRewriter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case a: Aggregate => rewrite(a) + case p => p + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression2 => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + // Only continue to rewrite if there is more than one distinct group. + if (distinctAggGroups.size > 1) { + // Create the attributes for the grouping id and the group by clause. + val gid = new AttributeReference("gid", IntegerType, false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction2, + id: Literal, + attrs: Map[Expression, Expression]): AggregateFunction2 = { + af.withNewChildren(af.children.map { case afc => + evalWithinGroup(id, attrs(afc)) + }).asInstanceOf[AggregateFunction2] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap + val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions. + val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap + + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren( + e.aggregateFunction, + regularGroupId, + regularAggChildAttrMap) + val a = Alias(e.copy(aggregateFunction = af), e.toString)() + + // Get the result of the first aggregate in the last aggregate. 
+ val b = AggregateExpression2( + aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val c = af.defaultResult match { + case Some(lit) => Coalesce(Seq(b, lit)) + case None => b + } + + (e, a, c) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq, + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. 
+ e -> new AttributeReference(e.prettyName, e.dataType, true)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index a2fab258fcac3..5c5b3d1ccd3cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp */ def supportsPartial: Boolean = true + /** + * Result of the aggregate function when the input is empty. This is currently only used for the + * proper rewriting of distinct aggregate functions. + */ + def defaultResult: Option[Literal] = None + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 338c5193cb7a2..d222dfa33ad8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) - if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + case a @ Aggregate(_, _, e @ Expand(_, _, child)) + if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4cb67aacf33ee..fb963e2f8f7e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -235,33 +235,17 @@ case class Window( projectList ++ windowExpressions.map(_.toAttribute) } -/** - * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. 
- * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions - * @param child Child operator - */ -case class Expand( - bitmasks: Seq[Int], - groupByExprs: Seq[Expression], - gid: Attribute, - child: LogicalPlan) extends UnaryNode { - override def statistics: Statistics = { - val sizeInBytes = child.statistics.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - - val projections: Seq[Seq[Expression]] = expand() - +private[sql] object Expand { /** - * Extract attribute set according to the grouping id + * Extract attribute set according to the grouping id. + * * @param bitmask bitmask to represent the selected of the attribute sequence * @param exprs the attributes in sequence * @return the attributes of non selected specified via bitmask (with the bit set to 1) */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { + private def buildNonSelectExprSet( + bitmask: Int, + exprs: Seq[Expression]): OpenHashSet[Expression] = { val set = new OpenHashSet[Expression](2) var bit = exprs.length - 1 @@ -274,18 +258,28 @@ case class Expand( } /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. + * + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions + * @param gid Attribute of the grouping id + * @param child Child operator */ - private[this] def expand(): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - bitmasks.foreach { bitmask => + def apply( + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, + child: LogicalPlan): Expand = { + // Create an array of Projections for the child projection, and replace the projections' + // expressions which equal GroupBy expressions with Literal(null), if those expressions + // are not set for this grouping set (according to the bit mask). + val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) - val substitution = (child.output :+ gid).map(expr => expr transformDown { + (child.output :+ gid).map(expr => expr transformDown { + // TODO this causes a problem when a column is used both for grouping and aggregation. case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null @@ -294,15 +288,29 @@ case class Expand( // replace the groupingId with concrete value (the bit mask) Literal.create(bitmask, IntegerType) }) - - result += substitution } - - result.toSeq + Expand(projections, child.output :+ gid, child) } +} - override def output: Seq[Attribute] = { - child.output :+ gid +/** + * Apply a number of projections to every input row, hence we will get multiple output rows for + * a input row. + * + * @param projections to apply + * @param output of all projections. + * @param child operator. 
+ */ +case class Expand( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override def statistics: Statistics = { + // TODO shouldn't we factor in the size of the projection versus the size of the backing child + // row? + val sizeInBytes = child.statistics.sizeInBytes * projections.length + Statistics(sizeInBytes = sizeInBytes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f4464e0b916f8..dd3bb33c57287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case e @ logical.Expand(_, _, _, child) => + case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled From 1c80d66e52c0bcc4e5adda78b3d8e5bf55e4f128 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 6 Nov 2015 17:13:46 -0800 Subject: [PATCH 224/324] [SPARK-11546] Thrift server makes too many logs about result schema SparkExecuteStatementOperation logs result schema for each getNextRowSet() calls which is by default every 1000 rows, overwhelming whole log file. Author: navis.ryu Closes #9514 from navis/SPARK-11546. --- .../SparkExecuteStatementOperation.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 719b03e1c7c71..82fef92dcb73b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -53,6 +53,18 @@ private[hive] class SparkExecuteStatementOperation( private var dataTypes: Array[DataType] = _ private var statementId: String = _ + private lazy val resultSchema: TableSchema = { + if (result == null || result.queryExecution.analyzed.output.size == 0) { + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) + } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema.asJava) + } + } + def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. 
hiveContext.sparkContext.clearJobGroup() @@ -120,17 +132,7 @@ private[hive] class SparkExecuteStatementOperation( } } - def getResultSetSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) - } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema.asJava) - } - } + def getResultSetSchema: TableSchema = resultSchema override def run(): Unit = { setState(OperationState.PENDING) From 105732dcc6b651b9779f4a5773a759c5b4fbd21d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 6 Nov 2015 17:22:30 -0800 Subject: [PATCH 225/324] [HOTFIX] Fix python tests after #9527 #9527 missed updating the python tests. Author: Michael Armbrust Closes #9533 from marmbrus/hotfixTextValue. --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 97bd90c4db829..927f4077424dc 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -203,7 +203,7 @@ def text(self, path): >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() - [Row(text=u'hello'), Row(text=u'this')] + [Row(value=u'hello'), Row(value=u'this')] """ return self._df(self._jreader.text(path)) From 30b706b7b36482921ec04145a0121ca147984fa8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 6 Nov 2015 18:17:34 -0800 Subject: [PATCH 226/324] [SPARK-11389][CORE] Add support for off-heap memory to MemoryManager In order to lay the groundwork for proper off-heap memory support in SQL / Tungsten, we need to extend our MemoryManager to perform bookkeeping for off-heap memory. ## User-facing changes This PR introduces a new configuration, `spark.memory.offHeapSize` (name subject to change), which specifies the absolute amount of off-heap memory that Spark and Spark SQL can use. If Tungsten is configured to use off-heap execution memory for allocating data pages, then all data page allocations must fit within this size limit. ## Internals changes This PR contains a lot of internal refactoring of the MemoryManager. The key change at the heart of this patch is the introduction of a `MemoryPool` class (name subject to change) to manage the bookkeeping for a particular category of memory (storage, on-heap execution, and off-heap execution). These MemoryPools are not fixed-size; they can be dynamically grown and shrunk according to the MemoryManager's policies. In StaticMemoryManager, these pools have fixed sizes, proportional to the legacy `[storage|shuffle].memoryFraction`. In the new UnifiedMemoryManager, the sizes of these pools are dynamically adjusted according to its policies. There are two subclasses of `MemoryPool`: `StorageMemoryPool` manages storage memory and `ExecutionMemoryPool` manages execution memory. The MemoryManager creates two execution pools, one for on-heap memory and one for off-heap. Instances of `ExecutionMemoryPool` manage the logic for fair sharing of their pooled memory across running tasks (in other words, the ShuffleMemoryManager-like logic has been moved out of MemoryManager and pushed into these ExecutionMemoryPool instances). 
I think that this design is substantially easier to understand and reason about than the previous design, where most of these responsibilities were handled by MemoryManager and its subclasses. To see this, take at look at how simple the logic in `UnifiedMemoryManager` has become: it's now very easy to see when memory is dynamically shifted between storage and execution. ## TODOs - [x] Fix handful of test failures in the MemoryManagerSuites. - [x] Fix remaining TODO comments in code. - [ ] Document new configuration. - [x] Fix commented-out tests / asserts: - [x] UnifiedMemoryManagerSuite. - [x] Write tests that exercise the new off-heap memory management policies. Author: Josh Rosen Closes #9344 from JoshRosen/offheap-memory-accounting. --- .../apache/spark/memory/MemoryConsumer.java | 7 +- .../org/apache/spark/memory/MemoryMode.java | 26 ++ .../spark/memory/TaskMemoryManager.java | 72 +++-- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/memory/ExecutionMemoryPool.scala | 153 +++++++++++ .../apache/spark/memory/MemoryManager.scala | 246 ++++++------------ .../org/apache/spark/memory/MemoryPool.scala | 71 +++++ .../spark/memory/StaticMemoryManager.scala | 75 +----- .../spark/memory/StorageMemoryPool.scala | 138 ++++++++++ .../spark/memory/UnifiedMemoryManager.scala | 138 +++++----- .../org/apache/spark/memory/package.scala | 75 ++++++ .../spark/util/collection/Spillable.scala | 8 +- .../spark/memory/TaskMemoryManagerSuite.java | 8 +- .../spark/memory/TestMemoryConsumer.java | 10 +- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 +- .../spark/memory/MemoryManagerSuite.scala | 104 +++++--- .../memory/StaticMemoryManagerSuite.scala | 39 +-- .../spark/memory/TestMemoryManager.scala | 20 +- .../memory/UnifiedMemoryManagerSuite.scala | 93 +++---- .../spark/storage/BlockManagerSuite.scala | 2 +- 21 files changed, 828 insertions(+), 465 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/MemoryMode.java create mode 100644 core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/MemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/package.scala diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 8fbdb72832adf..36138cc9a297c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -17,15 +17,15 @@ package org.apache.spark.memory; - import java.io.IOException; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; - /** * An memory consumer of TaskMemoryManager, which support spilling. + * + * Note: this only supports allocation / spilling of Tungsten memory. */ public abstract class MemoryConsumer { @@ -36,7 +36,6 @@ public abstract class MemoryConsumer { protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; this.pageSize = pageSize; - this.used = 0; } protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { @@ -67,6 +66,8 @@ public void spill() throws IOException { * * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). * + * Note: today, this only frees Tungsten-managed pages. 
+ * * @param size the amount of memory should be released * @param trigger the MemoryConsumer that trigger this spilling * @return the amount of released memory in bytes diff --git a/core/src/main/java/org/apache/spark/memory/MemoryMode.java b/core/src/main/java/org/apache/spark/memory/MemoryMode.java new file mode 100644 index 0000000000000..3a5e72d8aaec0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryMode.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.apache.spark.annotation.Private; + +@Private +public enum MemoryMode { + ON_HEAP, + OFF_HEAP +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 6440f9c0f30de..5f743b28857b4 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -103,10 +103,10 @@ public class TaskMemoryManager { * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. */ - private final boolean inHeap; + final MemoryMode tungstenMemoryMode; /** - * The size of memory granted to each consumer. + * Tracks spillable memory consumers. */ @GuardedBy("this") private final HashSet consumers; @@ -115,7 +115,7 @@ public class TaskMemoryManager { * Construct a new TaskMemoryManager. */ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { - this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); + this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); this.memoryManager = memoryManager; this.taskAttemptId = taskAttemptId; this.consumers = new HashSet<>(); @@ -127,12 +127,19 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { * * @return number of bytes successfully granted (<= N). */ - public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + public long acquireExecutionMemory( + long required, + MemoryMode mode, + MemoryConsumer consumer) { assert(required >= 0); + // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap + // memory here, then it may not make sense to spill since that would only end up freeing + // off-heap memory. This is subject to change, though, so it may be risky to make this + // optimization now in case we forget to undo it late when making changes. 
synchronized (this) { - long got = memoryManager.acquireExecutionMemory(required, taskAttemptId); + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode); - // try to release memory from other consumers first, then we can reduce the frequency of + // Try to release memory from other consumers first, then we can reduce the frequency of // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory @@ -140,10 +147,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (c != consumer && c.getUsed() > 0) { try { long released = c.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from {} for {}", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); if (got >= required) { break; } @@ -161,10 +168,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got < required && consumer != null) { try { long released = consumer.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from itself ({})", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from itself ({})", taskAttemptId, Utils.bytesToString(released), consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); @@ -184,9 +191,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { /** * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) { logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); - memoryManager.releaseExecutionMemory(size, taskAttemptId); + memoryManager.releaseExecutionMemory(size, taskAttemptId, mode); } /** @@ -195,11 +202,19 @@ public void releaseExecutionMemory(long size, MemoryConsumer consumer) { public void showMemoryUsage() { logger.info("Memory used in task " + taskAttemptId); synchronized (this) { + long memoryAccountedForByConsumers = 0; for (MemoryConsumer c: consumers) { - if (c.getUsed() > 0) { - logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed())); + long totalMemUsage = c.getUsed(); + memoryAccountedForByConsumers += totalMemUsage; + if (totalMemUsage > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage)); } } + long memoryNotAccountedFor = + memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers; + logger.info( + "{} bytes of memory were used by task {} but are not associated with specific consumers", + memoryNotAccountedFor, taskAttemptId); } } @@ -214,7 +229,8 @@ public long pageSizeBytes() { * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is * intended for allocating large blocks of Tungsten memory that will be shared between operators. 
* - * Returns `null` if there was not enough memory to allocate the page. + * Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. */ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { @@ -222,7 +238,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } - long acquired = acquireExecutionMemory(size, consumer); + long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer); if (acquired <= 0) { return null; } @@ -231,7 +247,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { - releaseExecutionMemory(acquired, consumer); + releaseExecutionMemory(acquired, tungstenMemoryMode, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } @@ -262,7 +278,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize, consumer); + releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer); } /** @@ -276,7 +292,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (!inHeap) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { // In off-heap mode, an offset is an absolute address that may require a full 64 bits to // encode. Due to our page size limitation, though, we can convert this into an offset that's // relative to the page's base offset; this relative offset will fit in 51 bits. @@ -305,7 +321,7 @@ private static long decodeOffset(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public Object getPage(long pagePlusOffsetAddress) { - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final MemoryBlock page = pageTable[pageNumber]; @@ -323,7 +339,7 @@ public Object getPage(long pagePlusOffsetAddress) { */ public long getOffsetInPage(long pagePlusOffsetAddress) { final long offsetInPage = decodeOffset(pagePlusOffsetAddress); - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { return offsetInPage; } else { // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we @@ -351,11 +367,19 @@ public long cleanUpAllAllocatedMemory() { } consumers.clear(); } + + for (MemoryBlock page : pageTable) { + if (page != null) { + memoryManager.tungstenMemoryAllocator().free(page); + } + } + Arrays.fill(pageTable, null); + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } /** - * Returns the memory consumption, in bytes, for the current task + * Returns the memory consumption, in bytes, for the current task. 
*/ public long getMemoryConsumptionForThisTask() { return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 23ae9360f6a22..4474a83bedbdb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -341,7 +341,7 @@ object SparkEnv extends Logging { if (useLegacyMemoryManager) { new StaticMemoryManager(conf, numUsableCores) } else { - new UnifiedMemoryManager(conf, numUsableCores) + UnifiedMemoryManager(conf, numUsableCores) } val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala new file mode 100644 index 0000000000000..7825bae425877 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * Implements policies and bookkeeping for sharing a adjustable-sized pool of memory between tasks. + * + * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up + * to a large amount first and then causing others to spill to disk repeatedly. + * + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this + * set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. + * + * @param lock a [[MemoryManager]] instance to synchronize on + * @param poolName a human-readable name for this pool, for use in log messages + */ +class ExecutionMemoryPool( + lock: Object, + poolName: String + ) extends MemoryPool(lock) with Logging { + + /** + * Map from taskAttemptId -> memory consumption in bytes + */ + @GuardedBy("lock") + private val memoryForTask = new mutable.HashMap[Long, Long]() + + override def memoryUsed: Long = lock.synchronized { + memoryForTask.values.sum + } + + /** + * Returns the memory consumption, in bytes, for the given task. 
+ */ + def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized { + memoryForTask.getOrElse(taskAttemptId, 0L) + } + + /** + * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes + * obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * @return the number of bytes granted to the task. + */ + def acquireMemory(numBytes: Long, taskAttemptId: Long): Long = lock.synchronized { + assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` + if (!memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + lock.notifyAll() + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). + // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = memoryForTask.keys.size + val curMem = memoryForTask(taskAttemptId) + + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; + // don't let it be negative + val maxToGrant = + math.min(numBytes, math.max(0, (poolSize / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, memoryFree) + + if (curMem < poolSize / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (memoryFree >= math.min(maxToGrant, poolSize / (2 * numActiveTasks) - curMem)) { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } else { + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() + } + } else { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** + * Release `numBytes` of memory acquired by the given task. + */ + def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized { + val curMem = memoryForTask.getOrElse(taskAttemptId, 0L) + var memoryToFree = if (curMem < numBytes) { + logWarning( + s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " + + s"of memory from the $poolName pool") + curMem + } else { + numBytes + } + if (memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) -= memoryToFree + if (memoryForTask(taskAttemptId) <= 0) { + memoryForTask.remove(taskAttemptId) + } + } + lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. 
+ */ + def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized { + val numBytesToFree = getMemoryUsageForTask(taskAttemptId) + releaseMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index b0cf2696a397f..ceb8ea434e1be 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -20,12 +20,8 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.util.Utils -import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator @@ -36,53 +32,40 @@ import org.apache.spark.unsafe.memory.MemoryAllocator * In this context, execution memory refers to that used for computation in shuffles, joins, * sorts and aggregations, while storage memory refers to that used for caching and propagating * internal data across the cluster. There exists one MemoryManager per JVM. - * - * The MemoryManager abstract base class itself implements policies for sharing execution memory - * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of - * some task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access to mutable state and using wait() and - * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across - * tasks was performed by the ShuffleMemoryManager. 
*/ -private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging { +private[spark] abstract class MemoryManager( + conf: SparkConf, + numCores: Int, + storageMemory: Long, + onHeapExecutionMemory: Long) extends Logging { // -- Methods related to memory allocation policies and bookkeeping ------------------------------ - // The memory store used to evict cached blocks - private var _memoryStore: MemoryStore = _ - protected def memoryStore: MemoryStore = { - if (_memoryStore == null) { - throw new IllegalArgumentException("memory store not initialized yet") - } - _memoryStore - } + @GuardedBy("this") + protected val storageMemoryPool = new StorageMemoryPool(this) + @GuardedBy("this") + protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "on-heap execution") + @GuardedBy("this") + protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "off-heap execution") - // Amount of execution/storage memory in use, accesses must be synchronized on `this` - @GuardedBy("this") protected var _executionMemoryUsed: Long = 0 - @GuardedBy("this") protected var _storageMemoryUsed: Long = 0 - // Map from taskAttemptId -> memory consumption in bytes - @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]() - - /** - * Set the [[MemoryStore]] used by this manager to evict cached blocks. - * This must be set after construction due to initialization ordering constraints. - */ - final def setMemoryStore(store: MemoryStore): Unit = { - _memoryStore = store - } + storageMemoryPool.incrementPoolSize(storageMemory) + onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeapSize", 0)) /** - * Total available memory for execution, in bytes. + * Total available memory for storage, in bytes. This amount can vary over time, depending on + * the MemoryManager implementation. + * In this model, this is equivalent to the amount of memory not occupied by execution. */ - def maxExecutionMemory: Long + def maxStorageMemory: Long /** - * Total available memory for storage, in bytes. + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. */ - def maxStorageMemory: Long + final def setMemoryStore(store: MemoryStore): Unit = synchronized { + storageMemoryPool.setMemoryStore(store) + } // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) @@ -94,7 +77,9 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -109,103 +94,25 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, evictedBlocks) - } - - /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. 
- * @return number of bytes successfully granted (<= N). - */ - @VisibleForTesting - private[memory] def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean /** - * Try to acquire up to `numBytes` of execution memory for the current task and return the number - * of bytes obtained, or 0 if none can be allocated. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * * This call may block until there is enough free memory in some situations, to make sure each * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. - * - * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies - * that control global sharing of memory between execution and storage. */ private[memory] - final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) = 0L - // This will later cause waiting tasks to wake up and check numTasks again - notifyAll() - } - - // Once the cross-task memory allocation policy has decided to grant more memory to a task, - // this method is called in order to actually obtain that execution memory, potentially - // triggering eviction of storage memory: - def acquire(toGrant: Long): Long = synchronized { - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } - executionMemoryForTask(taskAttemptId) += acquired - acquired - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
- // TODO: simplify this to limit each task to its own slot - while (true) { - val numActiveTasks = executionMemoryForTask.keys.size - val curMem = executionMemoryForTask(taskAttemptId) - val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem)) - // Only give it as much memory as is free, which might be none if it reached 1 / numTasks - val toGrant = math.min(maxToGrant, freeMemory) - - if (curMem < maxExecutionMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if ( - freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) { - return acquire(toGrant) - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free") - wait() - } - } else { - return acquire(toGrant) - } - } - 0L // Never reached - } - - @VisibleForTesting - private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _executionMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of execution " + - s"memory when we only have ${_executionMemoryUsed} bytes") - _executionMemoryUsed = 0 - } else { - _executionMemoryUsed -= numBytes + def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } } @@ -213,24 +120,14 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * Release numBytes of execution memory belonging to the given task. */ private[memory] - final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { - val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - if (Utils.isTesting) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") - } else { - logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem") - } - } - if (executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) -= numBytes - if (executionMemoryForTask(taskAttemptId) <= 0) { - executionMemoryForTask.remove(taskAttemptId) - } - releaseExecutionMemory(numBytes) + def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) } - notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed } /** @@ -238,35 +135,28 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * @return the number of bytes freed. 
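acquireExecutionMemory and releaseExecutionMemory now take a MemoryMode, so the same bookkeeping path serves the on-heap and off-heap pools. A sketch of the caller side, modelled on the Spillable and test changes later in this patch; the wrapper class itself is hypothetical:

{{{
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}

// Hypothetical task-side helper that grows and releases an on-heap reservation.
class ReservationSketch(taskMemoryManager: TaskMemoryManager) {
  private var reserved = 0L

  def maybeGrow(wanted: Long): Long = {
    // The third argument is the requesting MemoryConsumer; the Spillable call sites pass null.
    val granted = taskMemoryManager.acquireExecutionMemory(wanted, MemoryMode.ON_HEAP, null)
    reserved += granted  // the manager may grant less than requested
    granted
  }

  def releaseAll(): Unit = {
    // Memory must be returned to the same pool (mode) it was taken from.
    taskMemoryManager.releaseExecutionMemory(reserved, MemoryMode.ON_HEAP, null)
    reserved = 0L
  }
}
}}}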
*/ private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { - val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId) - releaseExecutionMemory(numBytesToFree, taskAttemptId) - numBytesToFree + onHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) + + offHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) } /** * Release N bytes of storage memory. */ def releaseStorageMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _storageMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of storage " + - s"memory when we only have ${_storageMemoryUsed} bytes") - _storageMemoryUsed = 0 - } else { - _storageMemoryUsed -= numBytes - } + storageMemoryPool.releaseMemory(numBytes) } /** * Release all storage memory acquired. */ - def releaseAllStorageMemory(): Unit = synchronized { - _storageMemoryUsed = 0 + final def releaseAllStorageMemory(): Unit = synchronized { + storageMemoryPool.releaseAllMemory() } /** * Release N bytes of unroll memory. */ - def releaseUnrollMemory(numBytes: Long): Unit = synchronized { + final def releaseUnrollMemory(numBytes: Long): Unit = synchronized { releaseStorageMemory(numBytes) } @@ -274,25 +164,34 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * Execution memory currently in use, in bytes. */ final def executionMemoryUsed: Long = synchronized { - _executionMemoryUsed + onHeapExecutionMemoryPool.memoryUsed + offHeapExecutionMemoryPool.memoryUsed } /** * Storage memory currently in use, in bytes. */ final def storageMemoryUsed: Long = synchronized { - _storageMemoryUsed + storageMemoryPool.memoryUsed } /** * Returns the execution memory consumption, in bytes, for the given task. */ private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { - executionMemoryForTask.getOrElse(taskAttemptId, 0L) + onHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) + + offHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) } // -- Fields related to Tungsten managed memory ------------------------------------------------- + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryMode: MemoryMode = { + if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP + } + /** * The default page size, in bytes. * @@ -306,21 +205,22 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case val safetyFactor = 16 - val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor) + val maxTungstenMemory: Long = tungstenMemoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize + } + val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) } - /** - * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using - * sun.misc.Unsafe. - */ - final val tungstenMemoryIsAllocatedInHeap: Boolean = - !conf.getBoolean("spark.unsafe.offHeap", false) - /** * Allocates memory for use by Unsafe/Tungsten code. 
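A worked example of the default page-size computation above, with assumed inputs (minPageSize and maxPageSize are the clamp bounds defined earlier in the file, outside this hunk):

{{{
import org.apache.spark.unsafe.array.ByteArrayMethods

val maxTungstenMemory = 512L * 1024 * 1024   // assume a 512 MB Tungsten execution pool
val cores = 8                                // assume 8 usable cores
val safetyFactor = 16
// 512 MB / 8 / 16 = 4 MB, already a power of two, so nextPowerOf2 leaves it unchanged.
val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor)
// The real code then clamps size between minPageSize and maxPageSize and lets
// spark.buffer.pageSize override the result.
}}}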
*/ - private[memory] final val tungstenMemoryAllocator: MemoryAllocator = - if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = { + tungstenMemoryMode match { + case MemoryMode.ON_HEAP => MemoryAllocator.HEAP + case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala new file mode 100644 index 0000000000000..bfeec47e3892e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +/** + * Manages bookkeeping for an adjustable-sized region of memory. This class is internal to + * the [[MemoryManager]]. See subclasses for more details. + * + * @param lock a [[MemoryManager]] instance, used for synchronization. We purposely erase the type + * to `Object` to avoid programming errors, since this object should only be used for + * synchronization purposes. + */ +abstract class MemoryPool(lock: Object) { + + @GuardedBy("lock") + private[this] var _poolSize: Long = 0 + + /** + * Returns the current size of the pool, in bytes. + */ + final def poolSize: Long = lock.synchronized { + _poolSize + } + + /** + * Returns the amount of free memory in the pool, in bytes. + */ + final def memoryFree: Long = lock.synchronized { + _poolSize - memoryUsed + } + + /** + * Expands the pool by `delta` bytes. + */ + final def incrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + _poolSize += delta + } + + /** + * Shrinks the pool by `delta` bytes. + */ + final def decrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + require(delta <= _poolSize) + require(_poolSize - delta >= memoryUsed) + _poolSize -= delta + } + + /** + * Returns the amount of used memory in this pool (in bytes). + */ + def memoryUsed: Long +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 9c2c2e90a2282..12a094306861f 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus} - /** * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. 
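The new MemoryPool base class only tracks capacity; how usage is accounted is left to subclasses (the StorageMemoryPool added later in this patch and the ExecutionMemoryPool referenced earlier). A minimal sketch with a hypothetical subclass, just to make the resize invariants concrete:

{{{
import org.apache.spark.memory.MemoryPool

// Hypothetical pool that tracks usage with a single counter.
class CounterPool(lock: Object) extends MemoryPool(lock) {
  private var _used = 0L
  override def memoryUsed: Long = lock.synchronized { _used }
  def reserve(bytes: Long): Boolean = lock.synchronized {
    if (bytes <= memoryFree) { _used += bytes; true } else false
  }
}

val lock = new Object
val pool = new CounterPool(lock)
pool.incrementPoolSize(1000L)   // capacity: 1000 bytes
assert(pool.reserve(600L))      // used: 600, free: 400
pool.decrementPoolSize(400L)    // OK: capacity shrinks to 600, exactly what is in use
// pool.decrementPoolSize(1L)   // would trip require(_poolSize - delta >= memoryUsed)
}}}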
* @@ -32,10 +31,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} */ private[spark] class StaticMemoryManager( conf: SparkConf, - override val maxExecutionMemory: Long, + maxOnHeapExecutionMemory: Long, override val maxStorageMemory: Long, numCores: Int) - extends MemoryManager(conf, numCores) { + extends MemoryManager( + conf, + numCores, + maxStorageMemory, + maxOnHeapExecutionMemory) { def this(conf: SparkConf, numCores: Int) { this( @@ -50,76 +53,15 @@ private[spark] class StaticMemoryManager( (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } - /** - * Acquire N bytes of memory for execution. - * @return number of bytes successfully granted (<= N). - */ - override def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { - assert(numBytes >= 0) - assert(_executionMemoryUsed <= maxExecutionMemory) - val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) - _executionMemoryUsed += bytesToGrant - bytesToGrant - } - - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) - } - - /** - * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. - * - * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage - * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any, - * are added to `evictedBlocks`. - * - * @return whether all N bytes were successfully granted. - */ override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - val currentUnrollMemory = memoryStore.currentUnrollMemory + val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) val numBytesToFree = math.min(numBytes, maxNumBytesToFree) - acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } - - /** - * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. - * - * @param blockId the ID of the block we are acquiring storage memory for - * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks - * @param evictedBlocks a holder for blocks evicted in the process - * @return whether all N bytes were successfully granted. 
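A worked example of the unroll path above, with assumed numbers: the amount of existing data that may be evicted to make room for unrolling is capped by spark.storage.unrollFraction of the storage region, minus what unrolling already occupies.

{{{
// StaticMemoryManager.acquireUnrollMemory's eviction cap, with assumed numbers.
val maxStorageMemory = 1000L
val unrollFraction = 0.2                                   // spark.storage.unrollFraction default
val maxMemoryToEvictForUnroll = (maxStorageMemory * unrollFraction).toLong             // 200
val currentUnrollMemory = 150L                             // bytes already used for unrolling
val numBytes = 120L                                        // requested for the new block
val maxNumBytesToFree = math.max(0L, maxMemoryToEvictForUnroll - currentUnrollMemory)  // 50
val numBytesToFree = math.min(numBytes, maxNumBytesToFree) // evict at most 50 bytes of blocks
}}}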
- */ - private def acquireStorageMemory( - blockId: BlockId, - numBytesToAcquire: Long, - numBytesToFree: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - assert(numBytesToAcquire >= 0) - assert(numBytesToFree >= 0) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - assert(_storageMemoryUsed <= maxStorageMemory) - val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory - if (enoughMemory) { - _storageMemoryUsed += numBytesToAcquire - } - enoughMemory - } - } @@ -135,7 +77,6 @@ private[spark] object StaticMemoryManager { (systemMaxMemory * memoryFraction * safetyFraction).toLong } - /** * Return the total amount of memory available for the execution region, in bytes. */ diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala new file mode 100644 index 0000000000000..6a322eabf81ed --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} + +/** + * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage + * (caching). + * + * @param lock a [[MemoryManager]] instance to synchronize on + */ +class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { + + @GuardedBy("lock") + private[this] var _memoryUsed: Long = 0L + + override def memoryUsed: Long = lock.synchronized { + _memoryUsed + } + + private var _memoryStore: MemoryStore = _ + def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalStateException("memory store not initialized yet") + } + _memoryStore + } + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. 
+ */ + def acquireMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + acquireMemory(blockId, numBytes, numBytes, evictedBlocks) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the size of space to be freed through evicting blocks + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) + assert(memoryUsed <= poolSize) + memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + // NOTE: If the memory store evicts blocks, then those evictions will synchronously call + // back into this StorageMemoryPool in order to free. Therefore, these variables should have + // been updated. + val enoughMemory = numBytesToAcquire <= memoryFree + if (enoughMemory) { + _memoryUsed += numBytesToAcquire + } + enoughMemory + } + + def releaseMemory(size: Long): Unit = lock.synchronized { + if (size > _memoryUsed) { + logWarning(s"Attempted to release $size bytes of storage " + + s"memory when we only have ${_memoryUsed} bytes") + _memoryUsed = 0 + } else { + _memoryUsed -= size + } + } + + def releaseAllMemory(): Unit = lock.synchronized { + _memoryUsed = 0 + } + + /** + * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number + * of bytes removed from the pool's capacity. + */ + def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { + // First, shrink the pool by reclaiming free memory: + val spaceFreedByReleasingUnusedMemory = Math.min(spaceToFree, memoryFree) + decrementPoolSize(spaceFreedByReleasingUnusedMemory) + if (spaceFreedByReleasingUnusedMemory == spaceToFree) { + spaceFreedByReleasingUnusedMemory + } else { + // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + memoryStore.ensureFreeSpace(spaceToFree - spaceFreedByReleasingUnusedMemory, evictedBlocks) + val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum + _memoryUsed -= spaceFreedByEviction + decrementPoolSize(spaceFreedByEviction) + spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index a3093030a0f93..8be5b05419094 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockStatus, BlockId} - /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. 
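shrinkPoolToFreeSpace is how execution claws back capacity from storage: it first reclaims whatever is unused, and only then asks the memory store to evict blocks. A short walk-through with assumed numbers:

{{{
// StorageMemoryPool.shrinkPoolToFreeSpace with assumed numbers.
// Start: poolSize = 500, memoryUsed = 450, memoryFree = 50; execution wants 200 back.
val spaceToFree = 200L
val freedWithoutEviction = math.min(spaceToFree, 50L)    // 50: shrink by the unused capacity first
// Still 150 short, so the memory store is asked to evict cached blocks; suppose it
// evicts blocks totalling 150 bytes, dropping memoryUsed to 300 and poolSize to 300.
val freedByEviction = 150L
val returned = freedWithoutEviction + freedByEviction    // 200 bytes of capacity handed back
}}}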
@@ -41,98 +40,105 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * The implication is that attempts to cache blocks may fail if execution has already eaten * up most of the storage space, in which case the new blocks will be evicted immediately * according to their respective storage levels. + * + * @param storageRegionSize Size of the storage region, in bytes. + * This region is not statically reserved; execution can borrow from + * it if necessary. Cached blocks can be evicted only if actual + * storage memory usage exceeds this region. */ -private[spark] class UnifiedMemoryManager( +private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, maxMemory: Long, + private val storageRegionSize: Long, numCores: Int) - extends MemoryManager(conf, numCores) { - - def this(conf: SparkConf, numCores: Int) { - this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores) - } - - /** - * Size of the storage region, in bytes. - * - * This region is not statically reserved; execution can borrow from it if necessary. - * Cached blocks can be evicted only if actual storage memory usage exceeds this region. - */ - private val storageRegionSize: Long = { - (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong - } - - /** - * Total amount of memory, in bytes, not currently occupied by either execution or storage. - */ - private def totalFreeMemory: Long = synchronized { - assert(_executionMemoryUsed <= maxMemory) - assert(_storageMemoryUsed <= maxMemory) - assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory) - maxMemory - _executionMemoryUsed - _storageMemoryUsed - } + extends MemoryManager( + conf, + numCores, + storageRegionSize, + maxMemory - storageRegionSize) { - /** - * Total available memory for execution, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by storage. - */ - override def maxExecutionMemory: Long = synchronized { - maxMemory - _storageMemoryUsed - } + // We always maintain this invariant: + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) - /** - * Total available memory for storage, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by execution. - */ override def maxStorageMemory: Long = synchronized { - maxMemory - _executionMemoryUsed + maxMemory - onHeapExecutionMemoryPool.memoryUsed } /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * - * This method evicts blocks only up to the amount of memory borrowed by storage. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. 
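With the constructor change above, both pools are sized up front and their sizes always sum to maxMemory. Note that maxStorageMemory subtracts execution's usage, not its pool size, which is what lets storage grow into capacity that execution is not actually using. Assumed numbers:

{{{
val maxMemory = 1000L
val storageRegionSize = (maxMemory * 0.5).toLong    // spark.memory.storageFraction default = 0.5
// Pools start at 500 (storage) + 500 (on-heap execution).
val executionUsed = 300L
val maxStorageMemory = maxMemory - executionUsed    // 700: storage may still grow up to here
}}}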
*/ - private[memory] override def doAcquireExecutionMemory( + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) - val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize) - // If there is not enough free memory AND storage has borrowed some execution memory, - // then evict as much memory borrowed by storage as needed to grant this request - val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0 - if (shouldEvictStorage) { - val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage) - memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks) + memoryMode match { + case MemoryMode.ON_HEAP => + if (numBytes > onHeapExecutionMemoryPool.memoryFree) { + val extraMemoryNeeded = numBytes - onHeapExecutionMemoryPool.memoryFree + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } + } + onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => + // For now, we only support on-heap caching of data, so we do not need to interact with + // the storage pool when allocating off-heap memory. This will change in the future, though. + super.acquireExecutionMemory(numBytes, taskAttemptId, memoryMode) } - val bytesToGrant = math.min(numBytes, totalFreeMemory) - _executionMemoryUsed += bytesToGrant - bytesToGrant } - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ override def acquireStorageMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) - memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks) - val enoughMemory = totalFreeMemory >= numBytes - if (enoughMemory) { - _storageMemoryUsed += numBytes + if (numBytes > storageMemoryPool.memoryFree) { + // There is not enough free memory in the storage pool, so try to borrow free memory from + // the execution pool. 
+ val memoryBorrowedFromExecution = Math.min(onHeapExecutionMemoryPool.memoryFree, numBytes) + onHeapExecutionMemoryPool.decrementPoolSize(memoryBorrowedFromExecution) + storageMemoryPool.incrementPoolSize(memoryBorrowedFromExecution) } - enoughMemory + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) } + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, evictedBlocks) + } } -private object UnifiedMemoryManager { +object UnifiedMemoryManager { + + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { + val maxMemory = getMaxMemory(conf) + new UnifiedMemoryManager( + conf, + maxMemory = maxMemory, + storageRegionSize = + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong, + numCores = numCores) + } /** * Return the total amount of memory shared between execution and storage, in bytes. diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala new file mode 100644 index 0000000000000..564e30d2ffd66 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This package implements Spark's memory management system. This system consists of two main + * components, a JVM-wide memory manager and a per-task manager: + * + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution (memory + * used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual tasks. + * Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. + * + * Internally, each of these components have additional abstractions for memory bookkeeping: + * + * - [[org.apache.spark.memory.MemoryConsumer]]s are clients of the TaskMemoryManager and + * correspond to individual operators and data structures within a task. The TaskMemoryManager + * receives memory allocation requests from MemoryConsumers and issues callbacks to consumers + * in order to trigger spilling when running low on memory. + * - [[org.apache.spark.memory.MemoryPool]]s are a bookkeeping abstraction used by the + * MemoryManager to track the division of memory between storage and execution. 
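The two acquire paths above borrow in opposite directions: acquireStorageMemory grows the storage pool by taking free execution capacity, while acquireExecutionMemory shrinks the storage pool (evicting blocks if needed) to win back anything storage borrowed beyond storageRegionSize. A walk-through with assumed numbers that mirrors the "execution evicts storage" test near the end of this patch:

{{{
val maxMemory = 1000L
val storageRegionSize = 500L
// 1. Storage caches a 750-byte block: it borrows 250 bytes of execution's free capacity,
//    so the pools become storage = 750, execution = 250.
// 2. Execution later needs 200 bytes but only 150 of its pool is free.
val extraMemoryNeeded = 200L - 150L                                          // 50
val memoryReclaimableFromStorage =
  math.max(0L /* storage free */, 750L - storageRegionSize)                  // 250, all borrowed
val reclaimed = math.min(extraMemoryNeeded, memoryReclaimableFromStorage)    // 50 bytes evicted
// Execution's pool grows back to 250 + 50 = 300 and the 200-byte request is granted in full.
}}}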
+ * + * Diagrammatically: + * + * {{{ + * +-------------+ + * | MemConsumer |----+ +------------------------+ + * +-------------+ | +-------------------+ | MemoryManager | + * +--->| TaskMemoryManager |----+ | | + * +-------------+ | +-------------------+ | | +------------------+ | + * | MemConsumer |----+ | | | StorageMemPool | | + * +-------------+ +-------------------+ | | +------------------+ | + * | TaskMemoryManager |----+ | | + * +-------------------+ | | +------------------+ | + * +---->| |OnHeapExecMemPool | | + * * | | +------------------+ | + * * | | | + * +-------------+ * | | +------------------+ | + * | MemConsumer |----+ | | |OffHeapExecMemPool| | + * +-------------+ | +-------------------+ | | +------------------+ | + * +--->| TaskMemoryManager |----+ | | + * +-------------------+ +------------------------+ + * }}} + * + * + * There are two implementations of [[org.apache.spark.memory.MemoryManager]] which vary in how + * they handle the sizing of their memory pools: + * + * - [[org.apache.spark.memory.UnifiedMemoryManager]], the default in Spark 1.6+, enforces soft + * boundaries between storage and execution memory, allowing requests for memory in one region + * to be fulfilled by borrowing memory from the other. + * - [[org.apache.spark.memory.StaticMemoryManager]] enforces hard boundaries between storage + * and execution memory by statically partitioning Spark's memory and preventing storage and + * execution from borrowing memory from each other. This mode is retained only for legacy + * compatibility purposes. + */ +package object memory diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 9e002621a6909..3a48af82b1dae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.{Logging, SparkEnv} /** @@ -78,7 +78,8 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null) + val granted = + taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -107,7 +108,8 @@ private[spark] trait Spillable[C] extends Logging { */ def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null) + taskMemoryManager.releaseExecutionMemory( + myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index c731317395612..711eed0193bc0 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ 
b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -28,8 +28,14 @@ public class TaskMemoryManagerSuite { @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new StaticMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "false"), + Long.MAX_VALUE, + Long.MAX_VALUE, + 1), + 0); manager.allocatePage(4096, null); // leak memory + Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index 8ae3642738509..e6e16fff80401 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -32,13 +32,19 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { } void use(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); + long got = taskMemoryManager.acquireExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); used += got; } void free(long size) { used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + taskMemoryManager.releaseExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4763395d7d401..0e0eca515afc1 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -423,7 +423,7 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 92bd45e5fa241..3bca790f30870 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -83,7 +83,9 @@ public OutputStream apply(OutputStream stream) { public void setup() { memoryManager = new TestMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())); + new SparkConf() + .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeapSize", "256mb")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 4a9479cf490fb..f55d435fa33a6 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} @@ -29,7 +30,7 @@ import org.mockito.stubbing.Answer import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.MemoryStore +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel} /** @@ -78,7 +79,12 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") val numBytes = args(numBytesPos).asInstanceOf[Long] - mockEnsureFreeSpace(mm, numBytes) + val success = mockEnsureFreeSpace(mm, numBytes) + if (success) { + args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytes, 0L, 0L))) + } + success } } } @@ -132,93 +138,95 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { } /** - * Create a MemoryManager with the specified execution memory limit and no storage memory. + * Create a MemoryManager with the specified execution memory limits and no storage memory. */ - protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager + protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long = 0L): MemoryManager // -- Tests of sharing of execution memory between tasks ---------------------------------------- // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. implicit val ec = ExecutionContext.global - test("single task requesting execution memory") { + test("single task requesting on-heap execution memory") { val manager = createMemoryManager(1000L) val taskMemoryManager = new TaskMemoryManager(manager, 0) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) - taskMemoryManager.releaseExecutionMemory(500L, null) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L) + taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L) taskMemoryManager.cleanUpAllAllocatedMemory() - 
assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) } - test("two tasks requesting full execution memory") { + test("two tasks requesting full on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 500 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 500L) assert(Await.result(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should immediately return 0 as they are // both now at 1 / N - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) assert(Await.result(t2Result2, 200.millis) === 0L) } - test("two tasks cannot grow past 1 / N of execution memory") { + test("two tasks cannot grow past 1 / N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 250 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 250L) assert(Await.result(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. 
// We should only grant 250 bytes to each of them on this second request - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, futureTimeout) === 250L) assert(Await.result(t2Result2, futureTimeout) === 250L) } - test("tasks can block to get at least 1 / 2N of execution memory") { + test("tasks can block to get at least 1 / 2N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) - t1MemManager.releaseExecutionMemory(250L, null) + t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) // The memory freed from t1 should now be granted to t2. assert(Await.result(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, 200.millis) === 0L) } @@ -229,18 +237,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. 
Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() assert(Await.result(t2Result1, futureTimeout) === 500L) - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, futureTimeout) === 500L) - val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result3, 200.millis) === 0L) } @@ -251,15 +259,35 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 700L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result1, futureTimeout) === 300L) - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) } + + test("off-heap execution allocations cannot exceed limit") { + val memoryManager = createMemoryManager( + maxOnHeapExecutionMemory = 0L, + maxOffHeapExecutionMemory = 1000L) + + val tMemManager = new TaskMemoryManager(memoryManager, 1) + val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result1, 200.millis) === 1000L) + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + + val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result2, 200.millis) === 0L) + + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 500L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 0L) + } } private object MemoryManagerSuite { diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 885c450d6d4f5..54cb28c389c2f 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -24,7 +24,6 @@ import org.mockito.Mockito.when import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] @@ -36,38 +35,47 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { maxExecutionMem: Long, maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1) + conf, + 
maxOnHeapExecutionMemory = maxExecutionMem, + maxStorageMemory = maxStorageMem, + numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): StaticMemoryManager = { new StaticMemoryManager( - conf, - maxExecutionMemory = maxMemory, + conf.clone + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString), + maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxStorageMemory = 0, numCores = 1) } test("basic execution memory") { val maxExecutionMem = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxExecutionMem) + mm.releaseExecutionMemory(maxExecutionMem, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 0L) } @@ -113,13 +121,14 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { test("execution and storage isolation") { val maxExecutionMem = 200L val maxStorageMem = 1000L + val taskAttemptId = 0L val dummyBlock = TestBlockId("ain't nobody love like you do") val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) // Only execution memory should increase - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase @@ -128,7 +137,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released - mm.releaseExecutionMemory(133L) + mm.releaseExecutionMemory(133L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 67L) // Only storage memory should be released diff --git 
a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 77e43554ee27c..0706a6e45de8f 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -22,19 +22,20 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockStatus, BlockId} -class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) { - private[memory] override def doAcquireExecutionMemory( +class TestMemoryManager(conf: SparkConf) + extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { + + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = { if (oomOnce) { oomOnce = false 0 } else if (available >= numBytes) { - _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory available -= numBytes numBytes } else { - _executionMemoryUsed += available val grant = available available = 0 grant @@ -48,12 +49,13 @@ class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { + override def releaseStorageMemory(numBytes: Long): Unit = {} + override private[memory] def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = { available += numBytes - _executionMemoryUsed -= numBytes } - override def releaseStorageMemory(numBytes: Long): Unit = {} - override def maxExecutionMemory: Long = Long.MaxValue override def maxStorageMemory: Long = Long.MaxValue private var oomOnce = false diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0c97f2bd89651..8cebe81c3bfff 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -24,57 +24,52 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { - private val conf = new SparkConf().set("spark.memory.storageFraction", "0.5") private val dummyBlock = TestBlockId("--") private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + private val storageFraction: Double = 0.5 + /** * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. 
*/ private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { - val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1) + val mm = createMemoryManager(maxMemory) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { - new UnifiedMemoryManager(conf, maxMemory, numCores = 1) - } - - private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { - mm invokePrivate PrivateMethod[Long]('storageRegionSize)() - } - - test("storage region size") { - val maxMemory = 1000L - val (mm, _) = makeThings(maxMemory) - val storageFraction = conf.get("spark.memory.storageFraction").toDouble - val expectedStorageRegionSize = maxMemory * storageFraction - val actualStorageRegionSize = getStorageRegionSize(mm) - assert(expectedStorageRegionSize === actualStorageRegionSize) + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): UnifiedMemoryManager = { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString) + .set("spark.memory.storageFraction", storageFraction.toString) + UnifiedMemoryManager(conf, numCores = 1) } test("basic execution memory") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxMemory) assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) assert(mm.executionMemoryUsed === maxMemory) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) assert(mm.executionMemoryUsed === maxMemory) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxMemory) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 0L) } @@ -118,44 +113,34 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("execution evicts storage") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") // Acquire enough storage memory to exceed the storage region 
assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assertEnsureFreeSpaceCalled(ms, 750L) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) - require(mm.storageMemoryUsed > storageRegionSize, - s"bad test: storage memory used should exceed the storage region") // Execution needs to request 250 bytes to evict storage memory - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) assertEnsureFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) - assertEnsureFreeSpaceCalled(ms, 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) + assert(mm.executionMemoryUsed === 300L) + assertEnsureFreeSpaceCalled(ms, 50L) assert(mm.executionMemoryUsed === 300L) mm.releaseAllStorageMemory() - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") + require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) assertEnsureFreeSpaceCalled(ms, 400L) - require(mm.storageMemoryUsed < storageRegionSize, - s"bad test: storage memory used should be within the storage region") + assert(mm.storageMemoryUsed === 400L) + assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, // so grant only what's remaining without evicting anything, i.e. 
1000 - 300 - 400 = 300 - assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) assertEnsureFreeSpaceNotCalled(ms) @@ -163,23 +148,13 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("storage does not evict execution") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") // Acquire enough execution memory to exceed the execution region - assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed > executionRegionSize, - s"bad test: execution memory used should exceed the execution region") // Storage should not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) @@ -189,15 +164,13 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) assertEnsureFreeSpaceCalled(ms, 250L) - mm.releaseExecutionMemory(maxMemory) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") // Storage should still not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d49015afcd594..53991d8a1aede 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -825,7 +825,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memoryManager = new StaticMemoryManager( conf, - maxExecutionMemory = Long.MaxValue, + maxOnHeapExecutionMemory = Long.MaxValue, maxStorageMemory = 1200, numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, From 7f741905b06ed6d3dfbff6db41a3355dab71aa3c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 7 Nov 2015 05:35:53 +0100 Subject: [PATCH 227/324] 
[SPARK-11112] DAG visualization: display RDD callsite screen shot 2015-11-01 at 9 42 33 am mateiz sarutak Author: Andrew Or Closes #9398 from andrewor14/rdd-callsite. --- .../apache/spark/ui/static/spark-dag-viz.css | 4 ++ .../org/apache/spark/storage/RDDInfo.scala | 16 +++++++- .../spark/ui/scope/RDDOperationGraph.scala | 10 +++-- .../org/apache/spark/util/JsonProtocol.scala | 17 ++++++++- .../scala/org/apache/spark/util/Utils.scala | 1 + .../org/apache/spark/ui/UISeleniumSuite.scala | 14 +++---- .../apache/spark/util/JsonProtocolSuite.scala | 37 ++++++++++++++++--- 7 files changed, 79 insertions(+), 20 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 3b4ae2ed354b8..9cc5c79f67346 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -122,3 +122,7 @@ stroke: #52C366; stroke-width: 2px; } + +.tooltip-inner { + white-space: pre-wrap; +} diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 96062626b5045..3fa209b924170 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDDOperationScope, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} @DeveloperApi class RDDInfo( @@ -28,9 +28,20 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], + val callSite: CallSite, val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { + def this( + id: Int, + name: String, + numPartitions: Int, + storageLevel: StorageLevel, + parentIds: Seq[Int], + scope: Option[RDDOperationScope] = None) { + this(id, name, numPartitions, storageLevel, parentIds, CallSite.empty, scope) + } + var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L @@ -56,6 +67,7 @@ private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, + rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 81f168a447ead..24274562657b3 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -38,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. 
*/ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -104,8 +105,8 @@ private[ui] object RDDOperationGraph extends Logging { edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) } // TODO: differentiate between the intention to cache an RDD and whether it's actually cached - val node = nodes.getOrElseUpdate( - rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE)) + val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode( + rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite)) if (rdd.scope.isEmpty) { // This RDD has no encompassing scope, so we put it directly in the root cluster @@ -177,7 +178,8 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - s"""${node.id} [label="${node.name} [${node.id}]"]""" + val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + s"""${node.id} [label="$label"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee2eb58cf5e2a..c9beeb25e05af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,6 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ + ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -407,6 +408,11 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } + def callsiteToJson(callsite: CallSite): JValue = { + ("Short Form" -> callsite.shortForm) ~ + ("Long Form" -> callsite.longForm) + } + def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -851,6 +857,9 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) + val callsite = Utils.jsonOption(json \ "Callsite") + .map(callsiteFromJson) + .getOrElse(CallSite.empty) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -863,7 +872,7 @@ private[spark] object JsonProtocol { .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope) + val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.externalBlockStoreSize = externalBlockStoreSize @@ -871,6 +880,12 @@ private[spark] object JsonProtocol { rddInfo } + def callsiteFromJson(json: JValue): CallSite = { + val shortForm = (json \ "Short Form").extract[String] + val longForm = (json \ "Long Form").extract[String] + CallSite(shortForm, longForm) + } + def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ 
"Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5a976ee839b1e..316c194ff3454 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -57,6 +57,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object CallSite { val SHORT_FORM = "callSite.short" val LONG_FORM = "callSite.long" + val empty = CallSite("", "") } /** diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 18eec7da9763e..ceecfd665bf87 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -615,29 +615,29 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]"];\n }")) + "0 [label="ParallelCollectionRDD [0]")) assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]"];\n }")) + "1 [label="MapPartitionsRDD [1]")) assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]"];\n }")) + "2 [label="MapPartitionsRDD [2]")) val stage1 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]"];\n }")) + "3 [label="ShuffledRDD [3]")) assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]"];\n }")) + "4 [label="MapPartitionsRDD [4]")) assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]"];\n }")) + "5 [label="MapPartitionsRDD [5]")) val stage2 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]"];\n }")) + "6 [label="ShuffledRDD [6]")) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 953456c2caa89..3f94ef7041914 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,6 +111,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) + testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -163,6 +164,10 @@ class JsonProtocolSuite extends SparkFunSuite { testBlockId(StreamBlockId(1, 2L)) } + /* ============================== * + | Backward compatibility tests | + * ============================== */ + test("ExceptionFailure backward compatibility") { val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None, 
None) @@ -334,14 +339,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) } - test("RDDInfo backward compatibility (scope, parent IDs)") { - // Prior to Spark 1.4.0, RDDInfo did not have the "Scope" and "Parent IDs" properties - val rddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), Some(new RDDOperationScope("fable"))) + test("RDDInfo backward compatibility (scope, parent IDs, callsite)") { + // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 + // "Callsite" was introduced in Spark 1.6.0 + val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), + CallSite("short", "long"), Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) - val expectedRddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq.empty, scope = None) + .removeField({ _._1 == "Callsite"}) + val expectedRddInfo = new RDDInfo( + 1, "one", 100, StorageLevel.NONE, Seq.empty, CallSite.empty, scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -389,6 +397,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testCallsite(callsite: CallSite): Unit = { + val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) + assert(callsite === newCallsite) + } + private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -713,7 +726,8 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, + Seq(1, 4, 7), CallSite(a.toString, b.toString)) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -856,6 +870,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", + | "Callsite": {"Short Form": "101", "Long Form": "201"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1258,6 +1273,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", + | "Callsite": {"Short Form": "1", "Long Form": "200"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1301,6 +1317,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", + | "Callsite": {"Short Form": "2", "Long Form": "400"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1318,6 +1335,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "401"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1361,6 +1379,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "600"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1378,6 +1397,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "601"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1395,6 +1415,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": 
{"Short Form": "5", "Long Form": "602"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1438,6 +1459,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "800"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1455,6 +1477,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": {"Short Form": "5", "Long Form": "801"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1472,6 +1495,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", + | "Callsite": {"Short Form": "6", "Long Form": "802"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1489,6 +1513,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", + | "Callsite": {"Short Form": "7", "Long Form": "803"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, From 2ff0e79a8647cca5c9c57f613a07e739ac4f677e Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 6 Nov 2015 22:56:29 -0800 Subject: [PATCH 228/324] [SPARK-8467] [MLLIB] [PYSPARK] Add LDAModel.describeTopics() in Python Could jkbradley and davies review it? - Create a wrapper class: `LDAModelWrapper` for `LDAModel`. Because we can't deal with the return value of`describeTopics` in Scala from pyspark directly. `Array[(Array[Int], Array[Double])]` is too complicated to convert it. - Add `loadLDAModel` in `PythonMLlibAPI`. Since `LDAModel` in Scala is an abstract class and we need to call `load` of `DistributedLDAModel`. [[SPARK-8467] Add LDAModel.describeTopics() in Python - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8467) Author: Yu ISHIKAWA Closes #8643 from yu-iskw/SPARK-8467-2. --- .../mllib/api/python/LDAModelWrapper.scala | 46 +++++++++++++++++++ .../mllib/api/python/PythonMLLibAPI.scala | 13 +++++- python/pyspark/mllib/clustering.py | 33 +++++++------ 3 files changed, 75 insertions(+), 17 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 0000000000000..63282eee6e656 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 40c41806cdfea..54b03a9f90283 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8629aa5a17164..12081f8c69075 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -671,7 +671,7 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) -class LDAModel(JavaModelWrapper): +class LDAModel(JavaModelWrapper, JavaSaveable, Loader): """ A clustering model derived from the LDA method. @@ -691,9 +691,14 @@ class LDAModel(JavaModelWrapper): ... [2, SparseVector(2, {0: 1.0})], ... ] >>> rdd = sc.parallelize(data) - >>> model = LDA.train(rdd, k=2) + >>> model = LDA.train(rdd, k=2, seed=1) >>> model.vocabSize() 2 + >>> model.describeTopics() + [([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])] + >>> model.describeTopics(1) + [([1], [0.5...]), ([0], [0.5...])] + >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) @@ -724,18 +729,17 @@ def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - @since('1.5.0') - def save(self, sc, path): - """Save the LDAModel on to disk. + @since('1.6.0') + def describeTopics(self, maxTermsPerTopic=None): + """Return the topics described by weighted terms. 
- :param sc: SparkContext - :param path: str, path to where the model needs to be stored. + WARNING: If vocabSize and k are large, this can return a large object! """ - if not isinstance(sc, SparkContext): - raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_model.save(sc._jsc.sc(), path) + if maxTermsPerTopic is None: + topics = self.call("describeTopics") + else: + topics = self.call("describeTopics", maxTermsPerTopic) + return topics @classmethod @since('1.5.0') @@ -749,9 +753,8 @@ def load(cls, sc, path): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( - sc._jsc.sc(), path) - return cls(java_model) + model = callMLlibFunc("loadLDAModel", sc, path) + return LDAModel(model) class LDA(object): From ef362846eb448769bcf774fc9090a5013d459464 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 7 Nov 2015 13:37:37 -0800 Subject: [PATCH 229/324] [SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up This PR is a follow up for PR https://github.com/apache/spark/pull/9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path and adds a multiple distinct test to the AggregationQuerySuite. cc yhuai marmbrus Author: Herman van Hovell Closes #9541 from hvanhovell/SPARK-9241-followup. --- .../expressions/aggregate/Utils.scala | 114 ++++++++++++++---- .../execution/AggregationQuerySuite.scala | 17 +++ 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index 39010c3be6d4e..ac23f727829b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -222,10 +222,76 @@ object Utils { * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * - * TODO Expression cannocalization - * TODO Eliminate foldable expressions from distinct clauses. - * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate - * operator. Perhaps this is a good thing? It is much simpler to plan later on... + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] 
+ * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can be identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both the de-duplication of distinct columns and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The results of the non-distinct group are 'aggregated' by using the first operator; + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and + * exchange operators. Keeping the number of distinct groups as low as possible should be a priority; + * we could improve this in the current rule by applying more advanced expression canonicalization + * techniques. */ object MultipleDistinctRewriter extends Rule[LogicalPlan] { @@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) def patchAggregateFunctionChildren( - af: AggregateFunction2, - id: Literal, - attrs: Map[Expression, Expression]): AggregateFunction2 = { - af.withNewChildren(af.children.map { case afc => - evalWithinGroup(id, attrs(afc)) + af: AggregateFunction2)( + attrs: Expression => Expression): AggregateFunction2 = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) }).asInstanceOf[AggregateFunction2] } @@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Final aggregate val operators = expressions.map { e => val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrMap(x)) + } (e, e.copy(aggregateFunction = naf, isDistinct = false)) } @@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val regularGroupId = Literal(0) val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren( - e.aggregateFunction, - regularGroupId, - regularAggChildAttrMap) - val a = Alias(e.copy(aggregateFunction = af), e.toString)() - - // Get the result of the first aggregate in the last aggregate. - val b = AggregateExpression2( - aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap) + val operator = Alias(e.copy(aggregateFunction = af), e.toString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression2( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), mode = Complete, isDistinct = false) // Some aggregate functions (COUNT) have the special property that they can return a // non-null result without any input. We need to make sure we return a result in this case. - val c = af.defaultResult match { - case Some(lit) => Coalesce(Seq(b, lit)) - case None => b + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result } - (e, a, c) + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) } // Construct the regular aggregate input projection only if we need one. 
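// Editor's sketch (not part of the patch above): the expand/gid rewrite documented in the new
// scaladoc, mimicked with plain Scala collections on the same example data. All names below are
// hypothetical; this only illustrates the three steps (expand, de-duplicating aggregate, final
// aggregate), not Catalyst itself.
object MultiDistinctRewriteSketch {
  // (key, cat1, cat2, value) -- the example rows from the scaladoc.
  val data = Seq(("a", "ca1", "cb1", 10), ("a", "ca1", "cb2", 5), ("b", "ca1", "cb1", 13))

  def main(args: Array[String]): Unit = {
    // 1. Expand: emit one copy of each row per aggregation group, nulling out unused columns.
    //    gid 0 = non-distinct group (keeps value), gid 1 = distinct cat1, gid 2 = distinct cat2.
    val expanded = data.flatMap { case (key, cat1, cat2, value) =>
      Seq((key, null: String, null: String, 0, Option(value)),
          (key, cat1, null: String, 1, None),
          (key, null: String, cat2, 2, None))
    }
    // 2. First aggregate: group by (key, cat1, cat2, gid). This de-duplicates each distinct
    //    column and sums the values of the non-distinct (gid = 0) group.
    val firstAgg = expanded
      .groupBy { case (key, c1, c2, gid, _) => (key, c1, c2, gid) }
      .map { case (k, rows) => (k, rows.flatMap(_._5).sum) }
    // 3. Final aggregate: group by key, count rows per distinct gid and pick up the gid-0 sum.
    val result = firstAgg.groupBy { case ((key, _, _, _), _) => key }.map { case (key, rows) =>
      val cat1Cnt = rows.count { case ((_, c1, _, gid), _) => gid == 1 && c1 != null }
      val cat2Cnt = rows.count { case ((_, _, c2, gid), _) => gid == 2 && c2 != null }
      val total = rows.collect { case ((_, _, _, 0), s) => s }.sum
      (key, cat1Cnt, cat2Cnt, total)
    }
    result.toSeq.sorted.foreach(println) // prints (a,1,2,15) and (b,1,1,13)
  }
}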
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea80060e370e0..7f6fe339232ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } + test("multiple distinct column sets") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1), + | count(distinct value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3, 3) :: + Row(1, 2, 3) :: + Row(2, 2, 1) :: + Row(3, 0, 1) :: Nil) + } + test("test count") { checkAnswer( sqlContext.sql( From 4b69a42eda3aff08eb7437c353fe2cc87ed67181 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Nov 2015 19:44:45 -0800 Subject: [PATCH 230/324] [SPARK-11362] [SQL] Use Spark BitSet in BroadcastNestedLoopJoin JIRA: https://issues.apache.org/jira/browse/SPARK-11362 We use scala.collection.mutable.BitSet in BroadcastNestedLoopJoin now. We should use Spark's BitSet. Author: Liang-Chi Hsieh Closes #9316 from viirya/use-spark-bitset. --- .../joins/BroadcastNestedLoopJoin.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 05d20f511aef8..aab177b2e8427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoin( @@ -95,9 +95,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new CompactBuffer[InternalRow] - // TODO: Use Spark's BitSet. 
- val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -115,11 +113,11 @@ case class BroadcastNestedLoopJoin( case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case _ => } i += 1 @@ -138,8 +136,8 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - )(_ ++ _) + new BitSet(broadcastedRelation.value.size) + )(_ | _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -155,7 +153,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withLeft(leftNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 @@ -164,7 +162,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withRight(rightNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 From d981902101767b32dc83a5a639311e197f5cbcc1 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 8 Nov 2015 11:15:58 +0000 Subject: [PATCH 231/324] [SPARK-11476][DOCS] Incorrect function referred to in MLib Random data generation documentation Fix Python example to use normalRDD as advertised Author: Sean Owen Closes #9529 from srowen/SPARK-11476. --- docs/mllib-statistics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 2c7c9ed693fd4..ade5b0768aefe 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -594,7 +594,7 @@ sc = ... # SparkContext # Generate a random double RDD that contains 1 million i.i.d. values drawn from the # standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +u = RandomRDDs.normalRDD(sc, 1000000L, 10) # Apply a transform to get a random double RDD following `N(1, 4)`. v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} From 5c4e6d7ec9157c02494a382dfb49e7fbde3be222 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Sun, 8 Nov 2015 14:24:26 +0000 Subject: [PATCH 232/324] [DOC][SQL] Remove redundant out-of-place python snippet This snippet seems to be mistakenly introduced at two places in #5348. Author: Rohit Agarwal Closes #9540 from mindprince/patch-1. --- docs/sql-programming-guide.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fe5c36338899..085874133d968 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1089,15 +1089,6 @@ for (teenName in collect(teenNames)) { -
    - -{% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.sql("REFRESH TABLE my_table") -{% endhighlight %} - -
    -
    {% highlight sql %} From 30c8ba71a76788cbc6916bc1ba6bc8522925fc2b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 8 Nov 2015 11:06:10 -0800 Subject: [PATCH 233/324] [SPARK-11451][SQL] Support single distinct count on multiple columns. This PR adds support for multiple column in a single count distinct aggregate to the new aggregation path. cc yhuai Author: Herman van Hovell Closes #9409 from hvanhovell/SPARK-11451. --- .../expressions/aggregate/Utils.scala | 44 +++++++++++-------- .../expressions/conditionalExpressions.scala | 30 ++++++++++++- .../plans/logical/basicOperators.scala | 3 ++ .../ConditionalExpressionSuite.scala | 14 ++++++ .../spark/sql/DataFrameAggregateSuite.scala | 25 +++++++++++ .../execution/AggregationQuerySuite.scala | 37 +++++++++++++--- 6 files changed, 127 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index ac23f727829b6..9b22ce2619731 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -22,26 +22,27 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType} +import org.apache.spark.sql.types._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - !hasComplexTypes + // Check if the DataType given cannot be part of a group by clause. + private def isUnGroupable(dt: DataType): Boolean = dt match { + case _: ArrayType | _: MapType => true + case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType)) + case _ => false } + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = + !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType)) + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { case expressions.Average(child) => aggregate.AggregateExpression2( @@ -55,10 +56,14 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - // We do not support multiple COUNT DISTINCT columns for now. 
- case expressions.CountDistinct(children) if children.length == 1 => + case expressions.CountDistinct(children) => + val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), + aggregateFunction = aggregate.Count(child), mode = aggregate.Complete, isDistinct = true) @@ -320,7 +325,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)() + case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -365,14 +370,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Setup expand for the 'regular' aggregate expressions. val regularAggExprs = aggExpressions.filter(!_.isDistinct) val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) // Setup aggregates for 'regular' aggregate expressions. val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap) - val operator = Alias(e.copy(aggregateFunction = af), e.toString)() + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression2( @@ -416,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Construct the expand operator. val expand = Expand( regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), a.child) // Construct the first aggregate operator. This de-duplicates the all the children of @@ -457,5 +463,5 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. 
- e -> new AttributeReference(e.prettyName, e.dataType, true)() + e -> new AttributeReference(e.prettyString, e.dataType, true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d532629984bec..0d4af43978ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} +import org.apache.spark.sql.types._ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends Expression { """ } } + +/** Operator that drops a row when it contains any nulls. */ +case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) + + protected override def nullSafeEval(input: Any): InternalRow = { + val row = input.asInstanceOf[InternalRow] + if (row.anyNull) { + null + } else { + row + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.anyNull()) { + ${ev.isNull} = true; + } else { + ${ev.value} = $eval; + } + """ + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fb963e2f8f7e7..09aac00a455f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -306,6 +306,9 @@ case class Expand( output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + override def statistics: Statistics = { // TODO shouldn't we factor in the size of the projection versus the size of the backing child // row? 
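// Editor's sketch (not part of the patch): why Count(DropAnyNull(CreateStruct(children))) with
// isDistinct = true, as used in the Utils.scala change above, matches COUNT(DISTINCT a, b)
// semantics -- a combination is only counted when every column in it is non-null. Plain Scala
// stand-in; the object and value names are hypothetical.
object CountDistinctStructSketch {
  def main(args: Array[String]): Unit = {
    val rows: Seq[(String, String)] =
      Seq(("a", "b"), ("a", "b"), ("x", "y"), ("x", null), (null, null))
    // DropAnyNull(CreateStruct(a, b)): the struct becomes null if any field is null ...
    val structs = rows.map { case (a, b) => if (a != null && b != null) Some((a, b)) else None }
    // ... and Count with isDistinct = true ignores nulls and counts the distinct surviving values.
    println(structs.flatten.distinct.size) // 2: ("a", "b") and ("x", "y")
  }
}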
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa02..c1e3c17b87102 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } + + test("function dropAnyNull") { + val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) + val a = create_row("a", "q") + val nullStr: String = null + checkEvaluation(drop, a, a) + checkEvaluation(drop, null, create_row("b", nullStr)) + checkEvaluation(drop, null, create_row(nullStr, nullStr)) + + val row = 'r.struct( + StructField("a", StringType, false), + StructField("b", StringType, true)).at(0) + checkEvaluation(DropAnyNull(row), null, create_row(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2e679e7bc4e0a..eb1ee266c5d28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("multiple column distinct count") { + val df1 = Seq( + ("a", "b", "c"), + ("a", "b", "c"), + ("a", "b", "d"), + ("x", "y", "z"), + ("x", "q", null.asInstanceOf[String])) + .toDF("key1", "key2", "key3") + + checkAnswer( + df1.agg(countDistinct('key1, 'key2)), + Row(3) + ) + + checkAnswer( + df1.agg(countDistinct('key1, 'key2, 'key3)), + Row(3) + ) + + checkAnswer( + df1.groupBy('key1).agg(countDistinct('key2, 'key3)), + Seq(Row("a", 2), Row("x", 1)) + ) + } + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7f6fe339232ad..ea36c132bb190 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -516,21 +516,46 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } - test("multiple distinct column sets") { + test("single distinct multiple columns set") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1, value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3) :: + Row(1, 3) :: + Row(2, 1) :: + Row(3, 0) :: Nil) + } + + test("multiple distinct multiple columns sets") { checkAnswer( sqlContext.sql( """ |SELECT | key, | count(distinct value1), - | count(distinct value2) + | sum(distinct value1), + | count(distinct value2), + | sum(distinct value2), + | count(distinct value1, value2), + | count(value1), + | sum(value1), + | count(value2), + | sum(value2), + | count(*), + | count(1) |FROM agg2 |GROUP BY key """.stripMargin), - Row(null, 3, 3) :: - Row(1, 2, 3) :: - Row(2, 2, 1) :: - Row(3, 0, 1) :: Nil) + Row(null, 3, 30, 3, 60, 3, 
3, 30, 3, 60, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil) } test("test count") { From 26739059bc39cd7aa7e0b1c16089c1cf8d8e4d7d Mon Sep 17 00:00:00 2001 From: xin Wu Date: Sun, 8 Nov 2015 12:28:19 -0800 Subject: [PATCH 234/324] =?UTF-8?q?[SPARK-10046][SQL]=20Hive=20warehouse?= =?UTF-8?q?=20dir=20not=20set=20in=20current=20directory=20when=20not=20?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Doc change to align with HiveConf default in terms of where to create `warehouse` directory. Author: xin Wu Closes #9365 from xwu0226/spark-10046-commit. --- docs/sql-programming-guide.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 085874133d968..52e03b951f966 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1627,8 +1627,10 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. +hive-site.xml, the context automatically creates `metastore_db` in the current directory and +creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. +Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts +the spark application. {% highlight scala %} // sc is an existing SparkContext. From b2d195e137fad88d567974659fa7023ff4da96cd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 8 Nov 2015 12:59:35 -0800 Subject: [PATCH 235/324] [SPARK-11554][SQL] add map/flatMap to GroupedDataset Author: Wenchen Fan Closes #9521 from cloud-fan/map. --- .../plans/logical/basicOperators.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 29 ++++++++++++-- .../spark/sql/execution/basicOperators.scala | 2 +- .../apache/spark/sql/JavaDatasetSuite.java | 16 ++++---- .../spark/sql/DatasetPrimitiveSuite.scala | 16 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 40 +++++++++---------- 6 files changed, 70 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09aac00a455f9..e151ac04ede2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -494,7 +494,7 @@ case class AppendColumn[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( @@ -514,7 +514,7 @@ object MapGroups { * object representation of all the rows with that key. 
*/ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b2803d5a9a1e3..5c3f626545875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. */ - def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups(f, groupingAttributes, logicalPlan)) } - def mapGroups[U]( + def flatMap[U]( f: JFunction2[K, JIterator[T], JIterator[U]], encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) + new Dataset[U]( + sqlContext, + MapGroups(func, groupingAttributes, logicalPlan)) + } + + def map[U]( + f: JFunction2[K, JIterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + map((key, data) => f.call(key, data.asJava))(encoder) } // To ensure valid overloading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 799650a4f784f..2593b16b1c8d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -356,7 +356,7 @@ case class AppendColumns[T, U]( * being output. 
*/ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a9493d576d179..0d3b1a5af52c4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,15 +170,15 @@ public Integer call(String v) throws Exception { } }, e.INT()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Dataset mapped = grouped.map( + new Function2, String>() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); @@ -224,15 +224,15 @@ public void testGroupByColumn() { Dataset ds = context.createDataset(data, e.STRING()); GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Dataset mapped = grouped.map( + new Function2, String>() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index e3b0346f857d3..fcf03f7180984 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 0, 1) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroups { case (g, iter) => + val agged = grouped.map { case (g, iter) => val name = if (g == 0) "even" else "odd" - Iterator((name, iter.size)) + (name, iter.size) } checkAnswer( agged, ("even", 5), ("odd", 6)) } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupBy(_.length) + val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkAnswer( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d61e17edc64ed..6f1174e6577e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1)) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v 
=> (v._1, "word")) - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g._1, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns, mapGroups") { + test("groupBy function, fatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + + checkAnswer( + agged, + "a", "30", "b", "3", "c", "1") + } + + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g.getString(0), iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey, mapGroups") { + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey tuple, mapGroups") { + test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) } - test("groupBy columns asKey class, mapGroups") { + test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 97b7080cf2d2846c7257f8926f775f27d457fe7d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 8 Nov 2015 20:57:09 -0800 Subject: [PATCH 236/324] [SPARK-11564][SQL] Dataset Java API audit A few changes: 1. Removed fold, since it can be confusing for distributed collections. 2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction) 3. Added more documentation and test cases. The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function. Author: Reynold Xin Closes #9531 from rxin/SPARK-11564. 
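For quick orientation, here is a minimal Scala sketch of the Dataset operations touched by this audit and by the preceding GroupedDataset patch; on the Java side the same calls now go through the dedicated MapFunction / FilterFunction / ReduceFunction / MapPartitionsFunction interfaces. The sample data is illustrative only, and the sketch assumes an existing SparkContext `sc` plus the implicit encoders from `sqlContext.implicits._`, as in the test suites:

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)  // assumes an existing SparkContext `sc`
import sqlContext.implicits._        // brings in toDS() and the implicit encoders

val ds = Seq("hello", "world").toDS()

val filtered = ds.filter(_.startsWith("h"))  // Java counterpart: FilterFunction
val lengths  = ds.map(_.length)              // Java counterpart: MapFunction
val total    = lengths.reduce(_ + _)         // Java counterpart: ReduceFunction

// map on GroupedDataset, from the map/flatMap patch earlier in this series
val counts = Seq(1, 2, 3, 4, 5).toDS()
  .groupBy(_ % 2)
  .map { case (g, iter) => (if (g == 0) "even" else "odd", iter.size) }
```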
--- .../api/java/function/FilterFunction.java | 29 +++++ .../api/java/function/ForeachFunction.java | 29 +++++ .../function/ForeachPartitionFunction.java | 28 +++++ .../spark/api/java/function/Function0.java | 2 +- .../spark/api/java/function/MapFunction.java | 27 +++++ .../java/function/MapPartitionsFunction.java | 28 +++++ .../api/java/function/ReduceFunction.java | 27 +++++ .../spark/sql/catalyst/encoders/Encoder.scala | 38 +++++-- .../org/apache/spark/sql/DataFrame.scala | 47 ++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 100 +++++++++--------- .../apache/spark/sql/JavaDataFrameSuite.java | 7 ++ .../apache/spark/sql/JavaDatasetSuite.java | 36 +++---- .../spark/sql/DatasetPrimitiveSuite.scala | 5 - .../org/apache/spark/sql/DatasetSuite.scala | 10 +- 14 files changed, 316 insertions(+), 97 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java new file mode 100644 index 0000000000000..e8d999dd00135 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's filter function. + * + * If the function returns true, the element is discarded in the returned Dataset. + */ +public interface FilterFunction extends Serializable { + boolean call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java new file mode 100644 index 0000000000000..07e54b28fa12c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's foreach function. + * + * Spark will invoke the call function on each element in the input Dataset. + */ +public interface ForeachFunction extends Serializable { + void call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java new file mode 100644 index 0000000000000..4938a51bcd712 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a function used in Dataset's foreachPartition function. + */ +public interface ForeachPartitionFunction extends Serializable { + void call(Iterator t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index 38e410c5debe6..c86928dd05408 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -23,5 +23,5 @@ * A zero-argument function that returns an R. */ public interface Function0 extends Serializable { - public R call() throws Exception; + R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java new file mode 100644 index 0000000000000..3ae6ef44898e1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a map function used in Dataset's map function. + */ +public interface MapFunction extends Serializable { + U call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java new file mode 100644 index 0000000000000..6cb569ce0cb6b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for function used in Dataset's mapPartitions. + */ +public interface MapPartitionsFunction extends Serializable { + Iterable call(Iterator input) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java new file mode 100644 index 0000000000000..ee092d0058f44 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for function used in Dataset's reduce. + */ +public interface ReduceFunction extends Serializable { + T call(T v1, T v2) throws Exception; +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index f05e18288de2b..6569b900fed90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import org.apache.spark.util.Utils -import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} import org.apache.spark.sql.catalyst.expressions._ /** @@ -100,7 +100,7 @@ object Encoder { expr.transformUp { case BoundReference(0, t: ObjectType, _) => Invoke( - BoundReference(0, ObjectType(cls), true), + BoundReference(0, ObjectType(cls), nullable = true), s"_${index + 1}", t) } @@ -114,13 +114,13 @@ object Encoder { } else { enc.constructExpression.transformUp { case BoundReference(ordinal, dt, _) => - GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt) } } } val constructExpression = - NewInstance(cls, constructExpressions, false, ObjectType(cls)) + NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls)) new ExpressionEncoder[Any]( schema, @@ -130,7 +130,6 @@ object Encoder { ClassTag.apply(cls)) } - def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] private def getTypeTag[T](c: Class[T]): TypeTag[T] = { @@ -148,9 +147,36 @@ object Encoder { }) } - def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { implicit val typeTag1 = getTypeTag(c1) implicit val typeTag2 = getTypeTag(c2) ExpressionEncoder[(T1, T2)]() } + + def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + ExpressionEncoder[(T1, T2, T3)]() + } + + def forTuple[T1, T2, T3, T4]( + c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + implicit val typeTag4 = getTypeTag(c4) + ExpressionEncoder[(T1, T2, T3, T4)]() + } + + def forTuple[T1, T2, T3, T4, T5]( + c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5]) + : Encoder[(T1, T2, T3, T4, T5)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + implicit val typeTag4 = getTypeTag(c4) + implicit val typeTag5 = getTypeTag(c5) + ExpressionEncoder[(T1, T2, T3, T4, T5)]() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f2d4db5550273..8ab958adadcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ 
-1478,18 +1478,54 @@ class DataFrame private[sql]( /** * Returns the first `n` rows in the [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so on a + * very large dataset can crash the driver process with OutOfMemoryError. + * * @group action * @since 1.3.0 */ def take(n: Int): Array[Row] = head(n) + /** + * Returns the first `n` rows in the [[DataFrame]] as a list. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.6.0 + */ + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * * @group action * @since 1.3.0 */ def collect(): Array[Row] = collect(needCallback = true) + /** + * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.3.0 + */ + def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + withNewExecutionId { + java.util.Arrays.asList(rdd.collect() : _*) + } + } + private def collect(needCallback: Boolean): Array[Row] = { def execute(): Array[Row] = withNewExecutionId { queryExecution.executedPlan.executeCollectPublic() @@ -1502,17 +1538,6 @@ class DataFrame private[sql]( } } - /** - * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => - withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) - } - } - /** * Returns the number of rows in the [[DataFrame]]. * @group action diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fecbdac9a6004..959e0f5ba03e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ @@ -75,7 +75,11 @@ class Dataset[T] private[sql]( private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) - /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ + /** + * Returns the schema of the encoded form of the objects in this [[Dataset]]. 
+ * + * @since 1.6.0 + */ def schema: StructType = encoder.schema /* ************* * @@ -103,6 +107,7 @@ class Dataset[T] private[sql]( /** * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have * the same name after two Datasets have been joined. + * @since 1.6.0 */ def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) @@ -166,8 +171,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ - def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = - filter(t => func.call(t).booleanValue()) + def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) /** * (Scala-specific) @@ -181,7 +185,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = map(t => func.call(t))(encoder) /** @@ -205,10 +209,8 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def mapPartitions[U]( - f: FlatMapFunction[java.util.Iterator[T], U], - encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala mapPartitions(func)(encoder) } @@ -248,7 +250,7 @@ class Dataset[T] private[sql]( * Runs `func` on each element of this Dataset. * @since 1.6.0 */ - def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) @@ -262,7 +264,7 @@ class Dataset[T] private[sql]( * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ - def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = foreachPartition(it => func.call(it.asJava)) /* ************* * @@ -271,7 +273,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -279,33 +281,11 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ - def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) - - /** - * (Scala-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * - * This behaves somewhat differently than the fold operations implemented for non-distributed - * collections in functional languages like Scala. This fold operation may be applied to - * partitions individually, and then those results will be folded into the final result. 
- * If op is not commutative, then the result may differ from that of a fold applied to a - * non-distributed collection. - * @since 1.6.0 - */ - def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) - - /** - * (Java-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * @since 1.6.0 - */ - def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** * (Scala-specific) @@ -351,7 +331,7 @@ class Dataset[T] private[sql]( * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ - def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = groupBy(f.call(_))(encoder) /* ****************** * @@ -367,7 +347,7 @@ class Dataset[T] private[sql]( */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs - def select(cols: Column*): DataFrame = toDF().select(cols: _*) + protected def select(cols: Column*): DataFrame = toDF().select(cols: _*) /** * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. @@ -462,8 +442,7 @@ class Dataset[T] private[sql]( * and thus is not affected by a custom `equals` function defined on `T`. * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Intersect) + def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect) /** * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] @@ -473,8 +452,7 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analagous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. @@ -542,27 +520,47 @@ class Dataset[T] private[sql]( def first(): T = rdd.first() /** - * Collects the elements to an Array. + * Returns an array that contains all the elements in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ def collect(): Array[T] = rdd.collect() /** - * (Java-specific) - * Collects the elements to a Java list. + * Returns an array that contains all the elements in this [[Dataset]]. * - * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at - * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method - * instead and keep the generic type for result. + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ - def collectAsList(): java.util.List[T] = - rdd.collect().toSeq.asJava + def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava - /** Returns the first `num` elements of this [[Dataset]] as an Array. */ + /** + * Returns the first `num` elements of this [[Dataset]] as an array. 
+ * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ def take(num: Int): Array[T] = rdd.take(num) + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /* ******************** * * Internal Functions * * ******************** */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 40bff57a17a03..d191b50fa802e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -65,6 +65,13 @@ public void testExecution() { Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } + @Test + public void testCollectAndTake() { + DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); + } + /** * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0d3b1a5af52c4..0f90de774dd3e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -68,8 +68,16 @@ private Tuple2 tuple2(T1 t1, T2 t2) { public void testCollect() { List data = Arrays.asList("hello", "world"); Dataset ds = context.createDataset(data, e.STRING()); - String[] collected = (String[]) ds.collect(); - Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + List collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + List collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); } @Test @@ -78,16 +86,16 @@ public void testCommonOperation() { Dataset ds = context.createDataset(data, e.STRING()); Assert.assertEquals("hello", ds.first()); - Dataset filtered = ds.filter(new Function() { + Dataset filtered = ds.filter(new FilterFunction() { @Override - public Boolean call(String v) throws Exception { + public boolean call(String v) throws Exception { return v.startsWith("h"); } }); Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map(new Function() { + Dataset mapped = ds.map(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -95,7 +103,7 @@ public Integer call(String v) throws Exception { }, e.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); - Dataset parMapped = ds.mapPartitions(new FlatMapFunction, String>() { + Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { @Override public Iterable call(Iterator it) throws Exception { List ls = new LinkedList(); @@ -128,7 
+136,7 @@ public void testForeach() { List data = Arrays.asList("a", "b", "c"); Dataset ds = context.createDataset(data, e.STRING()); - ds.foreach(new VoidFunction() { + ds.foreach(new ForeachFunction() { @Override public void call(String s) throws Exception { accum.add(1); @@ -142,28 +150,20 @@ public void testReduce() { List data = Arrays.asList(1, 2, 3); Dataset ds = context.createDataset(data, e.INT()); - int reduced = ds.reduce(new Function2() { + int reduced = ds.reduce(new ReduceFunction() { @Override public Integer call(Integer v1, Integer v2) throws Exception { return v1 + v2; } }); Assert.assertEquals(6, reduced); - - int folded = ds.fold(1, new Function2() { - @Override - public Integer call(Integer v1, Integer v2) throws Exception { - return v1 * v2; - } - }); - Assert.assertEquals(6, folded); } @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, e.STRING()); - GroupedDataset grouped = ds.groupBy(new Function() { + GroupedDataset grouped = ds.groupBy(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -187,7 +187,7 @@ public String call(Integer key, Iterator data) throws Exception { List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, e.INT()); - GroupedDataset grouped2 = ds2.groupBy(new Function() { + GroupedDataset grouped2 = ds2.groupBy(new MapFunction() { @Override public Integer call(Integer v) throws Exception { return v / 2; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index fcf03f7180984..63b00975e4eb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { assert(ds.reduce(_ + _) == 6) } - test("fold") { - val ds = Seq(1, 2, 3).toDS() - assert(ds.fold(0)(_ + _) == 6) - } - test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() val grouped = ds.groupBy(_ % 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6f1174e6577e3..aea5a700d0204 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) } + test("as case class - take") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) + } + test("map") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( @@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } - test("fold") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) - } - test("joinWith, flat schema") { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") From d8b50f70298dbf45e91074ee2d751fee7eecb119 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 8 Nov 2015 21:01:53 -0800 Subject: [PATCH 237/324] [SPARK-11453][SQL] append data to partitioned table will messes 
up the result The reason is that: 1. For partitioned hive table, we will move the partitioned columns after data columns. (e.g. `` partition by `a` will become ``) 2. When append data to table, we use position to figure out how to match input columns to table's columns. So when we append data to partitioned table, we will match wrong columns between input and table. A solution is reordering the input columns before match by position, like what we did for [`InsertIntoHadoopFsRelation`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala#L101-L105) Author: Wenchen Fan Closes #9408 from cloud-fan/append. --- .../apache/spark/sql/DataFrameWriter.scala | 29 ++++++++++++++++--- .../sql/sources/PartitionedWriteSuite.scala | 8 +++++ .../sql/hive/execution/SQLQuerySuite.scala | 20 +++++++++++++ 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7887e559a3025..e63a4d5e8b10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.sources.HadoopFsRelation @@ -167,17 +167,38 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { - val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite + + // A partitioned relation's schema can be different from the input logicalPlan, since + // partition columns are all moved after data columns. We Project to adjust the ordering. + // TODO: this belongs to the analyzer. + val input = normalizedParCols.map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => + parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) + }.getOrElse(df.logicalPlan) + df.sqlContext.executePlan( InsertIntoTable( UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, + input, overwrite, ifNotExists = false)).toRdd } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => + parCols.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + /** * Saves the content of the [[DataFrame]] as the specified table. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c9791879ec74c..3eaa817f9c0b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -53,4 +53,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { Utils.deleteRecursively(path) } + + test("partitioned columns should appear at the end of schema") { + withTempPath { f => + val path = f.getAbsolutePath + Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) + assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index af48d478953b4..9a425d7f6b265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1428,4 +1428,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) } } + + test("SPARK-11453: append data to partitioned table") { + withTable("tbl11453") { + Seq("1" -> "10", "2" -> "20").toDF("i", "j") + .write.partitionBy("i").saveAsTable("tbl11453") + + Seq("3" -> "30").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) + + // make sure case sensitivity is correct. + Seq("4" -> "40").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) + } + } } From 9e48cdfbdecc9554a425ba35c0252910fd1e8faa Mon Sep 17 00:00:00 2001 From: Charles Yeh Date: Mon, 9 Nov 2015 13:22:05 +0100 Subject: [PATCH 238/324] [SPARK-11218][CORE] show help messages for start-slave and start-master Addressing https://issues.apache.org/jira/browse/SPARK-11218, mostly copied start-thriftserver.sh. ``` charlesyeh-mbp:spark charlesyeh$ ./sbin/start-master.sh --help Usage: Master [options] Options: -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) -h HOST, --host HOST Hostname to listen on -p PORT, --port PORT Port to listen on (default: 7077) --webui-port PORT Port for web UI (default: 8080) --properties-file FILE Path to a custom Spark properties file. Default is conf/spark-defaults.conf. ``` ``` charlesyeh-mbp:spark charlesyeh$ ./sbin/start-slave.sh Usage: Worker [options] Master must be a URL of the form spark://hostname:port Options: -c CORES, --cores CORES Number of cores to use -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G) -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work) -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h) -h HOST, --host HOST Hostname to listen on -p PORT, --port PORT Port to listen on (default: random) --webui-port PORT Port for web UI (default: 8081) --properties-file FILE Path to a custom Spark properties file. Default is conf/spark-defaults.conf. 
``` Author: Charles Yeh Closes #9432 from CharlesYeh/helpmsg. --- sbin/start-master.sh | 24 +++++++++++++++++++----- sbin/start-slave.sh | 24 +++++++++++++++--------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c20e19a8412df..9f2e14dff609f 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -23,6 +23,20 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.master.Master" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-master.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + ORIGINAL_ARGS="$@" START_TACHYON=false @@ -30,7 +44,7 @@ START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -56,12 +70,12 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon format -s - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon-start.sh master + "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "${SPARK_HOME}"/tachyon/bin/tachyon format -s + "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master fi diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 21455648d1c6d..8c268b8859155 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -31,18 +31,24 @@ # worker. Subsequent workers will increment this # number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" - -if [ $# -lt 1 ]; then - echo $usage - echo Called as start-slave.sh $* - exit 1 -fi - if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-slave.sh [options] " + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + . "${SPARK_HOME}/sbin/spark-config.sh" . 
"${SPARK_HOME}/bin/load-spark-env.sh" @@ -72,7 +78,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } From b541b31630b1b85b48d6096079d073ccf46a62e8 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Mon, 9 Nov 2015 13:28:00 +0100 Subject: [PATCH 239/324] [DOC][MINOR][SQL] Fix internal link It doesn't show up as a hyperlink currently. It will show up as a hyperlink after this change. Author: Rohit Agarwal Closes #9544 from mindprince/patch-2. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 52e03b951f966..ccd26904329d3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2287,7 +2287,7 @@ Several caching related features are not supported yet: Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). +(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 8c0e1b50e960d3e8e51d0618c462eed2bb4936f0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 Nov 2015 08:56:22 -0800 Subject: [PATCH 240/324] [SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression Expose R-like summary statistics in SparkR::glm for linear regression, the output of ```summary``` like ```Java $DevianceResiduals Min Max -0.9509607 0.7291832 $Coefficients Estimate Std. Error t value Pr(>|t|) (Intercept) 1.6765 0.2353597 7.123139 4.456124e-11 Sepal_Length 0.3498801 0.04630128 7.556598 4.187317e-12 Species_versicolor -0.9833885 0.07207471 -13.64402 0 Species_virginica -1.00751 0.09330565 -10.79796 0 ``` Author: Yanbo Liang Closes #9561 from yanboliang/spark-11494. --- R/pkg/R/mllib.R | 22 ++++++-- R/pkg/inst/tests/test_mllib.R | 31 +++++++++--- .../apache/spark/ml/r/SparkRWrappers.scala | 50 +++++++++++++++++-- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b0d73dd93a79d..7ff859741b4a0 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -91,12 +91,26 @@ setMethod("predict", signature(object = "PipelineModel"), #'} setMethod("summary", signature(x = "PipelineModel"), function(x, ...) 
{ + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", x@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelFeatures", x@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelCoefficients", x@model) - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + if (modelName == "LinearRegressionModel") { + devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelDevianceResiduals", x@model) + devianceResiduals <- matrix(devianceResiduals, nrow = 1) + colnames(devianceResiduals) <- c("Min", "Max") + rownames(devianceResiduals) <- rep("", times = 1) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients)) + } else { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + } }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 4761e285a2479..2606407bdcb44 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) + coefs <- unlist(stats$Coefficients) + devianceResiduals <- unlist(stats$DevianceResiduals) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) + rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331) + rTValue <- c(7.123, 7.557, -13.644, -10.798) + rPValue <- c(0.0, 0.0, 0.0, 0.0) + rDevianceResiduals <- c(-0.95096, 0.72918) + + expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6)) + expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5)) + expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3)) + expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6)) + expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -85,14 +96,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$coefficients) + coefs <- as.vector(stats$Coefficients) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) + rStdError <- c(3.0974, 0.5169, 0.8628) + rTValue <- c(-4.212, 3.680, 0.469) + rPValue <- c(0.000, 0.000, 0.639) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4)) + expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4)) + expect_true(all(abs(rTValue - coefs[7:9]) < 
1e-3)) + expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 5be2f86936211..4d82b90bfdf20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -52,11 +52,36 @@ private[r] object SparkRWrappers { } def getModelCoefficients(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => { + val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ + m.summary.coefficientStandardErrors.dropRight(1) + val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) + val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ + tValuesR ++ pValuesR + } else { + m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR + } + } + case m: LogisticRegressionModel => { + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray + } else { + m.coefficients.toArray + } + } + } + } + + def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + m.summary.devianceResiduals case m: LogisticRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + throw new UnsupportedOperationException( + "No deviance residuals available for LogisticRegressionModel") } } @@ -65,11 +90,28 @@ private[r] object SparkRWrappers { case m: LinearRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } case m: LogisticRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } + } + } + + def getModelName(model: PipelineModel): String = { + model.stages.last match { + case m: LinearRegressionModel => + "LinearRegressionModel" + case m: LogisticRegressionModel => + "LogisticRegressionModel" } } } From d50a66cc04bfa1c483f04daffe465322316c745e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 Nov 2015 08:57:29 -0800 Subject: [PATCH 241/324] [SPARK-10689][ML][DOC] User guide and example code for AFTSurvivalRegression Add user guide and example code for ```AFTSurvivalRegression```. Author: Yanbo Liang Closes #9491 from yanboliang/spark-10689. 
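A minimal PySpark sketch of the API this guide covers, mirroring the Python example file added below (it assumes an existing `sqlContext`; the toy rows are taken from that example):

```python
# Minimal sketch of AFTSurvivalRegression usage, assuming an existing SQLContext `sqlContext`.
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.mllib.linalg import Vectors

training = sqlContext.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),   # (label, censor, features)
    (2.949, 0.0, Vectors.dense(0.346, 2.158))],
    ["label", "censor", "features"])
aft = AFTSurvivalRegression(quantileProbabilities=[0.3, 0.6], quantilesCol="quantiles")
model = aft.fit(training)
print("Coefficients: " + str(model.coefficients) + " Scale: " + str(model.scale))
```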
--- docs/ml-guide.md | 1 + docs/ml-survival-regression.md | 96 +++++++++++++++++++ .../ml/JavaAFTSurvivalRegressionExample.java | 71 ++++++++++++++ .../main/python/ml/aft_survival_regression.py | 51 ++++++++++ .../ml/AFTSurvivalRegressionExample.scala | 62 ++++++++++++ 5 files changed, 281 insertions(+) create mode 100644 docs/ml-survival-regression.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java create mode 100644 examples/src/main/python/ml/aft_survival_regression.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala diff --git a/docs/ml-guide.md b/docs/ml-guide.md index fd3a6167bc65e..c293e71d2870e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -44,6 +44,7 @@ provide class probabilities, and linear models provide model summaries. * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) +* [Survival Regression](ml-survival-regression.html) # Main concepts in Pipelines diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md new file mode 100644 index 0000000000000..ab275213b9a84 --- /dev/null +++ b/docs/ml-survival-regression.md @@ -0,0 +1,96 @@ +--- +layout: global +title: Survival Regression - ML +displayTitle: ML - Survival Regression +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. 
+The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +## Example: + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/aft_survival_regression.py %}
+</div>
+
+</div>
    \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java new file mode 100644 index 0000000000000..69a174562fcf5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.regression.AFTSurvivalRegression; +import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaAFTSurvivalRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)), + RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)), + RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)), + RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)), + RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + DataFrame training = jsql.createDataFrame(data, schema); + double[] quantileProbabilities = new double[]{0.3, 0.6}; + AFTSurvivalRegression aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles"); + + AFTSurvivalRegressionModel model = aft.fit(training); + + // Print the coefficients, intercept and scale parameter for AFT survival regression + System.out.println("Coefficients: " + model.coefficients() + " Intercept: " + + model.intercept() + " Scale: " + model.scale()); + model.transform(training).show(false); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py new file mode 100644 index 0000000000000..0ee01fd8258df --- /dev/null +++ 
b/examples/src/main/python/ml/aft_survival_regression.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import AFTSurvivalRegression +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="AFTSurvivalRegressionExample") + sqlContext = SQLContext(sc) + + # $example on$ + training = sqlContext.createDataFrame([ + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"]) + quantileProbabilities = [0.3, 0.6] + aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities, + quantilesCol="quantiles") + + model = aft.fit(training) + + # Print the coefficients, intercept and scale parameter for AFT survival regression + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + print("Scale: " + str(model.scale)) + model.transform(training).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala new file mode 100644 index 0000000000000..5da285e83681f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.regression.AFTSurvivalRegression +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +/** + * An example for AFTSurvivalRegression. + */ +object AFTSurvivalRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val training = sqlContext.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226)) + )).toDF("label", "censor", "features") + val quantileProbabilities = Array(0.3, 0.6) + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + + val model = aft.fit(training) + + // Print the coefficients, intercept and scale parameter for AFT survival regression + println(s"Coefficients: ${model.coefficients} Intercept: " + + s"${model.intercept} Scale: ${model.scale}") + model.transform(training).show(false) + // $example off$ + + sc.stop() + } +} +// scalastyle:off println From 9b88e1dcad6b5b14a22cf64a1055ad9870507b5a Mon Sep 17 00:00:00 2001 From: fazlan-nazeem Date: Mon, 9 Nov 2015 08:58:55 -0800 Subject: [PATCH 242/324] [SPARK-11582][MLLIB] specifying pmml version attribute =4.2 in the root node of pmml model The current pmml models generated do not specify the pmml version in its root node. This is a problem when using this pmml model in other tools because they expect the version attribute to be set explicitly. This fix adds the pmml version attribute to the generated pmml models and specifies its value as 4.2. Author: fazlan-nazeem Closes #9558 from fazlan-nazeem/master. --- .../org/apache/spark/mllib/pmml/export/PMMLModelExport.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index c5fdecd3ca17f..9267e6dbdb857 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -32,6 +32,7 @@ private[mllib] trait PMMLModelExport { @BeanProperty val pmml: PMML = new PMML + pmml.setVersion("4.2") setHeader(pmml) private def setHeader(pmml: PMML): Unit = { From 08a7a836c393d6a62b9b216eeb01fad0b90b6c52 Mon Sep 17 00:00:00 2001 From: Charles Yeh Date: Mon, 9 Nov 2015 11:59:32 -0600 Subject: [PATCH 243/324] [SPARK-10565][CORE] add missing web UI stats to /api/v1/applications JSON I looked at the other endpoints, and they don't seem to be missing any fields. Added fields: ![image](https://cloud.githubusercontent.com/assets/613879/10948801/58159982-82e4-11e5-86dc-62da201af910.png) Author: Charles Yeh Closes #9472 from CharlesYeh/api_vars. 
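As an illustration only (not part of the patch), a short Python sketch of reading the new fields from the REST endpoint; the `localhost:8080` master URL is an assumption:

```python
# Sketch: inspect the fields added to /api/v1/applications on a standalone master.
# Assumes a master web UI reachable at http://localhost:8080 (placeholder).
import json
import urllib2  # Python 2; use urllib.request on Python 3

apps = json.loads(urllib2.urlopen("http://localhost:8080/api/v1/applications").read())
for app in apps:
    # coresGranted, maxCores, coresPerExecutor and memoryPerExecutorMB are the new fields;
    # they are left unset for applications listed by the history server.
    print("%s cores=%s memPerExecutorMB=%s"
          % (app["id"], app.get("coresGranted"), app.get("memoryPerExecutorMB")))
```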
--- .../spark/deploy/master/ui/MasterWebUI.scala | 7 +- .../api/v1/ApplicationListResource.scala | 8 ++ .../org/apache/spark/status/api/v1/api.scala | 4 + .../scala/org/apache/spark/ui/SparkUI.scala | 4 + .../deploy/master/ui/MasterWebUISuite.scala | 90 +++++++++++++++++++ project/MimaExcludes.scala | 3 + 6 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 6174fc11f83d8..e41554a5a6d26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,14 +28,17 @@ import org.apache.spark.ui.JettyUtils._ * Web UI server for the standalone master. */ private[master] -class MasterWebUI(val master: Master, requestedPort: Int) +class MasterWebUI( + val master: Master, + requestedPort: Int, + customMasterPage: Option[MasterPage] = None) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - val masterPage = new MasterPage(this) + val masterPage = customMasterPage.getOrElse(new MasterPage(this)) initialize() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 17b521f3e1d41..0fc0fb59d861f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -62,6 +62,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = app.id, name = app.name, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = app.attempts.map { internalAttemptInfo => new ApplicationAttemptInfo( attemptId = internalAttemptInfo.attemptId, @@ -81,6 +85,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = internal.id, name = internal.desc.name, + coresGranted = Some(internal.coresGranted), + maxCores = internal.desc.maxCores, + coresPerExecutor = internal.desc.coresPerExecutor, + memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(internal.startTime), diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2bec64f2ef02b..baddfc50c1a40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -25,6 +25,10 @@ import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( val id: String, val name: String, + val coresGranted: Option[Int], + val maxCores: Option[Int], + val coresPerExecutor: Option[Int], + val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) class ApplicationAttemptInfo private[spark]( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 99085ada9f0af..4608bce202ec8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -102,6 +102,10 @@ private[spark] class SparkUI private ( Iterator(new ApplicationInfo( id = appId, name = appName, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(startTime), diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala new file mode 100644 index 0000000000000..fba835f054f8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date + +import scala.io.Source +import scala.language.postfixOps + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployTestUtils._ +import org.apache.spark.deploy.master._ +import org.apache.spark.rpc.RpcEnv + + +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { + + val masterPage = mock(classOf[MasterPage]) + val master = { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + + before { + masterWebUI.bind() + } + + after { + masterWebUI.stop() + } + + test("list applications") { + val worker = createWorkerInfo() + val appDesc = createAppDesc() + // use new start date so it isn't filtered by UI + val activeApp = new ApplicationInfo( + new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) + activeApp.addExecutor(worker, 2) + + val workers = Array[WorkerInfo](worker) + val activeApps = Array(activeApp) + val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array[DriverInfo]() + val completedDrivers = Array[DriverInfo]() + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, + activeDrivers, completedDrivers, RecoveryState.ALIVE) + + when(masterPage.getMasterState).thenReturn(stateResponse) + + val resultJson = Source.fromURL( + s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") + .mkString + val parsedJson = parse(resultJson) + val firstApp = parsedJson(0) + + 
assert(firstApp \ "id" === JString(activeApp.id)) + assert(firstApp \ "name" === JString(activeApp.desc.name)) + assert(firstApp \ "coresGranted" === JInt(2)) + assert(firstApp \ "maxCores" === JInt(4)) + assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) + assert(firstApp \ "coresPerExecutor" === JNothing) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dacef911e397e..50220790d1f84 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -134,6 +134,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") + ) ++ Seq ( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationInfo.this") ) case v if v.startsWith("1.5") => Seq( From 404a28f4edd09cf17361dcbd770e4cafde51bf6d Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 9 Nov 2015 10:07:58 -0800 Subject: [PATCH 244/324] [SPARK-11112] Fix Scala 2.11 compilation error in RDDInfo.scala As shown in https://amplab.cs.berkeley.edu/jenkins/view/Spark-QA-Compile/job/Spark-Master-Scala211-Compile/1946/console , compilation fails with: ``` [error] /home/jenkins/workspace/Spark-Master-Scala211-Compile/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala:25: in class RDDInfo, multiple overloaded alternatives of constructor RDDInfo define default arguments. [error] class RDDInfo( [error] ``` This PR tries to fix the compilation error Author: tedyu Closes #9538 from tedyu/master. --- .../scala/org/apache/spark/storage/RDDInfo.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 3fa209b924170..87c1b981e7e13 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -28,20 +28,10 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], - val callSite: CallSite, + val callSite: CallSite = CallSite.empty, val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { - def this( - id: Int, - name: String, - numPartitions: Int, - storageLevel: StorageLevel, - parentIds: Seq[Int], - scope: Option[RDDOperationScope] = None) { - this(id, name, numPartitions, storageLevel, parentIds, CallSite.empty, scope) - } - var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L From cd174882a5a211298d6e173fe989d567d08ebc0d Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 9 Nov 2015 10:26:09 -0800 Subject: [PATCH 245/324] [SPARK-9865][SPARKR] Flaky SparkR test: test_sparkSQL.R: sample on a DataFrame Make sample test less flaky by setting the seed Tested with ``` repeat { if (count(sample(df, FALSE, 0.1)) == 3) { break } } ``` Author: felixcheung Closes #9549 from felixcheung/rsample. 
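The idea behind the fix, as a short sketch (assumes an existing DataFrame `df`): passing an explicit seed makes `sample()` deterministic, so the count the test asserts on no longer fluctuates between runs.

```python
# Sketch: a fixed seed (here 0) makes the sampled subset, and hence its count, reproducible.
sampled = df.sample(False, 0.1, 0)  # withReplacement=False, fraction=0.1, seed=0
print(sampled.count())
```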
--- R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92cff1fba7193..fbdb9a8f1ef6b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -647,11 +647,11 @@ test_that("sample on a DataFrame", { sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") - sampled2 <- sample(df, FALSE, 0.1) + sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) # Also test sample_frac - sampled3 <- sample_frac(df, FALSE, 0.1) + sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) }) From 874cd66d4b6d156d0ef112a3d0f3bc5683c6a0ec Mon Sep 17 00:00:00 2001 From: chriskang90 Date: Mon, 9 Nov 2015 19:39:22 +0100 Subject: [PATCH 246/324] [DOCS] Fix typo for Python section on unifying Kafka streams 1) kafkaStreams is a list. The list should be unpacked when passing it into the streaming context union method, which accepts a variable number of streams. 2) print() should be pprint() for pyspark. This contribution is my original work, and I license the work to the project under the project's open source license. Author: chriskang90 Closes #9545 from c-kang/streaming_python_typo. --- docs/streaming-programming-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c751dbb41785a..e9a27f446a898 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1948,8 +1948,8 @@ unifiedStream.print(); {% highlight python %} numStreams = 5 kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] -unifiedStream = streamingContext.union(kafkaStreams) -unifiedStream.print() +unifiedStream = streamingContext.union(*kafkaStreams) +unifiedStream.pprint() {% endhighlight %}
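For context, a self-contained sketch of the corrected pattern (the `ssc`, `zkQuorum` and `topic` names are placeholders): `StreamingContext.union` is variadic in PySpark, so a Python list of streams has to be unpacked with `*`, and `pprint()` is the PySpark counterpart of `print()`.

```python
# Sketch assuming an existing StreamingContext `ssc`, a ZooKeeper quorum string
# `zkQuorum` and a Kafka topic name `topic` (all placeholders).
from pyspark.streaming.kafka import KafkaUtils

numStreams = 5
kafkaStreams = [KafkaUtils.createStream(ssc, zkQuorum, "consumer-group", {topic: 1})
                for _ in range(numStreams)]
unifiedStream = ssc.union(*kafkaStreams)  # unpack the list; union() takes varargs
unifiedStream.pprint()
```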
    From 860ea0d386b5fbbe26bf2954f402a9a73ad37edc Mon Sep 17 00:00:00 2001 From: Bharat Lal Date: Mon, 9 Nov 2015 11:33:01 -0800 Subject: [PATCH 247/324] [SPARK-11581][DOCS] Example mllib code in documentation incorrectly computes MSE Author: Bharat Lal Closes #9560 from bharatl/SPARK-11581. --- docs/mllib-decision-tree.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index f31c4f88936bd..b5b454bc69245 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -439,7 +439,7 @@ Double testMSE = public Double call(Double a, Double b) { return a + b; } - }) / data.count(); + }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); From 88a3fdcc783f880a8d01c7e194ec42fc114bdf8a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 13:16:04 -0800 Subject: [PATCH 248/324] [SPARK-10280][MLLIB][PYSPARK][DOCS] Add @since annotation to pyspark.ml.classification Author: Yu ISHIKAWA Closes #8690 from yu-iskw/SPARK-10280. --- python/pyspark/ml/classification.py | 56 +++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 2e468f67b8987..603f2c7f798dc 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -67,6 +67,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.3.0 """ # a placeholder to make it appear in the generated doc @@ -99,6 +101,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._checkThresholdConsistency() @keyword_only + @since("1.3.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", @@ -119,6 +122,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -129,6 +133,7 @@ def setThreshold(self, value): del self._paramMap[self.thresholds] return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. @@ -144,6 +149,7 @@ def getThreshold(self): else: return self.getOrDefault(self.threshold) + @since("1.5.0") def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. @@ -154,6 +160,7 @@ def setThresholds(self, value): del self._paramMap[self.threshold] return self + @since("1.5.0") def getThresholds(self): """ If :py:attr:`thresholds` is set, return its value. @@ -185,9 +192,12 @@ def _checkThresholdConsistency(self): class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. + + .. versionadded:: 1.3.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -205,6 +215,7 @@ def coefficients(self): return self._call_java("coefficients") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -215,6 +226,8 @@ def intercept(self): class TreeClassifierParams(object): """ Private class to track supported impurity measures. + + .. 
versionadded:: 1.4.0 """ supportedImpurities = ["entropy", "gini"] @@ -231,6 +244,7 @@ def __init__(self): "gain calculation (case-insensitive). Supported options: " + ", ".join(self.supportedImpurities)) + @since("1.6.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -238,6 +252,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.6.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -248,6 +263,8 @@ def getImpurity(self): class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. + + .. versionadded:: 1.4.0 """ supportedLossTypes = ["logistic"] @@ -287,6 +304,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ @keyword_only @@ -310,6 +329,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -333,6 +353,8 @@ def _create_model(self, java_model): class DecisionTreeClassificationModel(DecisionTreeModel): """ Model fitted by DecisionTreeClassifier. + + .. versionadded:: 1.4.0 """ @@ -371,6 +393,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ @keyword_only @@ -396,6 +420,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -419,6 +444,8 @@ def _create_model(self, java_model): class RandomForestClassificationModel(TreeEnsembleModels): """ Model fitted by RandomForestClassifier. + + .. versionadded:: 1.4.0 """ @@ -450,6 +477,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -482,6 +511,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -499,6 +529,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTClassificationModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -506,6 +537,7 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. 
@@ -516,6 +548,8 @@ def getLossType(self): class GBTClassificationModel(TreeEnsembleModels): """ Model fitted by GBTClassifier. + + .. versionadded:: 1.4.0 """ @@ -555,6 +589,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -587,6 +623,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, modelType="multinomial"): @@ -602,6 +639,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return NaiveBayesModel(java_model) + @since("1.5.0") def setSmoothing(self, value): """ Sets the value of :py:attr:`smoothing`. @@ -609,12 +647,14 @@ def setSmoothing(self, value): self._paramMap[self.smoothing] = value return self + @since("1.5.0") def getSmoothing(self): """ Gets the value of smoothing or its default value. """ return self.getOrDefault(self.smoothing) + @since("1.5.0") def setModelType(self, value): """ Sets the value of :py:attr:`modelType`. @@ -622,6 +662,7 @@ def setModelType(self, value): self._paramMap[self.modelType] = value return self + @since("1.5.0") def getModelType(self): """ Gets the value of modelType or its default value. @@ -632,9 +673,12 @@ def getModelType(self): class NaiveBayesModel(JavaModel): """ Model fitted by NaiveBayes. + + .. versionadded:: 1.5.0 """ @property + @since("1.5.0") def pi(self): """ log of class priors. @@ -642,6 +686,7 @@ def pi(self): return self._call_java("pi") @property + @since("1.5.0") def theta(self): """ log of class conditional probabilities. @@ -681,6 +726,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, |[0.0,0.0]| 0.0| +---------+----------+ ... + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -715,6 +762,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): """ @@ -731,6 +779,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return MultilayerPerceptronClassificationModel(java_model) + @since("1.6.0") def setLayers(self, value): """ Sets the value of :py:attr:`layers`. @@ -738,12 +787,14 @@ def setLayers(self, value): self._paramMap[self.layers] = value return self + @since("1.6.0") def getLayers(self): """ Gets the value of layers or its default value. """ return self.getOrDefault(self.layers) + @since("1.6.0") def setBlockSize(self, value): """ Sets the value of :py:attr:`blockSize`. @@ -751,6 +802,7 @@ def setBlockSize(self, value): self._paramMap[self.blockSize] = value return self + @since("1.6.0") def getBlockSize(self): """ Gets the value of blockSize or its default value. @@ -761,9 +813,12 @@ def getBlockSize(self): class MultilayerPerceptronClassificationModel(JavaModel): """ Model fitted by MultilayerPerceptronClassifier. + + .. 
versionadded:: 1.6.0 """ @property + @since("1.6.0") def layers(self): """ array of layer sizes including input and output layers. @@ -771,6 +826,7 @@ def layers(self): return self._call_java("javaLayers") @property + @since("1.6.0") def weights(self): """ vector of initial weights for the model that consists of the weights of layers. From 5039a49b636325f321daa089971107003fae9d4b Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Mon, 9 Nov 2015 13:36:14 -0800 Subject: [PATCH 249/324] [SPARK-10471][CORE][MESOS] prevent getting offers for unmet constraints this change rejects offers for slaves with unmet constraints for 120s to mitigate offer starvation. this prevents mesos to send us these offers again and again. in return, we get more offers for slaves which might meet our constraints. and it enables mesos to send the rejected offers to other frameworks. Author: Felix Bechstein Closes #8639 from felixb/decline_offers_constraint_mismatch. --- .../mesos/CoarseMesosSchedulerBackend.scala | 92 +++++++++++-------- .../cluster/mesos/MesosSchedulerBackend.scala | 48 +++++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 4 + 3 files changed, 91 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d10a77f8e5c78..2de9b6a651692 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -101,6 +101,10 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + // A client for talking to the external shuffle service, if it is a private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { @@ -249,48 +253,56 @@ private[spark] class CoarseMesosSchedulerBackend( val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. 
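To make the new behaviour concrete, a configuration sketch (the master URL and the `rack:us-east-1` attribute are placeholders): offers whose attributes do not satisfy `spark.mesos.constraints` are now declined for `spark.mesos.rejectOfferDurationForUnmetConstraints`, which this patch defaults to 120s.

```python
# Sketch: constraint-based offer filtering on Mesos; URL and attribute value are placeholders.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setMaster("mesos://host:5050")
        .setAppName("constraints-example")
        .set("spark.mesos.constraints", "rack:us-east-1")
        # added by this patch: how long to decline offers with unmet constraints
        .set("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s"))
sc = SparkContext(conf=conf)
```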
- val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + if (meetsConstraints) { + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId)) { + // Launch an executor on the slave + val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) + totalCoresAcquired += cpusToUse + val taskId = newMesosTaskId() + taskIdToSlaveId.put(taskId, slaveId) + slaveIdsWithExecutors += slaveId + coresByTaskId(taskId) = cpusToUse + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + } + + // Accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) + } else { + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } - - // accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + // This offer does not meet constraints. We don't need to see it again. + // Decline the offer for a long period of time. 
+ logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + d.declineOffer(offer.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index aaffac604a885..281965a5981bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -63,6 +63,10 @@ private[spark] class MesosSchedulerBackend( private[this] val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + @volatile var appId: String = _ override def start() { @@ -212,29 +216,47 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.asScala.partition { o => + // Fail first on offers with unmet constraints + val (offersMatchingConstraints, offersNotMatchingConstraints) = + offers.asScala.partition { o => + val offerAttributes = toAttributeMap(o.getAttributesList) + val meetsConstraints = + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + + // add some debug messaging + if (!meetsConstraints) { + val id = o.getId.getValue + logDebug(s"Declining offer: $id with attributes: $offerAttributes") + } + + meetsConstraints + } + + // These offers do not meet constraints. We don't need to see them again. + // Decline the offer for a long period of time. + offersNotMatchingConstraints.foreach { o => + d.declineOffer(o.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + } + + // Of the matching constraints, see which ones give us enough memory and cores + val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + // check offers for + // 1. Memory requirements + // 2. 
CPU requirements - need at least 1 for executor, 1 for task val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) - val meetsRequirements = - (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (meetsMemoryRequirements && meetsCPURequirements) || (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) - - // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" - val id = o.getId.getValue - logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + + s"$offerAttributes mem: $mem cpu: $cpus") meetsRequirements } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 860c8e097b3b9..721861fbbc517 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -336,4 +336,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } + protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + } + } From 51d41e4b1a3a25a3fde3a4345afcfe4766023d23 Mon Sep 17 00:00:00 2001 From: sachin aggarwal Date: Mon, 9 Nov 2015 14:25:42 -0800 Subject: [PATCH 250/324] [SPARK-11552][DOCS][Replaced example code in ml-decision-tree.md using include_example] I have tested it on my local, it is working fine, please review Author: sachin aggarwal Closes #9539 from agsachin/SPARK-11552-real. --- docs/ml-decision-tree.md | 338 +----------------- ...JavaDecisionTreeClassificationExample.java | 103 ++++++ .../ml/JavaDecisionTreeRegressionExample.java | 90 +++++ .../decision_tree_classification_example.py | 77 ++++ .../ml/decision_tree_regression_example.py | 74 ++++ .../DecisionTreeClassificationExample.scala | 94 +++++ .../ml/DecisionTreeRegressionExample.scala | 81 +++++ 7 files changed, 527 insertions(+), 330 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java create mode 100644 examples/src/main/python/ml/decision_tree_classification_example.py create mode 100644 examples/src/main/python/ml/decision_tree_regression_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index 542819e93e6dc..2bfac6f6c8378 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -118,196 +118,24 @@ We use two feature transformers to prepare the data; these help index categories More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). 
-{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier -import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] -println("Learned classification tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} +
    More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.DecisionTreeClassifier; -import org.apache.spark.ml.classification.DecisionTreeClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeClassifier dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -DecisionTreeClassificationModel treeModel = - (DecisionTreeClassificationModel)(model.stages()[2]); -System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} +
    More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import DecisionTreeClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and tree in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) +{% include_example python/ml/decision_tree_classification_example.py %} -treeModel = model.stages[2] -print treeModel # summary only -{% endhighlight %}
    @@ -323,171 +151,21 @@ We use a feature transformer to index categorical features, adding metadata to t More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Automatically identify categorical features, and index them. -// Here, we treat features with > 4 distinct values as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, dt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] -println("Learned regression tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %}
    More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.DecisionTreeRegressionModel; -import org.apache.spark.ml.regression.DecisionTreeRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, dt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -DecisionTreeRegressionModel treeModel = - (DecisionTreeRegressionModel)(model.stages()[1]); -System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %}
    More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import DecisionTreeRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeRegressor(featuresCol="indexedFeatures") - -# Chain indexer and tree in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, dt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -treeModel = model.stages[1] -print treeModel # summary only -{% endhighlight %} +{% include_example python/ml/decision_tree_regression_example.py %}
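The `{% include_example %}` tags above replace the inline `{% highlight %}` blocks by pulling code out of standalone example programs. As a minimal sketch (not itself part of this patch) of the convention those example files follow: only the regions between the `// $example on$` and `// $example off$` markers are rendered into the generated docs page, while the surrounding scaffolding (package declaration, context setup, `main`) is left out. The object name below is purely illustrative; the real example files are added later in this patch.

{% highlight scala %}
// Illustrative skeleton of the include_example marker convention; this object is
// hypothetical, not one of the files added by this patch.
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.classification.DecisionTreeClassifier
// $example off$

object IncludeExampleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("IncludeExampleSketch"))
    // $example on$
    // Only the code between the markers appears in the rendered documentation.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
    // $example off$
    sc.stop()
  }
}
{% endhighlight %}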
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..51c1730a8a085 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeClassificationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + RDD rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); + DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + + // Automatically identify categorical features, and index them. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. 
+ IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel) (model.stages()[2]); + System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..a4098a4233ec2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + RDD rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); + DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{featureIndexer, dt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel) (model.stages()[1]); + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py new file mode 100644 index 0000000000000..0af92050e3e3b --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +import sys + +# $example on$ +from pyspark import SparkContext, SQLContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and tree in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g " % (1.0 - accuracy)) + + treeModel = model.stages[2] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py new file mode 100644 index 0000000000000..3857aed538da2 --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + + # Chain indexer and tree in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, dt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + treeModel = model.stages[1] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..a24a344f1bcf4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object DecisionTreeClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] + println("Learned classification tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..64cd986129007 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils +// $example off$ +object DecisionTreeRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + // Automatically identify categorical features, and index them. + // Here, we treat features with > 4 distinct values as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] + println("Learned regression tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} From b7720fa45525cff6e812fa448d0841cb41f6c8a5 Mon Sep 17 00:00:00 2001 From: Rishabh Bhardwaj Date: Mon, 9 Nov 2015 14:27:36 -0800 Subject: [PATCH 251/324] [SPARK-11548][DOCS] Replaced example code in mllib-collaborative-filtering.md using include_example Kindly review the changes. Author: Rishabh Bhardwaj Closes #9519 from rishabhbhardwaj/SPARK-11337. 
--- docs/mllib-collaborative-filtering.md | 138 +----------------- .../mllib/JavaRecommendationExample.java | 97 ++++++++++++ .../python/mllib/recommendation_example.py | 54 +++++++ .../mllib/RecommendationExample.scala | 67 +++++++++ 4 files changed, 221 insertions(+), 135 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java create mode 100644 examples/src/main/python/mllib/recommendation_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 1ad52123c74aa..7cd1b894e7cb5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,43 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.recommendation.ALS -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel -import org.apache.spark.mllib.recommendation.Rating - -// Load and parse the data -val data = sc.textFile("data/mllib/als/test.data") -val ratings = data.map(_.split(',') match { case Array(user, item, rate) => - Rating(user.toInt, item.toInt, rate.toDouble) - }) - -// Build the recommendation model using ALS -val rank = 10 -val numIterations = 10 -val model = ALS.train(ratings, rank, numIterations, 0.01) - -// Evaluate the model on rating data -val usersProducts = ratings.map { case Rating(user, product, rate) => - (user, product) -} -val predictions = - model.predict(usersProducts).map { case Rating(user, product, rate) => - ((user, product), rate) - } -val ratesAndPreds = ratings.map { case Rating(user, product, rate) => - ((user, product), rate) -}.join(predictions) -val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => - val err = (r1 - r2) - err * err -}.mean() -println("Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} If the rating matrix is derived from another source of information (e.g., it is inferred from other signals), you can use the `trainImplicit` method to get better results. @@ -123,81 +87,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. 
-{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.SparkConf; - -public class CollaborativeFiltering { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/als/test.data"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); - - // Build the recommendation model using ALS - int rank = 10; - int numIterations = 10; - MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); - - // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); - System.out.println("Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
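The note above about `trainImplicit` (for rating matrices inferred from implicit signals) no longer has an inline snippet after this refactor. Below is a minimal sketch of what that call could look like in Scala, assuming the same sample data as the explicit example; the rank, iteration count, regularization and alpha values are placeholder choices, not taken from this patch.

{% highlight scala %}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Assumes an existing SparkContext `sc` and the same sample data used above.
val data = sc.textFile("data/mllib/als/test.data")
val ratings = data.map(_.split(',') match { case Array(user, item, rate) =>
  Rating(user.toInt, item.toInt, rate.toDouble)
})

// trainImplicit interprets the ratings as implicit-feedback confidence values
// rather than explicit scores.
val rank = 10
val numIterations = 10
val lambda = 0.01
val alpha = 1.0
val implicitModel = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha)
{% endhighlight %}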
    @@ -207,29 +97,7 @@ recommendation by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating - -# Load and parse the data -data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) - -# Build the recommendation model using Alternating Least Squares -rank = 10 -numIterations = 10 -model = ALS.train(ratings, rank, numIterations) - -# Evaluate the model on training data -testdata = ratings.map(lambda p: (p[0], p[1])) -predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) -ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/recommendation_example.py %} If the rating matrix is derived from other source of information (i.e., it is inferred from other signals), you can use the trainImplicit method to get better results. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java new file mode 100644 index 0000000000000..1065fde953b96 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRecommendationExample { + public static void main(String args[]) { + // $example on$ + SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/als/test.data"; + JavaRDD data = jsc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String s) { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + } + } + ); + + // Build the recommendation model using ALS + int rank = 10; + int numIterations = 10; + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); + + // Evaluate the model on rating data + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + Double err = pair._1() - pair._2(); + return err * err; + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myCollaborativeFilter"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(), + "target/tmp/myCollaborativeFilter"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py new file mode 100644 index 0000000000000..615db0749b182 --- /dev/null +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Collaborative Filtering Classification Example. 
+""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext + +# $example on$ +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonCollaborativeFilteringExample") + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/als/test.data") + ratings = data.map(lambda l: l.split(','))\ + .map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) + + # Build the recommendation model using Alternating Least Squares + rank = 10 + numIterations = 10 + model = ALS.train(ratings, rank, numIterations) + + # Evaluate the model on training data + testdata = ratings.map(lambda p: (p[0], p[1])) + predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) + ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) + MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala new file mode 100644 index 0000000000000..64e4602465444 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel +import org.apache.spark.mllib.recommendation.Rating +// $example off$ + +object RecommendationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("CollaborativeFilteringExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/als/test.data") + val ratings = data.map(_.split(',') match { case Array(user, item, rate) => + Rating(user.toInt, item.toInt, rate.toDouble) + }) + + // Build the recommendation model using ALS + val rank = 10 + val numIterations = 10 + val model = ALS.train(ratings, rank, numIterations, 0.01) + + // Evaluate the model on rating data + val usersProducts = ratings.map { case Rating(user, product, rate) => + (user, product) + } + val predictions = + model.predict(usersProducts).map { case Rating(user, product, rate) => + ((user, product), rate) + } + val ratesAndPreds = ratings.map { case Rating(user, product, rate) => + ((user, product), rate) + }.join(predictions) + val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => + val err = (r1 - r2) + err * err + }.mean() + println("Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + // $example off$ + } +} +// scalastyle:on println From f138cb873335654476d1cd1070900b552dd8b21a Mon Sep 17 00:00:00 2001 From: Nick Buroojy Date: Mon, 9 Nov 2015 14:30:37 -0800 Subject: [PATCH 252/324] [SPARK-9301][SQL] Add collect_set and collect_list aggregate functions For now they are thin wrappers around the corresponding Hive UDAFs. One limitation with these in Hive 0.13.0 is they only support aggregating primitive types. I chose snake_case here instead of camelCase because it seems to be used in the majority of the multi-word fns. Do we also want to add these to `functions.py`? This approach was recommended here: https://github.com/apache/spark/pull/8592#issuecomment-154247089 marmbrus rxin Author: Nick Buroojy Closes #9526 from nburoojy/nick/udaf-alias. 
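For reviewers, a rough usage sketch of the new aggregates from Scala. A `HiveContext` is assumed, since (as noted above) the functions are currently thin wrappers over the Hive UDAFs; the data and column names are illustrative only.

{% highlight scala %}
import org.apache.spark.sql.functions.{collect_list, collect_set}

// Assumes an existing HiveContext named `hiveContext`.
import hiveContext.implicits._

val df = Seq((1, "a"), (1, "b"), (1, "a"), (2, "c")).toDF("key", "value")

// collect_list keeps duplicates; collect_set eliminates them.
df.groupBy($"key")
  .agg(collect_list($"value").as("values"), collect_set($"value").as("distinct_values"))
  .show()
{% endhighlight %}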
(cherry picked from commit a6ee4f989d020420dd08b97abb24802200ff23b2) Signed-off-by: Michael Armbrust --- python/pyspark/sql/functions.py | 25 +++++++++++-------- python/pyspark/sql/tests.py | 17 +++++++++++++ .../org/apache/spark/sql/functions.scala | 20 +++++++++++++++ .../hive/HiveDataFrameAnalyticsSuite.scala | 15 +++++++++-- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2f7c2f4aacd47..962f676d406d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -124,17 +124,20 @@ def _(): _functions_1_6 = { # unary math functions - "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_pop": "Aggregate function: returns population standard deviation of" + - " the expression in a group.", - "variance": "Aggregate function: returns the population variance of the values in a group.", - "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", - "var_pop": "Aggregate function: returns the population variance of the values in a group.", - "skewness": "Aggregate function: returns the skewness of the values in a group.", - "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' 
} # math functions that take two arguments as input diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4c03a0d4ffe93..e224574bcb301 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1230,6 +1230,23 @@ def test_window_functions_without_partitionBy(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + if __name__ == "__main__": if xmlrunner: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 04627589886a8..3f0b24b68b816 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -174,6 +174,26 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. 
* diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae415e54b..9864acf765265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll @@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } @@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), From 150f6a89b79f0e5bc31aa83731429dc7ac5ea76b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 9 Nov 2015 14:32:52 -0800 Subject: [PATCH 253/324] [SPARK-11595] [SQL] Fixes ADD JAR when the input path contains URL scheme Author: Cheng Lian Closes #9569 from liancheng/spark-11595.fix-add-jar. --- .../hive/thriftserver/HiveThriftServer2Suites.scala | 1 + .../apache/spark/sql/hive/client/ClientWrapper.scala | 11 +++++++++-- .../spark/sql/hive/client/IsolatedClientLoader.scala | 9 +++------ .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +++++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ff8ca0150649d..5903b9e71cdd2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -41,6 +41,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3dce86c480747..f1c2489b38271 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} import java.util.{Map => JMap} -import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import 
scala.language.reflectiveCalls @@ -548,7 +547,15 @@ private[hive] class ClientWrapper( } def addJar(path: String): Unit = { - clientLoader.addJar(path) + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + clientLoader.addJar(jarURL) runSqlHive(s"ADD JAR $path") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f99c3ed2ae987..e041e0d8e5ae8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.collection.mutable import scala.language.reflectiveCalls import scala.util.Try @@ -30,10 +29,9 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.{MutableURLClassLoader, Utils} - import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ private[hive] object IsolatedClientLoader { @@ -190,9 +188,8 @@ private[hive] class IsolatedClientLoader( new NonClosableMutableURLClassLoader(isolatedClassLoader) } - private[hive] def addJar(path: String): Unit = synchronized { - val jarURL = new java.io.File(path).toURI.toURL - classLoader.addURL(jarURL) + private[hive] def addJar(path: URL): Unit = synchronized { + classLoader.addURL(path) } /** The isolated client interface to Hive. 
*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fc72e3c7dc6aa..78378c8b69c7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -927,7 +927,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -961,10 +961,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("CREATE TEMPORARY FUNCTION") { val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath - sql(s"ADD JAR $funcJar") + val jarURL = s"file://$funcJar" + sql(s"ADD JAR $jarURL") sql( """CREATE TEMPORARY FUNCTION udtf_count2 AS - | 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'""".stripMargin) + |'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) sql("DROP TEMPORARY FUNCTION udtf_count2") } From a3a7c9103e136035d65a5564f9eb0fa04727c4f3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 14:39:18 -0800 Subject: [PATCH 254/324] [SPARK-11359][STREAMING][KINESIS] Checkpoint to DynamoDB even when new data doesn't come in Currently, the checkpoints to DynamoDB occur only when new data comes in, as we update the clock for the checkpointState. This PR makes the checkpoint a scheduled execution based on the `checkpointInterval`. Author: Burak Yavuz Closes #9421 from brkyvz/kinesis-checkpoint. --- .../kinesis/KinesisCheckpointState.scala | 54 ------- .../kinesis/KinesisCheckpointer.scala | 133 +++++++++++++++ .../streaming/kinesis/KinesisReceiver.scala | 38 ++++- .../kinesis/KinesisRecordProcessor.scala | 59 ++----- .../kinesis/KinesisCheckpointerSuite.scala | 152 ++++++++++++++++++ .../kinesis/KinesisReceiverSuite.scala | 96 +++-------- 6 files changed, 349 insertions(+), 183 deletions(-) delete mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala deleted file mode 100644 index 83a4537559512..0000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import org.apache.spark.Logging -import org.apache.spark.streaming.Duration -import org.apache.spark.util.{Clock, ManualClock, SystemClock} - -/** - * This is a helper class for managing checkpoint clocks. - * - * @param checkpointInterval - * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) - */ -private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, - currentClock: Clock = new SystemClock()) - extends Logging { - - /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ - val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) - - /** - * Check if it's time to checkpoint based on the current time and the derived time - * for the next checkpoint - * - * @return true if it's time to checkpoint - */ - def shouldCheckpoint(): Boolean = { - new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() - } - - /** - * Advance the checkpoint clock by the checkpoint interval. - */ - def advanceCheckpoint(): Unit = { - checkpointClock.advance(checkpointInterval.milliseconds) - } -} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala new file mode 100644 index 0000000000000..1ca6d4302c2bb --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.kinesis + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} + +/** + * This is a helper class for managing Kinesis checkpointing. + * + * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint + * @param checkpointInterval How frequently we will checkpoint to DynamoDB + * @param workerId Worker Id of KCL worker for logging purposes + * @param clock In order to use ManualClocks for the purpose of testing + */ +private[kinesis] class KinesisCheckpointer( + receiver: KinesisReceiver[_], + checkpointInterval: Duration, + workerId: String, + clock: Clock = new SystemClock) extends Logging { + + // a map from shardId's to checkpointers + private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + + private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() + + private val checkpointerThread: RecurringTimer = startCheckpointerThread() + + /** Update the checkpointer instance to the most recent one for the given shardId. */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + checkpointers.put(shardId, checkpointer) + } + + /** + * Stop tracking the specified shardId. + * + * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], + * we will use that to make the final checkpoint. If `null` is provided, we will not make the + * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + synchronized { + checkpointers.remove(shardId) + checkpoint(shardId, checkpointer) + } + } + + /** Perform the checkpoint. */ + private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + try { + if (checkpointer != null) { + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + val lastSeqNum = lastCheckpointedSeqNums.get(shardId) + // Kinesis sequence numbers are monotonically increasing strings, therefore we can do + // safely do the string comparison + if (lastSeqNum == null || latestSeqNum > lastSeqNum) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + + s" $latestSeqNum for shardId $shardId") + lastCheckpointedSeqNums.put(shardId, latestSeqNum) + } + } + } else { + logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.") + } + } catch { + case NonFatal(e) => + logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e) + } + } + + /** Checkpoint the latest saved sequence numbers for all active shardId's. 
*/ + private def checkpointAll(): Unit = synchronized { + // if this method throws an exception, then the scheduled task will not run again + try { + val shardIds = checkpointers.keys() + while (shardIds.hasMoreElements) { + val shardId = shardIds.nextElement() + checkpoint(shardId, checkpointers.get(shardId)) + } + } catch { + case NonFatal(e) => + logWarning("Failed to checkpoint to DynamoDB.", e) + } + } + + /** + * Start the checkpointer thread with the given checkpoint duration. + */ + private def startCheckpointerThread(): RecurringTimer = { + val period = checkpointInterval.milliseconds + val threadName = s"Kinesis Checkpointer - Worker $workerId" + val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) + timer.start() + logDebug(s"Started checkpointer thread: $threadName") + timer + } + + /** + * Shutdown the checkpointer. Should be called on the onStop of the Receiver. + */ + def shutdown(): Unit = { + // the recurring timer checkpoints for us one last time. + checkpointerThread.stop(interruptTimer = false) + checkpointers.clear() + lastCheckpointedSeqNums.clear() + logInfo("Successfully shutdown Kinesis Checkpointer.") + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 134d627cdaffa..50993f157cd95 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ -31,8 +31,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkEnv} - +import org.apache.spark.Logging private[kinesis] case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) @@ -127,6 +126,11 @@ private[kinesis] class KinesisReceiver[T]( private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + /** + * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. + */ + @volatile private var kinesisCheckpointer: KinesisCheckpointer = null + /** * Latest sequence number ranges that have been stored successfully. 
* This is used for checkpointing through KCL */ @@ -141,6 +145,7 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = @@ -157,8 +162,8 @@ private[kinesis] class KinesisReceiver[T]( * We're using our custom KinesisRecordProcessor in this case. */ val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, - workerId, new KinesisCheckpointState(checkpointInterval)) + override def createProcessor: IRecordProcessor = + new KinesisRecordProcessor(receiver, workerId) } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) @@ -198,6 +203,10 @@ private[kinesis] class KinesisReceiver[T]( logInfo(s"Stopped receiver for workerId $workerId") } workerId = null + if (kinesisCheckpointer != null) { + kinesisCheckpointer.shutdown() + kinesisCheckpointer = null + } } /** Add records of the given shard to the current block being generated */ @@ -216,6 +225,25 @@ private[kinesis] class KinesisReceiver[T]( shardIdToLatestStoredSeqNum.get(shardId) } + /** + * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the + * given shardId. + */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.setCheckpointer(shardId, checkpointer) + } + + /** + * Remove the checkpointer for the given shardId. The provided checkpointer will be used to + * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not + * checkpoint. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) + } + /** * Remember the range of sequence numbers that was added to the currently active block. * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 1d5178790ec4c..e381ffa0cbef4 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -27,26 +27,23 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.Logging +import org.apache.spark.streaming.Duration /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create - * multiple Receivers. + * shard in the Kinesis stream upon startup. 
This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes - * @param checkpointState represents the checkpoint state including the next checkpoint time. - * It's injected here for mocking purposes. */ -private[kinesis] class KinesisRecordProcessor[T]( - receiver: KinesisReceiver[T], - workerId: String, - checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) + extends IRecordProcessor with Logging { - // shardId to be populated during initialize() + // shardId populated during initialize() @volatile private var shardId: String = _ @@ -74,34 +71,7 @@ private[kinesis] class KinesisRecordProcessor[T]( try { receiver.addRecords(shardId, batch) logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") - - /* - * - * Checkpoint the sequence number of the last record successfully stored. - * Note that in this current implementation, the checkpointing occurs only when after - * checkpointIntervalMillis from the last checkpoint, AND when there is new record - * to process. This leads to the checkpointing lagging behind what records have been - * stored by the receiver. Ofcourse, this can lead records processed more than once, - * under failures and restarts. - * - * TODO: Instead of checkpointing here, run a separate timer task to perform - * checkpointing so that it checkpoints in a timely manner independent of whether - * new records are available or not. - */ - if (checkpointState.shouldCheckpoint()) { - receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() - - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + - s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") - } - } + receiver.setCheckpointer(shardId, checkpointer) } catch { case NonFatal(e) => { /* @@ -142,23 +112,18 @@ private[kinesis] class KinesisRecordProcessor[T]( * It's now OK to read from the new shards that resulted from a resharding event. */ case ShutdownReason.TERMINATE => - val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) - if (latestSeqNumToCheckpointOption.nonEmpty) { - KinesisRecordProcessor.retryRandom( - checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) - } + receiver.removeCheckpointer(shardId, checkpointer) /* - * ZOMBIE Use Case. NoOp. + * ZOMBIE Use Case or Unknown reason. NoOp. * No checkpoint because other workers may have taken over and already started processing * the same records. * This may lead to records being processed more than once. */ - case ShutdownReason.ZOMBIE => - - /* Unknown reason. 
NoOp */ case _ => + receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint } + } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala new file mode 100644 index 0000000000000..645e64a0bc3a0 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.util.concurrent.{TimeoutException, ExecutorService} + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.ManualClock + +class KinesisCheckpointerSuite extends TestSuiteBase + with MockitoSugar + with BeforeAndAfterEach + with PrivateMethodTester + with Eventually { + + private val workerId = "dummyWorkerId" + private val shardId = "dummyShardId" + private val seqNum = "123" + private val otherSeqNum = "245" + private val checkpointInterval = Duration(10) + private val someSeqNum = Some(seqNum) + private val someOtherSeqNum = Some(otherSeqNum) + + private var receiverMock: KinesisReceiver[Array[Byte]] = _ + private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var kinesisCheckpointer: KinesisCheckpointer = _ + private var clock: ManualClock = _ + + private val checkpoint = PrivateMethod[Unit]('checkpoint) + + override def beforeEach(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + clock = new ManualClock() + kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + } + + test("checkpoint is not called twice for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("checkpoint is called after sequence number increases") { 
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(5 * checkpointInterval.milliseconds) + + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds / 2) + + verify(checkpointerMock, never()).checkpoint(anyString()) + } + + test("should not checkpoint for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + + clock.advance(checkpointInterval.milliseconds * 5) + eventually(timeout(1 second)) { + verify(checkpointerMock, atMost(1)).checkpoint(anyString()) + } + } + + test("removing checkpointer checkpoints one last time") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock) + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("if checkpointing is going on, wait until finished before removing and checkpointing") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2) + } + }) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + // don't block test thread + val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))( + ExecutionContext.global) + + intercept[TimeoutException] { + Await.ready(f, 50 millis) + } + + clock.advance(checkpointInterval.milliseconds / 2) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(2)).checkpoint(anyString()) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 17ab444704f44..e5c70db554a27 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -25,12 +25,13 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import 
com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.{BeforeAndAfter, Matchers} -import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -44,6 +45,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val workerId = "dummyWorkerId" val shardId = "dummyShardId" val seqNum = "dummySeqNum" + val checkpointInterval = Duration(10) val someSeqNum = Some(seqNum) val record1 = new Record() @@ -54,24 +56,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var receiverMock: KinesisReceiver[Array[Byte]] = _ var checkpointerMock: IRecordProcessorCheckpointer = _ - var checkpointClockMock: ManualClock = _ - var checkpointStateMock: KinesisCheckpointState = _ - var currentClockMock: Clock = _ override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver[Array[Byte]]] checkpointerMock = mock[IRecordProcessorCheckpointer] - checkpointClockMock = mock[ManualClock] - checkpointStateMock = mock[KinesisCheckpointState] - currentClockMock = mock[Clock] - } - - override def afterFunction(): Unit = { - super.afterFunction() - // Since this suite was originally written using EasyMock, add this to preserve the old - // mocking semantics (see SPARK-5735 for more details) - verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, - checkpointStateMock, currentClockMock) } test("check serializability of SerializableAWSCredentials") { @@ -79,113 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft Utils.serialize(new SerializableAWSCredentials("x", "y"))) } - test("process records including store and checkpoint") { + test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointStateMock, times(1)).shouldCheckpoint() - verify(checkpointerMock, times(1)).checkpoint(anyString) - verify(checkpointStateMock, times(1)).advanceCheckpoint() + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) } - test("shouldn't store and checkpoint when receiver is stopped") { + test("shouldn't store and update checkpointer when receiver is stopped") { when(receiverMock.isStopped()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) - 
verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("shouldn't checkpoint when exception occurs during store") { + test("shouldn't update checkpointer when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) when( receiverMock.addRecords(anyString, anyListOf(classOf[Record])) ).thenThrow(new RuntimeException()) intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) } verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should checkpoint if we have exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should add to time when advancing checkpoint") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) - - verify(currentClockMock, times(1)).getTimeMillis() + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } test("shutdown should checkpoint if the reason is TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointerMock, times(1)).checkpoint(anyString) + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) } + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val 
recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) recordProcessor.shutdown(checkpointerMock, null) - verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), + meq[IRecordProcessorCheckpointer](null)) } test("retry success on first attempt") { From 8a2336893a7ff610a6c4629dd567b85078730616 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 14:56:36 -0800 Subject: [PATCH 255/324] [SPARK-6517][MLLIB] Implement the Algorithm of Hierarchical Clustering I implemented a hierarchical clustering algorithm again. This PR doesn't include examples, documentation and spark.ml APIs. I am going to send another PRs later. https://issues.apache.org/jira/browse/SPARK-6517 - This implementation based on a bi-sectiong K-means clustering. - It derives from the freeman-lab 's implementation - The basic idea is not changed from the previous version. (#2906) - However, It is 1000x faster than the previous version through parallel processing. Thank you for your great cooperation, RJ Nowling(rnowling), Jeremy Freeman(freeman-lab), Xiangrui Meng(mengxr) and Sean Owen(srowen). Author: Yu ISHIKAWA Author: Xiangrui Meng Author: Yu ISHIKAWA Closes #5267 from yu-iskw/new-hierarchical-clustering. --- .../mllib/clustering/BisectingKMeans.scala | 491 ++++++++++++++++++ .../clustering/BisectingKMeansModel.scala | 95 ++++ .../clustering/JavaBisectingKMeansSuite.java | 73 +++ .../clustering/BisectingKMeansSuite.scala | 182 +++++++ 4 files changed, 841 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..29a7aa0bb63f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if + * there are no divisible leaf clusters. + * @param maxIterations the max number of k-means iterations to split clusters (default: 20) + * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1) + * @param seed a random seed (default: hash value of the class name) + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("1.6.0") +@Experimental +class BisectingKMeans private ( + private var k: Int, + private var maxIterations: Int, + private var minDivisibleClusterSize: Double, + private var seed: Long) extends Logging { + + import BisectingKMeans._ + + /** + * Constructs with the default configuration + */ + @Since("1.6.0") + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + + /** + * Sets the desired number of leaf clusters (default: 4). + * The actual number could be smaller if there are no divisible leaf clusters. + */ + @Since("1.6.0") + def setK(k: Int): this.type = { + require(k > 0, s"k must be positive but got $k.") + this.k = k + this + } + + /** + * Gets the desired number of leaf clusters. + */ + @Since("1.6.0") + def getK: Int = this.k + + /** + * Sets the max number of k-means iterations to split clusters (default: 20). + */ + @Since("1.6.0") + def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.") + this.maxIterations = maxIterations + this + } + + /** + * Gets the max number of k-means iterations to split clusters. + */ + @Since("1.6.0") + def getMaxIterations: Int = this.maxIterations + + /** + * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster (default: 1). 
+ */ + @Since("1.6.0") + def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = { + require(minDivisibleClusterSize > 0.0, + s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.") + this.minDivisibleClusterSize = minDivisibleClusterSize + this + } + + /** + * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster. + */ + @Since("1.6.0") + def getMinDivisibleClusterSize: Double = minDivisibleClusterSize + + /** + * Sets the random seed (default: hash value of the class name). + */ + @Since("1.6.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Gets the random seed. + */ + @Since("1.6.0") + def getSeed: Long = this.seed + + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + if (input.getStorageLevel == StorageLevel.NONE) { + logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + + " its parent RDDs are also not cached.") + } + val d = input.map(_.size).first() + logInfo(s"Feature dimension: $d.") + // Compute and cache vector norms for fast distance computation. + val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) + val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } + var assignments = vectors.map(v => (ROOT_INDEX, v)) + var activeClusters = summarize(d, assignments) + val rootSummary = activeClusters(ROOT_INDEX) + val n = rootSummary.size + logInfo(s"Number of points: $n.") + logInfo(s"Initial cost: ${rootSummary.cost}.") + val minSize = if (minDivisibleClusterSize >= 1.0) { + math.ceil(minDivisibleClusterSize).toLong + } else { + math.ceil(minDivisibleClusterSize * n).toLong + } + logInfo(s"The minimum number of points of a divisible cluster is $minSize.") + var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)] + val random = new Random(seed) + var numLeafClustersNeeded = k - 1 + var level = 1 + while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) { + // Divisible clusters are sufficiently large and have non-trivial cost. + var divisibleClusters = activeClusters.filter { case (_, summary) => + (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size) + } + // If we don't need all divisible clusters, take the larger ones. 
+ if (divisibleClusters.size > numLeafClustersNeeded) { + divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) => + -summary.size + }.take(numLeafClustersNeeded) + .toMap + } + if (divisibleClusters.nonEmpty) { + val divisibleIndices = divisibleClusters.keys.toSet + logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") + var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => + val (left, right) = splitCenter(summary.center, random) + Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) + }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map + var newClusters: Map[Long, ClusterSummary] = null + var newAssignments: RDD[(Long, VectorWithNorm)] = null + for (iter <- 0 until maxIterations) { + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + .filter { case (index, _) => + divisibleIndices.contains(parentIndex(index)) + } + newClusters = summarize(d, newAssignments) + newClusterCenters = newClusters.mapValues(_.center).map(identity) + } + // TODO: Unpersist old indices. + val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + .persist(StorageLevel.MEMORY_AND_DISK) + assignments = indices.zip(vectors) + inactiveClusters ++= activeClusters + activeClusters = newClusters + numLeafClustersNeeded -= divisibleClusters.size + } else { + logInfo(s"None active and divisible clusters left on level $level. Stop iterations.") + inactiveClusters ++= activeClusters + activeClusters = Map.empty + } + level += 1 + } + val clusters = activeClusters ++ inactiveClusters + val root = buildTree(clusters) + new BisectingKMeansModel(root) + } + + /** + * Java-friendly version of [[run(RDD[Vector])*]] + */ + def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) +} + +private object BisectingKMeans extends Serializable { + + /** The index of the root node of a tree. */ + private val ROOT_INDEX: Long = 1 + + private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2 + + private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2) + + /** Returns the left child index of the given node index. */ + private def leftChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.") + 2 * index + } + + /** Returns the right child index of the given node index. */ + private def rightChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.") + 2 * index + 1 + } + + /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */ + private def parentIndex(index: Long): Long = { + index / 2 + } + + /** + * Summarizes data by each cluster as Map. + * @param d feature dimension + * @param assignments pairs of point and its cluster index + * @return a map from cluster indices to corresponding cluster summaries + */ + private def summarize( + d: Int, + assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + seqOp = (agg, v) => agg.add(v), + combOp = (agg1, agg2) => agg1.merge(agg2) + ).mapValues(_.summary) + .collect().toMap + } + + /** + * Cluster summary aggregator. 
+ * @param d feature dimension + */ + private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private var n: Long = 0L + private val sum: Vector = Vectors.zeros(d) + private var sumSq: Double = 0.0 + + /** Adds a point. */ + def add(v: VectorWithNorm): this.type = { + n += 1L + // TODO: use a numerically stable approach to estimate cost + sumSq += v.norm * v.norm + BLAS.axpy(1.0, v.vector, sum) + this + } + + /** Merges another aggregator. */ + def merge(other: ClusterSummaryAggregator): this.type = { + n += other.n + sumSq += other.sumSq + BLAS.axpy(1.0, other.sum, sum) + this + } + + /** Returns the summary. */ + def summary: ClusterSummary = { + val mean = sum.copy + if (n > 0L) { + BLAS.scal(1.0 / n, mean) + } + val center = new VectorWithNorm(mean) + val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) + new ClusterSummary(n, center, cost) + } + } + + /** + * Bisects a cluster center. + * + * @param center current cluster center + * @param random a random number generator + * @return initial centers + */ + private def splitCenter( + center: VectorWithNorm, + random: Random): (VectorWithNorm, VectorWithNorm) = { + val d = center.vector.size + val norm = center.norm + val level = 1e-4 * norm + val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) + val left = center.vector.copy + BLAS.axpy(-level, noise, left) + val right = center.vector.copy + BLAS.axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * Updates assignments. + * @param assignments current assignments + * @param divisibleIndices divisible cluster indices + * @param newClusterCenters new cluster centers + * @return new assignments + */ + private def updateAssignments( + assignments: RDD[(Long, VectorWithNorm)], + divisibleIndices: Set[Long], + newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + assignments.map { case (index, v) => + if (divisibleIndices.contains(index)) { + val children = Seq(leftChildIndex(index), rightChildIndex(index)) + val selected = children.minBy { child => + KMeans.fastSquaredDistance(newClusterCenters(child), v) + } + (selected, v) + } else { + (index, v) + } + } + } + + /** + * Builds a clustering tree by re-indexing internal and leaf clusters. + * @param clusters a map from cluster indices to corresponding cluster summaries + * @return the root node of the clustering tree + */ + private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + var leafIndex = 0 + var internalIndex = -1 + + /** + * Builds a subtree from this given node index. + */ + def buildSubTree(rawIndex: Long): ClusteringTreeNode = { + val cluster = clusters(rawIndex) + val size = cluster.size + val center = cluster.center + val cost = cluster.cost + val isInternal = clusters.contains(leftChildIndex(rawIndex)) + if (isInternal) { + val index = internalIndex + internalIndex -= 1 + val leftIndex = leftChildIndex(rawIndex) + val rightIndex = rightChildIndex(rawIndex) + val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex => + KMeans.fastSquaredDistance(center, clusters(childIndex).center) + }.max) + val left = buildSubTree(leftIndex) + val right = buildSubTree(rightIndex) + new ClusteringTreeNode(index, size, center, cost, height, Array(left, right)) + } else { + val index = leafIndex + leafIndex += 1 + val height = 0.0 + new ClusteringTreeNode(index, size, center, cost, height, Array.empty) + } + } + + buildSubTree(ROOT_INDEX) + } + + /** + * Summary of a cluster. 
+ * + * @param size the number of points within this cluster + * @param center the center of the points within this cluster + * @param cost the sum of squared distances to the center + */ + private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double) +} + +/** + * Represents a node in a clustering tree. + * + * @param index node index, negative for internal nodes and non-negative for leaf nodes + * @param size size of the cluster + * @param centerWithNorm cluster center with norm + * @param cost cost of the cluster, i.e., the sum of squared distances to the center + * @param height height of the node in the dendrogram. Currently this is defined as the max distance + * from the center to the centers of the children's, but subject to change. + * @param children children nodes + */ +@Since("1.6.0") +@Experimental +class ClusteringTreeNode private[clustering] ( + val index: Int, + val size: Long, + private val centerWithNorm: VectorWithNorm, + val cost: Double, + val height: Double, + val children: Array[ClusteringTreeNode]) extends Serializable { + + /** Whether this is a leaf node. */ + val isLeaf: Boolean = children.isEmpty + + require((isLeaf && index >= 0) || (!isLeaf && index < 0)) + + /** Cluster center. */ + def center: Vector = centerWithNorm.vector + + /** Predicts the leaf cluster node index that the input point belongs to. */ + def predict(point: Vector): Int = { + val (index, _) = predict(new VectorWithNorm(point)) + index + } + + /** Returns the full prediction path from root to leaf. */ + def predictPath(point: Vector): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point)).toArray + } + + /** Returns the full prediction path from root to leaf. */ + private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + if (isLeaf) { + this :: Nil + } else { + val selected = children.minBy { child => + KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + } + selected :: selected.predictPath(pointWithNorm) + } + } + + /** + * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + */ + def computeCost(point: Vector): Double = { + val (_, cost) = predict(new VectorWithNorm(point)) + cost + } + + /** + * Predicts the cluster index and the cost of the input point. + */ + private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { + predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + } + + /** + * Predicts the cluster index and the cost of the input point. + * @param pointWithNorm input point + * @param cost the cost to the current center + * @return (predicted leaf cluster index, cost) + */ + private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + if (isLeaf) { + (index, cost) + } else { + val (selectedChild, minCost) = children.map { child => + (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + }.minBy(_._2) + selectedChild.predict(pointWithNorm, minCost) + } + } + + /** + * Returns all leaf nodes from this node. 
+ */ + def leafNodes: Array[ClusteringTreeNode] = { + if (isLeaf) { + Array(this) + } else { + children.flatMap(_.leafNodes) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala new file mode 100644 index 0000000000000..5015f1540d920 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * Clustering model produced by [[BisectingKMeans]]. + * The prediction is done level-by-level from the root node to a leaf node, and at each node among + * its children the closest to the input point is selected. + * + * @param root the root node of the clustering tree + */ +@Since("1.6.0") +@Experimental +class BisectingKMeansModel @Since("1.6.0") ( + @Since("1.6.0") val root: ClusteringTreeNode + ) extends Serializable with Logging { + + /** + * Leaf cluster centers. + */ + @Since("1.6.0") + def clusterCenters: Array[Vector] = root.leafNodes.map(_.center) + + /** + * Number of leaf clusters. + */ + lazy val k: Int = clusterCenters.length + + /** + * Predicts the index of the cluster that the input point belongs to. + */ + @Since("1.6.0") + def predict(point: Vector): Int = { + root.predict(point) + } + + /** + * Predicts the indices of the clusters that the input points belong to. + */ + @Since("1.6.0") + def predict(points: RDD[Vector]): RDD[Int] = { + points.map { p => root.predict(p) } + } + + /** + * Java-friendly version of [[predict(RDD[Vector])*]] + */ + @Since("1.6.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + + /** + * Computes the squared distance between the input point and the cluster center it belongs to. + */ + @Since("1.6.0") + def computeCost(point: Vector): Double = { + root.computeCost(point) + } + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("1.6.0") + def computeCost(data: RDD[Vector]): Double = { + data.map(root.computeCost).sum() + } + + /** + * Java-friendly version of [[computeCost(RDD[Vector])*]]. 
+ */ + @Since("1.6.0") + def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java new file mode 100644 index 0000000000000..a714620ff7e4b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaBisectingKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void twoDimensionalData() { + JavaRDD points = sc.parallelize(Lists.newArrayList( + Vectors.dense(4, -1), + Vectors.dense(4, 1), + Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + ), 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(2) + .setSeed(1L); + BisectingKMeansModel model = bkm.run(points); + Assert.assertEquals(3, model.k()); + Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child: model.root().children()) { + double[] center = child.center().toArray(); + if (center[0] > 2) { + Assert.assertEquals(2, child.size()); + Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + } else { + Assert.assertEquals(1, child.size()); + Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..41b9d5c0d93bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("default values") { + val bkm0 = new BisectingKMeans() + assert(bkm0.getK === 4) + assert(bkm0.getMaxIterations === 20) + assert(bkm0.getMinDivisibleClusterSize === 1.0) + val bkm1 = new BisectingKMeans() + assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.") + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + + val k = 10 + assert(bkm.getK !== k) + assert(bkm.setK(k).getK === k) + val maxIter = 100 + assert(bkm.getMaxIterations !== maxIter) + assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter) + val minSize = 2.0 + assert(bkm.getMinDivisibleClusterSize !== minSize) + assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize) + val seed = 10L + assert(bkm.getSeed !== seed) + assert(bkm.setSeed(seed).getSeed === seed) + + intercept[IllegalArgumentException] { + bkm.setK(0) + } + intercept[IllegalArgumentException] { + bkm.setMaxIterations(0) + } + intercept[IllegalArgumentException] { + bkm.setMinDivisibleClusterSize(0.0) + } + } + + test("1D data") { + val points = Vectors.sparse(1, Array.empty, Array.empty) +: + (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(1) + .setSeed(1L) + // The clusters should be + // (0, 1, 2, 3, 4, 5, 6, 7) + // - (0, 1, 2, 3) + // - (0, 1) + // - (2, 3) + // - (4, 5, 6, 7) + // - (4, 5) + // - (6, 7) + val model = bkm.run(data) + assert(model.k === 4) + // The total cost should be 8 * 0.5 * 0.5 = 2.0. 
+ assert(model.computeCost(data) ~== 2.0 relTol 1e-12) + val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap() + Range(0, 8, 2).foreach { i => + assert(predictions(i) === predictions(i + 1), + s"$i and ${i + 1} should belong to the same cluster.") + } + val root = model.root + assert(root.center(0) ~== 3.5 relTol 1e-12) + assert(root.height ~== 2.0 relTol 1e-12) + assert(root.children.length === 2) + assert(root.children(0).height ~== 1.0 relTol 1e-12) + assert(root.children(1).height ~== 1.0 relTol 1e-12) + } + + test("points are the same") { + val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2) + val bkm = new BisectingKMeans() + .setK(2) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 1) + } + + test("more desired clusters than points") { + val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2) + val bkm = new BisectingKMeans() + .setK(8) + .setMaxIterations(2) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 4) + } + + test("min divisible cluster") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMinDivisibleClusterSize(10) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + + bkm.setMinDivisibleClusterSize(0.5) + val sameModel = bkm.run(data) + assert(sameModel.k === 3) + } + + test("larger clusters get selected first") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + } + + test("2D data") { + val points = Seq( + (11, 10), (9, 10), (10, 9), (10, 11), + (11, -10), (9, -10), (10, -9), (10, -11), + (0, 1), (0, -1) + ).map { case (x, y) => + if (x == 0) { + Vectors.sparse(2, Array(1), Array(y)) + } else { + Vectors.dense(x, y) + } + } + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(4) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12) + model.root.leafNodes.foreach { node => + if (node.center(0) < 5) { + assert(node.size === 2) + assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12) + } else if (node.center(1) > 0) { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12) + } else { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12) + } + } + } +} From fcb57e9c7323e24b8563800deb035f94f616474e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 9 Nov 2015 15:16:47 -0800 Subject: [PATCH 256/324] [SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset created `MapGroupFunction`, `FlatMapGroupFunction`, `CoGroupFunction` Author: Wenchen Fan Closes #9564 from cloud-fan/map. 
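Besides adding the new Java interfaces, this patch also relaxes the Scala-side `cogroup` (and the `CoGroup` plan node) to accept functions returning any `TraversableOnce`. A minimal Scala sketch of that relaxation — `ds1` and `ds2` are assumed, already-constructed Datasets keyed by their first field, with the usual `sqlContext.implicits._` encoders in scope:

```scala
val grouped1 = ds1.groupBy(_._1)
val grouped2 = ds2.groupBy(_._1)

// The cogroup function may now return any TraversableOnce (e.g. a Seq);
// it no longer has to construct an Iterator explicitly.
val cogrouped = grouped1.cogroup(grouped2) { (key, left, right) =>
  Seq(key.toString + left.mkString + right.mkString)
}
```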
--- .../api/java/function/CoGroupFunction.java | 29 +++++++++++++++ .../api/java/function/FlatMapFunction.java | 2 +- .../api/java/function/FlatMapFunction2.java | 2 +- .../java/function/FlatMapGroupFunction.java | 28 +++++++++++++++ .../api/java/function/MapGroupFunction.java | 28 +++++++++++++++ .../plans/logical/basicOperators.scala | 4 +-- .../org/apache/spark/sql/GroupedDataset.scala | 12 +++---- .../spark/sql/execution/basicOperators.scala | 2 +- .../apache/spark/sql/JavaDatasetSuite.java | 36 ++++++++++++------- 9 files changed, 118 insertions(+), 25 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 0000000000000..279639af5d430 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction extends Serializable { + Iterable call(K key, Iterator left, Iterator right) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd43631b..ef0d1824121ec 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,5 +23,5 @@ * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterable call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f535ff5..14a98a38ef5ab 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,5 +23,5 @@ * A function that takes two inputs and returns zero or more output records. 
*/ public interface FlatMapFunction2 extends Serializable { - public Iterable call(T1 t1, T2 t2) throws Exception; + Iterable call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java new file mode 100644 index 0000000000000..18a2d733ca70d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupFunction extends Serializable { + Iterable call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java new file mode 100644 index 0000000000000..2935f9986a560 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's map function. 
+ */ +public interface MapGroupFunction extends Serializable { + R call(K key, Iterator values) throws Exception; +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e151ac04ede2a..d771088d69dea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -527,7 +527,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, @@ -551,7 +551,7 @@ object CoGroup { * right children. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5c3f626545875..850315e281dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } - def flatMap[U]( - f: JFunction2[K, JIterator[T], JIterator[U]], - encoder: Encoder[U]): Dataset[U] = { + def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) } @@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(func, groupingAttributes, logicalPlan)) } - def map[U]( - f: JFunction2[K, JIterator[T], U], - encoder: Encoder[U]): Dataset[U] = { + def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { map((key, data) => f.call(key, data.asJava))(encoder) } @@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql]( */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { + f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.tEncoder new Dataset[R]( sqlContext, @@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R]( other: GroupedDataset[K, U], - f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + f: CoGroupFunction[K, T, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2593b16b1c8d7..145de0db9edaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -391,7 +391,7 @@ case class MapGroups[K, T, U]( * The result of this function is encoded and flattened before being output. 
*/ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0f90de774dd3e..312cf33e4c2d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -29,7 +29,6 @@ import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.catalyst.encoders.Encoder; import org.apache.spark.sql.catalyst.encoders.Encoder$; @@ -170,20 +169,33 @@ public Integer call(String v) throws Exception { } }, e.INT()); - Dataset mapped = grouped.map( - new Function2, String>() { + Dataset mapped = grouped.map(new MapGroupFunction() { + @Override + public String call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset flatMapped = grouped.flatMap( + new FlatMapGroupFunction() { @Override - public String call(Integer key, Iterator data) throws Exception { + public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); + while (values.hasNext()) { + sb.append(values.next()); } - return sb.toString(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, e.INT()); @@ -196,9 +208,9 @@ public Integer call(Integer v) throws Exception { Dataset cogrouped = grouped.cogroup( grouped2, - new Function3, Iterator, Iterator>() { + new CoGroupFunction() { @Override - public Iterator call( + public Iterable call( Integer key, Iterator left, Iterator right) throws Exception { @@ -210,7 +222,7 @@ public Iterator call( while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); @@ -225,7 +237,7 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); Dataset mapped = grouped.map( - new Function2, String>() { + new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); From 9565c246eadecf4836d247d0067f2200f061d25f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 9 Nov 2015 15:20:50 -0800 Subject: [PATCH 257/324] [SPARK-9557][SQL] Refactor ParquetFilterSuite and remove old ParquetFilters code Actually this was resolved by https://github.com/apache/spark/pull/8275. 
But I found the JIRA issue for this is not marked as resolved since the PR above was made for another issue but the PR above resolved both. I commented that this is resolved by the PR above; however, I opened this PR as I would like to just add a little bit of corrections. In the previous PR, I refactored the test by not reducing just collecting filters; however, this would not test properly `And` filter (which is not given to the tests). I unintentionally changed this from the original way (before being refactored). In this PR, I just followed the original way to collect filters by reducing. I would like to close this if this PR is inappropriate and somebody would like this deal with it in the separate PR related with this. Author: hyukjinkwon Closes #9554 from HyukjinKwon/SPARK-9557. --- .../datasources/parquet/ParquetFilterSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c24c9f025dad7..579dabf73318b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -54,12 +54,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val analyzedPredicate = query.queryExecution.optimizedPlan.collect { + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten - assert(analyzedPredicate.nonEmpty) + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined) - val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter) + val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) assert(selectedFilters.nonEmpty) selectedFilters.foreach { pred => From 2f38378856fb56bdd9be7ccedf56427e81701f4e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 9 Nov 2015 16:06:48 -0800 Subject: [PATCH 258/324] [SPARK-11360][DOC] Loss of nullability when writing parquet files This fix is to add one line to explain the current behavior of Spark SQL when writing Parquet files. All columns are forced to be nullable for compatibility reasons. Author: gatorsmile Closes #9314 from gatorsmile/lossNull. --- docs/sql-programming-guide.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ccd26904329d3..6e02d6564b002 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -982,7 +982,8 @@ when a table is dropped. [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +compatibility reasons. 
### Loading Data Programmatically From 9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 9 Nov 2015 16:11:00 -0800 Subject: [PATCH 259/324] [SPARK-11578][SQL] User API for Typed Aggregation This PR adds a new interface for user-defined aggregations, that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value. For example, the following aggregator extracts an `int` from a specific class and adds them up: ```scala case class Data(i: Int) val customSummer = new Aggregator[Data, Int, Int] { def prepare(d: Data) = d.i def reduce(l: Int, r: Int) = l + r def present(r: Int) = r }.toColumn() val ds: Dataset[Data] = ... val aggregated = ds.select(customSummer) ``` By using helper functions, users can make a generic `Aggregator` that works on any input type: ```scala /** An `Aggregator` that adds up any numeric type returned by the given function. */ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { val numeric = implicitly[Numeric[N]] override def zero: N = numeric.zero override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) override def present(reduction: N): N = reduction } def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn ``` These aggregators can then be used alongside other built-in SQL aggregations. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds .groupBy(_._1) .agg( sum(_._2), // The aggregator defined above. expr("sum(_2)").as[Int], // A built-in dynatically typed aggregation. count("*")) // A built-in statically typed aggregation. .collect() res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L) ``` The current implementation focuses on integrating this into the typed API, but currently only supports running aggregations that return a single long value as explained in `TypedAggregateExpression`. This will be improved in a followup PR. Author: Michael Armbrust Closes #9555 from marmbrus/dataset-useragg. 
--- .../scala/org/apache/spark/sql/Column.scala | 11 +- .../scala/org/apache/spark/sql/Dataset.scala | 30 ++-- .../org/apache/spark/sql/GroupedDataset.scala | 51 ++++--- .../org/apache/spark/sql/SQLContext.scala | 1 - .../aggregate/TypedAggregateExpression.scala | 129 ++++++++++++++++++ .../spark/sql/expressions/Aggregator.scala | 81 +++++++++++ .../org/apache/spark/sql/functions.scala | 30 +++- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../spark/sql/DatasetAggregatorSuite.scala | 65 +++++++++ 9 files changed, 360 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c32c93897ce0b..d26b6c3579205 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -39,10 +39,13 @@ private[sql] object Column { } /** - * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. * @since 1.6.0 + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. */ -class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) +class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr) /** * :: Experimental :: @@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 959e0f5ba03e6..6d2968e2881f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -358,7 +358,7 @@ class Dataset[T] private[sql]( * }}} * @since 1.6.0 */ - def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) } @@ -367,7 +367,7 @@ class Dataset[T] private[sql]( * code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. 
*/ - protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) @@ -385,7 +385,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** @@ -393,9 +393,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** @@ -403,10 +403,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** @@ -414,11 +414,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4], - c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 850315e281dfe..db61499229284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.{Iterator => JIterator} + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.QueryExecution + /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql]( * that cast appropriately for the user facing interface. 
* TODO: does not handle aggrecations that return nonflat results, */ - protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = (groupingAttributes ++ columns.map(_.expr)).map { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr @@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql]( } val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) + + // Fill in the input encoders for any aggregators in the plan. + val withEncoders = unresolvedPlan transformAllExpressions { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]), + children = dataAttributes) + } + val execution = new QueryExecution(sqlContext, withEncoders) val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) @@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql]( case (e, a) => e.unbind(a :: Nil).resolve(execution.analyzed.output) } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(encoders)) } /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. */ - def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + def agg[U1, U2, U3]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. 
*/ - def agg[A1, A2, A3, A4]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3], - col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + def agg[U1, U2, U3, U4]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3], + col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5598731af5fcc..1cf1e30f967cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 0000000000000..24d8122b6222b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructType, DataType} + +object TypedAggregateExpression { + def apply[A, B : Encoder, C : Encoder]( + aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], + Nil, + 0, + 0) + } +} + +/** + * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has + * the following limitations: + * - It assumes the aggregator reduces and returns a single column of type `long`. 
+ * - It might only work when there is a single aggregator in the first column. + * - It assumes the aggregator has a zero, `0`. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + aEncoder: Option[ExpressionEncoder[Any]], + bEncoder: ExpressionEncoder[Any], + cEncoder: ExpressionEncoder[Any], + children: Seq[Expression], + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + // TODO: this assumes flat results... + override def dataType: DataType = cEncoder.schema.head.dataType + + override def deterministic: Boolean = true + + override lazy val resolved: Boolean = aEncoder.isDefined + + override lazy val inputTypes: Seq[DataType] = + aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil) + + override val aggBufferSchema: StructType = bEncoder.schema + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + lazy val inputAttributes = aEncoder.get.schema.toAttributes + lazy val inputMapping = AttributeMap(inputAttributes.zip(children)) + lazy val boundA = + aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform { + case a: AttributeReference => inputMapping(a) + }) + + // TODO: this probably only works when we are in the first column. + val bAttributes = bEncoder.schema.toAttributes + lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + + override def initialize(buffer: MutableRow): Unit = { + // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for + // this in execution. 
+ buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int]) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val inputA = boundA.fromRow(input) + val currentB = boundB.fromRow(buffer) + val merged = aggregator.reduce(currentB, inputA) + val returned = boundB.toRow(merged) + buffer.setInt(mutableAggBufferOffset, returned.getInt(0)) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + buffer1.setLong( + mutableAggBufferOffset, + buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset)) + } + + override def eval(buffer: InternalRow): Any = { + buffer.getInt(mutableAggBufferOffset) + } + + override def toString: String = { + s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = aggregator.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 0000000000000..0b3192a6da9d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} + +/** + * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] + * operations to take all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero = 0 + * def reduce(b: Int, a: Data) = b + a.i + * def present(r: Int) = r + * }.toColumn() + * + * val ds: Dataset[Data] + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam A The input type for the aggregation. + * @tparam B The type of the intermediate value of the reduction. + * @tparam C The type of the final result. + */ +abstract class Aggregator[-A, B, C] { + + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + def zero: B + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. 
+ */ + def reduce(b: B, a: A): B + + /** + * Transform the output of the reduction. + */ + def present(reduction: B): C + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * operations. + */ + def toColumn( + implicit bEncoder: Encoder[B], + cEncoder: Encoder[C]): TypedColumn[A, C] = { + val expr = + new AggregateExpression2( + TypedAggregateExpression(this), + Complete, + false) + + new TypedColumn[A, C](expr, encoderFor[C]) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3f0b24b68b816..6d56542ee0875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql + + import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -24,11 +26,32 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have + * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate + * "bridge" methods due to the use of covariant return types. + * + * {{{ + * In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * + * In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); + * }}} + * + * This allows us to use the same functions both in typed [[Dataset]] operations and untyped + * [[DataFrame]] operations when the return type for a given function is statically known. + */ +private[sql] abstract class LegacyFunctions { + def count(columnName: String): Column +} + /** * :: Experimental :: * Functions available for [[DataFrame]]. @@ -48,11 +71,14 @@ import org.apache.spark.util.Utils */ @Experimental // scalastyle:off -object functions { +object functions extends LegacyFunctions { // scalastyle:on private def withExpr(expr: Expression): Column = Column(expr) + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + + /** * Returns a [[Column]] based on the given column name. * @@ -234,7 +260,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long] /** * Aggregate function: returns the number of distinct items in a group. 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 312cf33e4c2d4..2da63d1b96706 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -258,8 +258,8 @@ public void testSelect() { Dataset ds = context.createDataset(data, e.INT()); Dataset> selected = ds.select( - expr("value + 1").as(e.INT()), - col("value").cast("string").as(e.STRING())); + expr("value + 1"), + col("value").cast("string")).as(e.tuple(e.INT(), e.STRING())); Assert.assertEquals( Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 0000000000000..340470c096b87 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.functions._ + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +import org.apache.spark.sql.expressions.Aggregator + +/** An `Aggregator` that adds up any numeric type returned by the given function. */ +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + val numeric = implicitly[Numeric[N]] + + override def zero: N = numeric.zero + + override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + + override def present(reduction: N): N = reduction +} + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum(_._2)), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Int], + count("*")), + ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + } +} From 675c7e723cadff588405c23826a00686587728b8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 9 Nov 2015 16:22:15 -0800 Subject: [PATCH 260/324] [SPARK-11564][SQL] Fix documentation for DataFrame.take/collect Author: Reynold Xin Closes #9557 from rxin/SPARK-11564-1. 
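For context, a small sketch of the two calls whose scaladoc is corrected below; `df` is an assumed, already-built DataFrame:

```scala
val firstTen = df.take(10)   // moves only the first 10 rows into the driver process
val allRows  = df.collect()  // moves the entire dataset into the driver; can OOM on very large data
```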
--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8ab958adadcca..d25807cf8d09c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1479,8 +1479,8 @@ class DataFrame private[sql]( /** * Returns the first `n` rows in the [[DataFrame]]. * - * Running take requires moving data into the application's driver process, and doing so on a - * very large dataset can crash the driver process with OutOfMemoryError. + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. * * @group action * @since 1.3.0 @@ -1501,8 +1501,8 @@ class DataFrame private[sql]( /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * From 7dc9d8dba6c4bc655896b137062d896dec4ef64a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 16:25:29 -0800 Subject: [PATCH 261/324] [SPARK-11610][MLLIB][PYTHON][DOCS] Make the docs of LDAModel.describeTopics in Python more specific cc jkbradley Author: Yu ISHIKAWA Closes #9577 from yu-iskw/SPARK-11610. --- python/pyspark/mllib/clustering.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12081f8c69075..1fa061dc2da99 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -734,6 +734,12 @@ def describeTopics(self, maxTermsPerTopic=None): """Return the topics described by weighted terms. WARNING: If vocabSize and k are large, this can return a large object! + + :param maxTermsPerTopic: Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: Array over topics. Each topic is represented as a pair of matching arrays: + (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") From 61f9c8711c79f35d67b0456155866da316b131d9 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 9 Nov 2015 16:55:23 -0800 Subject: [PATCH 262/324] [SPARK-11069][ML] Add RegexTokenizer option to convert to lowercase jira: https://issues.apache.org/jira/browse/SPARK-11069 quotes from jira: Tokenizer converts strings to lowercase automatically, but RegexTokenizer does not. It would be nice to add an option to RegexTokenizer to convert to lowercase. Proposal: call the Boolean Param "toLowercase" set default to false (so behavior does not change) Actually sklearn converts to lowercase before tokenizing too Author: Yuhao Yang Closes #9092 from hhbyyh/tokenLower. 
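A minimal sketch of the new option as exercised by the test changes below; the column names are illustrative, and note that the implementation in this diff defaults `toLowercase` to `true` (unlike the `false` default floated in the JIRA quote above):

```scala
import org.apache.spark.ml.feature.RegexTokenizer

val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")
  .setToLowercase(false)  // keep the original casing; omit this call to get the default lowercasing
```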
--- .../apache/spark/ml/feature/Tokenizer.scala | 19 ++++++++++++++-- .../spark/ml/feature/JavaTokenizerSuite.java | 1 + .../spark/ml/feature/TokenizerSuite.scala | 22 ++++++++++++++----- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 248288ca73e99..1b82b40caac18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -100,10 +100,25 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + /** + * Indicates whether to convert all characters to lowercase before tokenizing. + * Default: true + * @group param + */ + final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", + "whether to convert all characters to lowercase before tokenizing.") + + /** @group setParam */ + def setToLowercase(value: Boolean): this.type = set(toLowercase, value) + + /** @group getParam */ + def getToLowercase: Boolean = $(toLowercase) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) - override protected def createTransformFunc: String => Seq[String] = { str => + override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + val str = if ($(toLowercase)) originStr.toLowerCase() else originStr val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 02309ce63219a..c407d98f1b795 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -53,6 +53,7 @@ public void regexTokenizer() { .setOutputCol("tokens") .setPattern("\\s") .setGaps(true) + .setToLowercase(false) .setMinTokenLength(3); diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index e5fd21c3f6fca..a02992a2407b3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -48,13 +48,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. 
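The essence of the change is in `KinesisSequenceRangeIterator`: records fetched through the low-level GetRecords API during recovery are passed through `UserRecord.deaggregate`, so KPL-aggregated records are split back into user records (on the normal receive path the KCL already does this automatically). A minimal sketch, where `getRecordsResult` stands for a `GetRecordsResult` already obtained from the Kinesis client:

```scala
import scala.collection.JavaConverters._
import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord

// Recovery bypasses the KCL, so de-aggregation has to be done explicitly here.
val userRecords = UserRecord.deaggregate(getRecordsResult.getRecords).iterator().asScala
```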
punct", Array("punct")) )) tokenizer0.setMinTokenLength(3) @@ -64,11 +64,23 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) )) testRegexTokenizer(tokenizer2, dataset2) } + + test("RegexTokenizer with toLowercase false") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setToLowercase(false) + val dataset = sqlContext.createDataFrame(Seq( + TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), + TokenizerTestData("java scala", Array("java", "scala")) + )) + testRegexTokenizer(tokenizer, dataset) + } } object RegexTokenizerSuite extends SparkFunSuite { From 26062d22607e1f9854bc2588ba22a4e0f8bba48c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 17:18:49 -0800 Subject: [PATCH 263/324] [SPARK-11198][STREAMING][KINESIS] Support de-aggregation of records during recovery While the KCL handles de-aggregation during the regular operation, during recovery we use the lower level api, and therefore need to de-aggregate the records. tdas Testing is an issue, we need protobuf magic to do the aggregated records. Maybe we could depend on KPL for tests? Author: Burak Yavuz Closes #9403 from brkyvz/kinesis-deaggregation. --- extras/kinesis-asl/pom.xml | 6 ++ .../kinesis/KinesisBackedBlockRDD.scala | 6 +- .../streaming/kinesis/KinesisReceiver.scala | 1 - .../kinesis/KinesisRecordProcessor.scala | 2 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 12 +++- .../kinesis/KinesisStreamSuite.scala | 17 +++--- .../streaming/kinesis/KinesisTestUtils.scala | 55 +++++++++++++++---- pom.xml | 2 + 8 files changed, 76 insertions(+), 25 deletions(-) rename extras/kinesis-asl/src/{main => test}/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala (80%) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index ef72d97eae69d..519a920279c97 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -64,6 +64,12 @@ aws-java-sdk ${aws.java.sdk.version} + + com.amazonaws + amazon-kinesis-producer + ${aws.kinesis.producer.version} + test + org.mockito mockito-core diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 000897a4e7290..691c1790b207f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ import org.apache.spark._ @@ -210,7 +211,10 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + // De-aggregate records, if KPL was used 
in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) + (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 50993f157cd95..97dbb918573a3 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -216,7 +216,6 @@ private[kinesis] class KinesisReceiver[T]( val metadata = SequenceNumberRange(streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) - } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index e381ffa0cbef4..b5b76cb92d866 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -80,7 +80,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * more than once. */ logError(s"Exception: WorkerId $workerId encountered and exception while storing " + - " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + s" or checkpointing a batch for workerId $workerId and shardId $shardId.", e) /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. 
*/ throw e diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 9f9e146a08d46..52c61dfb1c023 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} -class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { +abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) + extends KinesisFunSuite with BeforeAndAfterAll { private val testData = 1 to 8 @@ -37,13 +38,12 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll private var sc: SparkContext = null private var blockManager: BlockManager = null - override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KinesisTestUtils() testUtils.createStream() - shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq @@ -247,3 +247,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll Array.tabulate(num) { i => new StreamBlockId(0, i) } } } + +class WithAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = true) + +class WithoutAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ba84e557dfcc2..dee30444d8cc6 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils import org.apache.spark.{SparkConf, SparkContext} -class KinesisStreamSuite extends KinesisFunSuite +abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB @@ -182,13 +182,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) assert(collected === testData.toSet, "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) @@ -207,13 +207,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new 
mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) val modData = testData.map(_ + 5) assert(collected === modData.toSet, "\nData received does not match data sent") } @@ -254,7 +254,7 @@ class KinesisStreamSuite extends KinesisFunSuite // If this times out because numBatchesWithData is empty, then its likely that foreachRDD // function failed with exceptions, and nothing got added to `collectedData` eventually(timeout(2 minutes), interval(1 seconds)) { - testUtils.pushData(1 to 5) + testUtils.pushData(1 to 5, aggregateTestData) assert(isCheckpointPresent && numBatchesWithData > 10) } ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused @@ -285,5 +285,8 @@ class KinesisStreamSuite extends KinesisFunSuite } ssc.stop() } - } + +class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true) + +class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 80% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 634bf94521079..7487aa1c12639 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,6 +31,8 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ +import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging @@ -64,6 +66,16 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } + private lazy val kinesisProducer: KinesisProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KinesisProducer(conf) + } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName @@ -90,22 +102,41 @@ private[kinesis] class KinesisTestUtils extends Logging { * Push data to Kinesis stream and return a map of * shardId -> seq of (data, seq number) pushed to corresponding shard */ - def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() testData.foreach { num => val str = num.toString - val putRecordRequest = new 
PutRecordRequest().withStreamName(streamName) - .withData(ByteBuffer.wrap(str.getBytes())) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) + val data = ByteBuffer.wrap(str.getBytes()) + if (aggregate) { + val future = kinesisProducer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + + Futures.addCallback(future, kinesisCallBack) + kinesisProducer.flushSync() // make sure we send all data before returning the map + } else { + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } } logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") @@ -116,7 +147,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. */ def pushData(testData: java.util.List[Int]): Unit = { - pushData(testData.asScala) + pushData(testData.asScala, aggregate = false) } def deleteStream(): Unit = { diff --git a/pom.xml b/pom.xml index 4ed1c0c82dee6..fd8c773513881 100644 --- a/pom.xml +++ b/pom.xml @@ -154,6 +154,8 @@ 0.7.1 1.9.40 1.4.0 + + 0.10.1 4.3.2 From 0ce6f9b2d203ce67aeb4d3aedf19bbd997fe01b9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 17:35:12 -0800 Subject: [PATCH 264/324] [SPARK-11141][STREAMING] Batch ReceivedBlockTrackerLogEvents for WAL writes When using S3 as a directory for WALs, the writes take too long. The driver gets very easily bottlenecked when multiple receivers send AddBlock events to the ReceiverTracker. This PR adds batching of events in the ReceivedBlockTracker so that receivers don't get blocked by the driver for too long. cc zsxwing tdas Author: Burak Yavuz Closes #9143 from brkyvz/batch-wal-writes. 
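For context (not part of the patch itself), the batching added here is opt-in through the new driver WAL configuration keys defined in WriteAheadLogUtils below. A minimal sketch of how a streaming application would enable it, assuming an otherwise standard setup (master, app name and checkpoint path are illustrative placeholders):

    // Minimal sketch: enable the batched driver-side write ahead log introduced by this patch.
    // Only the two "driver.writeAheadLog" keys are new here; the rest is ordinary Spark Streaming setup.
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setMaster("local[2]")                                               // illustrative
      .setAppName("wal-batching-example")                                  // illustrative
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")        // turns on the WAL itself
      .set("spark.streaming.driver.writeAheadLog.allowBatching", "true")   // new in this patch, default false
      .set("spark.streaming.driver.writeAheadLog.batchingTimeout", "5000") // new in this patch, timeout in ms
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("/tmp/checkpoint")  // the driver-side WAL is written under the checkpoint directory

Receivers keep calling addBlock as before; with batching enabled the ReceiverTracker replies from a separate thread pool and the BatchedWriteAheadLog groups the queued records into a single write, which is what relieves the slow-write bottleneck described above.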
--- .../scheduler/ReceivedBlockTracker.scala | 62 ++- .../streaming/scheduler/ReceiverTracker.scala | 25 +- .../streaming/util/BatchedWriteAheadLog.scala | 223 ++++++++ .../streaming/util/WriteAheadLogUtils.scala | 21 +- .../streaming/util/WriteAheadLogSuite.scala | 506 ++++++++++++------ .../util/WriteAheadLogUtilsSuite.scala | 122 +++++ 6 files changed, 767 insertions(+), 192 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index f2711d1355e60..500dc70c98506 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -22,12 +22,13 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} +import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} import org.apache.spark.{Logging, SparkConf} @@ -41,7 +42,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: private[streaming] case class BatchCleanupEvent(times: Seq[Time]) extends ReceivedBlockTrackerLogEvent - /** Class representing the blocks of all the streams allocated to a batch */ private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { @@ -82,15 +82,22 @@ private[streaming] class ReceivedBlockTracker( } /** Add received block. This event will get written to the write ahead log (if enabled). 
*/ - def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { try { - writeToLog(BlockAdditionEvent(receivedBlockInfo)) - getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug(s"Stream ${receivedBlockInfo.streamId} received " + - s"block ${receivedBlockInfo.blockStoreResult.blockId}") - true + val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo)) + if (writeResult) { + synchronized { + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + } else { + logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " + + s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.") + } + writeResult } catch { - case e: Exception => + case NonFatal(e) => logError(s"Error adding block $receivedBlockInfo", e) false } @@ -106,10 +113,12 @@ private[streaming] class ReceivedBlockTracker( (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) - writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) - timeToAllocatedBlocks(batchTime) = allocatedBlocks - lastAllocatedBatchTime = batchTime - allocatedBlocks + if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime + } else { + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + } } else { // This situation occurs when: // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, @@ -157,9 +166,12 @@ private[streaming] class ReceivedBlockTracker( require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) - writeToLog(BatchCleanupEvent(timesToCleanup)) - timeToAllocatedBlocks --= timesToCleanup - writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + if (writeToLog(BatchCleanupEvent(timesToCleanup))) { + timeToAllocatedBlocks --= timesToCleanup + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + } else { + logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.") + } } /** Stop the block tracker. 
*/ @@ -185,8 +197,8 @@ private[streaming] class ReceivedBlockTracker( logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } - lastAllocatedBatchTime = batchTime timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime } // Cleanup the batch allocations @@ -213,12 +225,20 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { - logDebug(s"Writing to log $record") - writeAheadLogOption.foreach { logManager => - logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) + logTrace(s"Writing record: $record") + try { + writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)), + clock.getTimeMillis()) + true + } catch { + case NonFatal(e) => + logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e) + false } + } else { + true } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b183d856f50c3..ea5d12b50fcc5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.language.existentials import scala.util.{Failure, Success} @@ -437,7 +437,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) + + private val walBatchingThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) + + @volatile private var active: Boolean = true override def receive: PartialFunction[Any, Unit] = { // Local messages @@ -488,7 +493,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => - context.reply(addBlock(receivedBlockInfo)) + if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) { + walBatchingThreadPool.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + if (active) { + context.reply(addBlock(receivedBlockInfo)) + } else { + throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + } + } + }) + } else { + context.reply(addBlock(receivedBlockInfo)) + } case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) @@ -599,6 +616,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def onStop(): Unit = { 
submitJobThreadPool.shutdownNow() + active = false + walBatchingThreadPool.shutdown() } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala new file mode 100644 index 0000000000000..9727ed2ba1445 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils + +/** + * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation + * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle + * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned + * after the write will contain the batch of records rather than individual records. + * + * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp + * of the latest record in the batch. This is very important in achieving correctness. Consider the + * following example: + * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive + * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding + * 5 and 7. + * + * This means the caller can assume the same write semantics as any other WriteAheadLog + * implementation despite the batching in the background - when the write() returns, the data is + * written to the WAL and is durable. To take advantage of the batching, the caller can write from + * multiple threads, each of which will stay blocked until the corresponding data has been written. + * + * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog. + */ +private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf) + extends WriteAheadLog with Logging { + + import BatchedWriteAheadLog._ + + private val walWriteQueue = new LinkedBlockingQueue[Record]() + + // Whether the writer thread is active + @volatile private var active: Boolean = true + private val buffer = new ArrayBuffer[Record]() + + private val batchedWriterThread = startBatchedWriterThread() + + /** + * Write a byte buffer to the log file. 
This method adds the byteBuffer to a queue and blocks + * until the record is properly written by the parent. + */ + override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + val promise = Promise[WriteAheadLogRecordHandle]() + val putSuccessfully = synchronized { + if (active) { + walWriteQueue.offer(Record(byteBuffer, time, promise)) + true + } else { + false + } + } + if (putSuccessfully) { + Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + } else { + throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + + s"write request with time $time could be fulfilled.") + } + } + + /** + * This method is not supported as the resulting ByteBuffer would actually require de-aggregation. + * This method is primarily used in testing, and to ensure that it is not used in production, + * we throw an UnsupportedOperationException. + */ + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " + + "as the data may require de-aggregation.") + } + + /** + * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog + * will be de-aggregated. + */ + override def readAll(): JIterator[ByteBuffer] = { + wrappedLog.readAll().asScala.flatMap(deaggregate).asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * This method is handled by the parent WriteAheadLog. + */ + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wrappedLog.clean(threshTime, waitForCompletion) + } + + + /** + * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL. + */ + override def close(): Unit = { + logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") + synchronized { + active = false + } + batchedWriterThread.interrupt() + batchedWriterThread.join() + while (!walWriteQueue.isEmpty) { + val Record(_, time, promise) = walWriteQueue.poll() + promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " + + s"before write request with time $time could be fulfilled.")) + } + wrappedLog.close() + } + + /** Start the actual log writer on a separate thread. */ + private def startBatchedWriterThread(): Thread = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + while (active) { + try { + flushRecords() + } catch { + case NonFatal(e) => + logWarning("Encountered exception in Batched Writer Thread.", e) + } + } + logInfo("BatchedWriteAheadLog Writer thread exiting.") + } + }, "BatchedWriteAheadLog Writer") + thread.setDaemon(true) + thread.start() + thread + } + + /** Write all the records in the buffer to the write ahead log. */ + private def flushRecords(): Unit = { + try { + buffer.append(walWriteQueue.take()) + val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1 + logDebug(s"Received $numBatched records from queue") + } catch { + case _: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.") + } + try { + var segment: WriteAheadLogRecordHandle = null + if (buffer.length > 0) { + logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // We take the latest record for the timestamp. 
Please refer to the class Javadoc for + // detailed explanation + val time = buffer.last.time + segment = wrappedLog.write(aggregate(buffer), time) + } + buffer.foreach(_.promise.success(segment)) + } catch { + case e: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.", e) + buffer.foreach(_.promise.failure(e)) + case NonFatal(e) => + logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e) + buffer.foreach(_.promise.failure(e)) + } finally { + buffer.clear() + } + } +} + +/** Static methods for aggregating and de-aggregating records. */ +private[util] object BatchedWriteAheadLog { + + /** + * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled + * with the timestamp for the write request of the record, and the promise that will block the + * write request, while a separate thread is actually performing the write. + */ + case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) + + /** Copies the byte array of a ByteBuffer. */ + private def getByteArray(buffer: ByteBuffer): Array[Byte] = { + val byteArray = new Array[Byte](buffer.remaining()) + buffer.get(byteArray) + byteArray + } + + /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ + def aggregate(records: Seq[Record]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( + records.map(record => getByteArray(record.data)).toArray)) + } + + /** + * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. + * A stream may not have used batching initially, but started using it after a restart. This + * method therefore needs to be backwards compatible. + */ + def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + try { + Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + } catch { + case _: ClassCastException => // users may restart a stream with batching enabled + Array(buffer) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 0ea970e61b694..731a369fc92c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -38,6 +38,8 @@ private[streaming] object WriteAheadLogUtils extends Logging { val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching" + val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout" val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" @@ -64,6 +66,18 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = false) + } + + /** + * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records + * before we fail the write attempt to unblock receivers. 
+ */ + def getBatchingTimeout(conf: SparkConf): Long = { + conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000) + } + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { if (isDriver) { conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) @@ -115,7 +129,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } else { sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) } - classNameOption.map { className => + val wal = classNameOption.map { className => try { instantiateClass( Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) @@ -128,6 +142,11 @@ private[streaming] object WriteAheadLogUtils extends Logging { getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), shouldCloseFileAfterWrite(sparkConf, isDriver)) } + if (isBatchingEnabled(sparkConf, isDriver)) { + new BatchedWriteAheadLog(wal, sparkConf) + } else { + wal + } } /** Instantiate the class, either using single arg constructor or zero arg constructor */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 93ae41a3d2ecd..e96f4c2a29347 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,31 +18,47 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util +import java.util.concurrent.{ExecutionException, ThreadPoolExecutor} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.reflect.ClassTag +import scala.util.{Failure, Success} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{eq => meq} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.mock.MockitoSugar -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{ThreadUtils, ManualClock, Utils} +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} -class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { +/** Common tests for WriteAheadLogs that we would like to test with different configurations. 
*/ +abstract class CommonWriteAheadLogTests( + allowBatching: Boolean, + closeFileAfterWrite: Boolean, + testTag: String = "") + extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - val hadoopConf = new Configuration() - var tempDir: File = null - var testDir: String = null - var testFile: String = null - var writeAheadLog: FileBasedWriteAheadLog = null + protected val hadoopConf = new Configuration() + protected var tempDir: File = null + protected var testDir: String = null + protected var testFile: String = null + protected var writeAheadLog: WriteAheadLog = null + protected def testPrefix = if (testTag != "") testTag + " - " else testTag before { tempDir = Utils.createTempDir() @@ -58,49 +74,130 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogUtils - log selection and creation") { - val logDir = Utils.createTempDir().getAbsolutePath() + test(testPrefix + "read all logs") { + // Write data manually for testing reading through WriteAheadLog + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten - def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + val logDirectoryPath = new Path(testDir) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + assert(fileSystem.exists(logDirectoryPath) === true) + + // Read data using manager and verify + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === writtenData) + } + + test(testPrefix + "write logs") { + // Write data with rotation using WriteAheadLog class + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite, + allowBatching = allowBatching) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val writtenData = readAndDeserializeDataManually(logFiles, allowBatching) + assert(writtenData === dataToWrite) + } + + test(testPrefix + "read all logs after write") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(dataToWrite === readData) + } + + test(testPrefix + "clean old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test(testPrefix + "clean old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { + // Write data with manager, recover with new manager and verify + val manualClock = new ManualClock + val dataToWrite = generateRandomData() + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, + allowBatching, manualClock, closeLog = false) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + + if (waitForCompletion) { + 
assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(Eventually.timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } } + } - def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + test(testPrefix + "handling file errors while reading rotating logs") { + // Generate a set of log files + val manualClock = new ManualClock + val dataToWrite1 = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles1 = getLogFilesInDirectory(testDir) + assert(logFiles1.size > 1) + + + // Recover old files and generate a second set of log files + val dataToWrite2 = generateRandomData() + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching , + manualClock) + val logFiles2 = getLogFilesInDirectory(testDir) + assert(logFiles2.size > logFiles1.size) + + // Read the files and verify that all the written data can be read + val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + + // Corrupt the first set of files so that they are basically unreadable + logFiles1.foreach { f => + val raf = new FileOutputStream(f, true).getChannel() + raf.truncate(1) + raf.close() } - val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) - assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) - - // Verify setting driver WAL class - val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](conf1) - assertReceiverLogClass[FileBasedWriteAheadLog](conf1) - - // Verify setting receiver WAL class - val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) - assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) - - // Verify setting receiver WAL class with 1-arg constructor - val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog1].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) - - // Verify failure setting receiver WAL class with 2-arg constructor - intercept[SparkException] { - val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog2].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + // Verify that the corrupted files do not prevent reading of the second set of data + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === dataToWrite2) + } + + test(testPrefix + "do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + if (allowBatching) 
{ + intercept[UnsupportedOperationException](wal.read(writtenSegment.head)) + } else { + wal.read(writtenSegment.head) } + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") } +} + +class FileBasedWriteAheadLogSuite + extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") { + + import WriteAheadLogSuite._ test("FileBasedWriteAheadLogWriter - writing data") { val dataToWrite = generateRandomData() @@ -122,7 +219,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() - writeDataManually(writtenData, testFile) + writeDataManually(writtenData, testFile, allowBatching = false) val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) @@ -166,7 +263,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() - val segments = writeDataManually(writtenData, testFile) + val segments = writeDataManually(writtenData, testFile, allowBatching = false) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten @@ -190,163 +287,212 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { } reader.close() } +} - test("FileBasedWriteAheadLog - write rotating logs") { - // Write data with rotation using WriteAheadLog class - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - - // Read data manually to verify the written data - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val writtenData = logFiles.flatMap { file => readDataManually(file)} - assert(writtenData === dataToWrite) - } +abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String) + extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) { - test("FileBasedWriteAheadLog - close after write flag") { + import WriteAheadLogSuite._ + test(testPrefix + "close after write flag") { // Write data with rotation using WriteAheadLog class val numFiles = 3 val dataToWrite = Seq.tabulate(numFiles)(_.toString) // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, - closeFileAfterWrite = true) + closeFileAfterWrite = true, allowBatching = allowBatching) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size === numFiles) - val writtenData = logFiles.flatMap { file => readDataManually(file)} + val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching) assert(writtenData === dataToWrite) } +} - test("FileBasedWriteAheadLog - read rotating logs") { - // Write data manually for testing reading through WriteAheadLog - val writtenData = (1 to 10).map { i => - val data = generateRandomData() - val file = testDir + s"/log-$i-$i" - writeDataManually(data, file) - data - }.flatten +class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog") - val logDirectoryPath = new Path(testDir) - val 
fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - assert(fileSystem.exists(logDirectoryPath) === true) +class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( + allowBatching = true, + closeFileAfterWrite = false, + "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually { - // Read data using manager and verify - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === writtenData) - } + import BatchedWriteAheadLog._ + import WriteAheadLogSuite._ - test("FileBasedWriteAheadLog - recover past logs when creating new manager") { - // Write data with manager, recover with new manager and verify - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val readData = readDataUsingWriteAheadLog(testDir) - assert(dataToWrite === readData) + private var wal: WriteAheadLog = _ + private var walHandle: WriteAheadLogRecordHandle = _ + private var walBatchingThreadPool: ThreadPoolExecutor = _ + private var walBatchingExecutionContext: ExecutionContextExecutorService = _ + private val sparkConf = new SparkConf() + + override def beforeEach(): Unit = { + wal = mock[WriteAheadLog] + walHandle = mock[WriteAheadLogRecordHandle] + walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") + walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool) } - test("FileBasedWriteAheadLog - clean old logs") { - logCleanUpTest(waitForCompletion = false) + override def afterEach(): Unit = { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() + } } - test("FileBasedWriteAheadLog - clean old logs synchronously") { - logCleanUpTest(waitForCompletion = true) - } + test("BatchedWriteAheadLog - serializing and deserializing batched records") { + val events = Seq( + BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)), + BatchAllocationEvent(null, null), + BatchCleanupEvent(Nil) + ) - private def logCleanUpTest(waitForCompletion: Boolean): Unit = { - // Write data with manager, recover with new manager and verify - val manualClock = new ManualClock - val dataToWrite = generateRandomData() - writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) + val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null)) + val batched = BatchedWriteAheadLog.aggregate(buffers) + val deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array())) - writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + assert(deaggregate.toSeq === events) + } - if (waitForCompletion) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } else { - eventually(timeout(1 second), interval(10 milliseconds)) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } + test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") { + when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!")) + // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + + intercept[RuntimeException] { + val buffer = mock[ByteBuffer] + batchedWal.write(buffer, 2L) } } - test("FileBasedWriteAheadLog - handling 
file errors while reading rotating logs") { - // Generate a set of log files - val manualClock = new ManualClock - val dataToWrite1 = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) - val logFiles1 = getLogFilesInDirectory(testDir) - assert(logFiles1.size > 1) + // we make the write requests in separate threads so that we don't block the test thread + private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + val p = Promise[Unit]() + p.completeWith(Future { + val v = wal.write(event, time) + assert(v === walHandle) + }(walBatchingExecutionContext)) + p + } + /** + * In order to block the writes on the writer thread, we mock the write method, and block it + * for some time with a promise. + */ + private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = { + // we would like to block the write so that we can queue requests + val promise = Promise[Any]() + when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( + new Answer[WriteAheadLogRecordHandle] { + override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { + Await.ready(promise.future, 4.seconds) + walHandle + } + } + ) + promise + } - // Recover old files and generate a second set of log files - val dataToWrite2 = generateRandomData() - manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) - val logFiles2 = getLogFilesInDirectory(testDir) - assert(logFiles2.size > logFiles1.size) + test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + // block the write so that we can batch some records + val promise = writeBlockingPromise(wal) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + val event4 = "is" + val event5 = "doge" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + promiseWriteEvent(batchedWal, event1, 3L) + // rest of the records will be batched while it takes 3 to get written + promiseWriteEvent(batchedWal, event2, 5L) + promiseWriteEvent(batchedWal, event3, 8L) + promiseWriteEvent(batchedWal, event4, 12L) + promiseWriteEvent(batchedWal, event5, 10L) + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + } + promise.success(true) - // Read the files and verify that all the written data can be read - val readData1 = readDataUsingWriteAheadLog(testDir) - assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + val buffer1 = wrapArrayArrayByte(Array(event1)) + val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) - // Corrupt the first set of files so that they are basically unreadable - logFiles1.foreach { f => - val raf = new FileOutputStream(f, true).getChannel() - raf.truncate(1) - raf.close() + eventually(timeout(1 second)) { + verify(wal, times(1)).write(meq(buffer1), meq(3L)) + // the file name should be the timestamp of the last record, as events should be naturally + // in order of timestamp, and we need the last element. 
+ verify(wal, times(1)).write(meq(buffer2), meq(10L)) } - - // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === dataToWrite2) } - test("FileBasedWriteAheadLog - do not create directories or files unless write") { - val nonexistentTempPath = File.createTempFile("test", "") - nonexistentTempPath.delete() - assert(!nonexistentTempPath.exists()) + test("BatchedWriteAheadLog - shutdown properly") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + batchedWal.close() + verify(wal, times(1)).close() - val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath, - new Configuration(), 1, 1, closeFileAfterWrite = false) - assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") - wal.read(writtenSegment.head) - assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L)) } -} -object WriteAheadLogSuite { + test("BatchedWriteAheadLog - fail everything in queue during shutdown") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - class MockWriteAheadLog0() extends WriteAheadLog { - override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } - override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } - override def readAll(): util.Iterator[ByteBuffer] = { null } - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } - override def close(): Unit = { } - } + // block the write so that we can batch some records + writeBlockingPromise(wal) + + val event1 = ("hello", 3L) + val event2 = ("world", 5L) + val event3 = ("this", 8L) + val event4 = ("is", 9L) + val event5 = ("doge", 10L) + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + val writePromises = Seq(event1, event2, event3, event4, event5).map { event => + promiseWriteEvent(batchedWal, event._1, event._2) + } - class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + } + + batchedWal.close() + eventually(timeout(1 second)) { + assert(writePromises.forall(_.isCompleted)) + assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed + } + } +} - class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +class BatchedWriteAheadLogWithCloseFileAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog") +object WriteAheadLogSuite { private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. 
*/ - def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + def writeDataManually( + data: Seq[String], + file: String, + allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = { val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) - data.foreach { item => + def writeToStream(bytes: Array[Byte]): Unit = { val offset = writer.getPos - val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } + if (allowBatching) { + writeToStream(wrapArrayArrayByte(data.toArray[String]).array()) + } else { + data.foreach { item => + writeToStream(Utils.serialize(item)) + } + } writer.close() segments } @@ -356,8 +502,7 @@ object WriteAheadLogSuite { */ def writeDataUsingWriter( filePath: String, - data: Seq[String] - ): Seq[FileBasedWriteAheadLogSegment] = { + data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = { val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) @@ -370,13 +515,13 @@ object WriteAheadLogSuite { def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], + closeFileAfterWrite: Boolean, + allowBatching: Boolean, manualClock: ManualClock = new ManualClock, closeLog: Boolean = true, - clockAdvanceTime: Int = 500, - closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = { + clockAdvanceTime: Int = 500): WriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite) + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => @@ -406,16 +551,16 @@ object WriteAheadLogSuite { } /** Read all the data from a log file directly and return the list of byte buffers. */ - def readDataManually(file: String): Seq[String] = { + def readDataManually[T](file: String): Seq[T] = { val reader = HdfsUtils.getInputStream(file, hadoopConf) - val buffer = new ArrayBuffer[String] + val buffer = new ArrayBuffer[T] try { while (true) { // Read till EOF is thrown val length = reader.readInt() val bytes = new Array[Byte](length) reader.read(bytes) - buffer += Utils.deserialize[String](bytes) + buffer += Utils.deserialize[T](bytes) } } catch { case ex: EOFException => @@ -434,15 +579,17 @@ object WriteAheadLogSuite { } /** Read all the data in the log file in a directory using the WriteAheadLog class. */ - def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite = false) + def readDataUsingWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): Seq[String] = { + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data } - /** Get the log files in a direction */ + /** Get the log files in a directory. 
*/ def getLogFilesInDirectory(directory: String): Seq[String] = { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) @@ -458,10 +605,31 @@ object WriteAheadLogSuite { } } + def createWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): WriteAheadLog = { + val sparkConf = new SparkConf + val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) + if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal + } + def generateRandomData(): Seq[String] = { (1 to 100).map { _.toString } } + def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = { + if (allowBatching) { + logFiles.flatMap { file => + val data = readDataManually[Array[Array[Byte]]](file) + data.flatMap(byteArray => byteArray.map(Utils.deserialize[String])) + } + } else { + logFiles.flatMap { file => readDataManually[String](file)} + } + } + implicit def stringToByteBuffer(str: String): ByteBuffer = { ByteBuffer.wrap(Utils.serialize(str)) } @@ -469,4 +637,8 @@ object WriteAheadLogSuite { implicit def byteBufferToString(byteBuffer: ByteBuffer): String = { Utils.deserialize[String](byteBuffer.array) } + + def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala new file mode 100644 index 0000000000000..9152728191ea1 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils + +class WriteAheadLogUtilsSuite extends SparkFunSuite { + import WriteAheadLogUtilsSuite._ + + private val logDir = Utils.createTempDir().getAbsolutePath() + private val hadoopConf = new Configuration() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag]( + conf: SparkConf, + isBatched: Boolean = false): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + if (isBatched) { + assert(log.isInstanceOf[BatchedWriteAheadLog]) + val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog + assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass) + } else { + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + } + log + } + + def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + test("log selection and creation") { + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") { + def getBatchedSparkConf: SparkConf = + new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true") + + val justBatchingConf = getBatchedSparkConf + assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf) + + // Verify setting driver WAL class + val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify receivers are not wrapped + val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + 
assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + } +} + +object WriteAheadLogUtilsSuite { + + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +} From 1f0f14efe35f986e338ee2cbc1ef2a9ce7395c00 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 9 Nov 2015 17:38:19 -0800 Subject: [PATCH 265/324] [SPARK-11462][STREAMING] Add JavaStreamingListener Currently, StreamingListener is not Java friendly because it exposes some Scala collections to Java users directly, such as Option, Map. This PR added a Java version of StreamingListener and a bunch of Java friendly classes for Java users. Author: zsxwing Author: Shixiong Zhu Closes #9420 from zsxwing/java-streaming-listener. --- .../api/java/JavaStreamingListener.scala | 168 ++++++++++ .../java/JavaStreamingListenerWrapper.scala | 122 ++++++++ .../JavaStreamingListenerAPISuite.java | 85 +++++ .../JavaStreamingListenerWrapperSuite.scala | 290 ++++++++++++++++++ 4 files changed, 665 insertions(+) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala new file mode 100644 index 0000000000000..c86c7101ff6d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming.Time + +/** + * A listener interface for receiving information about an ongoing streaming computation. 
+ */ +private[streaming] class JavaStreamingListener { + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { } +} + +/** + * Base trait for events related to JavaStreamingListener + */ +private[streaming] sealed trait JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchCompleted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchStarted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationStarted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationCompleted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStarted(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverError(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStopped(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +/** + * Class having information on batches. + * + * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info + * @param submissionTime Clock time of when jobs of this batch was submitted to the streaming + * scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing. + * `-1` means the batch has not yet started + * @param processingEndTime Clock time of when the last job of this batch finished processing. `-1` + * means the batch has not yet completed. + * @param schedulingDelay Time taken for the first job of this batch to start processing from the + * time this batch was submitted to the streaming scheduler. Essentially, it + * is `processingStartTime` - `submissionTime`. `-1` means the batch has not + * yet started + * @param processingDelay Time taken for the all jobs of this batch to finish processing from the + * time they started processing. 
Essentially, it is + * `processingEndTime` - `processingStartTime`. `-1` means the batch has not + * yet completed. + * @param totalDelay Time taken for all the jobs of this batch to finish processing from the time + * they were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + * `-1` means the batch has not yet completed. + * @param numRecords The number of recorders received by the receivers in this batch + * @param outputOperationInfos The output operations in this batch + */ +private[streaming] case class JavaBatchInfo( + batchTime: Time, + streamIdToInputInfo: java.util.Map[Int, JavaStreamInputInfo], + submissionTime: Long, + processingStartTime: Long, + processingEndTime: Long, + schedulingDelay: Long, + processingDelay: Long, + totalDelay: Long, + numRecords: Long, + outputOperationInfos: java.util.Map[Int, JavaOutputOperationInfo]) + +/** + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + * @param metadataDescription description of this input stream + */ +private[streaming] case class JavaStreamInputInfo( + inputStreamId: Int, + numRecords: Long, + metadata: java.util.Map[String, Any], + metadataDescription: String) + +/** + * Class having information about a receiver + */ +private[streaming] case class JavaReceiverInfo( + streamId: Int, + name: String, + active: Boolean, + location: String, + lastErrorMessage: String, + lastError: String, + lastErrorTime: Long) + +/** + * Class having information on output operations. + * + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing. `-1` means the + * output operation has not yet started + * @param endTime Clock time of when the output operation started processing. `-1` means the output + * operation has not yet completed + * @param failureReason Failure reason if this output operation fails. If the output operation is + * successful, this field is `null`. + */ +private[streaming] case class JavaOutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Long, + endTime: Long, + failureReason: String) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala new file mode 100644 index 0000000000000..2c60b396a6616 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.streaming.scheduler._ + +/** + * A wrapper to convert a [[JavaStreamingListener]] to a [[StreamingListener]]. + */ +private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: JavaStreamingListener) + extends StreamingListener { + + private def toJavaReceiverInfo(receiverInfo: ReceiverInfo): JavaReceiverInfo = { + JavaReceiverInfo( + receiverInfo.streamId, + receiverInfo.name, + receiverInfo.active, + receiverInfo.location, + receiverInfo.lastErrorMessage, + receiverInfo.lastError, + receiverInfo.lastErrorTime + ) + } + + private def toJavaStreamInputInfo(streamInputInfo: StreamInputInfo): JavaStreamInputInfo = { + JavaStreamInputInfo( + streamInputInfo.inputStreamId, + streamInputInfo.numRecords: Long, + streamInputInfo.metadata.asJava, + streamInputInfo.metadataDescription.orNull + ) + } + + private def toJavaOutputOperationInfo( + outputOperationInfo: OutputOperationInfo): JavaOutputOperationInfo = { + JavaOutputOperationInfo( + outputOperationInfo.batchTime, + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description: String, + outputOperationInfo.startTime.getOrElse(-1), + outputOperationInfo.endTime.getOrElse(-1), + outputOperationInfo.failureReason.orNull + ) + } + + private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { + JavaBatchInfo( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.submissionTime, + batchInfo.processingStartTime.getOrElse(-1), + batchInfo.processingEndTime.getOrElse(-1), + batchInfo.schedulingDelay.getOrElse(-1), + batchInfo.processingDelay.getOrElse(-1), + batchInfo.totalDelay.getOrElse(-1), + batchInfo.numRecords, + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + ) + } + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + javaStreamingListener.onReceiverStarted( + new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) + } + + override def onReceiverError(receiverError: StreamingListenerReceiverError): Unit = { + javaStreamingListener.onReceiverError( + new JavaStreamingListenerReceiverError(toJavaReceiverInfo(receiverError.receiverInfo))) + } + + override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped): Unit = { + javaStreamingListener.onReceiverStopped( + new JavaStreamingListenerReceiverStopped(toJavaReceiverInfo(receiverStopped.receiverInfo))) + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + javaStreamingListener.onBatchSubmitted( + new JavaStreamingListenerBatchSubmitted(toJavaBatchInfo(batchSubmitted.batchInfo))) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + javaStreamingListener.onBatchStarted( + new JavaStreamingListenerBatchStarted(toJavaBatchInfo(batchStarted.batchInfo))) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + 
javaStreamingListener.onBatchCompleted( + new JavaStreamingListenerBatchCompleted(toJavaBatchInfo(batchCompleted.batchInfo))) + } + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + javaStreamingListener.onOutputOperationStarted(new JavaStreamingListenerOutputOperationStarted( + toJavaOutputOperationInfo(outputOperationStarted.outputOperationInfo))) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + javaStreamingListener.onOutputOperationCompleted( + new JavaStreamingListenerOutputOperationCompleted( + toJavaOutputOperationInfo(outputOperationCompleted.outputOperationInfo))) + } + +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java new file mode 100644 index 0000000000000..8cc285aa7fb34 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.*; + +public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + + @Override + public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { + JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { + JavaReceiverInfo receiverInfo = receiverError.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopped) { + JavaReceiverInfo receiverInfo = receiverStopped.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onBatchSubmitted(JavaStreamingListenerBatchSubmitted batchSubmitted) { + super.onBatchSubmitted(batchSubmitted); + } + + @Override + public void onBatchStarted(JavaStreamingListenerBatchStarted batchStarted) { + super.onBatchStarted(batchStarted); + } + + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + super.onBatchCompleted(batchCompleted); + } + + @Override + public void onOutputOperationStarted(JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + super.onOutputOperationStarted(outputOperationStarted); + } + + @Override + public void onOutputOperationCompleted(JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + super.onOutputOperationCompleted(outputOperationCompleted); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala new file mode 100644 index 0000000000000..6d6d61e70cafc --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler._ + +class JavaStreamingListenerWrapperSuite extends SparkFunSuite { + + test("basic") { + val listener = new TestJavaStreamingListener() + val listenerWrapper = new JavaStreamingListenerWrapper(listener) + + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( + streamId = 2, + name = "test", + active = true, + location = "localhost" + )) + listenerWrapper.onReceiverStarted(receiverStarted) + assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) + + val receiverStopped = StreamingListenerReceiverStopped(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost" + )) + listenerWrapper.onReceiverStopped(receiverStopped) + assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) + + val receiverError = StreamingListenerReceiverError(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + lastErrorMessage = "failed", + lastError = "failed", + lastErrorTime = System.currentTimeMillis() + )) + listenerWrapper.onReceiverError(receiverError) + assertReceiverInfo(listener.receiverError.receiverInfo, receiverError.receiverInfo) + + val batchSubmitted = StreamingListenerBatchSubmitted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + None, + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = None, + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = None, + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchSubmitted(batchSubmitted) + assertBatchInfo(listener.batchSubmitted.batchInfo, batchSubmitted.batchInfo) + + val batchStarted = StreamingListenerBatchStarted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchStarted(batchStarted) + assertBatchInfo(listener.batchStarted.batchInfo, batchStarted.batchInfo) + + val batchCompleted = StreamingListenerBatchCompleted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = 
Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + Some(1010L), + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = Some(1010L), + failureReason = None)) + )) + listenerWrapper.onBatchCompleted(batchCompleted) + assertBatchInfo(listener.batchCompleted.batchInfo, batchCompleted.batchInfo) + + val outputOperationStarted = StreamingListenerOutputOperationStarted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None + )) + listenerWrapper.onOutputOperationStarted(outputOperationStarted) + assertOutputOperationInfo(listener.outputOperationStarted.outputOperationInfo, + outputOperationStarted.outputOperationInfo) + + val outputOperationCompleted = StreamingListenerOutputOperationCompleted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None + )) + listenerWrapper.onOutputOperationCompleted(outputOperationCompleted) + assertOutputOperationInfo(listener.outputOperationCompleted.outputOperationInfo, + outputOperationCompleted.outputOperationInfo) + } + + private def assertReceiverInfo( + javaReceiverInfo: JavaReceiverInfo, receiverInfo: ReceiverInfo): Unit = { + assert(javaReceiverInfo.streamId === receiverInfo.streamId) + assert(javaReceiverInfo.name === receiverInfo.name) + assert(javaReceiverInfo.active === receiverInfo.active) + assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) + assert(javaReceiverInfo.lastError === receiverInfo.lastError) + assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) + } + + private def assertBatchInfo(javaBatchInfo: JavaBatchInfo, batchInfo: BatchInfo): Unit = { + assert(javaBatchInfo.batchTime === batchInfo.batchTime) + assert(javaBatchInfo.streamIdToInputInfo.size === batchInfo.streamIdToInputInfo.size) + batchInfo.streamIdToInputInfo.foreach { case (streamId, streamInputInfo) => + assertStreamingInfo(javaBatchInfo.streamIdToInputInfo.get(streamId), streamInputInfo) + } + assert(javaBatchInfo.submissionTime === batchInfo.submissionTime) + assert(javaBatchInfo.processingStartTime === batchInfo.processingStartTime.getOrElse(-1)) + assert(javaBatchInfo.processingEndTime === batchInfo.processingEndTime.getOrElse(-1)) + assert(javaBatchInfo.schedulingDelay === batchInfo.schedulingDelay.getOrElse(-1)) + assert(javaBatchInfo.processingDelay === batchInfo.processingDelay.getOrElse(-1)) + assert(javaBatchInfo.totalDelay === batchInfo.totalDelay.getOrElse(-1)) + assert(javaBatchInfo.numRecords === batchInfo.numRecords) + assert(javaBatchInfo.outputOperationInfos.size === batchInfo.outputOperationInfos.size) + batchInfo.outputOperationInfos.foreach { case (outputOperationId, outputOperationInfo) => + assertOutputOperationInfo( + javaBatchInfo.outputOperationInfos.get(outputOperationId), outputOperationInfo) + } + } + + private 
def assertStreamingInfo( + javaStreamInputInfo: JavaStreamInputInfo, streamInputInfo: StreamInputInfo): Unit = { + assert(javaStreamInputInfo.inputStreamId === streamInputInfo.inputStreamId) + assert(javaStreamInputInfo.numRecords === streamInputInfo.numRecords) + assert(javaStreamInputInfo.metadata === streamInputInfo.metadata.asJava) + assert(javaStreamInputInfo.metadataDescription === streamInputInfo.metadataDescription.orNull) + } + + private def assertOutputOperationInfo( + javaOutputOperationInfo: JavaOutputOperationInfo, + outputOperationInfo: OutputOperationInfo): Unit = { + assert(javaOutputOperationInfo.batchTime === outputOperationInfo.batchTime) + assert(javaOutputOperationInfo.id === outputOperationInfo.id) + assert(javaOutputOperationInfo.name === outputOperationInfo.name) + assert(javaOutputOperationInfo.description === outputOperationInfo.description) + assert(javaOutputOperationInfo.startTime === outputOperationInfo.startTime.getOrElse(-1)) + assert(javaOutputOperationInfo.endTime === outputOperationInfo.endTime.getOrElse(-1)) + assert(javaOutputOperationInfo.failureReason === outputOperationInfo.failureReason.orNull) + } +} + +class TestJavaStreamingListener extends JavaStreamingListener { + + var receiverStarted: JavaStreamingListenerReceiverStarted = null + var receiverError: JavaStreamingListenerReceiverError = null + var receiverStopped: JavaStreamingListenerReceiverStopped = null + var batchSubmitted: JavaStreamingListenerBatchSubmitted = null + var batchStarted: JavaStreamingListenerBatchStarted = null + var batchCompleted: JavaStreamingListenerBatchCompleted = null + var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null + var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + this.receiverStarted = receiverStarted + } + + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + this.receiverError = receiverError + } + + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + this.receiverStopped = receiverStopped + } + + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + this.batchSubmitted = batchSubmitted + } + + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + this.batchStarted = batchStarted + } + + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + this.batchCompleted = batchCompleted + } + + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + this.outputOperationStarted = outputOperationStarted + } + + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + this.outputOperationCompleted = outputOperationCompleted + } +} From 6502944f39893b9dfb472f8406d5f3a02a316eff Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 9 Nov 2015 18:13:37 -0800 Subject: [PATCH 266/324] [SPARK-11333][STREAMING] Add executorId to ReceiverInfo and display it in UI Expose executorId to `ReceiverInfo` and UI since it's helpful when there are multiple executors running in the same host. Screenshot: screen shot 2015-11-02 at 10 52 19 am Author: Shixiong Zhu Author: zsxwing Closes #9418 from zsxwing/SPARK-11333. 
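
For context, a driver-side StreamingListener can surface the new placement information once this
change is in. The sketch below is illustrative only and not part of this patch: the class name
ReceiverPlacementLogger is made up, and `ssc` is assumed to be an existing StreamingContext.

    import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}

    class ReceiverPlacementLogger extends StreamingListener {
      override def onReceiverStarted(started: StreamingListenerReceiverStarted): Unit = {
        val info = started.receiverInfo
        // executorId is filled in once the receiver has been scheduled onto an executor;
        // until then it is "", mirroring the ReceiverTrackingInfo change below.
        println(s"Receiver ${info.name} (stream ${info.streamId}) is running on " +
          s"executor ${info.executorId}, host ${info.location}")
      }
    }

    // Registration, assuming an existing StreamingContext `ssc`:
    // ssc.addStreamingListener(new ReceiverPlacementLogger)
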
--- .../spark/streaming/api/java/JavaStreamingListener.scala | 1 + .../streaming/api/java/JavaStreamingListenerWrapper.scala | 1 + .../apache/spark/streaming/scheduler/ReceiverInfo.scala | 1 + .../spark/streaming/scheduler/ReceiverTrackingInfo.scala | 1 + .../org/apache/spark/streaming/ui/StreamingPage.scala | 8 ++++++-- .../spark/streaming/JavaStreamingListenerAPISuite.java | 3 +++ .../api/java/JavaStreamingListenerWrapperSuite.scala | 8 ++++++-- .../streaming/ui/StreamingJobProgressListenerSuite.scala | 6 +++--- 8 files changed, 22 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala index c86c7101ff6d5..34429074fe804 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -140,6 +140,7 @@ private[streaming] case class JavaReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String, lastError: String, lastErrorTime: Long) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala index 2c60b396a6616..b109b9f1cbeae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -33,6 +33,7 @@ private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: Jav receiverInfo.name, receiverInfo.active, receiverInfo.location, + receiverInfo.executorId, receiverInfo.lastErrorMessage, receiverInfo.lastError, receiverInfo.lastErrorTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index 59df892397fe0..3b35964114c02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -30,6 +30,7 @@ case class ReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index ab0a84f05214d..4dc5bb9c3bfbe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -49,6 +49,7 @@ private[streaming] case class ReceiverTrackingInfo( name.getOrElse(""), state == ReceiverState.ACTIVE, location = runningExecutor.map(_.host).getOrElse(""), + executorId = runningExecutor.map(_.executorId).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 
96d943e75d272..4588b2163cd44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -402,7 +402,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
         <th>Status</th>
-        <th>Location</th>
+        <th>Executor ID / Host</th>
         <th>Last Error Time</th>
         <th>Last Error Message</th>
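
To see how the new field travels through the Java-friendly listener API added earlier in this
series, the following sketch wraps a JavaStreamingListener and feeds it a Scala-side event. It is
modeled on JavaStreamingListenerWrapperSuite; ExecutorIdPassThroughSketch and CapturingListener are
hypothetical names, and because these classes are still private[streaming] the code has to live
under org.apache.spark.streaming:

    package org.apache.spark.streaming.api.java

    import org.apache.spark.streaming.scheduler.{ReceiverInfo, StreamingListenerReceiverStarted}

    object ExecutorIdPassThroughSketch {

      // Captures the last "receiver started" event delivered through the wrapper.
      class CapturingListener extends JavaStreamingListener {
        var started: JavaStreamingListenerReceiverStarted = null
        override def onReceiverStarted(e: JavaStreamingListenerReceiverStarted): Unit = {
          started = e
        }
      }

      def main(args: Array[String]): Unit = {
        val listener = new CapturingListener
        val wrapper = new JavaStreamingListenerWrapper(listener)
        wrapper.onReceiverStarted(StreamingListenerReceiverStarted(
          ReceiverInfo(streamId = 0, name = "socket", active = true,
            location = "host-1", executorId = "2")))
        // The executor id set on the Scala-side ReceiverInfo is exactly what
        // JavaReceiverInfo.executorId and the new "Executor ID / Host" column expose.
        assert(listener.started.receiverInfo.executorId == "2")
        assert(listener.started.receiverInfo.location == "host-1")
      }
    }
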