diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 386fdfd218a8..7f386bd6f81c 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -19,10 +19,11 @@ package org.apache.spark.util
 
 import java.lang.management.ManagementFactory
 import java.lang.reflect.{Field, Modifier}
-import java.util.{IdentityHashMap, Random}
+import java.util.{IdentityHashMap, WeakHashMap}
+import java.util.concurrent.ThreadLocalRandom
 
 import scala.collection.mutable.ArrayBuffer
-import scala.runtime.ScalaRunTime
+import scala.concurrent.util.Unsafe
 
 import com.google.common.collect.MapMaker
 
@@ -89,7 +90,13 @@ object SizeEstimator extends Logging {
 
   // A cache of ClassInfo objects for each class
   // We use weakKeys to allow GC of dynamically created classes
-  private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()
+  private val classInfos = new ThreadLocal[WeakHashMap[Class[_], ClassInfo]] {
+    override def initialValue(): java.util.WeakHashMap[Class[_], ClassInfo] = {
+      val toReturn = new WeakHashMap[Class[_], ClassInfo]()
+      toReturn.put(classOf[Object], new ClassInfo(objectSize, new Array[Int](0)))
+      toReturn
+    }
+  }
 
   // Object and pointer sizes are arch dependent
   private var is64bit = false
@@ -119,8 +126,6 @@ object SizeEstimator extends Logging {
       }
     }
     pointerSize = if (is64bit && !isCompressedOops) 8 else 4
-    classInfos.clear()
-    classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
   }
 
   private def getIsCompressedOops: Boolean = {
@@ -192,7 +197,7 @@ object SizeEstimator extends Logging {
    */
   private class ClassInfo(
     val shellSize: Long,
-    val pointerFields: List[Field]) {}
+    val fieldOffsets: Array[Int]) {}
 
   private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
     val state = new SearchState(visited)
@@ -221,8 +226,13 @@ object SizeEstimator extends Logging {
         case _ =>
           val classInfo = getClassInfo(cls)
           state.size += alignSize(classInfo.shellSize)
-          for (field <- classInfo.pointerFields) {
-            state.enqueue(field.get(obj))
+          // avoid an iterator based for loop for performance
+          var index = 0
+          val fieldCount = classInfo.fieldOffsets.length
+          val us = Unsafe.instance
+          while (index < fieldCount) {
+            state.enqueue(us.getObject(obj, classInfo.fieldOffsets(index).toLong))
+            index += 1
           }
       }
     }
@@ -233,7 +243,7 @@ object SizeEstimator extends Logging {
   private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
 
   private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
-    val length = ScalaRunTime.array_length(array)
+    val length = java.lang.reflect.Array.getLength(array)
     val elementClass = arrayClass.getComponentType()
 
     // Arrays have object header and length field which is an integer
@@ -243,24 +253,32 @@ object SizeEstimator extends Logging {
       arrSize += alignSize(length.toLong * primitiveSize(elementClass))
       state.size += arrSize
     } else {
+      // We know that the array we are dealing with is an array of references
+      // so explicitly expose this type so we can directly manipulate the array
+      // without help from the Scala runtime for efficiency
       arrSize += alignSize(length.toLong * pointerSize)
       state.size += arrSize
 
+      val objArray = array.asInstanceOf[Array[AnyRef]]
+
       if (length <= ARRAY_SIZE_FOR_SAMPLING) {
         var arrayIndex = 0
         while (arrayIndex < length) {
-          state.enqueue(ScalaRunTime.array_apply(array, arrayIndex).asInstanceOf[AnyRef])
+          state.enqueue(objArray(arrayIndex))
           arrayIndex += 1
         }
       } else {
         // Estimate the size of a large array by sampling elements without replacement.
         // To exclude the shared objects that the array elements may link, sample twice
-        // and use the min one to calculate array size.
-        val rand = new Random(42)
+        // and use the min one to calculate array size.
+        // Use ThreadLocalRandom here since the random is only accessed from 1 thread
+        // and we can save the overhead of the full thread-safe Random
+        val rand = ThreadLocalRandom.current
         val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
-        val s1 = sampleArray(array, state, rand, drawn, length)
-        val s2 = sampleArray(array, state, rand, drawn, length)
+        val s1 = sampleArray(objArray, state, rand, drawn, length)
+        val s2 = sampleArray(objArray, state, rand, drawn, length)
         val size = math.min(s1, s2)
+
         state.size += math.max(s1, s2) +
           (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
       }
@@ -268,22 +286,26 @@ object SizeEstimator extends Logging {
   }
 
   private def sampleArray(
-      array: AnyRef,
+      array: Array[AnyRef],
       state: SearchState,
-      rand: Random,
+      rand: ThreadLocalRandom,
       drawn: OpenHashSet[Int],
       length: Int): Long = {
     var size = 0L
-    for (i <- 0 until ARRAY_SAMPLE_SIZE) {
+    // avoid the use of an iterator derived from the range syntax here for performance
+    var count = 0
+    val end = ARRAY_SAMPLE_SIZE
+    while (count < end) {
       var index = 0
       do {
         index = rand.nextInt(length)
       } while (drawn.contains(index))
       drawn.add(index)
-      val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
+      val obj = array(index)
       if (obj != null) {
         size += SizeEstimator.estimate(obj, state.visited).toLong
       }
+      count += 1
     }
     size
   }
@@ -316,62 +338,40 @@ object SizeEstimator extends Logging {
    */
   private def getClassInfo(cls: Class[_]): ClassInfo = {
     // Check whether we've already cached a ClassInfo for this class
-    val info = classInfos.get(cls)
+    val info = classInfos.get().get(cls)
     if (info != null) {
       return info
     }
 
     val parent = getClassInfo(cls.getSuperclass)
+    val fields = cls.getDeclaredFields
+    val fieldCount = fields.length
     var shellSize = parent.shellSize
-    var pointerFields = parent.pointerFields
-    val sizeCount = Array.fill(fieldSizes.max + 1)(0)
+    var fieldOffsets = parent.fieldOffsets.toList
+
+    var index = 0
 
-    // iterate through the fields of this class and gather information.
-    for (field <- cls.getDeclaredFields) {
+    while (index < fieldCount) {
+      val field = fields(index)
       if (!Modifier.isStatic(field.getModifiers)) {
         val fieldClass = field.getType
         if (fieldClass.isPrimitive) {
-          sizeCount(primitiveSize(fieldClass)) += 1
+          if (fieldClass == classOf[Double] || fieldClass == classOf[Long]) {
+            shellSize += 8
+          } else {
+            shellSize += 4
+          }
         } else {
-          field.setAccessible(true) // Enable future get()'s on this field
-          sizeCount(pointerSize) += 1
-          pointerFields = field :: pointerFields
+          shellSize += pointerSize
+          fieldOffsets = Unsafe.instance.objectFieldOffset(field).toInt :: fieldOffsets
         }
       }
+      index += 1
     }
 
-    // Based on the simulated field layout code in Aleksey Shipilev's report:
-    // http://cr.openjdk.java.net/~shade/papers/2013-shipilev-fieldlayout-latest.pdf
-    // The code is in Figure 9.
-    // The simplified idea of field layout consists of 4 parts (see more details in the report):
-    //
-    // 1. field alignment: HotSpot lays out the fields aligned by their size.
-    // 2. object alignment: HotSpot rounds instance size up to 8 bytes
-    // 3. consistent fields layouts throughout the hierarchy: This means we should layout
-    // superclass first. And we can use superclass's shellSize as a starting point to layout the
-    // other fields in this class.
-    // 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed
-    // with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322
-    //
-    // The real world field layout is much more complicated. There are three kinds of fields
-    // order in Java 8. And we don't consider the @contended annotation introduced by Java 8.
-    // see the HotSpot classloader code, layout_fields method for more details.
-    // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp
-    var alignedSize = shellSize
-    for (size <- fieldSizes if sizeCount(size) > 0) {
-      val count = sizeCount(size).toLong
-      // If there are internal gaps, smaller field can fit in.
-      alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count)
-      shellSize += size * count
-    }
-
-    // Should choose a larger size to be new shellSize and clearly alignedSize >= shellSize, and
-    // round up the instance filed blocks
-    shellSize = alignSizeUp(alignedSize, pointerSize)
 
-    // Create and cache a new ClassInfo
-    val newInfo = new ClassInfo(shellSize, pointerFields)
-    classInfos.put(cls, newInfo)
+    val newInfo = new ClassInfo(shellSize, fieldOffsets.toArray)
+    classInfos.get().put(cls, newInfo)
     newInfo
   }
 