KryoSerializer: resolve the optional SQL / ML / MLlib classes only once.

Hunk 1 removes the inline list of class names (and the classForName() call made
for each of them) from the registration loop in class KryoSerializer; the loop
now iterates over KryoSerializer.loadableSparkClasses and only registers.
Hunk 2 adds that member to object KryoSerializer as a private lazy val: every
name is passed to Utils.classForName, and names that fail with a NonFatal error
(or with NoClassDefFoundError while Utils.isTesting, see SPARK-23422) are dropped.

NOTE(review): leading whitespace inside the hunks was rebuilt from the code's
nesting; confirm it against the target file before applying.

diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 39691069bf5f..20774c8d999c 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -212,40 +212,8 @@ class KryoSerializer(conf: SparkConf)
 
     // We can't load those class directly in order to avoid unnecessary jar dependencies.
     // We load them safely, ignore it if the class not found.
-    Seq(
-      "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
-      "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
-      "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
-
-      "org.apache.spark.ml.attribute.Attribute",
-      "org.apache.spark.ml.attribute.AttributeGroup",
-      "org.apache.spark.ml.attribute.BinaryAttribute",
-      "org.apache.spark.ml.attribute.NominalAttribute",
-      "org.apache.spark.ml.attribute.NumericAttribute",
-
-      "org.apache.spark.ml.feature.Instance",
-      "org.apache.spark.ml.feature.LabeledPoint",
-      "org.apache.spark.ml.feature.OffsetInstance",
-      "org.apache.spark.ml.linalg.DenseMatrix",
-      "org.apache.spark.ml.linalg.DenseVector",
-      "org.apache.spark.ml.linalg.Matrix",
-      "org.apache.spark.ml.linalg.SparseMatrix",
-      "org.apache.spark.ml.linalg.SparseVector",
-      "org.apache.spark.ml.linalg.Vector",
-      "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
-      "org.apache.spark.ml.tree.impl.TreePoint",
-      "org.apache.spark.mllib.clustering.VectorWithNorm",
-      "org.apache.spark.mllib.linalg.DenseMatrix",
-      "org.apache.spark.mllib.linalg.DenseVector",
-      "org.apache.spark.mllib.linalg.Matrix",
-      "org.apache.spark.mllib.linalg.SparseMatrix",
-      "org.apache.spark.mllib.linalg.SparseVector",
-      "org.apache.spark.mllib.linalg.Vector",
-      "org.apache.spark.mllib.regression.LabeledPoint",
-      "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
-    ).foreach { name =>
+    KryoSerializer.loadableSparkClasses.foreach { clazz =>
       try {
-        val clazz = Utils.classForName(name)
         kryo.register(clazz)
       } catch {
         case NonFatal(_) => // do nothing
@@ -516,6 +484,50 @@ private[serializer] object KryoSerializer {
       }
     }
   )
+
+  // classForName() is expensive in case the class is not found, so we filter the list of
+  // SQL / ML / MLlib classes once and then re-use that filtered list in newInstance() calls.
+  private lazy val loadableSparkClasses: Seq[Class[_]] = {
+    Seq(
+      "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
+      "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
+      "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
+
+      "org.apache.spark.ml.attribute.Attribute",
+      "org.apache.spark.ml.attribute.AttributeGroup",
+      "org.apache.spark.ml.attribute.BinaryAttribute",
+      "org.apache.spark.ml.attribute.NominalAttribute",
+      "org.apache.spark.ml.attribute.NumericAttribute",
+
+      "org.apache.spark.ml.feature.Instance",
+      "org.apache.spark.ml.feature.LabeledPoint",
+      "org.apache.spark.ml.feature.OffsetInstance",
+      "org.apache.spark.ml.linalg.DenseMatrix",
+      "org.apache.spark.ml.linalg.DenseVector",
+      "org.apache.spark.ml.linalg.Matrix",
+      "org.apache.spark.ml.linalg.SparseMatrix",
+      "org.apache.spark.ml.linalg.SparseVector",
+      "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
+      "org.apache.spark.ml.tree.impl.TreePoint",
+      "org.apache.spark.mllib.clustering.VectorWithNorm",
+      "org.apache.spark.mllib.linalg.DenseMatrix",
+      "org.apache.spark.mllib.linalg.DenseVector",
+      "org.apache.spark.mllib.linalg.Matrix",
+      "org.apache.spark.mllib.linalg.SparseMatrix",
+      "org.apache.spark.mllib.linalg.SparseVector",
+      "org.apache.spark.mllib.linalg.Vector",
+      "org.apache.spark.mllib.regression.LabeledPoint",
+      "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
+    ).flatMap { name =>
+      try {
+        Some[Class[_]](Utils.classForName(name))
+      } catch {
+        case NonFatal(_) => None // do nothing
+        case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422.
+      }
+    }
+  }
 }
 
 /**