Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 58 additions & 58 deletions core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package org.apache.spark.util

import java.lang.management.ManagementFactory
import java.lang.reflect.{Field, Modifier}
import java.util.{IdentityHashMap, Random}
import java.util.{IdentityHashMap, WeakHashMap}
import java.util.concurrent.ThreadLocalRandom

import scala.collection.mutable.ArrayBuffer
import scala.runtime.ScalaRunTime
import scala.concurrent.util.Unsafe

import com.google.common.collect.MapMaker

Expand Down Expand Up @@ -89,7 +90,13 @@ object SizeEstimator extends Logging {

// A cache of ClassInfo objects for each class
// We use weakKeys to allow GC of dynamically created classes
private val classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()
private val classInfos = new ThreadLocal[WeakHashMap[Class[_], ClassInfo]] {
override def initialValue(): java.util.WeakHashMap[Class[_], ClassInfo] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: remove java.util, and 'return' below. "map" is better than "toReturn"

This is going to expand the memory footprint, because redundant copies of this info will be maintained per thread. Is the contention that significant?

val toReturn = new WeakHashMap[Class[_], ClassInfo]()
toReturn.put(classOf[Object], new ClassInfo(objectSize, new Array[Int](0)))
return toReturn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not keep the returned value same as before ?
And move the initialization back into initialize() - so that use of classInfos Map across threads wont happen.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Built and profiled, averaging 42 sec run times with the initial commit, averaging 45 second run times with this. No changes = 48 sec.

My code as a diff (so using a ConcurrentHashMap and var not val so we can initialise it later) provided here:

 import java.lang.management.ManagementFactory
 import java.lang.reflect.{Field, Modifier}
 import java.util.{IdentityHashMap, WeakHashMap}
-import java.util.concurrent.ThreadLocalRandom
+import java.util.concurrent.{ThreadLocalRandom, ConcurrentMap}

 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.util.Unsafe
@@ -88,16 +88,6 @@ object SizeEstimator extends Logging {
   // TODO: Is this arch dependent ?
   private val ALIGN_SIZE = 8

-  // A cache of ClassInfo objects for each class
-  // We use weakKeys to allow GC of dynamically created classes
-  private val classInfos = new ThreadLocal[WeakHashMap[Class[_], ClassInfo]] {
-    override def initialValue(): java.util.WeakHashMap[Class[_], ClassInfo] = {
-      val toReturn = new WeakHashMap[Class[_], ClassInfo]()
-      toReturn.put(classOf[Object], new ClassInfo(objectSize, new Array[Int](0)))
-      return toReturn
-    }
-  }
-
   // Object and pointer sizes are arch dependent
   private var is64bit = false

@@ -109,6 +99,8 @@ object SizeEstimator extends Logging {
   // Minimum size of a java.lang.Object
   private var objectSize = 8

+  private var classInfos: ConcurrentMap[Class[_], ClassInfo] = null
+
   initialize()

   // Sets object size, pointer size based on architecture and CompressedOops settings
@@ -126,6 +118,9 @@ object SizeEstimator extends Logging {
       }
     }
     pointerSize = if (is64bit && !isCompressedOops) 8 else 4
+
+    classInfos = new MapMaker().weakKeys().makeMap[Class[_], ClassInfo]()
+    classInfos.put(classOf[Object], new ClassInfo(objectSize, new Array[Int](0)))
   }

   private def getIsCompressedOops: Boolean = {
@@ -338,7 +333,7 @@ object SizeEstimator extends Logging {
    */
   private def getClassInfo(cls: Class[_]): ClassInfo = {
     // Check whether we've already cached a ClassInfo for this class
-    val info = classInfos.get().get(cls)
+    val info = classInfos.get(cls)
     if (info != null) {
       return info
     }
@@ -371,7 +366,7 @@ object SizeEstimator extends Logging {

     // Create and cache a new ClassInfo
     val newInfo = new ClassInfo(shellSize, fieldOffsets.toArray)
-    classInfos.get().put(cls, newInfo)
+    classInfos.put(cls, newInfo)
     newInfo
   }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant was, continue to use ThreadLocal, but maintain the MapMaker's result for thlocal.get()

And move the initilization to initialize() instead of in initialValue()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks will give this a try

}
}

// Object and pointer sizes are arch dependent
private var is64bit = false
Expand Down Expand Up @@ -119,8 +126,6 @@ object SizeEstimator extends Logging {
}
}
pointerSize = if (is64bit && !isCompressedOops) 8 else 4
classInfos.clear()
classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
}

private def getIsCompressedOops: Boolean = {
Expand Down Expand Up @@ -192,7 +197,7 @@ object SizeEstimator extends Logging {
*/
private class ClassInfo(
val shellSize: Long,
val pointerFields: List[Field]) {}
val fieldOffsets: Array[Int]) {}

private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
val state = new SearchState(visited)
Expand Down Expand Up @@ -221,8 +226,13 @@ object SizeEstimator extends Logging {
case _ =>
val classInfo = getClassInfo(cls)
state.size += alignSize(classInfo.shellSize)
for (field <- classInfo.pointerFields) {
state.enqueue(field.get(obj))
// avoid an iterator based for loop for performance
var index = 0
val fieldCount = classInfo.fieldOffsets.length
val us = Unsafe.instance
while (index < fieldCount) {
state.enqueue(us.getObject(obj, classInfo.fieldOffsets(index).toLong))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand avoiding reflection, but this is a dicier way to access fields of an object. I don't have a specific reason this would fail but the fact that it uses unsafe is riskier. Is this worth it?

index += 1
}
}
}
Expand All @@ -233,7 +243,7 @@ object SizeEstimator extends Logging {
private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING

private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
val length = ScalaRunTime.array_length(array)
val length = java.lang.reflect.Array.getLength(array)
val elementClass = arrayClass.getComponentType()

// Arrays have object header and length field which is an integer
Expand All @@ -243,47 +253,59 @@ object SizeEstimator extends Logging {
arrSize += alignSize(length.toLong * primitiveSize(elementClass))
state.size += arrSize
} else {
// We know that the array we are dealing with is an array of references
// so explicitly expose this type so we can directly manipulate the array
// without help form the Scala runtime for efficency
arrSize += alignSize(length.toLong * pointerSize)
state.size += arrSize

val objArray = array.asInstanceOf[Array[AnyRef]]

if (length <= ARRAY_SIZE_FOR_SAMPLING) {
var arrayIndex = 0
while (arrayIndex < length) {
state.enqueue(ScalaRunTime.array_apply(array, arrayIndex).asInstanceOf[AnyRef])
state.enqueue(objArray(arrayIndex))
arrayIndex += 1
}
} else {
// Estimate the size of a large array by sampling elements without replacement.
// To exclude the shared objects that the array elements may link, sample twice
// and use the min one to calculate array size.
val rand = new Random(42)
// and use the min one to calculate array size.
// Use ThreadLocalRandom here since the random is only accessed from 1 thread
// and we can save the overhead of the full thread-safe Random
val rand = ThreadLocalRandom.current
val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
val s1 = sampleArray(array, state, rand, drawn, length)
val s2 = sampleArray(array, state, rand, drawn, length)
val s1 = sampleArray(objArray, state, rand, drawn, length)
val s2 = sampleArray(objArray, state, rand, drawn, length)
val size = math.min(s1, s2)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes to this method are excellent and should speed things up !

state.size += math.max(s1, s2) +
(size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
}
}
}

private def sampleArray(
array: AnyRef,
array: Array[AnyRef],
state: SearchState,
rand: Random,
rand: ThreadLocalRandom,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this has to change

drawn: OpenHashSet[Int],
length: Int): Long = {
var size = 0L
for (i <- 0 until ARRAY_SAMPLE_SIZE) {
// avoid the use of an iterator derrived from the range syntax here for performance
var count = 0
val end = ARRAY_SAMPLE_SIZE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end is redundant here

while (count <= end) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

< end for until semantics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, should be just < not <=, will add into the next commit

var index = 0
do {
index = rand.nextInt(length)
} while (drawn.contains(index))
drawn.add(index)
val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
val obj = array(index)
if (obj != null) {
size += SizeEstimator.estimate(obj, state.visited).toLong
}
count += 1
}
size
}
Expand Down Expand Up @@ -316,62 +338,40 @@ object SizeEstimator extends Logging {
*/
private def getClassInfo(cls: Class[_]): ClassInfo = {
// Check whether we've already cached a ClassInfo for this class
val info = classInfos.get(cls)
val info = classInfos.get().get(cls)
if (info != null) {
return info
}

val parent = getClassInfo(cls.getSuperclass)
val fields = cls.getDeclaredFields
val fieldCount = fields.length
var shellSize = parent.shellSize
var pointerFields = parent.pointerFields
val sizeCount = Array.fill(fieldSizes.max + 1)(0)
var fieldOffsets = parent.fieldOffsets.toList

var index = 0

// iterate through the fields of this class and gather information.
for (field <- cls.getDeclaredFields) {
while (index < fieldCount) {
val field = fields(index)
if (!Modifier.isStatic(field.getModifiers)) {
val fieldClass = field.getType
if (fieldClass.isPrimitive) {
sizeCount(primitiveSize(fieldClass)) += 1
if (cls == classOf[Double] || cls == classOf[Long]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the logic changes below aren't obviously OK. this seems to lose a lot of logic. I think this has to be explained or backed out

shellSize += 8
} else {
shellSize += 4
}
} else {
field.setAccessible(true) // Enable future get()'s on this field
sizeCount(pointerSize) += 1
pointerFields = field :: pointerFields
shellSize += pointerSize
fieldOffsets = Unsafe.instance.objectFieldOffset(field).toInt :: fieldOffsets
}
}
index += 1
}

// Based on the simulated field layout code in Aleksey Shipilev's report:
// http://cr.openjdk.java.net/~shade/papers/2013-shipilev-fieldlayout-latest.pdf
// The code is in Figure 9.
// The simplified idea of field layout consists of 4 parts (see more details in the report):
//
// 1. field alignment: HotSpot lays out the fields aligned by their size.
// 2. object alignment: HotSpot rounds instance size up to 8 bytes
// 3. consistent fields layouts throughout the hierarchy: This means we should layout
// superclass first. And we can use superclass's shellSize as a starting point to layout the
// other fields in this class.
// 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed
// with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322
//
// The real world field layout is much more complicated. There are three kinds of fields
// order in Java 8. And we don't consider the @contended annotation introduced by Java 8.
// see the HotSpot classloader code, layout_fields method for more details.
// hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp
var alignedSize = shellSize
for (size <- fieldSizes if sizeCount(size) > 0) {
val count = sizeCount(size).toLong
// If there are internal gaps, smaller field can fit in.
alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count)
shellSize += size * count
}

// Should choose a larger size to be new shellSize and clearly alignedSize >= shellSize, and
// round up the instance filed blocks
shellSize = alignSizeUp(alignedSize, pointerSize)

// Create and cache a new ClassInfo
val newInfo = new ClassInfo(shellSize, pointerFields)
classInfos.put(cls, newInfo)
val newInfo = new ClassInfo(shellSize, fieldOffsets.toArray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are loosing out on padding due to allignment here which the earlier code was computing. No ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into this and determine if the padding is needed

classInfos.get().put(cls, newInfo)
newInfo
}

Expand Down