Commit 15557a7

viirya authored and cloud-fan committed
[SPARK-31071][SQL] Allow annotating non-null fields when encoding Java Beans
### What changes were proposed in this pull request?

When encoding Java Beans to a Spark DataFrame, respect `javax.annotation.Nonnull` and produce non-null fields.

### Why are the changes needed?

When encoding Java Beans to a Spark DataFrame, non-primitive types are encoded as nullable fields. Although this works for most cases, it can be an issue in a few situations, e.g. the one described in the JIRA ticket, where a DataFrame with a non-null field is saved to Avro format. We should allow Spark users more flexibility when creating a DataFrame from Java Beans. Currently, Spark users cannot create a DataFrame whose schema has non-nullable fields from beans with non-nullable properties. Although it is possible to project top-level columns with SQL expressions like `AssertNotNull` to make them non-null, doing the same for nested fields is much trickier.

### Does this PR introduce any user-facing change?

Yes. After this change, Spark users can use `javax.annotation.Nonnull` to annotate non-null fields in Java Beans when encoding beans to a Spark DataFrame.

### How was this patch tested?

Added unit test.

Closes apache#27851 from viirya/SPARK-31071.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
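Before the per-file diffs, a minimal usage sketch of the new behavior. The `LabeledPoint` bean here is hypothetical, modeled on the `JavaDatasetSuite` test added below; annotating a getter with `javax.annotation.Nonnull` now yields a non-nullable schema field:

```java
import java.io.Serializable;
import java.util.Collections;
import javax.annotation.Nonnull;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class NonNullBeanExample {
  // Hypothetical bean: nullability is read from the getter's annotation.
  public static class LabeledPoint implements Serializable {
    private String label;
    private Double weight;

    @Nonnull
    public String getLabel() { return label; }
    public void setLabel(String label) { this.label = label; }

    public Double getWeight() { return weight; }
    public void setWeight(Double weight) { this.weight = weight; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]").appName("nonnull-bean-sketch").getOrCreate();

    LabeledPoint p = new LabeledPoint();
    p.setLabel("origin");
    p.setWeight(1.0);

    Dataset<LabeledPoint> ds = spark.createDataset(
        Collections.singletonList(p), Encoders.bean(LabeledPoint.class));

    // After this patch: label is non-nullable, weight stays nullable.
    System.out.println(ds.schema().apply("label").nullable());   // false
    System.out.println(ds.schema().apply("weight").nullable());  // true
    spark.stop();
  }
}
```

Primitive properties (e.g. `int`) were already non-nullable; the annotation matters for non-primitive types such as boxed values, strings, maps, and nested beans.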
1 parent 3493162 commit 15557a7

File tree

4 files changed: +136, -5 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 11 additions & 3 deletions
@@ -21,6 +21,7 @@ import java.beans.{Introspector, PropertyDescriptor}
 import java.lang.{Iterable => JIterable}
 import java.lang.reflect.Type
 import java.util.{Iterator => JIterator, List => JList, Map => JMap}
+import javax.annotation.Nonnull

 import scala.language.existentials

@@ -148,7 +149,9 @@ object JavaTypeInference {
       val fields = properties.map { property =>
         val returnType = typeToken.method(property.getReadMethod).getReturnType
         val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
-        new StructField(property.getName, dataType, nullable)
+        // The presence of `javax.annotation.Nonnull` means this field is not nullable.
+        val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull])
+        new StructField(property.getName, dataType, nullable && !hasNonNull)
       }
       (new StructType(fields), true)
     }

@@ -340,10 +343,12 @@ object JavaTypeInference {
         val fieldType = typeToken.method(p.getReadMethod).getReturnType
         val (dataType, nullable) = inferDataType(fieldType)
         val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
+        // The presence of `javax.annotation.Nonnull` means this field is not nullable.
+        val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
         val setter = expressionWithNullSafety(
           deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
             newTypePath),
-          nullable = nullable,
+          nullable = nullable && !hasNonNull,
           newTypePath)
         p.getWriteMethod.getName -> setter
       }.toMap

@@ -442,10 +447,13 @@ object JavaTypeInference {
       val fields = properties.map { p =>
         val fieldName = p.getName
         val fieldType = typeToken.method(p.getReadMethod).getReturnType
+        val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull])
         val fieldValue = Invoke(
           inputObject,
           p.getReadMethod.getName,
-          inferExternalType(fieldType.getRawType))
+          inferExternalType(fieldType.getRawType),
+          propagateNull = !hasNonNull,
+          returnNullable = !hasNonNull)
         (fieldName, serializerFor(fieldValue, fieldType))
       }
       createSerializerForObject(inputObject, fields)
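All three hunks key the nullability decision off the property's read method rather than the backing field. A standalone sketch of that reflection, using only `java.beans` and JSR-305 (the `Bean` class is hypothetical; assumes `javax.annotation.Nonnull` is on the classpath); this is essentially what `isAnnotationPresent(classOf[Nonnull])` does above:

```java
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import javax.annotation.Nonnull;

public class AnnotationProbe {
  public static class Bean {
    @Nonnull
    public String getName() { return "x"; }
    public void setName(String name) {}
  }

  public static void main(String[] args) throws Exception {
    // Same check the patch performs: isAnnotationPresent on the getter.
    for (PropertyDescriptor p :
        Introspector.getBeanInfo(Bean.class).getPropertyDescriptors()) {
      if (p.getReadMethod() != null) {
        System.out.println(p.getName() + " nonNull="
            + p.getReadMethod().isAnnotationPresent(Nonnull.class));
      }
    }
  }
}
```

Because only the read method is inspected, annotating the private field alone would not change the inferred nullability.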

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 6 additions & 2 deletions
@@ -187,8 +187,12 @@ object SerializerBuildHelper {
     val nonNullOutput = CreateNamedStruct(fields.flatMap { case(fieldName, fieldExpr) =>
       argumentsForFieldSerializer(fieldName, fieldExpr)
     })
-    val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-    expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+    if (inputObject.nullable) {
+      val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+      expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+    } else {
+      nonNullOutput
+    }
   }

   def createSerializerForUserDefinedType(
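This change reads as a simple guard: the `If(IsNull(...))` wrapper is only generated when the input expression can actually be null. A hypothetical plain-Java analogy (`serializeBean` and `toRow` are illustrative names, not Spark APIs):

```java
public class NullGuardAnalogy {
  static Object toRow(Object bean) {
    return "struct(" + bean + ")";  // stands in for the CreateNamedStruct output
  }

  static Object serializeBean(Object bean, boolean beanNullable) {
    if (beanNullable) {
      // corresponds to If(IsNull(inputObject), nullOutput, nonNullOutput)
      return bean == null ? null : toRow(bean);
    } else {
      // input is statically non-null: emit the bare struct, no null check
      return toRow(bean);
    }
  }

  public static void main(String[] args) {
    System.out.println(serializeBean(null, true));   // null
    System.out.println(serializeBean("b", false));   // struct(b)
  }
}
```

This is also why the `ExpressionEncoder` change below adds a bare `case s: CreateNamedStruct`: for a statically non-null input there is no `If` node to strip.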

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ case class ExpressionEncoder[T](
     }
     nullSafeSerializer match {
       case If(_: IsNull, _, s: CreateNamedStruct) => s
+      case s: CreateNamedStruct => s
       case _ =>
         throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
     }

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 118 additions & 0 deletions
@@ -24,6 +24,7 @@
 import java.time.Instant;
 import java.time.LocalDate;
 import java.util.*;
+import javax.annotation.Nonnull;

 import org.apache.spark.sql.streaming.GroupStateTimeout;
 import org.apache.spark.sql.streaming.OutputMode;

@@ -823,6 +824,75 @@ public int hashCode() {
     }
   }

+  public static class NestedSmallBeanWithNonNullField implements Serializable {
+    private SmallBean nonNull_f;
+    private SmallBean nullable_f;
+    private Map<String, SmallBean> childMap;
+
+    @Nonnull
+    public SmallBean getNonNull_f() {
+      return nonNull_f;
+    }
+
+    public void setNonNull_f(SmallBean f) {
+      this.nonNull_f = f;
+    }
+
+    public SmallBean getNullable_f() {
+      return nullable_f;
+    }
+
+    public void setNullable_f(SmallBean f) {
+      this.nullable_f = f;
+    }
+
+    @Nonnull
+    public Map<String, SmallBean> getChildMap() { return childMap; }
+    public void setChildMap(Map<String, SmallBean> childMap) {
+      this.childMap = childMap;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBeanWithNonNullField that = (NestedSmallBeanWithNonNullField) o;
+      return Objects.equal(nullable_f, that.nullable_f) &&
+        Objects.equal(nonNull_f, that.nonNull_f) && Objects.equal(childMap, that.childMap);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(nullable_f, nonNull_f, childMap);
+    }
+  }
+
+  public static class NestedSmallBean2 implements Serializable {
+    private NestedSmallBeanWithNonNullField f;
+
+    @Nonnull
+    public NestedSmallBeanWithNonNullField getF() {
+      return f;
+    }
+
+    public void setF(NestedSmallBeanWithNonNullField f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean2 that = (NestedSmallBean2) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
   @Rule
   public transient ExpectedException nullabilityCheck = ExpectedException.none();

@@ -1504,6 +1574,54 @@ public void testSerializeNull() {
     Assert.assertEquals(beans, ds2.collectAsList());
   }

+  @Test
+  public void testNonNullField() {
+    NestedSmallBeanWithNonNullField bean1 = new NestedSmallBeanWithNonNullField();
+    SmallBean smallBean = new SmallBean();
+    bean1.setNonNull_f(smallBean);
+    Map<String, SmallBean> map = new HashMap<>();
+    bean1.setChildMap(map);
+
+    Encoder<NestedSmallBeanWithNonNullField> encoder1 =
+      Encoders.bean(NestedSmallBeanWithNonNullField.class);
+    List<NestedSmallBeanWithNonNullField> beans1 = Arrays.asList(bean1);
+    Dataset<NestedSmallBeanWithNonNullField> ds1 = spark.createDataset(beans1, encoder1);
+
+    StructType schema = ds1.schema();
+    Assert.assertFalse(schema.apply("nonNull_f").nullable());
+    Assert.assertTrue(schema.apply("nullable_f").nullable());
+    Assert.assertFalse(schema.apply("childMap").nullable());
+
+    Assert.assertEquals(beans1, ds1.collectAsList());
+    Dataset<NestedSmallBeanWithNonNullField> ds2 = ds1.map(
+      (MapFunction<NestedSmallBeanWithNonNullField, NestedSmallBeanWithNonNullField>) b -> b,
+      encoder1);
+    Assert.assertEquals(beans1, ds2.collectAsList());
+
+    // Nonnull nested fields
+    NestedSmallBean2 bean2 = new NestedSmallBean2();
+    bean2.setF(bean1);
+
+    Encoder<NestedSmallBean2> encoder2 =
+      Encoders.bean(NestedSmallBean2.class);
+    List<NestedSmallBean2> beans2 = Arrays.asList(bean2);
+    Dataset<NestedSmallBean2> ds3 = spark.createDataset(beans2, encoder2);
+
+    StructType nestedSchema = (StructType) ds3.schema()
+      .fields()[ds3.schema().fieldIndex("f")]
+      .dataType();
+    Assert.assertFalse(nestedSchema.apply("nonNull_f").nullable());
+    Assert.assertTrue(nestedSchema.apply("nullable_f").nullable());
+    Assert.assertFalse(nestedSchema.apply("childMap").nullable());
+
+    Assert.assertEquals(beans2, ds3.collectAsList());
+
+    Dataset<NestedSmallBean2> ds4 = ds3.map(
+      (MapFunction<NestedSmallBean2, NestedSmallBean2>) b -> b,
+      encoder2);
+    Assert.assertEquals(beans2, ds4.collectAsList());
+  }
+
   @Test
   public void testSpecificLists() {
     SpecificListsBean bean = new SpecificListsBean();
