diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 49313a3ce9152..bf207d6068f73 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -100,7 +100,7 @@ private[sql] object ProtobufUtils extends Logging {
      */
    def validateNoExtraRequiredProtoFields(): Unit = {
      val extraFields = protoFieldArray.toSet -- matchedFields.map(_.fieldDescriptor)
-      extraFields.filterNot(isNullable).foreach { extraField =>
+      extraFields.filter(_.isRequired).foreach { extraField =>
        throw QueryCompilationErrors.cannotFindProtobufFieldInCatalystError(
          toFieldStr(protoPath :+ extraField.getName()))
      }
@@ -283,9 +283,4 @@ private[sql] object ProtobufUtils extends Logging {
    case Seq() => "top-level record"
    case n => s"field '${n.mkString(".")}'"
  }
-
-  /** Return true if `fieldDescriptor` is optional. */
-  private[protobuf] def isNullable(fieldDescriptor: FieldDescriptor): Boolean =
-    !fieldDescriptor.isOptional
-
 }
diff --git a/connector/protobuf/src/test/resources/protobuf/proto2_messages.desc b/connector/protobuf/src/test/resources/protobuf/proto2_messages.desc
new file mode 100644
index 0000000000000..a9e4099a7f2b5
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages.desc
@@ -0,0 +1,8 @@
+
+
+proto2_messages.proto$org.apache.spark.sql.protobuf.protos"@
+FoobarWithRequiredFieldBar
+foo ( Rfoo
+bar (Rbar"
+ NestedFoobarWithRequiredFieldBare
+ nested_foobar ( 2@.org.apache.spark.sql.protobuf.protos.FoobarWithRequiredFieldBarR nestedFoobarBBProto2Messages
\ No newline at end of file
diff --git a/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
new file mode 100644
index 0000000000000..a5d09df8514e0
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/proto2_messages.proto
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos;
+option java_outer_classname = "Proto2Messages";
+
+
+// Used to test missing required field bar in top level schema.
+message FoobarWithRequiredFieldBar {
+  optional string foo = 1;
+  required int32 bar = 2;
+}
+
+// Used to test missing required field bar in nested struct.
+message NestedFoobarWithRequiredFieldBar {
+  optional FoobarWithRequiredFieldBar nested_foobar = 1;
+}
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
index a7459213a87b2..87152b035b015 100644
--- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
@@ -59,17 +59,6 @@ message TypeMiss {
   int64 bar = 1;
 }
 
-/* Field boo missing from SQL root, but available in Protobuf root*/
-message FieldMissingInSQLRoot {
-  Foo foo = 1;
-  int32 boo = 2;
-}
-
-/* Field baz missing from SQL nested and available in Protobuf nested*/
-message FieldMissingInSQLNested {
-  Baz foo = 1;
-}
-
 message Baz {
   int32 bar = 1;
   int32 baz = 2;
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 87ed534094331..356cd20eb4e4d 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -21,7 +21,7 @@ import com.google.protobuf.Descriptors.Descriptor
 import com.google.protobuf.DynamicMessage
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.NoopFilters
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters}
 import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
@@ -39,6 +39,8 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
   val testFileDesc = testFile("serde_suite.desc", "protobuf/serde_suite.desc")
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
 
+  val proto2Desc = testFile("proto2_messages.desc", "protobuf/proto2_messages.desc")
+
   test("Test basic conversion") {
     withFieldMatchType { fieldMatch =>
       val (top, nest) = fieldMatch match {
@@ -64,6 +66,28 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
     }
   }
 
+  test("Optional fields can be dropped from input SQL schema for the serializer") {
+    // This test verifies that optional fields can be missing from the input Catalyst schema
+    // while serializing rows to Protobuf.
+
+    val desc = ProtobufUtils.buildDescriptor(proto2Desc, "FoobarWithRequiredFieldBar")
+
+    // Confirm desc contains optional field 'foo' and required field 'bar'.
+    assert(desc.getFields.size() == 2)
+    assert(desc.findFieldByName("foo").isOptional)
+
+    // Use a Catalyst type without the optional "foo".
+    val sqlType = structFromDDL("struct<bar: int>")
+    val serializer = new ProtobufSerializer(sqlType, desc, nullable = false) // Should work fine.
+
+    // Should be able to serialize a row.
+    val protoMessage = serializer.serialize(InternalRow(22)).asInstanceOf[DynamicMessage]
+
+    // Verify the descriptor and the value.
+    assert(protoMessage.getDescriptorForType == desc)
+    assert(protoMessage.getField(desc.findFieldByName("bar")).asInstanceOf[Int] == 22)
+  }
+
   test("Fail to convert with field type mismatch") {
     val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot")
     withFieldMatchType { fieldMatch =>
@@ -144,44 +168,50 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
 
   test("Fail to convert with missing Catalyst fields") {
     val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
-    // serializing with extra fails if extra field is missing in SQL Schema
+    val foobarSQLType = structFromDDL("struct<foo: string>") // "bar" is missing.
+
     assertFailedConversionMessage(
-      protoFile,
+      ProtobufUtils.buildDescriptor(proto2Desc, "FoobarWithRequiredFieldBar"),
       Serializer,
       BY_NAME,
+      catalystSchema = foobarSQLType,
       errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
       params = Map(
-        "protobufType" -> "FieldMissingInSQLRoot",
-        "toType" -> toSQLType(CATALYST_STRUCT)))
+        "protobufType" -> "FoobarWithRequiredFieldBar",
+        "toType" -> toSQLType(foobarSQLType)))
 
     /* deserializing should work regardless of whether the extra field is missing
      in SQL Schema or not */
     withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
     withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
 
-    val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested")
+    val protoNestedFile = ProtobufUtils
+      .buildDescriptor(proto2Desc, "NestedFoobarWithRequiredFieldBar")
 
-    // serializing with extra fails if extra field is missing in SQL Schema
+    val nestedFoobarSQLType = structFromDDL(
+      "struct<nested_foobar: struct<foo: string>>" // "bar" field is missing.
+    )
+    // serializing with extra fails if required field is missing in inner struct
     assertFailedConversionMessage(
-      protoNestedFile,
+      ProtobufUtils.buildDescriptor(proto2Desc, "NestedFoobarWithRequiredFieldBar"),
       Serializer,
       BY_NAME,
+      catalystSchema = nestedFoobarSQLType,
       errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
       params = Map(
-        "protobufType" -> "FieldMissingInSQLNested",
-        "toType" -> toSQLType(CATALYST_STRUCT)))
+        "protobufType" -> "NestedFoobarWithRequiredFieldBar",
+        "toType" -> toSQLType(nestedFoobarSQLType)))
 
     /* deserializing should work regardless of whether the extra field is missing
      in SQL Schema or not */
-    withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
-    withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+    withFieldMatchType(Deserializer.create(nestedFoobarSQLType, protoNestedFile, _))
   }
 
   test("raise cannot parse and construct protobuf descriptor error") {
     // passing serde_suite.proto instead serde_suite.desc
     var testFileDesc = testFile("serde_suite.proto", "protobuf/serde_suite.proto")
     val e1 = intercept[AnalysisException] {
-      ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+      ProtobufUtils.buildDescriptor(testFileDesc, "SerdeBasicMessage")
     }
 
     checkError(
@@ -191,7 +221,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
 
     testFileDesc = testFile("basicmessage_noimports.desc", "protobuf/basicmessage_noimports.desc")
     val e2 = intercept[AnalysisException] {
-      ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+      ProtobufUtils.buildDescriptor(testFileDesc, "SerdeBasicMessage")
     }
 
     checkError(
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestBase.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestBase.scala
index 831b4a26c06d9..2ead89e4545c2 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestBase.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestBase.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.protobuf
 
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{DataType, StructType}
 
 trait ProtobufTestBase extends SQLTestUtils {
 
@@ -34,4 +35,7 @@ trait ProtobufTestBase extends SQLTestUtils {
     }
     ret.replace("file:/", "/")
   }
+
+  protected def structFromDDL(ddl: String): StructType =
+    DataType.fromDDL(ddl).asInstanceOf[StructType]
 }
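---
Reviewer note (illustrative, not part of the patch): the tests above exercise ProtobufSerializer directly; the sketch below shows the same behavior change through the public to_protobuf API. The local SparkSession setup and the checkout-relative descriptor path are assumptions made for the example.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.protobuf.functions.to_protobuf

object RequiredFieldSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Compiled descriptor added by this patch (path relative to a Spark checkout).
    val descPath = "connector/protobuf/src/test/resources/protobuf/proto2_messages.desc"

    // Optional field "foo" is absent from the input schema: with this patch the
    // serializer accepts it, because only proto2 `required` fields are enforced.
    val okDf = Seq(22).toDF("bar").select(struct($"bar").as("msg"))
    okDf.select(to_protobuf($"msg", "FoobarWithRequiredFieldBar", descPath)).collect()

    // Required field "bar" is absent: conversion fails with error class
    // UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE.
    val badDf = Seq("hello").toDF("foo").select(struct($"foo").as("msg"))
    try {
      badDf.select(to_protobuf($"msg", "FoobarWithRequiredFieldBar", descPath)).collect()
    } catch {
      // Surfaces as an AnalysisException, possibly wrapped in a SparkException.
      case e: Exception => println(s"Failed as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}

Before this patch the check was inverted: isNullable returned !isOptional, so a missing optional (or any proto3) field failed serialization while a missing proto2 required field slipped through. filter(_.isRequired) errors exactly when a required field is absent from the Catalyst schema.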