diff --git a/LICENSE b/LICENSE
index d6b9ccf07d999..9d1b00beff748 100644
--- a/LICENSE
+++ b/LICENSE
@@ -861,7 +861,7 @@ The following components are provided under a BSD-style license. See project lin
(BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model)
- (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/)
+ (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
(BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 819e9a24e5c0e..64ffdcffc9caf 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -37,7 +37,7 @@ exportMethods("arrange",
"registerTempTable",
"rename",
"repartition",
- "sampleDF",
+ "sample",
"sample_frac",
"saveAsParquetFile",
"saveAsTable",
@@ -53,38 +53,62 @@ exportMethods("arrange",
"unpersist",
"where",
"withColumn",
- "withColumnRenamed")
+ "withColumnRenamed",
+ "write.df")
exportClasses("Column")
exportMethods("abs",
+ "acos",
"alias",
"approxCountDistinct",
"asc",
+ "asin",
+ "atan",
+ "atan2",
"avg",
"cast",
+ "cbrt",
+ "ceiling",
"contains",
+ "cos",
+ "cosh",
"countDistinct",
"desc",
"endsWith",
+ "exp",
+ "expm1",
+ "floor",
"getField",
"getItem",
+ "hypot",
"isNotNull",
"isNull",
"last",
"like",
+ "log",
+ "log10",
+ "log1p",
"lower",
"max",
"mean",
"min",
"n",
"n_distinct",
+ "rint",
"rlike",
+ "sign",
+ "sin",
+ "sinh",
"sqrt",
"startsWith",
"substr",
"sum",
"sumDistinct",
+ "tan",
+ "tanh",
+ "toDegrees",
+ "toRadians",
"upper")
exportClasses("GroupedData")
@@ -101,6 +125,7 @@ export("cacheTable",
"jsonFile",
"loadDF",
"parquetFile",
+ "read.df",
"sql",
"table",
"tableNames",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 2705817531019..a7fa32e291fb1 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -294,8 +294,8 @@ setMethod("registerTempTable",
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
-#' df <- loadDF(sqlCtx, path, "parquet")
-#' df2 <- loadDF(sqlCtx, path2, "parquet")
+#' df <- read.df(sqlCtx, path, "parquet")
+#' df2 <- read.df(sqlCtx, path2, "parquet")
#' registerTempTable(df, "table1")
#' insertInto(df2, "table1", overwrite = TRUE)
#'}
@@ -473,14 +473,14 @@ setMethod("distinct",
dataFrame(sdf)
})
-#' SampleDF
+#' Sample
#'
#' Return a sampled subset of this DataFrame using a random seed.
#'
#' @param x A SparkSQL DataFrame
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
-#' @rdname sampleDF
+#' @rdname sample
#' @aliases sample_frac
#' @export
#' @examples
@@ -489,10 +489,10 @@ setMethod("distinct",
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
-#' collect(sampleDF(df, FALSE, 0.5))
-#' collect(sampleDF(df, TRUE, 0.5))
+#' collect(sample(df, FALSE, 0.5))
+#' collect(sample(df, TRUE, 0.5))
#'}
-setMethod("sampleDF",
+setMethod("sample",
# TODO : Figure out how to send integer as java.lang.Long to JVM so
# we can send seed as an argument through callJMethod
signature(x = "DataFrame", withReplacement = "logical",
@@ -503,13 +503,13 @@ setMethod("sampleDF",
dataFrame(sdf)
})
-#' @rdname sampleDF
-#' @aliases sampleDF
+#' @rdname sample
+#' @aliases sample
setMethod("sample_frac",
signature(x = "DataFrame", withReplacement = "logical",
fraction = "numeric"),
function(x, withReplacement, fraction) {
- sampleDF(x, withReplacement, fraction)
+ sample(x, withReplacement, fraction)
})
#' Count
@@ -1303,7 +1303,7 @@ setMethod("except",
#' @param source A name for external data source
#' @param mode One of 'append', 'overwrite', 'error', 'ignore'
#'
-#' @rdname saveAsTable
+#' @rdname write.df
#' @export
#' @examples
#'\dontrun{
@@ -1311,9 +1311,9 @@ setMethod("except",
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
-#' saveAsTable(df, "myfile")
+#' write.df(df, "myfile", "parquet", "overwrite")
#' }
-setMethod("saveDF",
+setMethod("write.df",
signature(df = "DataFrame", path = 'character', source = 'character',
mode = 'character'),
function(df, path = NULL, source = NULL, mode = "append", ...){
@@ -1334,6 +1334,15 @@ setMethod("saveDF",
callJMethod(df@sdf, "save", source, jmode, options)
})
+#' @rdname write.df
+#' @aliases saveDF
+#' @export
+setMethod("saveDF",
+ signature(df = "DataFrame", path = 'character', source = 'character',
+ mode = 'character'),
+ function(df, path = NULL, source = NULL, mode = "append", ...){
+ write.df(df, path, source, mode, ...)
+ })
#' saveAsTable
#'
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 9138629cac9c0..d3a68fff780ce 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -927,7 +927,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
MAXINT)))))
# TODO(zongheng): investigate if this call is an in-place shuffle?
- sample(samples)[1:total]
+ base::sample(samples)[1:total]
})
# Creates tuples of the elements in this RDD by applying a function.
@@ -996,7 +996,7 @@ setMethod("coalesce",
if (shuffle || numPartitions > SparkR:::numPartitions(x)) {
func <- function(partIndex, part) {
set.seed(partIndex) # partIndex as seed
- start <- as.integer(sample(numPartitions, 1) - 1)
+ start <- as.integer(base::sample(numPartitions, 1) - 1)
lapply(seq_along(part),
function(i) {
pos <- (start + i) %% numPartitions
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index cae06e6af2bff..531442e8459e4 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -421,7 +421,7 @@ clearCache <- function(sqlCtx) {
#' \dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
-#' df <- loadDF(sqlCtx, path, "parquet")
+#' df <- read.df(sqlCtx, path, "parquet")
#' registerTempTable(df, "table")
#' dropTempTable(sqlCtx, "table")
#' }
@@ -450,10 +450,10 @@ dropTempTable <- function(sqlCtx, tableName) {
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
-#' df <- load(sqlCtx, "path/to/file.json", source = "json")
+#' df <- read.df(sqlCtx, "path/to/file.json", source = "json")
#' }
-loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) {
+read.df <- function(sqlCtx, path = NULL, source = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
options[['path']] <- path
@@ -462,6 +462,13 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) {
dataFrame(sdf)
}
+#' @aliases loadDF
+#' @export
+
+loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) {
+ read.df(sqlCtx, path, source, ...)
+}
+
#' Create an external table
#'
#' Creates an external table based on the dataset in a data source,
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 9a68445ab451a..80e92d3105a36 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -55,12 +55,17 @@ operators <- list(
"+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod",
"==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq",
# we can not override `&&` and `||`, so use `&` and `|` instead
- "&" = "and", "|" = "or" #, "!" = "unary_$bang"
+ "&" = "and", "|" = "or", #, "!" = "unary_$bang"
+ "^" = "pow"
)
column_functions1 <- c("asc", "desc", "isNull", "isNotNull")
column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains")
functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt",
- "first", "last", "lower", "upper", "sumDistinct")
+ "first", "last", "lower", "upper", "sumDistinct",
+ "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp",
+ "expm1", "floor", "log", "log10", "log1p", "rint", "sign",
+ "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians")
+binary_mathfunctions <- c("atan2", "hypot")
createOperator <- function(op) {
setMethod(op,
@@ -76,7 +81,11 @@ createOperator <- function(op) {
if (class(e2) == "Column") {
e2 <- e2@jc
}
- callJMethod(e1@jc, operators[[op]], e2)
+ if (op == "^") {
+ jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2)
+ } else {
+ callJMethod(e1@jc, operators[[op]], e2)
+ }
}
column(jc)
})
@@ -106,11 +115,29 @@ createStaticFunction <- function(name) {
setMethod(name,
signature(x = "Column"),
function(x) {
+ if (name == "ceiling") {
+ name <- "ceil"
+ }
+ if (name == "sign") {
+ name <- "signum"
+ }
jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
column(jc)
})
}
+createBinaryMathfunctions <- function(name) {
+ setMethod(name,
+ signature(y = "Column"),
+ function(y, x) {
+ if (class(x) == "Column") {
+ x <- x@jc
+ }
+ jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x)
+ column(jc)
+ })
+}
+
createMethods <- function() {
for (op in names(operators)) {
createOperator(op)
@@ -124,6 +151,9 @@ createMethods <- function() {
for (x in functions) {
createStaticFunction(x)
}
+ for (name in binary_mathfunctions) {
+ createBinaryMathfunctions(name)
+ }
}
createMethods()
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 557128a419f19..a23d3b217b2fd 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -456,19 +456,19 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") })
#' @export
setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
-#' @rdname sampleDF
+#' @rdname sample
#' @export
-setGeneric("sample_frac",
+setGeneric("sample",
function(x, withReplacement, fraction, seed) {
- standardGeneric("sample_frac")
- })
+ standardGeneric("sample")
+ })
-#' @rdname sampleDF
+#' @rdname sample
#' @export
-setGeneric("sampleDF",
+setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) {
- standardGeneric("sampleDF")
- })
+ standardGeneric("sample_frac")
+ })
#' @rdname saveAsParquetFile
#' @export
@@ -480,7 +480,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
standardGeneric("saveAsTable")
})
-#' @rdname saveAsTable
+#' @rdname write.df
+#' @export
+setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") })
+
+#' @rdname write.df
#' @export
setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") })
@@ -548,6 +552,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") })
#' @export
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
+#' @rdname column
+#' @export
+setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
+
#' @rdname column
#' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") })
@@ -571,6 +579,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") })
#' @export
setGeneric("getItem", function(x, ...) { standardGeneric("getItem") })
+#' @rdname column
+#' @export
+setGeneric("hypot", function(y, x) { standardGeneric("hypot") })
+
#' @rdname column
#' @export
setGeneric("isNull", function(x) { standardGeneric("isNull") })
@@ -599,6 +611,10 @@ setGeneric("n", function(x) { standardGeneric("n") })
#' @export
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+#' @rdname column
+#' @export
+setGeneric("rint", function(x, ...) { standardGeneric("rint") })
+
#' @rdname column
#' @export
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
@@ -611,6 +627,14 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") })
#' @export
setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
+#' @rdname column
+#' @export
+setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") })
+
+#' @rdname column
+#' @export
+setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
+
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 99c28830c6237..3e5658eb5b24b 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -209,18 +209,18 @@ test_that("registerTempTable() results in a queryable table and sql() results in
})
test_that("insertInto() on a registered table", {
- df <- loadDF(sqlCtx, jsonPath, "json")
- saveDF(df, parquetPath, "parquet", "overwrite")
- dfParquet <- loadDF(sqlCtx, parquetPath, "parquet")
+ df <- read.df(sqlCtx, jsonPath, "json")
+ write.df(df, parquetPath, "parquet", "overwrite")
+ dfParquet <- read.df(sqlCtx, parquetPath, "parquet")
lines <- c("{\"name\":\"Bob\", \"age\":24}",
"{\"name\":\"James\", \"age\":35}")
jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp")
parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
writeLines(lines, jsonPath2)
- df2 <- loadDF(sqlCtx, jsonPath2, "json")
- saveDF(df2, parquetPath2, "parquet", "overwrite")
- dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet")
+ df2 <- read.df(sqlCtx, jsonPath2, "json")
+ write.df(df2, parquetPath2, "parquet", "overwrite")
+ dfParquet2 <- read.df(sqlCtx, parquetPath2, "parquet")
registerTempTable(dfParquet, "table1")
insertInto(dfParquet2, "table1")
@@ -421,12 +421,12 @@ test_that("distinct() on DataFrames", {
expect_true(count(uniques) == 3)
})
-test_that("sampleDF on a DataFrame", {
+test_that("sample on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
- sampled <- sampleDF(df, FALSE, 1.0)
+ sampled <- sample(df, FALSE, 1.0)
expect_equal(nrow(collect(sampled)), count(df))
expect_true(inherits(sampled, "DataFrame"))
- sampled2 <- sampleDF(df, FALSE, 0.1)
+ sampled2 <- sample(df, FALSE, 0.1)
expect_true(count(sampled2) < 3)
# Also test sample_frac
@@ -491,16 +491,16 @@ test_that("column calculation", {
expect_true(count(df2) == 3)
})
-test_that("load() from json file", {
- df <- loadDF(sqlCtx, jsonPath, "json")
+test_that("read.df() from json file", {
+ df <- read.df(sqlCtx, jsonPath, "json")
expect_true(inherits(df, "DataFrame"))
expect_true(count(df) == 3)
})
-test_that("save() as parquet file", {
- df <- loadDF(sqlCtx, jsonPath, "json")
- saveDF(df, parquetPath, "parquet", mode="overwrite")
- df2 <- loadDF(sqlCtx, parquetPath, "parquet")
+test_that("write.df() as parquet file", {
+ df <- read.df(sqlCtx, jsonPath, "json")
+ write.df(df, parquetPath, "parquet", mode="overwrite")
+ df2 <- read.df(sqlCtx, parquetPath, "parquet")
expect_true(inherits(df2, "DataFrame"))
expect_true(count(df2) == 3)
})
@@ -530,6 +530,7 @@ test_that("column operators", {
c2 <- (- c + 1 - 2) * 3 / 4.0
c3 <- (c + c2 - c2) * c2 %% c2
c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3)
+ c5 <- c2 ^ c3 ^ c4
})
test_that("column functions", {
@@ -538,6 +539,29 @@ test_that("column functions", {
c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
c5 <- n(c) + n_distinct(c)
+ c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
+ c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c)
+ c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
+ c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
+ c9 <- toDegrees(c) + toRadians(c)
+})
+
+test_that("column binary mathfunctions", {
+ lines <- c("{\"a\":1, \"b\":5}",
+ "{\"a\":2, \"b\":6}",
+ "{\"a\":3, \"b\":7}",
+ "{\"a\":4, \"b\":8}")
+ jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(lines, jsonPathWithDup)
+ df <- jsonFile(sqlCtx, jsonPathWithDup)
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
})
test_that("string operators", {
@@ -670,7 +694,7 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
"{\"name\":\"James\", \"age\":35}")
jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
writeLines(lines, jsonPath2)
- df2 <- loadDF(sqlCtx, jsonPath2, "json")
+ df2 <- read.df(sqlCtx, jsonPath2, "json")
unioned <- arrange(unionAll(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
@@ -712,9 +736,9 @@ test_that("mutate() and rename()", {
expect_true(columns(newDF2)[1] == "newerAge")
})
-test_that("saveDF() on DataFrame and works with parquetFile", {
+test_that("write.df() on DataFrame and works with parquetFile", {
df <- jsonFile(sqlCtx, jsonPath)
- saveDF(df, parquetPath, "parquet", mode="overwrite")
+ write.df(df, parquetPath, "parquet", mode="overwrite")
parquetDF <- parquetFile(sqlCtx, parquetPath)
expect_true(inherits(parquetDF, "DataFrame"))
expect_equal(count(df), count(parquetDF))
@@ -722,9 +746,9 @@ test_that("saveDF() on DataFrame and works with parquetFile", {
test_that("parquetFile works with multiple input paths", {
df <- jsonFile(sqlCtx, jsonPath)
- saveDF(df, parquetPath, "parquet", mode="overwrite")
+ write.df(df, parquetPath, "parquet", mode="overwrite")
parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
- saveDF(df, parquetPath2, "parquet", mode="overwrite")
+ write.df(df, parquetPath2, "parquet", mode="overwrite")
parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2)
expect_true(inherits(parquetDF, "DataFrame"))
expect_true(count(parquetDF) == count(df)*2)
diff --git a/core/pom.xml b/core/pom.xml
index 262a3320db106..bfa49d0d6dc25 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -361,6 +361,16 @@
      <artifactId>junit</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-library</artifactId>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>com.novocode</groupId>
      <artifactId>junit-interface</artifactId>
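
Note: hamcrest-core carries the basic matchers that JUnit 4 itself depends on, while hamcrest-library adds the richer comparison and collection matchers; both are test-scoped here. A minimal sketch of the assertion style these dependencies enable (the test class and values below are illustrative only, not part of this patch):

    import static org.hamcrest.MatcherAssert.assertThat;
    import static org.hamcrest.Matchers.greaterThan;
    import static org.hamcrest.Matchers.lessThanOrEqualTo;

    import org.junit.Test;

    public class HamcrestStyleExampleSuite {
      @Test
      public void pageSizeIsWithinTheDocumentedLimit() {
        int maximumPageSizeBytes = 1 << 27;  // mirrors PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES
        assertThat(maximumPageSizeBytes, greaterThan(0));
        assertThat(maximumPageSizeBytes, lessThanOrEqualTo(128 * 1024 * 1024));
      }
    }
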
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
new file mode 100644
index 0000000000000..3f746b886bc9b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+import scala.reflect.ClassTag;
+
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ * Our shuffle write path doesn't actually use this serializer (since we end up calling the
+ * `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ * around this, we pass a dummy no-op serializer.
+ */
+final class DummySerializerInstance extends SerializerInstance {
+
+ public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
+
+ private DummySerializerInstance() { }
+
+ @Override
+ public SerializationStream serializeStream(final OutputStream s) {
+ return new SerializationStream() {
+ @Override
+ public void flush() {
+ // Need to implement this because DiskBlockObjectWriter uses it to flush the compression stream
+ try {
+ s.flush();
+ } catch (IOException e) {
+ PlatformDependent.throwException(e);
+ }
+ }
+
+ @Override
+ public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() {
+ // Need to implement this because DiskBlockObjectWriter uses it to close the compression stream
+ try {
+ s.close();
+ } catch (IOException e) {
+ PlatformDependent.throwException(e);
+ }
+ }
+ };
+ }
+
+ @Override
+ public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public DeserializationStream deserializeStream(InputStream s) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
new file mode 100644
index 0000000000000..4ee6a82c0423e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ *
+ * Within the long, the data is laid out as follows:
+ *
+ * [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ *
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ *
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+ static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes
+
+ /**
+ * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+ */
+ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+
+ /** Bit mask for the lower 40 bits of a long. */
+ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+ /** Bit mask for the upper 24 bits of a long */
+ private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+ /** Bit mask for the lower 27 bits of a long. */
+ private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+ /** Bit mask for the upper 13 bits of a long */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Pack a record address and partition id into a single word.
+ *
+ * @param recordPointer a record pointer encoded by TaskMemoryManager.
+ * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+ * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+ */
+ public static long packPointer(long recordPointer, int partitionId) {
+ assert (partitionId <= MAXIMUM_PARTITION_ID);
+ // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+ // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+ final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+ final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+ return (((long) partitionId) << 40) | compressedAddress;
+ }
+
+ private long packedRecordPointer;
+
+ public void set(long packedRecordPointer) {
+ this.packedRecordPointer = packedRecordPointer;
+ }
+
+ public int getPartitionId() {
+ return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+ }
+
+ public long getRecordPointer() {
+ final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+ final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+ return pageNumber | offsetInPage;
+ }
+
+}
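
The bit layout above is easiest to verify with a small standalone sketch that repeats the same shifts and masks outside this package-private class (the mask constants are copied from PackedRecordPointer; the page number, offset, and partition id are arbitrary illustrative values, and the address encoding mirrors the TaskMemoryManager-style layout the class assumes):

    public final class PackedRecordPointerSketch {
      private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
      private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
      private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
      private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
      private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;

      public static void main(String[] args) {
        // A TaskMemoryManager-style address: 13-bit page number in the upper bits,
        // offset within the page in the lower bits.
        long pageNumber = 42;          // fits in 13 bits
        long offsetInPage = 1234567;   // fits in 27 bits
        long recordPointer = (pageNumber << 51) | offsetInPage;
        int partitionId = 99;          // fits in 24 bits

        // packPointer(): move the page number down next to the offset, then place the
        // partition id in the top 24 bits.
        long compressedPage = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
        long packed = (((long) partitionId) << 40)
            | compressedPage | (recordPointer & MASK_LONG_LOWER_27_BITS);

        // Decoding recovers both halves, as in getPartitionId() and getRecordPointer().
        int decodedPartition = (int) ((packed & MASK_LONG_UPPER_24_BITS) >>> 40);
        long decodedPointer = ((packed << 24) & MASK_LONG_UPPER_13_BITS)
            | (packed & MASK_LONG_LOWER_27_BITS);

        System.out.println(decodedPartition == partitionId);  // prints true
        System.out.println(decodedPointer == recordPointer);  // prints true
      }
    }
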
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
new file mode 100644
index 0000000000000..7bac0dc0bbeb6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ */
+final class SpillInfo {
+ final long[] partitionLengths;
+ final File file;
+ final TempShuffleBlockId blockId;
+
+ public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+ this.partitionLengths = new long[numPartitions];
+ this.file = file;
+ this.blockId = blockId;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
new file mode 100644
index 0000000000000..9e9ed94b7890c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ *
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ *
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class UnsafeShuffleExternalSorter {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+
+ private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+ @VisibleForTesting
+ static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+ @VisibleForTesting
+ static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+ private final int initialSize;
+ private final int numPartitions;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final BlockManager blockManager;
+ private final TaskContext taskContext;
+ private final ShuffleWriteMetrics writeMetrics;
+
+ /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+ private final int fileBufferSizeBytes;
+
+ /**
+ * Memory pages that hold the records being sorted. The pages in this list are freed when
+ * spilling, although in principle we could recycle these pages across spills (on the other hand,
+ * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+ * itself).
+ */
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+ private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+ // These variables are reset after spilling:
+ private UnsafeShuffleInMemorySorter sorter;
+ private MemoryBlock currentPage = null;
+ private long currentPagePosition = -1;
+ private long freeSpaceInCurrentPage = 0;
+
+ public UnsafeShuffleExternalSorter(
+ TaskMemoryManager memoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ int initialSize,
+ int numPartitions,
+ SparkConf conf,
+ ShuffleWriteMetrics writeMetrics) throws IOException {
+ this.memoryManager = memoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.blockManager = blockManager;
+ this.taskContext = taskContext;
+ this.initialSize = initialSize;
+ this.numPartitions = numPartitions;
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+
+ this.writeMetrics = writeMetrics;
+ initializeForWriting();
+ }
+
+ /**
+ * Allocates new sort data structures. Called when creating the sorter and after each spill.
+ */
+ private void initializeForWriting() throws IOException {
+ // TODO: move this sizing calculation logic into a static method of sorter:
+ final long memoryRequested = initialSize * 8L;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryAcquired != memoryRequested) {
+ shuffleMemoryManager.release(memoryAcquired);
+ throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ }
+
+ this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+ }
+
+ /**
+ * Sorts the in-memory records and writes the sorted records to an on-disk file.
+ * This method does not free the sort data structures.
+ *
+ * @param isLastFile if true, this indicates that we're writing the final output file and that the
+ * bytes written should be counted towards shuffle spill metrics rather than
+ * shuffle write metrics.
+ */
+ private void writeSortedFile(boolean isLastFile) throws IOException {
+
+ final ShuffleWriteMetrics writeMetricsToUse;
+
+ if (isLastFile) {
+ // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+ writeMetricsToUse = writeMetrics;
+ } else {
+ // We're spilling, so bytes written should be counted towards spill rather than write.
+ // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+ // them towards shuffle bytes written.
+ writeMetricsToUse = new ShuffleWriteMetrics();
+ }
+
+ // This call performs the actual sort.
+ final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+ sorter.getSortedIterator();
+
+ // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+ // after SPARK-5581 is fixed.
+ BlockObjectWriter writer;
+
+ // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+ // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+ // data through a byte array. This array does not need to be large enough to hold a single
+ // record, since records larger than the buffer are copied out to disk in multiple chunks.
+ final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+ // Because this output will be read during shuffle, its compression codec must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more details.
+ final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = spilledFileInfo._2();
+ final TempShuffleBlockId blockId = spilledFileInfo._1();
+ final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+ // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+ // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ // around this, we pass a dummy no-op serializer.
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+ int currentPartition = -1;
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+ assert (partition >= currentPartition);
+ if (partition != currentPartition) {
+ // Switch to the new partition
+ if (currentPartition != -1) {
+ writer.commitAndClose();
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ }
+ currentPartition = partition;
+ writer =
+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+ }
+
+ final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+ final Object recordPage = memoryManager.getPage(recordPointer);
+ final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+ int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+ long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+ while (dataRemaining > 0) {
+ final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+ PlatformDependent.copyMemory(
+ recordPage,
+ recordReadPosition,
+ writeBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ toTransfer);
+ writer.write(writeBuffer, 0, toTransfer);
+ recordReadPosition += toTransfer;
+ dataRemaining -= toTransfer;
+ }
+ writer.recordWritten();
+ }
+
+ if (writer != null) {
+ writer.commitAndClose();
+ // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+ // then the file might be empty. Note that it might be better to avoid calling
+ // writeSortedFile() in that case.
+ if (currentPartition != -1) {
+ spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+ spills.add(spillInfo);
+ }
+ }
+
+ if (!isLastFile) { // i.e. this is a spill file
+ // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+ // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+ // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+ // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+ // counter at a higher-level, then the in-progress metrics for records written and bytes
+ // written would get out of sync.
+ //
+ // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+ // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+ // metrics to the true write metrics here. The reason for performing this copying is so that
+ // we can avoid reporting spilled bytes as shuffle write bytes.
+ //
+ // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+ // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+ // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+ writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+ }
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
+ @VisibleForTesting
+ void spill() throws IOException {
+ logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spills.size(),
+ spills.size() > 1 ? "times" : "time");
+
+ writeSortedFile(false);
+ final long sorterMemoryUsage = sorter.getMemoryUsage();
+ sorter = null;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ final long spillSize = freeMemory();
+ taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+ initializeForWriting();
+ }
+
+ private long getMemoryUsage() {
+ return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+ }
+
+ private long freeMemory() {
+ long memoryFreed = 0;
+ for (MemoryBlock block : allocatedPages) {
+ memoryManager.freePage(block);
+ shuffleMemoryManager.release(block.size());
+ memoryFreed += block.size();
+ }
+ allocatedPages.clear();
+ currentPage = null;
+ currentPagePosition = -1;
+ freeSpaceInCurrentPage = 0;
+ return memoryFreed;
+ }
+
+ /**
+ * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+ */
+ public void cleanupAfterError() {
+ freeMemory();
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && !spill.file.delete()) {
+ logger.error("Unable to delete spill file {}", spill.file.getPath());
+ }
+ }
+ if (sorter != null) {
+ shuffleMemoryManager.release(sorter.getMemoryUsage());
+ sorter = null;
+ }
+ }
+
+ /**
+ * Checks whether there is enough space to insert a new record into the sorter.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ *
+ * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ */
+ private boolean haveSpaceForRecord(int requiredSpace) {
+ assert (requiredSpace > 0);
+ return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size.
+ */
+ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+ if (!sorter.hasSpaceForAnotherRecord()) {
+ logger.debug("Attempting to expand sort pointer array");
+ final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ if (memoryAcquired < memoryToGrowPointerArray) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ } else {
+ sorter.expandPointerArray();
+ shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ }
+ }
+ if (requiredSpace > freeSpaceInCurrentPage) {
+ logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+ freeSpaceInCurrentPage);
+ // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+ // without using the free space at the end of the current page. We should also do this for
+ // BytesToBytesMap.
+ if (requiredSpace > PAGE_SIZE) {
+ throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+ PAGE_SIZE + ")");
+ } else {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquired < PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+ if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+ }
+ }
+ currentPage = memoryManager.allocatePage(PAGE_SIZE);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = PAGE_SIZE;
+ allocatedPages.add(currentPage);
+ }
+ }
+ }
+
+ /**
+ * Write a record to the shuffle sorter.
+ */
+ public void insertRecord(
+ Object recordBaseObject,
+ long recordBaseOffset,
+ int lengthInBytes,
+ int partitionId) throws IOException {
+ // Need 4 bytes to store the record length.
+ final int totalSpaceRequired = lengthInBytes + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+ currentPagePosition += 4;
+ freeSpaceInCurrentPage -= 4;
+ PlatformDependent.copyMemory(
+ recordBaseObject,
+ recordBaseOffset,
+ dataPageBaseObject,
+ currentPagePosition,
+ lengthInBytes);
+ currentPagePosition += lengthInBytes;
+ freeSpaceInCurrentPage -= lengthInBytes;
+ sorter.insertRecord(recordAddress, partitionId);
+ }
+
+ /**
+ * Close the sorter, causing any buffered data to be sorted and written out to disk.
+ *
+ * @return metadata for the spill files written by this sorter. If no records were ever inserted
+ * into this sorter, then this will return an empty array.
+ * @throws IOException
+ */
+ public SpillInfo[] closeAndGetSpills() throws IOException {
+ try {
+ if (sorter != null) {
+ // Do not count the final file towards the spill count.
+ writeSortedFile(true);
+ freeMemory();
+ }
+ return spills.toArray(new SpillInfo[spills.size()]);
+ } catch (IOException e) {
+ cleanupAfterError();
+ throw e;
+ }
+ }
+
+}
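
The acquire-or-spill protocol in allocateSpaceForRecord() (reserve a full page; if the grant falls short, hand back the partial grant, spill, and retry exactly once) is the core of how this sorter stays within its shuffle memory budget. A compact, self-contained sketch of that protocol, using a hypothetical MemoryBudget stand-in rather than Spark's ShuffleMemoryManager:

    import java.io.IOException;

    final class MemoryBudget {
      private long available;
      MemoryBudget(long available) { this.available = available; }
      // Grants up to the requested number of bytes (possibly fewer), like tryToAcquire().
      long tryToAcquire(long bytes) {
        long granted = Math.min(bytes, available);
        available -= granted;
        return granted;
      }
      void release(long bytes) { available += bytes; }
    }

    final class AcquireOrSpillSketch {
      private static final long PAGE_SIZE = 1L << 27;  // 128 MB, as in the sorter above
      private final MemoryBudget budget = new MemoryBudget(PAGE_SIZE / 2);  // deliberately tight

      // Stand-in for writeSortedFile(false) + freeMemory(): spilling frees a page's worth.
      private void spill() {
        budget.release(PAGE_SIZE);
      }

      long allocatePage() throws IOException {
        long acquired = budget.tryToAcquire(PAGE_SIZE);
        if (acquired < PAGE_SIZE) {
          budget.release(acquired);  // hand back the partial grant before spilling
          spill();                   // free memory by writing buffered records to disk
          acquired = budget.tryToAcquire(PAGE_SIZE);
          if (acquired != PAGE_SIZE) {
            budget.release(acquired);
            throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
          }
        }
        return acquired;
      }
    }
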
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
new file mode 100644
index 0000000000000..5bab501da9364
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class UnsafeShuffleInMemorySorter {
+
+ private final Sorter<PackedRecordPointer, long[]> sorter;
+ private static final class SortComparator implements Comparator<PackedRecordPointer> {
+ @Override
+ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+ return left.getPartitionId() - right.getPartitionId();
+ }
+ }
+ private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+ /**
+ * An array of record pointers and partition ids that have been encoded by
+ * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
+ * records.
+ */
+ private long[] pointerArray;
+
+ /**
+ * The position in the pointer array where new records can be inserted.
+ */
+ private int pointerArrayInsertPosition = 0;
+
+ public UnsafeShuffleInMemorySorter(int initialSize) {
+ assert (initialSize > 0);
+ this.pointerArray = new long[initialSize];
+ this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = pointerArray;
+ // Guard against overflow:
+ final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+ pointerArray = new long[newLength];
+ System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pointerArrayInsertPosition + 1 < pointerArray.length;
+ }
+
+ public long getMemoryUsage() {
+ return pointerArray.length * 8L;
+ }
+
+ /**
+ * Inserts a record to be sorted.
+ *
+ * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+ * certain pointer compression techniques used by the sorter, the sort can
+ * only operate on pointers that point to locations in the first
+ * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+ * @param partitionId the partition id, which must be less than or equal to
+ * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+ */
+ public void insertRecord(long recordPointer, int partitionId) {
+ if (!hasSpaceForAnotherRecord()) {
+ if (pointerArray.length == Integer.MAX_VALUE) {
+ throw new IllegalStateException("Sort pointer array has reached maximum size");
+ } else {
+ expandPointerArray();
+ }
+ }
+ pointerArray[pointerArrayInsertPosition] =
+ PackedRecordPointer.packPointer(recordPointer, partitionId);
+ pointerArrayInsertPosition++;
+ }
+
+ /**
+ * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
+ */
+ public static final class UnsafeShuffleSorterIterator {
+
+ private final long[] pointerArray;
+ private final int numRecords;
+ final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+ private int position = 0;
+
+ public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ this.numRecords = numRecords;
+ this.pointerArray = pointerArray;
+ }
+
+ public boolean hasNext() {
+ return position < numRecords;
+ }
+
+ public void loadNext() {
+ packedRecordPointer.set(pointerArray[position]);
+ position++;
+ }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order.
+ */
+ public UnsafeShuffleSorterIterator getSortedIterator() {
+ sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+ return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
new file mode 100644
index 0000000000000..a66d74ee44782
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+ public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+
+ private UnsafeShuffleSortDataFormat() { }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos) {
+ // Since we re-use keys, this method shouldn't be called.
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PackedRecordPointer newKey() {
+ return new PackedRecordPointer();
+ }
+
+ @Override
+ public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+ reuse.set(data[pos]);
+ return reuse;
+ }
+
+ @Override
+ public void swap(long[] data, int pos0, int pos1) {
+ final long temp = data[pos0];
+ data[pos0] = data[pos1];
+ data[pos1] = temp;
+ }
+
+ @Override
+ public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+ dst[dstPos] = src[srcPos];
+ }
+
+ @Override
+ public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+ System.arraycopy(src, srcPos, dst, dstPos, length);
+ }
+
+ @Override
+ public long[] allocate(int length) {
+ return new long[length];
+ }
+
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000000000..ad7eb04afcd8c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+ private static final ClassTag