Skip to content

Commit ab9d6fc

Browse files
committed
Add returnNullable parameter to callers of StaticInvoke.
1 parent 886beb0 commit ab9d6fc

File tree

3 files changed

+58
-29
lines changed

3 files changed

+58
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,17 @@ object JavaTypeInference {
204204
ObjectType(c),
205205
"toJavaDate",
206206
getPath :: Nil,
207-
propagateNull = true)
207+
propagateNull = true,
208+
returnNullable = false)
208209

209210
case c if c == classOf[java.sql.Timestamp] =>
210211
StaticInvoke(
211212
DateTimeUtils.getClass,
212213
ObjectType(c),
213214
"toJavaTimestamp",
214215
getPath :: Nil,
215-
propagateNull = true)
216+
propagateNull = true,
217+
returnNullable = false)
216218

217219
case c if c == classOf[java.lang.String] =>
218220
Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -256,7 +258,8 @@ object JavaTypeInference {
256258
"array",
257259
ObjectType(classOf[Array[Any]]))
258260

259-
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
261+
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil,
262+
returnNullable = false)
260263

261264
case _ if mapType.isAssignableFrom(typeToken) =>
262265
val (keyType, valueType) = mapKeyValueType(typeToken)
@@ -285,7 +288,8 @@ object JavaTypeInference {
285288
ArrayBasedMapData.getClass,
286289
ObjectType(classOf[JMap[_, _]]),
287290
"toJavaMap",
288-
keyData :: valueData :: Nil)
291+
keyData :: valueData :: Nil,
292+
returnNullable = false)
289293

290294
case other =>
291295
val properties = getJavaBeanProperties(other)
@@ -350,28 +354,32 @@ object JavaTypeInference {
350354
classOf[UTF8String],
351355
StringType,
352356
"fromString",
353-
inputObject :: Nil)
357+
inputObject :: Nil,
358+
returnNullable = true)
354359

355360
case c if c == classOf[java.sql.Timestamp] =>
356361
StaticInvoke(
357362
DateTimeUtils.getClass,
358363
TimestampType,
359364
"fromJavaTimestamp",
360-
inputObject :: Nil)
365+
inputObject :: Nil,
366+
returnNullable = false)
361367

362368
case c if c == classOf[java.sql.Date] =>
363369
StaticInvoke(
364370
DateTimeUtils.getClass,
365371
DateType,
366372
"fromJavaDate",
367-
inputObject :: Nil)
373+
inputObject :: Nil,
374+
returnNullable = false)
368375

369376
case c if c == classOf[java.math.BigDecimal] =>
370377
StaticInvoke(
371378
Decimal.getClass,
372379
DecimalType.SYSTEM_DEFAULT,
373380
"apply",
374-
inputObject :: Nil)
381+
inputObject :: Nil,
382+
returnNullable = false)
375383

376384
case c if c == classOf[java.lang.Boolean] =>
377385
Invoke(inputObject, "booleanValue", BooleanType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,14 +239,16 @@ object ScalaReflection extends ScalaReflection {
239239
DateTimeUtils.getClass,
240240
ObjectType(classOf[java.sql.Date]),
241241
"toJavaDate",
242-
getPath :: Nil)
242+
getPath :: Nil,
243+
returnNullable = false)
243244

244245
case t if t <:< localTypeOf[java.sql.Timestamp] =>
245246
StaticInvoke(
246247
DateTimeUtils.getClass,
247248
ObjectType(classOf[java.sql.Timestamp]),
248249
"toJavaTimestamp",
249-
getPath :: Nil)
250+
getPath :: Nil,
251+
returnNullable = false)
250252

251253
case t if t <:< localTypeOf[java.lang.String] =>
252254
Invoke(getPath, "toString", ObjectType(classOf[String]))
@@ -316,7 +318,8 @@ object ScalaReflection extends ScalaReflection {
316318
scala.collection.mutable.WrappedArray.getClass,
317319
ObjectType(classOf[Seq[_]]),
318320
"make",
319-
array :: Nil)
321+
array :: Nil,
322+
returnNullable = false)
320323

321324
case t if t <:< localTypeOf[Map[_, _]] =>
322325
// TODO: add walked type path for map
@@ -344,7 +347,8 @@ object ScalaReflection extends ScalaReflection {
344347
ArrayBasedMapData.getClass,
345348
ObjectType(classOf[Map[_, _]]),
346349
"toScalaMap",
347-
keyData :: valueData :: Nil)
350+
keyData :: valueData :: Nil,
351+
returnNullable = false)
348352

349353
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
350354
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
@@ -449,7 +453,8 @@ object ScalaReflection extends ScalaReflection {
449453
classOf[UnsafeArrayData],
450454
ArrayType(dt, false),
451455
"fromPrimitiveArray",
452-
input :: Nil)
456+
input :: Nil,
457+
returnNullable = false)
453458
} else {
454459
NewInstance(
455460
classOf[GenericArrayData],
@@ -505,49 +510,56 @@ object ScalaReflection extends ScalaReflection {
505510
classOf[UTF8String],
506511
StringType,
507512
"fromString",
508-
inputObject :: Nil)
513+
inputObject :: Nil,
514+
returnNullable = true)
509515

510516
case t if t <:< localTypeOf[java.sql.Timestamp] =>
511517
StaticInvoke(
512518
DateTimeUtils.getClass,
513519
TimestampType,
514520
"fromJavaTimestamp",
515-
inputObject :: Nil)
521+
inputObject :: Nil,
522+
returnNullable = false)
516523

517524
case t if t <:< localTypeOf[java.sql.Date] =>
518525
StaticInvoke(
519526
DateTimeUtils.getClass,
520527
DateType,
521528
"fromJavaDate",
522-
inputObject :: Nil)
529+
inputObject :: Nil,
530+
returnNullable = false)
523531

524532
case t if t <:< localTypeOf[BigDecimal] =>
525533
StaticInvoke(
526534
Decimal.getClass,
527535
DecimalType.SYSTEM_DEFAULT,
528536
"apply",
529-
inputObject :: Nil)
537+
inputObject :: Nil,
538+
returnNullable = false)
530539

531540
case t if t <:< localTypeOf[java.math.BigDecimal] =>
532541
StaticInvoke(
533542
Decimal.getClass,
534543
DecimalType.SYSTEM_DEFAULT,
535544
"apply",
536-
inputObject :: Nil)
545+
inputObject :: Nil,
546+
returnNullable = false)
537547

538548
case t if t <:< localTypeOf[java.math.BigInteger] =>
539549
StaticInvoke(
540550
Decimal.getClass,
541551
DecimalType.BigIntDecimal,
542552
"apply",
543-
inputObject :: Nil)
553+
inputObject :: Nil,
554+
returnNullable = false)
544555

545556
case t if t <:< localTypeOf[scala.math.BigInt] =>
546557
StaticInvoke(
547558
Decimal.getClass,
548559
DecimalType.BigIntDecimal,
549560
"apply",
550-
inputObject :: Nil)
561+
inputObject :: Nil,
562+
returnNullable = false)
551563

552564
case t if t <:< localTypeOf[java.lang.Integer] =>
553565
Invoke(inputObject, "intValue", IntegerType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,28 +96,32 @@ object RowEncoder {
9696
DateTimeUtils.getClass,
9797
TimestampType,
9898
"fromJavaTimestamp",
99-
inputObject :: Nil)
99+
inputObject :: Nil,
100+
returnNullable = false)
100101

101102
case DateType =>
102103
StaticInvoke(
103104
DateTimeUtils.getClass,
104105
DateType,
105106
"fromJavaDate",
106-
inputObject :: Nil)
107+
inputObject :: Nil,
108+
returnNullable = false)
107109

108110
case d: DecimalType =>
109111
StaticInvoke(
110112
Decimal.getClass,
111113
d,
112114
"fromDecimal",
113-
inputObject :: Nil)
115+
inputObject :: Nil,
116+
returnNullable = false)
114117

115118
case StringType =>
116119
StaticInvoke(
117120
classOf[UTF8String],
118121
StringType,
119122
"fromString",
120-
inputObject :: Nil)
123+
inputObject :: Nil,
124+
returnNullable = true)
121125

122126
case t @ ArrayType(et, cn) =>
123127
et match {
@@ -126,7 +130,8 @@ object RowEncoder {
126130
classOf[ArrayData],
127131
t,
128132
"toArrayData",
129-
inputObject :: Nil)
133+
inputObject :: Nil,
134+
returnNullable = false)
130135
case _ => MapObjects(
131136
element => serializerFor(ValidateExternalType(element, et), et),
132137
inputObject,
@@ -252,14 +257,16 @@ object RowEncoder {
252257
DateTimeUtils.getClass,
253258
ObjectType(classOf[java.sql.Timestamp]),
254259
"toJavaTimestamp",
255-
input :: Nil)
260+
input :: Nil,
261+
returnNullable = false)
256262

257263
case DateType =>
258264
StaticInvoke(
259265
DateTimeUtils.getClass,
260266
ObjectType(classOf[java.sql.Date]),
261267
"toJavaDate",
262-
input :: Nil)
268+
input :: Nil,
269+
returnNullable = false)
263270

264271
case _: DecimalType =>
265272
Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
@@ -277,7 +284,8 @@ object RowEncoder {
277284
scala.collection.mutable.WrappedArray.getClass,
278285
ObjectType(classOf[Seq[_]]),
279286
"make",
280-
arrayData :: Nil)
287+
arrayData :: Nil,
288+
returnNullable = false)
281289

282290
case MapType(kt, vt, valueNullable) =>
283291
val keyArrayType = ArrayType(kt, false)
@@ -290,7 +298,8 @@ object RowEncoder {
290298
ArrayBasedMapData.getClass,
291299
ObjectType(classOf[Map[_, _]]),
292300
"toScalaMap",
293-
keyData :: valueData :: Nil)
301+
keyData :: valueData :: Nil,
302+
returnNullable = false)
294303

295304
case schema @ StructType(fields) =>
296305
val convertedFields = fields.zipWithIndex.map { case (f, i) =>

0 commit comments

Comments
 (0)