Skip to content

Commit ae1b642

Browse files
committed
fix
1 parent fe5145e commit ae1b642

File tree

1 file changed

+45
-33
lines changed

1 file changed

+45
-33
lines changed

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,50 @@ class KryoSerializer(conf: SparkConf)
8888
// Value of the KRYO_USE_UNSAFE config entry, read once from conf; newKryoOutput() branches on it.
private val useUnsafe = conf.get(KRYO_USE_UNSAFE)
8989
// Value of the KRYO_USE_POOL config entry, read once from conf.
private val usePool = conf.get(KRYO_USE_POOL)
9090

91+
// Looking up a class that is absent from the classpath via classForName() is costly, so the
// optional SQL / ML / MLlib class names are resolved a single time here and the surviving
// classes are reused by newInstance() calls.
private lazy val loadableClasses: Seq[Class[_]] = {
  val candidateNames = Seq(
    "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
    "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
    "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",

    "org.apache.spark.ml.attribute.Attribute",
    "org.apache.spark.ml.attribute.AttributeGroup",
    "org.apache.spark.ml.attribute.BinaryAttribute",
    "org.apache.spark.ml.attribute.NominalAttribute",
    "org.apache.spark.ml.attribute.NumericAttribute",

    "org.apache.spark.ml.feature.Instance",
    "org.apache.spark.ml.feature.LabeledPoint",
    "org.apache.spark.ml.feature.OffsetInstance",
    "org.apache.spark.ml.linalg.DenseMatrix",
    "org.apache.spark.ml.linalg.DenseVector",
    "org.apache.spark.ml.linalg.Matrix",
    "org.apache.spark.ml.linalg.SparseMatrix",
    "org.apache.spark.ml.linalg.SparseVector",
    "org.apache.spark.ml.linalg.Vector",
    "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
    "org.apache.spark.ml.tree.impl.TreePoint",
    "org.apache.spark.mllib.clustering.VectorWithNorm",
    "org.apache.spark.mllib.linalg.DenseMatrix",
    "org.apache.spark.mllib.linalg.DenseVector",
    "org.apache.spark.mllib.linalg.Matrix",
    "org.apache.spark.mllib.linalg.SparseMatrix",
    "org.apache.spark.mllib.linalg.SparseVector",
    "org.apache.spark.mllib.linalg.Vector",
    "org.apache.spark.mllib.regression.LabeledPoint",
    "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
  )

  // Resolves a single class name; a class that cannot be loaded is silently dropped.
  def resolve(className: String): Option[Class[_]] = {
    try {
      Some[Class[_]](Utils.classForName(className))
    } catch {
      case NonFatal(_) => None
      // NoClassDefFoundError is not matched by NonFatal; tolerated under test only.
      case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422.
    }
  }

  candidateNames.flatMap(className => resolve(className))
}
134+
91135
def newKryoOutput(): KryoOutput =
92136
if (useUnsafe) {
93137
new KryoUnsafeOutput(bufferSize, math.max(bufferSize, maxBufferSize))
@@ -212,40 +256,8 @@ class KryoSerializer(conf: SparkConf)
212256

213257
// We can't reference those classes directly, in order to avoid unnecessary jar dependencies.
214258
// We load them safely and ignore any class that is not found.
215-
Seq(
216-
"org.apache.spark.sql.catalyst.expressions.UnsafeRow",
217-
"org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
218-
"org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
219-
220-
"org.apache.spark.ml.attribute.Attribute",
221-
"org.apache.spark.ml.attribute.AttributeGroup",
222-
"org.apache.spark.ml.attribute.BinaryAttribute",
223-
"org.apache.spark.ml.attribute.NominalAttribute",
224-
"org.apache.spark.ml.attribute.NumericAttribute",
225-
226-
"org.apache.spark.ml.feature.Instance",
227-
"org.apache.spark.ml.feature.LabeledPoint",
228-
"org.apache.spark.ml.feature.OffsetInstance",
229-
"org.apache.spark.ml.linalg.DenseMatrix",
230-
"org.apache.spark.ml.linalg.DenseVector",
231-
"org.apache.spark.ml.linalg.Matrix",
232-
"org.apache.spark.ml.linalg.SparseMatrix",
233-
"org.apache.spark.ml.linalg.SparseVector",
234-
"org.apache.spark.ml.linalg.Vector",
235-
"org.apache.spark.ml.stat.distribution.MultivariateGaussian",
236-
"org.apache.spark.ml.tree.impl.TreePoint",
237-
"org.apache.spark.mllib.clustering.VectorWithNorm",
238-
"org.apache.spark.mllib.linalg.DenseMatrix",
239-
"org.apache.spark.mllib.linalg.DenseVector",
240-
"org.apache.spark.mllib.linalg.Matrix",
241-
"org.apache.spark.mllib.linalg.SparseMatrix",
242-
"org.apache.spark.mllib.linalg.SparseVector",
243-
"org.apache.spark.mllib.linalg.Vector",
244-
"org.apache.spark.mllib.regression.LabeledPoint",
245-
"org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
246-
).foreach { name =>
259+
loadableClasses.foreach { clazz =>
247260
try {
248-
val clazz = Utils.classForName(name)
249261
kryo.register(clazz)
250262
} catch {
251263
case NonFatal(_) => // do nothing

0 commit comments

Comments
 (0)