diff --git a/core/pom.xml b/core/pom.xml
index e80829b7a7f3d..317fb3bb879af 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -91,6 +91,11 @@
      <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+   <dependency>
+     <groupId>org.apache.spark</groupId>
+     <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+     <version>${project.version}</version>
+   </dependency>
    <dependency>
      <groupId>net.java.dev.jets3t</groupId>
      <artifactId>jets3t</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 959aefabd8de4..0c4d28f786edd 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{RpcUtils, Utils}
/**
@@ -69,6 +70,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
@@ -382,6 +384,15 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
+ val executorMemoryManager: ExecutorMemoryManager = {
+ val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
+ MemoryAllocator.UNSAFE
+ } else {
+ MemoryAllocator.HEAP
+ }
+ new ExecutorMemoryManager(allocator)
+ }
+
val envInstance = new SparkEnv(
executorId,
rpcEnv,
@@ -398,6 +409,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ executorMemoryManager,
outputCommitCoordinator,
conf)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 7d7fe1a446313..d09e17dea0911 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
/** ::DeveloperApi:: */
@DeveloperApi
def taskMetrics(): TaskMetrics
+
+ /**
+ * Returns the manager for this task's managed memory.
+ */
+ private[spark] def taskMemoryManager(): TaskMemoryManager
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 337c8e4ebebcd..b4d572cb52313 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
+ override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 327d155b38c22..b31082ff2de9c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
/**
@@ -179,6 +180,7 @@ private[spark] class Executor(
}
override def run(): Unit = {
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
@@ -191,6 +193,7 @@ private[spark] class Executor(
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
@@ -207,7 +210,21 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ val value = try {
+ task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ } finally {
+ // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
+ // when changing this, make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(errMsg)
+ } else {
+ logError(errMsg)
+ }
+ }
+ }
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4a32f8936fb0e..956c75afdd45b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -643,8 +644,15 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
- attemptNumber = 0, runningLocally = true)
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+ val taskContext =
+ new TaskContextImpl(
+ job.finalStage.id,
+ job.partitions(0),
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = true)
TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -652,6 +660,16 @@ class DAGScheduler(
} finally {
taskContext.markTaskCompleted()
TaskContext.unset()
+ // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
+ // make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+ } else {
+ logError(s"Managed memory leak detected; size = $freedMemory bytes")
+ }
+ }
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 8b592867ee31d..c4187a0cfab69 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* @return the result of the task
*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
- context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
- taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+ context = new TaskContextImpl(
+ stageId = stageId,
+ partitionId = partitionId,
+ taskAttemptId = taskAttemptId,
+ attemptNumber = attemptNumber,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
}
}
+ private var taskMemoryManager: TaskMemoryManager = _
+
+ def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+ this.taskMemoryManager = taskMemoryManager
+ }
+
def runTask(context: TaskContext): T
def preferredLocations: Seq[TaskLocation] = Nil
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8a4f2a08fe701..34ac9361d46c6 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 70529d9216591..668ddf9f5f0a9 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index aea76c1adcc09..85eb2a1d07ba4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0, 0, null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 057e226916027..83ae8701243e5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 37b593b2c5f79..2080c432d77db 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0),
+ new TaskContextImpl(0, 0, 0, 0, null),
transfer,
blockManager,
blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/pom.xml b/pom.xml
index bcc2f57f1af5d..92275ad4400f6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -97,6 +97,7 @@
    <module>sql/catalyst</module>
    <module>sql/core</module>
    <module>sql/hive</module>
+   <module>unsafe</module>
    <module>assembly</module>
    <module>external/twitter</module>
    <module>external/flume</module>
@@ -1205,6 +1206,7 @@
              <spark.ui.enabled>false</spark.ui.enabled>
              <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress>
              <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+             <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 09b4976d10c26..b7dbcd9bc562a 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -34,11 +34,11 @@ object BuildCommons {
val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
- streamingMqtt, streamingTwitter, streamingZeromq, launcher) =
+ streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
"sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
"streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
- "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _))
+ "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _))
val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl,
sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
@@ -159,7 +159,7 @@ object SparkBuild extends PomBuild {
// TODO: Add Sql to mima checks
// TODO: remove launcher from this list after 1.3.
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
- networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach {
+ networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach {
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}
@@ -496,6 +496,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.ui.enabled=false",
javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
+ javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 3dea2ee76542f..5c322d032d474 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -50,6 +50,11 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+   <dependency>
+     <groupId>org.apache.spark</groupId>
+     <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+     <version>${project.version}</version>
+   </dependency>
    <dependency>
      <groupId>org.scalacheck</groupId>
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
new file mode 100644
index 0000000000000..299ff3728a6d9
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
+ *
+ * This map supports a maximum of 2 billion keys.
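+ *
+ * Typical usage (a sketch, not prescribed by this class): obtain the per-group buffer with
+ * getAggregationBuffer(groupingKey), update it in place through the UnsafeRow setters (e.g.
+ * setLong), and call free() once the aggregation is finished to release the backing memory.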
+ */
+public final class UnsafeFixedWidthAggregationMap {
+
+ /**
+ * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+ * map, we copy this buffer and use it as the value.
+ */
+ private final long[] emptyAggregationBuffer;
+
+ private final StructType aggregationBufferSchema;
+
+ private final StructType groupingKeySchema;
+
+ /**
+ * Encodes grouping keys as UnsafeRows.
+ */
+ private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+
+ /**
+ * A hashmap which maps from opaque bytearray keys to bytearray values.
+ */
+ private final BytesToBytesMap map;
+
+ /**
+ * Re-used pointer to the current aggregation buffer
+ */
+ private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+
+ /**
+ * Scratch space that is used when encoding grouping keys into UnsafeRow format.
+ *
+ * By default, this is a 1 kilobyte (128-word) array, but it will grow as necessary in case
+ * larger keys are encountered.
+ */
+ private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
+
+ private final boolean enablePerfMetrics;
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
+ * false otherwise.
+ */
+ public static boolean supportsGroupKeySchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+ * schema, false otherwise.
+ */
+ public static boolean supportsAggregationBufferSchema(StructType schema) {
+ for (StructField field: schema.fields()) {
+ if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ return false;
+ }
+ }
+ return true;
+ }
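+
+ // For example, a StringType column is acceptable in a grouping-key schema (StringType is in
+ // UnsafeRow.readableFieldTypes) but not in an aggregation-buffer schema, because UnsafeRow
+ // cannot update variable-length values in place (StringType is absent from settableFieldTypes).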
+
+ /**
+ * Create a new UnsafeFixedWidthAggregationMap.
+ *
+ * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+ * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+ * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+ * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
+ */
+ public UnsafeFixedWidthAggregationMap(
+ Row emptyAggregationBuffer,
+ StructType aggregationBufferSchema,
+ StructType groupingKeySchema,
+ TaskMemoryManager memoryManager,
+ int initialCapacity,
+ boolean enablePerfMetrics) {
+ this.emptyAggregationBuffer =
+ convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
+ this.aggregationBufferSchema = aggregationBufferSchema;
+ this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
+ this.groupingKeySchema = groupingKeySchema;
+ this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.enablePerfMetrics = enablePerfMetrics;
+ }
+
+ /**
+ * Convert a Java object row into an UnsafeRow, allocating it into a new long array.
+ */
+ private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
+ final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
+ final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
+ final long writtenLength =
+ converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
+ assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
+ return unsafeRow;
+ }
+
+ /**
+ * Return the aggregation buffer for the current group. For efficiency, all calls to this method
+ * return the same object.
+ */
+ public UnsafeRow getAggregationBuffer(Row groupingKey) {
+ final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
+ // Make sure that the buffer is large enough to hold the key. If it's not, grow it:
+ if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
+ // This new array will be initially zero, so there's no need to zero it out here
+ groupingKeyConversionScratchSpace = new long[groupingKeySize];
+ } else {
+ // Zero out the buffer that's used to hold the current row. This is necessary in order
+ // to ensure that rows hash properly, since garbage data from the previous row could
+ // otherwise end up as padding in this row. As a performance optimization, we only zero out
+ // the portion of the buffer that we'll actually write to.
+ Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0);
+ }
+ final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
+ groupingKey,
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET);
+ assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
+
+ // Probe our map using the serialized key
+ final BytesToBytesMap.Location loc = map.lookup(
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ groupingKeySize);
+ if (!loc.isDefined()) {
+ // This is the first time that we've seen this grouping key, so we'll insert a copy of the
+ // empty aggregation buffer into the map:
+ loc.putNewKey(
+ groupingKeyConversionScratchSpace,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ groupingKeySize,
+ emptyAggregationBuffer,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ emptyAggregationBuffer.length
+ );
+ }
+
+ // Reset the pointer to point to the value that we just stored or looked up:
+ final MemoryLocation address = loc.getValueAddress();
+ currentAggregationBuffer.pointTo(
+ address.getBaseObject(),
+ address.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ aggregationBufferSchema
+ );
+ return currentAggregationBuffer;
+ }
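+
+ // Hypothetical caller sketch (names here are illustrative, not part of this class): for a
+ // single LongType aggregation column, one update step could be
+ //   UnsafeRow buffer = map.getAggregationBuffer(groupingKey);
+ //   buffer.setLong(0, buffer.getLong(0) + delta);
+ // Since every call returns the same UnsafeRow instance, the returned buffer must not be
+ // retained across calls to getAggregationBuffer().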
+
+ /**
+ * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
+ */
+ public static class MapEntry {
+ private MapEntry() { };
+ public final UnsafeRow key = new UnsafeRow();
+ public final UnsafeRow value = new UnsafeRow();
+ }
+
+ /**
+ * Returns an iterator over the keys and values in this map.
+ *
+ * For efficiency, each call to {@code next()} returns the same {@link MapEntry} object.
+ */
+ public Iterator<MapEntry> iterator() {
+ return new Iterator<MapEntry>() {
+
+ private final MapEntry entry = new MapEntry();
+ private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+
+ @Override
+ public boolean hasNext() {
+ return mapLocationIterator.hasNext();
+ }
+
+ @Override
+ public MapEntry next() {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ entry.key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ groupingKeySchema
+ );
+ entry.value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ aggregationBufferSchema
+ );
+ return entry;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
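+
+ // Iteration sketch: the MapEntry (and its key/value UnsafeRows) is reused on every next() call,
+ // so each entry must be consumed, or its fields copied out, before advancing:
+ //   Iterator<MapEntry> iter = map.iterator();
+ //   while (iter.hasNext()) {
+ //     MapEntry entry = iter.next();
+ //     long total = entry.value.getLong(0);  // assumes a LongType aggregation column at ordinal 0
+ //   }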
+
+ /**
+ * Free the unsafe memory associated with this map.
+ */
+ public void free() {
+ map.free();
+ }
+
+ @SuppressWarnings("UseOfSystemOutOrSystemErr")
+ public void printPerfMetrics() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException("Perf metrics not enabled");
+ }
+ System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+ System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+ System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+ System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ }
+
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
new file mode 100644
index 0000000000000..0a358ed408aa1
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import scala.collection.Map;
+import scala.collection.Seq;
+import scala.collection.mutable.ArraySeq;
+
+import javax.annotation.Nullable;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.util.*;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.DataType;
+import static org.apache.spark.sql.types.DataTypes.*;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.UTF8String;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+
+/**
+ * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
+ *
+ * Each tuple has three parts: [null bit set] [values] [variable length portion]
+ *
+ * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores
+ * one bit per field.
+ *
+ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length
+ * primitive types, such as long, double, or int, we store the value directly in the word. For
+ * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
+ * base address of the row) that points to the beginning of the variable-length field.
+ *
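+ * For example, assuming the layout described above, a row with schema (long, string, long) uses
+ * one 8-byte word of null-tracking bits, then three 8-byte words: the first long value, the
+ * relative offset of the string's bytes within the row, and the second long value, followed by
+ * the variable-length string data itself.
+ *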
+ * Instances of `UnsafeRow` act as pointers to row data stored in this format.
+ */
+public final class UnsafeRow implements MutableRow {
+
+ private Object baseObject;
+ private long baseOffset;
+
+ Object getBaseObject() { return baseObject; }
+ long getBaseOffset() { return baseOffset; }
+
+ /** The number of fields in this row, used for calculating the bitset width (and in assertions) */
+ private int numFields;
+
+ /** The width of the null tracking bit set, in bytes */
+ private int bitSetWidthInBytes;
+ /**
+ * This optional schema is required if you want to call generic get() and set() methods on
+ * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE()
+ * methods. This should be removed after the planned InternalRow / Row split; right now, it's only
+ * needed by the generic get() method, which is only called internally by code that accesses
+ * UTF8String-typed columns.
+ */
+ @Nullable
+ private StructType schema;
+
+ private long getFieldOffset(int ordinal) {
+ return baseOffset + bitSetWidthInBytes + ordinal * 8L;
+ }
+
+ public static int calculateBitSetWidthInBytes(int numFields) {
+ return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
+ }
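+
+ // Worked example of the two methods above: for numFields = 3, the null bitset occupies one
+ // 64-bit word, so calculateBitSetWidthInBytes(3) == 8 and field i lives at
+ // baseOffset + 8 + i * 8; for numFields = 65 the bitset needs two words, so the width is 16 bytes.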
+
+ /**
+ * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
+ */
+ public static final Set<DataType> settableFieldTypes;
+
+ /**
+ * Field types that can be read (but not set); calling set() for these types will throw an
+ * UnsupportedOperationException.
+ */
+ public static final Set<DataType> readableFieldTypes;
+
+ static {
+ settableFieldTypes = Collections.unmodifiableSet(
+ new HashSet<DataType>(
+ Arrays.asList(new DataType[] {
+ NullType,
+ BooleanType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ LongType,
+ FloatType,
+ DoubleType
+ })));
+
+ // We support get() on a superset of the types for which we support set():
+ final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
+ Arrays.asList(new DataType[]{
+ StringType
+ }));
+ _readableFieldTypes.addAll(settableFieldTypes);
+ readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
+ }
+
+ /**
+ * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
+ * since the value returned by this constructor is equivalent to a null pointer.
+ */
+ public UnsafeRow() { }
+
+ /**
+ * Update this UnsafeRow to point to different backing data.
+ *
+ * @param baseObject the base object
+ * @param baseOffset the offset within the base object
+ * @param numFields the number of fields in this row
+ * @param schema an optional schema; this is necessary if you want to call generic get() or set()
+ * methods on this row, but is optional if the caller will only use type-specific
+ * getTYPE() and setTYPE() methods.
+ */
+ public void pointTo(
+ Object baseObject,
+ long baseOffset,
+ int numFields,
+ @Nullable StructType schema) {
+ assert numFields >= 0 : "numFields should >= 0";
+ assert schema == null || schema.fields().length == numFields;
+ this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+ this.baseObject = baseObject;
+ this.baseOffset = baseOffset;
+ this.numFields = numFields;
+ this.schema = schema;
+ }
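+
+ // For example, the backing data may be an on-heap long[] (baseObject = the array,
+ // baseOffset = PlatformDependent.LONG_ARRAY_OFFSET) or off-heap memory (baseObject = null,
+ // baseOffset = an absolute address); UnsafeFixedWidthAggregationMap points rows at
+ // BytesToBytesMap value locations this way via MemoryLocation.getBaseObject()/getBaseOffset().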
+
+ private void assertIndexIsValid(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < numFields : "index (" + index + ") should < " + numFields;
+ }
+
+ @Override
+ public void setNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.set(baseObject, baseOffset, i);
+ // To preserve row equality, zero out the value when setting the column to null.
+ // Since this row does not currently support updates to variable-length values, we don't
+ // have to worry about zeroing out that data.
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0);
+ }
+
+ private void setNotNullAt(int i) {
+ assertIndexIsValid(i);
+ BitSetMethods.unset(baseObject, baseOffset, i);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ assertIndexIsValid(ordinal);
+ setNotNullAt(ordinal);
+ PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
+ }
+
+ @Override
+ public void setString(int ordinal, String value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int size() {
+ return numFields;
+ }
+
+ @Override
+ public int length() {
+ return size();
+ }
+
+ @Override
+ public StructType schema() {
+ return schema;
+ }
+
+ @Override
+ public Object apply(int i) {
+ return get(i);
+ }
+
+ @Override
+ public Object get(int i) {
+ assertIndexIsValid(i);
+ assert (schema != null) : "Schema must be defined when calling generic get() method";
+ final DataType dataType = schema.fields()[i].dataType();
+ // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
+ // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
+ // separate the internal and external row interfaces, then internal code can fetch strings via
+ // a new getUTF8String() method and we'll be able to remove this method.
+ if (isNullAt(i)) {
+ return null;
+ } else if (dataType == StringType) {
+ return getUTF8String(i);
+ } else {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int i) {
+ assertIndexIsValid(i);
+ return BitSetMethods.isSet(baseObject, baseOffset, i);
+ }
+
+ @Override
+ public boolean getBoolean(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public byte getByte(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public short getShort(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public int getInt(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public long getLong(int i) {
+ assertIndexIsValid(i);
+ return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+ }
+
+ @Override
+ public float getFloat(int i) {
+ assertIndexIsValid(i);
+ if (isNullAt(i)) {
+ return Float.NaN;
+ } else {
+ return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i));
+ }
+ }
+
+ @Override
+ public double getDouble(int i) {
+ assertIndexIsValid(i);
+ if (isNullAt(i)) {
+ return Double.NaN;
+ } else {
+ return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i));
+ }
+ }
+
+ public UTF8String getUTF8String(int i) {
+ assertIndexIsValid(i);
+ final UTF8String str = new UTF8String();
+ final long offsetToStringSize = getLong(i);
+ final int stringSizeInBytes =
+ (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
+ final byte[] strBytes = new byte[stringSizeInBytes];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
+ strBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ stringSizeInBytes
+ );
+ str.set(strBytes);
+ return str;
+ }
+
+ @Override
+ public String getString(int i) {
+ return getUTF8String(i).toString();
+ }
+
+ @Override
+ public BigDecimal getDecimal(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Date getDate(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> Seq<T> getSeq(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> List<T> getList(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> Map<K, V> getMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <K, V> java.util.Map<K, V> getJavaMap(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row getStruct(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(int i) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public <T> T getAs(String fieldName) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int fieldIndex(String name) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Row copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean anyNull() {
+ return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
+ }
+
+ @Override
+ public Seq