Commit fab4ca5

hvanhovell authored and dongjoon-hyun committed
[SPARK-31450][SQL] Make ExpressionEncoder thread-safe
### What changes were proposed in this pull request?

This PR moves the `ExpressionEncoder.toRow` and `ExpressionEncoder.fromRow` functions into their own function objects (`ExpressionEncoder.Serializer` & `ExpressionEncoder.Deserializer`). This effectively makes the `ExpressionEncoder` stateless, thread-safe, and (more) reusable. The function objects themselves are not thread-safe; however, this is documented, and they should be used in a more limited scope (making it easier to reason about thread safety).

### Why are the changes needed?

`ExpressionEncoder`s are not thread-safe. We had various (nasty) bugs because of this.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#28223 from hvanhovell/SPARK-31450.

Authored-by: herman <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
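
To make the migration concrete, here is a minimal sketch of how caller code changes. The `Person` case class and the surrounding object are illustrative assumptions, not part of this commit; `createSerializer`, `createDeserializer`, and `resolveAndBind` are the real APIs touched by this PR.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical example type; not part of this commit.
case class Person(name: String, age: Int)

object EncoderMigrationSketch {
  def main(args: Array[String]): Unit = {
    // resolveAndBind is required before deserialization can happen.
    val encoder = ExpressionEncoder[Person]().resolveAndBind()

    // Before this PR (stateful, so not safe to share across threads):
    //   val row    = encoder.toRow(Person("a", 1))
    //   val person = encoder.fromRow(row)

    // After this PR: dedicated function objects. Each one is still not
    // thread-safe, so create one per thread or task instead of sharing it.
    val toRow = encoder.createSerializer()
    val fromRow = encoder.createDeserializer()

    // The serializer reuses its output row; copy before making another call.
    val row: InternalRow = toRow(Person("a", 1)).copy()
    val person: Person = fromRow(row)
    println(person)
  }
}
```

Creating the function objects is cheap, so the intended pattern is one serializer/deserializer per thread or task, as the "more limited scope" note above suggests.
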
1 parent 8608189 · commit fab4ca5

39 files changed (+282, −238 lines)

mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
       if (requiredSchema.isEmpty) {
         filteredResult.map(_ => emptyUnsafeRow)
       } else {
-        val converter = RowEncoder(requiredSchema)
-        filteredResult.map(row => converter.toRow(row))
+        val toRow = RowEncoder(requiredSchema).createSerializer()
+        filteredResult.map(row => toRow(row))
       }
     }
   }

mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala

Lines changed: 2 additions & 2 deletions
@@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat
        LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
      }
 
-     val converter = RowEncoder(dataSchema)
+     val toRow = RowEncoder(dataSchema).createSerializer()
      val fullOutput = dataSchema.map { f =>
        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
      }
@@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat
 
      points.map { pt =>
        val features = if (isSparse) pt.features.toSparse else pt.features.toDense
-       requiredColumns(converter.toRow(Row(pt.label, features)))
+       requiredColumns(toRow(Row(pt.label, features)))
      }
    }
 }

mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala

Lines changed: 6 additions & 4 deletions
@@ -38,20 +38,22 @@ object UDTSerializationBenchmark extends BenchmarkBase {
     val iters = 1e2.toInt
     val numRows = 1e3.toInt
 
-    val encoder = ExpressionEncoder[Vector].resolveAndBind()
+    val encoder = ExpressionEncoder[Vector]().resolveAndBind()
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
 
     val vectors = (1 to numRows).map { i =>
       Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
     }.toArray
-    val rows = vectors.map(encoder.toRow)
+    val rows = vectors.map(toRow)
 
     val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output)
 
     benchmark.addCase("serialize") { _ =>
       var sum = 0
       var i = 0
       while (i < numRows) {
-        sum += encoder.toRow(vectors(i)).numFields
+        sum += toRow(vectors(i)).numFields
         i += 1
       }
     }
@@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
       var sum = 0
       var i = 0
       while (i < numRows) {
-        sum += encoder.fromRow(rows(i)).numActives
+        sum += fromRow(rows(i)).numActives
         i += 1
       }
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala

Lines changed: 1 addition & 8 deletions
@@ -58,8 +58,7 @@ import org.apache.spark.sql.types._
  * }}}
  *
  * == Implementation ==
- * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard
- *   against concurrent access if they reuse internal buffers to improve performance.
+ * - Encoders should be thread-safe.
  *
  * @since 1.6.0
  */
@@ -76,10 +75,4 @@ trait Encoder[T] extends Serializable {
    * A ClassTag that can be used to construct an Array to contain a collection of `T`.
    */
   def clsTag: ClassTag[T]
-
-  /**
-   * Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed
-   * up the [[Encoder]] creation.
-   */
-  def makeCopy: Encoder[T]
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 70 additions & 32 deletions
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import java.io.ObjectInputStream
+
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
@@ -162,6 +165,56 @@ object ExpressionEncoder {
       e4: ExpressionEncoder[T4],
       e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
     tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+
+  private val anyObjectType = ObjectType(classOf[Any])
+
+  /**
+   * Function that deserializes an [[InternalRow]] into an object of type `T`. This class is not
+   * thread-safe.
+   */
+  class Deserializer[T](private val expressions: Seq[Expression])
+    extends (InternalRow => T) with Serializable {
+    @transient
+    private[this] var constructProjection: Projection = _
+
+    override def apply(row: InternalRow): T = try {
+      if (constructProjection == null) {
+        constructProjection = SafeProjection.create(expressions)
+      }
+      constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while decoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+  }
+
+  /**
+   * Function that serializes an object of type `T` to an [[InternalRow]]. This class is not
+   * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow]]
+   * object. Thus, the caller should copy the result before making another call if required.
+   */
+  class Serializer[T](private val expressions: Seq[Expression])
+    extends (T => InternalRow) with Serializable {
+    @transient
+    private[this] var inputRow: GenericInternalRow = _
+
+    @transient
+    private[this] var extractProjection: UnsafeProjection = _
+
+    override def apply(t: T): InternalRow = try {
+      if (extractProjection == null) {
+        inputRow = new GenericInternalRow(1)
+        extractProjection = GenerateUnsafeProjection.generate(expressions)
+      }
+      inputRow(0) = t
+      extractProjection(inputRow)
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while encoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
    }
+  }
 }
 
 /**
@@ -302,25 +355,22 @@ case class ExpressionEncoder[T](
   }
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate({
+  private lazy val optimizedDeserializer: Seq[Expression] = {
     // When using `ExpressionEncoder` directly, we will skip the normal query processing steps
     // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
     // important to codegen performance.
-    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
+    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
     optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
-  })
-
-  @transient
-  private lazy val inputRow = new GenericInternalRow(1)
+  }
 
   @transient
-  private lazy val constructProjection = SafeProjection.create({
+  private lazy val optimizedSerializer = {
     // When using `ExpressionEncoder` directly, we will skip the normal query processing steps
     // (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
     // important to codegen performance.
-    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
+    val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
     optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
-  })
+  }
 
   /**
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
@@ -332,31 +382,21 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
-   * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
-   * copy the result before making another call if required.
+   * Create a serializer that can convert an object of type `T` to a Spark SQL Row.
+   *
+   * Note that the returned [[Serializer]] is not thread-safe. Multiple calls to
+   * `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus,
+   * the caller should copy the result before making another call if required.
    */
-  def toRow(t: T): InternalRow = try {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  } catch {
-    case e: Exception =>
-      throw new RuntimeException(s"Error while encoding: $e\n" +
-        s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
-  }
+  def createSerializer(): Serializer[T] = new Serializer[T](optimizedSerializer)
 
   /**
-   * Returns an object of type `T`, extracting the required values from the provided row. Note that
-   * you must `resolveAndBind` an encoder to a specific schema before you can call this
-   * function.
+   * Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
+   *
+   * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
+   * deserializer.
    */
-  def fromRow(row: InternalRow): T = try {
-    constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
-  } catch {
-    case e: Exception =>
-      throw new RuntimeException(s"Error while decoding: $e\n" +
-        s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
-  }
+  def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer)
 
   /**
    * The process of resolution to a given schema throws away information about where a given field
@@ -383,8 +423,6 @@ case class ExpressionEncoder[T](
     .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
 
   override def toString: String = s"class[$schemaString]"
-
-  override def makeCopy: ExpressionEncoder[T] = copy()
 }
 
 // A dummy logical plan that can hold expressions and go through optimizer rules.
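
The change above makes the encoder itself stateless: all mutable state (the reused input row and the lazily built projections) now lives in the short-lived function objects. Below is a minimal sketch of the resulting usage pattern; the `ThreadSafetySketch` object and the tuple encoder are illustrative assumptions, not code from this commit.

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical sketch: one shared, now-stateless encoder, with a private
// serializer per thread.
object ThreadSafetySketch {
  def main(args: Array[String]): Unit = {
    // Safe to share: after this change the encoder no longer carries
    // mutable projection/row state.
    val encoder = ExpressionEncoder[(Int, String)]().resolveAndBind()

    val threads = (1 to 4).map { id =>
      new Thread(() => {
        // Not safe to share: each Serializer lazily builds its own projection
        // and reuses a single output row, so keep it thread-local.
        val toRow = encoder.createSerializer()
        val row = toRow((id, s"thread-$id")).copy() // copy: the row is reused
        println(s"thread $id encoded ${row.numFields} fields")
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```
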

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 2 additions & 2 deletions
@@ -110,8 +110,8 @@ case class ScalaUDF(
     } else {
       val encoder = inputEncoders(i)
       if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
-        val enc = encoder.get.resolveAndBind()
-        row: Any => enc.fromRow(row.asInstanceOf[InternalRow])
+        val fromRow = encoder.get.resolveAndBind().createDeserializer()
+        row: Any => fromRow(row.asInstanceOf[InternalRow])
       } else {
         CatalystTypeConverters.createToScalaConverter(dataType)
       }

sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase {
   def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = {
     runBenchmark(name) {
       val generator = RandomDataGenerator.forType(schema, nullable = false).get
-      val encoder = RowEncoder(schema)
+      val toRow = RowEncoder(schema).createSerializer()
       val attrs = schema.toAttributes
       val safeProjection = GenerateSafeProjection.generate(attrs, attrs)
 
       val rows = (1 to numRows).map(_ =>
         // The output of encoder is UnsafeRow, use safeProjection to turn in into safe format.
-        safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
+        safeProjection(toRow(generator().asInstanceOf[Row])).copy()
       ).toArray
 
       val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output)

sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {
 
   def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
     val generator = RandomDataGenerator.forType(schema, nullable = false).get
-    val encoder = RowEncoder(schema)
-    (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
+    val toRow = RowEncoder(schema).createSerializer()
+    (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray
   }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala

Lines changed: 18 additions & 10 deletions
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
@@ -42,37 +43,44 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass])
 class EncoderResolutionSuite extends PlanTest {
   private val str = UTF8String.fromString("hello")
 
+  def testFromRow[T](
+      encoder: ExpressionEncoder[T],
+      attributes: Seq[Attribute],
+      row: InternalRow): Unit = {
+    encoder.resolveAndBind(attributes).createDeserializer().apply(row)
+  }
+
   test("real type doesn't match encoder schema but they are compatible: product") {
     val encoder = ExpressionEncoder[StringLongClass]
 
     // int type can be up cast to long type
     val attrs1 = Seq('a.string, 'b.int)
-    encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
+    testFromRow(encoder, attrs1, InternalRow(str, 1))
 
     // int type can be up cast to string type
     val attrs2 = Seq('a.int, 'b.long)
-    encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
+    testFromRow(encoder, attrs2, InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+    testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
     val encoder = ExpressionEncoder.tuple(
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+    testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
   test("real type doesn't match encoder schema but they are compatible: primitive array") {
     val encoder = ExpressionEncoder[PrimitiveArrayClass]
     val attrs = Seq('arr.array(IntegerType))
     val array = new GenericArrayData(Array(1, 2, 3))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+    testFromRow(encoder, attrs, InternalRow(array))
   }
 
   test("the real type is not compatible with encoder schema: primitive array") {
@@ -93,7 +101,7 @@
     val encoder = ExpressionEncoder[ArrayClass]
     val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
     val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+    testFromRow(encoder, attrs, InternalRow(array))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested array") {
@@ -103,7 +111,7 @@
     val attrs = Seq('nestedArr.array(et))
     val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
     val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
-    encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+    testFromRow(encoder, attrs, InternalRow(outerArr))
   }
 
   test("the real type is not compatible with encoder schema: non-array field") {
@@ -142,14 +150,14 @@
     val attrs = 'a.array(IntegerType) :: Nil
 
     // It should pass analysis
-    val bound = encoder.resolveAndBind(attrs)
+    val fromRow = encoder.resolveAndBind(attrs).createDeserializer()
 
     // If no null values appear, it should work fine
-    bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+    fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
 
     // If there is null value, it should throw runtime exception
    val e = intercept[RuntimeException] {
-      bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+      fromRow(InternalRow(new GenericArrayData(Array(1, null))))
    }
    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
  }

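The new `testFromRow` helper bundles the resolve-bind-deserialize sequence that every test in this suite exercises. Here is a hedged standalone sketch of the same flow; the `StringLong` case class, the `ResolutionSketch` object, and the printed output are illustrative assumptions standing in for the suite's fixtures, while `resolveAndBind` and `createDeserializer` are the APIs the suite actually uses.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.unsafe.types.UTF8String

// Illustrative stand-in for the suite's StringLongClass; not part of the commit.
case class StringLong(a: String, b: Long)

object ResolutionSketch {
  def main(args: Array[String]): Unit = {
    val encoder = ExpressionEncoder[StringLong]()
    // Attributes built with the catalyst expression DSL ('a.string, 'b.int);
    // the int column is up-cast to the encoder's long field during resolution.
    val attrs = Seq('a.string, 'b.int)
    // Resolve and bind first, then build the deserializer, as testFromRow does.
    val fromRow = encoder.resolveAndBind(attrs).createDeserializer()
    val obj = fromRow(InternalRow(UTF8String.fromString("hello"), 1))
    println(obj) // StringLong(hello,1)
  }
}
```
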