Skip to content

Commit c03811f

Browse files
committed
Clean up Scala reflection garbage after creating Encoder
1 parent 7323186 commit c03811f

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ object ScalaReflection extends ScalaReflection {
6161
*/
6262
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
6363

64-
private def dataTypeFor(tpe: `Type`): DataType = {
64+
private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
6565
tpe.dealias match {
6666
case t if t <:< definitions.IntTpe => IntegerType
6767
case t if t <:< definitions.LongTpe => LongType
@@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection {
9393
* Special handling is performed for primitive types to map them back to their raw
9494
* JVM form instead of the Scala Array that handles auto boxing.
9595
*/
96-
private def arrayClassFor(tpe: `Type`): ObjectType = {
96+
private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects {
9797
val cls = tpe.dealias match {
9898
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
9999
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -140,7 +140,7 @@ object ScalaReflection extends ScalaReflection {
140140
private def deserializerFor(
141141
tpe: `Type`,
142142
path: Option[Expression],
143-
walkedTypePath: Seq[String]): Expression = {
143+
walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
144144

145145
/** Returns the current path with a sub-field extracted. */
146146
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
@@ -435,7 +435,7 @@ object ScalaReflection extends ScalaReflection {
435435
inputObject: Expression,
436436
tpe: `Type`,
437437
walkedTypePath: Seq[String],
438-
seenTypeSet: Set[`Type`] = Set.empty): Expression = {
438+
seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects {
439439

440440
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
441441
dataTypeFor(elementType) match {
@@ -642,7 +642,7 @@ object ScalaReflection extends ScalaReflection {
642642
* Returns true if the given type is option of product type, e.g. `Option[Tuple2]`. Note that,
643643
* we also treat [[DefinedByConstructorParams]] as product type.
644644
*/
645-
def optionOfProductType(tpe: `Type`): Boolean = {
645+
def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
646646
tpe.dealias match {
647647
case t if t <:< localTypeOf[Option[_]] =>
648648
val TypeRef(_, _, Seq(optType)) = t
@@ -704,7 +704,7 @@ object ScalaReflection extends ScalaReflection {
704704
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
705705

706706
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
707-
def schemaFor(tpe: `Type`): Schema = {
707+
def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects {
708708
tpe.dealias match {
709709
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
710710
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
@@ -774,7 +774,7 @@ object ScalaReflection extends ScalaReflection {
774774
/**
775775
* Whether the fields of the given type are defined entirely by its constructor parameters.
776776
*/
777-
def definedByConstructorParams(tpe: Type): Boolean = {
777+
def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
  // Resolve type aliases once and reuse the result for both subtype checks.
  val dealiased = tpe.dealias
  // `||` short-circuits: the second check only runs when the type is not a Product.
  (dealiased <:< localTypeOf[Product]) ||
    (dealiased <:< localTypeOf[DefinedByConstructorParams])
}
780780

@@ -803,6 +803,17 @@ trait ScalaReflection {
803803
// Since the map values can be mutable, we explicitly import scala.collection.Map at here.
804804
import scala.collection.Map
805805

806+
/**
807+
* Any code calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to
808+
* clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
809+
* `scala.reflect.runtime.JavaUniverse.undoLog`.
810+
*
811+
* @see https://github.com/scala/bug/issues/8302
812+
*/
813+
def cleanUpReflectionObjects[T](func: => T): T = {
  // `undoLog` is only exposed on the runtime (Java) universe, hence the downcast.
  val runtimeUniverse = universe.asInstanceOf[scala.reflect.runtime.JavaUniverse]
  // Evaluate `func` through the undo log's `undo` and hand back its result.
  runtimeUniverse.undoLog.undo(func)
}
816+
806817
/**
807818
* Return the Scala Type for `T` in the current classloader mirror.
808819
*

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ object ReferenceValueClass {
114114
class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
115115
OuterScopes.addOuterScope(this)
116116

117-
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
117+
// Every encoder resolved through this implicit is built inside the reflection-leak check.
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] =
  verifyNotLeakingReflectionObjects[ExpressionEncoder[T]](ExpressionEncoder[T]())
118120

119121
// test flat encoders
120122
encodeDecodeTest(false, "primitive boolean")
@@ -370,7 +372,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
370372
private def encodeDecodeTest[T : ExpressionEncoder](
371373
input: T,
372374
testName: String): Unit = {
373-
test(s"encode/decode for $testName: $input") {
375+
testAndVerifyNotLeakingReflectionObjects(s"encode/decode for $testName: $input") {
374376
val encoder = implicitly[ExpressionEncoder[T]]
375377
val row = encoder.toRow(input)
376378
val schema = encoder.schema.toAttributes
@@ -441,4 +443,28 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
441443
}
442444
}
443445
}
446+
447+
/**
448+
* Verify the size of scala.reflect.runtime.JavaUniverse.undoLog before and after `func` to
449+
* ensure we don't leak Scala reflection garbage.
450+
*
451+
* @see org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects
452+
*/
453+
private def verifyNotLeakingReflectionObjects[T](func: => T): T = {
  import scala.reflect.runtime.{universe => runtimeUniverse, JavaUniverse}

  // Reads the current number of entries held by the runtime universe's undo log.
  def currentUndoLogSize(): Int =
    runtimeUniverse.asInstanceOf[JavaUniverse].undoLog.log.size

  val sizeBefore = currentUndoLogSize()
  // `func` is by-name: it is evaluated exactly once, between the two size probes.
  val result = func
  val sizeAfter = currentUndoLogSize()
  assert(sizeBefore == sizeAfter)
  result
}
464+
465+
/**
 * Registers a test named `testName` whose body `testFun` runs inside
 * `verifyNotLeakingReflectionObjects`, so the test also fails when the size of the Scala
 * reflection undo log differs before and after the body.
 *
 * @param testName name under which the test is registered
 * @param testFun  test body; evaluated lazily when the registered test runs
 */
private def testAndVerifyNotLeakingReflectionObjects(
    testName: String)(testFun: => Any): Unit = {
  test(testName) {
    verifyNotLeakingReflectionObjects(testFun)
  }
}
444470
}

0 commit comments

Comments
 (0)