Skip to content

Commit 50af8fc

Browse files
authored
[SPARK-18231] Optimise SizeEstimator implementation
Several improvements to the SizeEstimator for performance, most of the benefit comes from, when estimating, contending to not contending on multiple threads. There can be a small boost in uncontended scenarios from the removal of the synchronisation code but the cost of that synchronisation when not truly contended is low. On the PageRank workload for HiBench we see 49~ second durations reduced to ~41 second durations. I don't see any changes for other workloads. Observed with both IBM's SDK for Java and OpenJDK.
1 parent 79f5f28 commit 50af8fc

File tree

1 file changed

+58
-58
lines changed

1 file changed

+58
-58
lines changed

core/src/main/scala/org/apache/spark/util/SizeEstimator.scala

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ package org.apache.spark.util
1919

2020
import java.lang.management.ManagementFactory
2121
import java.lang.reflect.{Field, Modifier}
22-
import java.util.{IdentityHashMap, Random}
22+
import java.util.{IdentityHashMap, WeakHashMap}
23+
import java.util.concurrent.ThreadLocalRandom
2324

2425
import scala.collection.mutable.ArrayBuffer
25-
import scala.runtime.ScalaRunTime
26+
import scala.concurrent.util.Unsafe
2627

2728
import com.google.common.collect.MapMaker
2829

@@ -89,7 +90,13 @@ object SizeEstimator extends Logging {
8990

9091
// A cache of ClassInfo objects for each class
9192
// We use weakKeys to allow GC of dynamically created classes
92-
private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()
93+
private val classInfos = new ThreadLocal[WeakHashMap[Class[_], ClassInfo]] {
94+
override def initialValue(): java.util.WeakHashMap[Class[_], ClassInfo] = {
95+
val toReturn = new WeakHashMap[Class[_], ClassInfo]()
96+
toReturn.put(classOf[Object], new ClassInfo(objectSize, new Array[Int](0)))
97+
return toReturn
98+
}
99+
}
93100

94101
// Object and pointer sizes are arch dependent
95102
private var is64bit = false
@@ -119,8 +126,6 @@ object SizeEstimator extends Logging {
119126
}
120127
}
121128
pointerSize = if (is64bit && !isCompressedOops) 8 else 4
122-
classInfos.clear()
123-
classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
124129
}
125130

126131
private def getIsCompressedOops: Boolean = {
@@ -192,7 +197,7 @@ object SizeEstimator extends Logging {
192197
*/
193198
private class ClassInfo(
194199
val shellSize: Long,
195-
val pointerFields: List[Field]) {}
200+
val fieldOffsets: Array[Int]) {}
196201

197202
private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
198203
val state = new SearchState(visited)
@@ -221,8 +226,13 @@ object SizeEstimator extends Logging {
221226
case _ =>
222227
val classInfo = getClassInfo(cls)
223228
state.size += alignSize(classInfo.shellSize)
224-
for (field <- classInfo.pointerFields) {
225-
state.enqueue(field.get(obj))
229+
// avoid an iterator based for loop for performance
230+
var index = 0
231+
val fieldCount = classInfo.fieldOffsets.length
232+
val us = Unsafe.instance
233+
while (index < fieldCount) {
234+
state.enqueue(us.getObject(obj, classInfo.fieldOffsets(index).toLong))
235+
index += 1
226236
}
227237
}
228238
}
@@ -233,7 +243,7 @@ object SizeEstimator extends Logging {
233243
private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
234244

235245
private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
236-
val length = ScalaRunTime.array_length(array)
246+
val length = java.lang.reflect.Array.getLength(array)
237247
val elementClass = arrayClass.getComponentType()
238248

239249
// Arrays have object header and length field which is an integer
@@ -243,47 +253,59 @@ object SizeEstimator extends Logging {
243253
arrSize += alignSize(length.toLong * primitiveSize(elementClass))
244254
state.size += arrSize
245255
} else {
256+
// We know that the array we are dealing with is an array of references
257+
// so explicitly expose this type so we can directly manipulate the array
258+
// without help form the Scala runtime for efficency
246259
arrSize += alignSize(length.toLong * pointerSize)
247260
state.size += arrSize
248261

262+
val objArray = array.asInstanceOf[Array[AnyRef]]
263+
249264
if (length <= ARRAY_SIZE_FOR_SAMPLING) {
250265
var arrayIndex = 0
251266
while (arrayIndex < length) {
252-
state.enqueue(ScalaRunTime.array_apply(array, arrayIndex).asInstanceOf[AnyRef])
267+
state.enqueue(objArray(arrayIndex))
253268
arrayIndex += 1
254269
}
255270
} else {
256271
// Estimate the size of a large array by sampling elements without replacement.
257272
// To exclude the shared objects that the array elements may link, sample twice
258-
// and use the min one to calculate array size.
259-
val rand = new Random(42)
273+
// and use the min one to calculate array size.
274+
// Use ThreadLocalRandom here since the random is only accessed from 1 thread
275+
// and we can save the overhead of the full thread-safe Random
276+
val rand = ThreadLocalRandom.current
260277
val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
261-
val s1 = sampleArray(array, state, rand, drawn, length)
262-
val s2 = sampleArray(array, state, rand, drawn, length)
278+
val s1 = sampleArray(objArray, state, rand, drawn, length)
279+
val s2 = sampleArray(objArray, state, rand, drawn, length)
263280
val size = math.min(s1, s2)
281+
264282
state.size += math.max(s1, s2) +
265283
(size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
266284
}
267285
}
268286
}
269287

270288
private def sampleArray(
271-
array: AnyRef,
289+
array: Array[AnyRef],
272290
state: SearchState,
273-
rand: Random,
291+
rand: ThreadLocalRandom,
274292
drawn: OpenHashSet[Int],
275293
length: Int): Long = {
276294
var size = 0L
277-
for (i <- 0 until ARRAY_SAMPLE_SIZE) {
295+
// avoid the use of an iterator derrived from the range syntax here for performance
296+
var count = 0
297+
val end = ARRAY_SAMPLE_SIZE
298+
while (count <= end) {
278299
var index = 0
279300
do {
280301
index = rand.nextInt(length)
281302
} while (drawn.contains(index))
282303
drawn.add(index)
283-
val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
304+
val obj = array(index)
284305
if (obj != null) {
285306
size += SizeEstimator.estimate(obj, state.visited).toLong
286307
}
308+
count += 1
287309
}
288310
size
289311
}
@@ -316,62 +338,40 @@ object SizeEstimator extends Logging {
316338
*/
317339
private def getClassInfo(cls: Class[_]): ClassInfo = {
318340
// Check whether we've already cached a ClassInfo for this class
319-
val info = classInfos.get(cls)
341+
val info = classInfos.get().get(cls)
320342
if (info != null) {
321343
return info
322344
}
323345

324346
val parent = getClassInfo(cls.getSuperclass)
347+
val fields = cls.getDeclaredFields
348+
val fieldCount = fields.length
325349
var shellSize = parent.shellSize
326-
var pointerFields = parent.pointerFields
327-
val sizeCount = Array.fill(fieldSizes.max + 1)(0)
350+
var fieldOffsets = parent.fieldOffsets.toList
351+
352+
var index = 0
328353

329-
// iterate through the fields of this class and gather information.
330-
for (field <- cls.getDeclaredFields) {
354+
while (index < fieldCount) {
355+
val field = fields(index)
331356
if (!Modifier.isStatic(field.getModifiers)) {
332357
val fieldClass = field.getType
333358
if (fieldClass.isPrimitive) {
334-
sizeCount(primitiveSize(fieldClass)) += 1
359+
if (cls == classOf[Double] || cls == classOf[Long]) {
360+
shellSize += 8
361+
} else {
362+
shellSize += 4
363+
}
335364
} else {
336-
field.setAccessible(true) // Enable future get()'s on this field
337-
sizeCount(pointerSize) += 1
338-
pointerFields = field :: pointerFields
365+
shellSize += pointerSize
366+
fieldOffsets = Unsafe.instance.objectFieldOffset(field).toInt :: fieldOffsets
339367
}
340368
}
369+
index += 1
341370
}
342371

343-
// Based on the simulated field layout code in Aleksey Shipilev's report:
344-
// http://cr.openjdk.java.net/~shade/papers/2013-shipilev-fieldlayout-latest.pdf
345-
// The code is in Figure 9.
346-
// The simplified idea of field layout consists of 4 parts (see more details in the report):
347-
//
348-
// 1. field alignment: HotSpot lays out the fields aligned by their size.
349-
// 2. object alignment: HotSpot rounds instance size up to 8 bytes
350-
// 3. consistent fields layouts throughout the hierarchy: This means we should layout
351-
// superclass first. And we can use superclass's shellSize as a starting point to layout the
352-
// other fields in this class.
353-
// 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed
354-
// with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322
355-
//
356-
// The real world field layout is much more complicated. There are three kinds of fields
357-
// order in Java 8. And we don't consider the @contended annotation introduced by Java 8.
358-
// see the HotSpot classloader code, layout_fields method for more details.
359-
// hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp
360-
var alignedSize = shellSize
361-
for (size <- fieldSizes if sizeCount(size) > 0) {
362-
val count = sizeCount(size).toLong
363-
// If there are internal gaps, smaller field can fit in.
364-
alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count)
365-
shellSize += size * count
366-
}
367-
368-
// Should choose a larger size to be new shellSize and clearly alignedSize >= shellSize, and
369-
// round up the instance filed blocks
370-
shellSize = alignSizeUp(alignedSize, pointerSize)
371-
372372
// Create and cache a new ClassInfo
373-
val newInfo = new ClassInfo(shellSize, pointerFields)
374-
classInfos.put(cls, newInfo)
373+
val newInfo = new ClassInfo(shellSize, fieldOffsets.toArray)
374+
classInfos.get().put(cls, newInfo)
375375
newInfo
376376
}
377377

0 commit comments

Comments
 (0)