
Commit 1b490e9

yhuai authored and marmbrus committed
[SPARK-5950][SQL] Insert array into a metastore table saved as parquet should work when using datasource api
This PR contains the following changes:

1. Add a new method, `DataType.equalsIgnoreCompatibleNullability`, which is the middle ground between DataType's equality check and `DataType.equalsIgnoreNullability`. For two data types `from` and `to`, it performs `equalsIgnoreNullability` and also checks whether the nullability of `from` is compatible with that of `to`. For example, the nullability of `ArrayType(IntegerType, containsNull = false)` is compatible with that of `ArrayType(IntegerType, containsNull = true)` (for an array without null values, we can always say it may contain null values). However, the nullability of `ArrayType(IntegerType, containsNull = true)` is incompatible with that of `ArrayType(IntegerType, containsNull = false)` (for an array that may have null values, we cannot say it does not have null values).
2. For the `resolved` field of `InsertIntoTable`, use `equalsIgnoreCompatibleNullability` instead of the equality check of the data types.
3. For our data source write path, when appending data, we always use the schema of the existing table to write the data. This is important for parquet, since nullability directly impacts the way values are encoded/decoded. If we do not do this, we may see corrupted values when reading from a set of parquet files generated with different nullability settings.
4. When generating a new parquet table, we always set nullable/containsNull/valueContainsNull to true. So, we will not face situations where we cannot append data because containsNull/valueContainsNull in an Array/Map column of the existing table has already been set to `false`. This change makes the whole data pipeline more robust.
5. Update the equality check of the JSON relation. Since JSON does not really care about nullability, `equalsIgnoreNullability` seems a better choice for comparing the schemata of two JSON tables.

JIRA: https://issues.apache.org/jira/browse/SPARK-5950

Thanks viirya for the initial work in #4729.

cc marmbrus liancheng

Author: Yin Huai <[email protected]>

Closes #4826 from yhuai/insertNullabilityCheck and squashes the following commits:

3b61a04 [Yin Huai] Revert change on equals.
80e487e [Yin Huai] asNullable in UDT.
587d88b [Yin Huai] Make methods private.
0cb7ea2 [Yin Huai] marmbrus's comments.
3cec464 [Yin Huai] Cheng's comments.
486ed08 [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertNullabilityCheck
d3747d1 [Yin Huai] Remove unnecessary change.
8360817 [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertNullabilityCheck
8a3f237 [Yin Huai] Use equalsIgnoreNullability instead of equality check.
0eb5578 [Yin Huai] Fix tests.
f6ed813 [Yin Huai] Update old parquet path.
e4f397c [Yin Huai] Unit tests.
b2c06f8 [Yin Huai] Ignore nullability in JSON relation's equality check.
8bd008b [Yin Huai] nullable, containsNull, and valueContainsNull will be always true for parquet data.
bf50d73 [Yin Huai] When appending data, we use the schema of the existing table instead of the schema of the new data.
0a703e7 [Yin Huai] Test failed again since we cannot read correct content.
9a26611 [Yin Huai] Make InsertIntoTable happy.
8f19fe5 [Yin Huai] equalsIgnoreCompatibleNullability
4ec17fd [Yin Huai] Failed test.

(cherry picked from commit 1259994)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent ffd0591 commit 1b490e9
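
To make change 1 concrete, here is a minimal sketch of the new check. Hedged: `equalsIgnoreCompatibleNullability` is `private[sql]`, so calling it like this assumes code living in the `org.apache.spark.sql` package of a build that includes this patch.

import org.apache.spark.sql.types._

val from = ArrayType(IntegerType, containsNull = false)
val to = ArrayType(IntegerType, containsNull = true)

// Widening nullability is compatible: an array that never contains nulls
// can safely be treated as one that may.
DataType.equalsIgnoreCompatibleNullability(from, to) // true

// Narrowing is not: an array that may contain nulls cannot be claimed null-free.
DataType.equalsIgnoreCompatibleNullability(to, from) // false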

File tree

17 files changed (+330, -36 lines)


mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
      case _ => false
    }
  }
+
+  private[spark] override def asNullable: VectorUDT = this
}

/**

mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ private[mllib] object Loader {
      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
        s" Expected field with name ${field.name} was missing in loaded schema:" +
        s" ${loadedFields.mkString(", ")}")
-      assert(loadedFields(field.name) == field.dataType,
+      assert(loadedFields(field.name).sameType(field.dataType),
        s"Unable to parse model data. Expected field $field but found field" +
        s" with different type: ${loadedFields(field.name)}")
    }
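
`sameType` is the new helper on `DataType` (see the dataTypes.scala changes below); it delegates to `equalsIgnoreNullability`, so the loader no longer rejects a schema that differs only in nullability. This matters here because, after change 4, schemas read back from parquet always carry all-true nullability flags. A small sketch of the difference, assuming code inside the `org.apache.spark` package since `sameType` is `private[spark]`:

import org.apache.spark.sql.types._

val saved = StructType(
  StructField("weights", ArrayType(DoubleType, containsNull = false), nullable = false) :: Nil)
val loaded = StructType(
  StructField("weights", ArrayType(DoubleType, containsNull = true), nullable = true) :: Nil)

saved == loaded        // false: strict equality still sees the nullability difference
saved.sameType(loaded) // true: nullability is ignored, in both directions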

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 2 additions & 1 deletion
@@ -120,7 +120,8 @@ case class InsertIntoTable(
  override def output = child.output

  override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
-    case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType
+    case (childAttr, tableAttr) =>
+      DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
  }
}
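
This is the resolution check that blocked the scenario in the commit title. A rough end-to-end sketch against the 1.3 data source API — assuming a `HiveContext` named `sqlContext`, a `SparkContext` named `sc`, and a hypothetical table name:

import org.apache.spark.sql.SaveMode
import sqlContext.implicits._

val df = sc.parallelize(Seq(Tuple1(Seq(1, 2, 3)))).toDF("a")

// Create a metastore table backed by the parquet data source.
df.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite)

// Appending previously failed to resolve when the incoming ArrayType's
// containsNull did not exactly match the table's schema; with
// equalsIgnoreCompatibleNullability it resolves whenever nullability only widens.
df.insertInto("arrayInParquet")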

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 96 additions & 1 deletion
@@ -181,7 +181,7 @@ object DataType {
  /**
   * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
   */
-  private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+  private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
    (left, right) match {
      case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
        equalsIgnoreNullability(leftElementType, rightElementType)
@@ -198,6 +198,43 @@ object DataType {
      case (left, right) => left == right
    }
  }
+
+  /**
+   * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
+   *
+   * Compatible nullability is defined as follows:
+   *   - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
+   *     if and only if `to.containsNull` is true, or both of `from.containsNull` and
+   *     `to.containsNull` are false.
+   *   - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
+   *     if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+   *     `to.valueContainsNull` are false.
+   *   - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
+   *     if and only if, for every pair of fields, `toField.nullable` is true, or both
+   *     `fromField.nullable` and `toField.nullable` are false.
+   */
+  private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
+    (from, to) match {
+      case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
+        (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
+
+      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+        (tn || !fn) &&
+          equalsIgnoreCompatibleNullability(fromKey, toKey) &&
+          equalsIgnoreCompatibleNullability(fromValue, toValue)
+
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.size == toFields.size &&
+          fromFields.zip(toFields).forall {
+            case (fromField, toField) =>
+              fromField.name == toField.name &&
+                (toField.nullable || !fromField.nullable) &&
+                equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
+          }
+
+      case (fromDataType, toDataType) => fromDataType == toDataType
+    }
+  }
}
@@ -230,6 +267,17 @@ abstract class DataType {
  def prettyJson: String = pretty(render(jsonValue))

  def simpleString: String = typeName
+
+  /** Check if `this` and `other` are the same data type when ignoring nullability
+    * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+    */
+  private[spark] def sameType(other: DataType): Boolean =
+    DataType.equalsIgnoreNullability(this, other)
+
+  /** Returns the same data type but with all nullability fields set to true
+    * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+    */
+  private[spark] def asNullable: DataType
}
@@ -245,6 +293,8 @@ class NullType private() extends DataType {
  // this type. Otherwise, the companion object would be of type "NullType$" in byte code.
  // Defined with a private constructor so the companion object is the only possible instantiation.
  override def defaultSize: Int = 1
+
+  private[spark] override def asNullable: NullType = this
}

case object NullType extends NullType
@@ -310,6 +360,8 @@ class StringType private() extends NativeType with PrimitiveType {
   * The default size of a value of the StringType is 4096 bytes.
   */
  override def defaultSize: Int = 4096
+
+  private[spark] override def asNullable: StringType = this
}

case object StringType extends StringType
@@ -344,6 +396,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
   * The default size of a value of the BinaryType is 4096 bytes.
   */
  override def defaultSize: Int = 4096
+
+  private[spark] override def asNullable: BinaryType = this
}

case object BinaryType extends BinaryType
@@ -369,6 +423,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
   * The default size of a value of the BooleanType is 1 byte.
   */
  override def defaultSize: Int = 1
+
+  private[spark] override def asNullable: BooleanType = this
}

case object BooleanType extends BooleanType
@@ -399,6 +455,8 @@ class TimestampType private() extends NativeType {
   * The default size of a value of the TimestampType is 12 bytes.
   */
  override def defaultSize: Int = 12
+
+  private[spark] override def asNullable: TimestampType = this
}

case object TimestampType extends TimestampType
@@ -427,6 +485,8 @@ class DateType private() extends NativeType {
   * The default size of a value of the DateType is 4 bytes.
   */
  override def defaultSize: Int = 4
+
+  private[spark] override def asNullable: DateType = this
}

case object DateType extends DateType
@@ -485,6 +545,8 @@ class LongType private() extends IntegralType {
  override def defaultSize: Int = 8

  override def simpleString = "bigint"
+
+  private[spark] override def asNullable: LongType = this
}

case object LongType extends LongType
@@ -514,6 +576,8 @@ class IntegerType private() extends IntegralType {
  override def defaultSize: Int = 4

  override def simpleString = "int"
+
+  private[spark] override def asNullable: IntegerType = this
}

case object IntegerType extends IntegerType
@@ -543,6 +607,8 @@ class ShortType private() extends IntegralType {
  override def defaultSize: Int = 2

  override def simpleString = "smallint"
+
+  private[spark] override def asNullable: ShortType = this
}

case object ShortType extends ShortType
@@ -572,6 +638,8 @@ class ByteType private() extends IntegralType {
  override def defaultSize: Int = 1

  override def simpleString = "tinyint"
+
+  private[spark] override def asNullable: ByteType = this
}

case object ByteType extends ByteType
@@ -638,6 +706,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
    case None => "decimal(10,0)"
  }
+
+  private[spark] override def asNullable: DecimalType = this
}
@@ -696,6 +766,8 @@ class DoubleType private() extends FractionalType {
   * The default size of a value of the DoubleType is 8 bytes.
   */
  override def defaultSize: Int = 8
+
+  private[spark] override def asNullable: DoubleType = this
}

case object DoubleType extends DoubleType
@@ -724,6 +796,8 @@ class FloatType private() extends FractionalType {
   * The default size of a value of the FloatType is 4 bytes.
   */
  override def defaultSize: Int = 4
+
+  private[spark] override def asNullable: FloatType = this
}

case object FloatType extends FloatType
@@ -772,6 +846,9 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
  override def defaultSize: Int = 100 * elementType.defaultSize

  override def simpleString = s"array<${elementType.simpleString}>"
+
+  private[spark] override def asNullable: ArrayType =
+    ArrayType(elementType.asNullable, containsNull = true)
}
@@ -1017,6 +1094,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   */
  private[sql] def merge(that: StructType): StructType =
    StructType.merge(this, that).asInstanceOf[StructType]
+
+  private[spark] override def asNullable: StructType = {
+    val newFields = fields.map {
+      case StructField(name, dataType, nullable, metadata) =>
+        StructField(name, dataType.asNullable, nullable = true, metadata)
+    }
+
+    StructType(newFields)
+  }
}
@@ -1069,6 +1155,9 @@ case class MapType(
  override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)

  override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
+
+  private[spark] override def asNullable: MapType =
+    MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
}
@@ -1122,4 +1211,10 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   * The default size of a value of the UserDefinedType is 4096 bytes.
   */
  override def defaultSize: Int = 4096
+
+  /**
+   * For a UDT, asNullable does not change the nullability of its internal sqlType; it just
+   * returns the UDT itself.
+   */
+  private[spark] override def asNullable: UserDefinedType[UserType] = this
}
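
To see what the recursive `asNullable` definitions above produce, here is a small sketch on a hypothetical nested schema (again assuming code inside the `org.apache.spark` package, since `asNullable` is `private[spark]`):

import org.apache.spark.sql.types._

val schema = StructType(
  StructField("id", IntegerType, nullable = false) ::
  StructField("tags", ArrayType(StringType, containsNull = false), nullable = false) ::
  StructField("props", MapType(StringType, LongType, valueContainsNull = false), nullable = false) :: Nil)

// Every nullable/containsNull/valueContainsNull flag flips to true, recursively;
// the field names and element/key/value types are left untouched.
val widened = schema.asNullable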

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 83 additions & 0 deletions
@@ -115,4 +115,87 @@ class DataTypeSuite extends FunSuite {
  checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
  checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
  checkDefaultSize(structType, 812)
+
+  def checkEqualsIgnoreCompatibleNullability(
+      from: DataType,
+      to: DataType,
+      expected: Boolean): Unit = {
+    val testName =
+      s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})"
+    test(testName) {
+      assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected)
+    }
+  }
+
+  checkEqualsIgnoreCompatibleNullability(
+    from = ArrayType(DoubleType, containsNull = true),
+    to = ArrayType(DoubleType, containsNull = true),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = ArrayType(DoubleType, containsNull = false),
+    to = ArrayType(DoubleType, containsNull = false),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = ArrayType(DoubleType, containsNull = false),
+    to = ArrayType(DoubleType, containsNull = true),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = ArrayType(DoubleType, containsNull = true),
+    to = ArrayType(DoubleType, containsNull = false),
+    expected = false)
+  checkEqualsIgnoreCompatibleNullability(
+    from = ArrayType(DoubleType, containsNull = false),
+    to = ArrayType(StringType, containsNull = false),
+    expected = false)
+
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, DoubleType, valueContainsNull = true),
+    to = MapType(StringType, DoubleType, valueContainsNull = true),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, DoubleType, valueContainsNull = false),
+    to = MapType(StringType, DoubleType, valueContainsNull = false),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, DoubleType, valueContainsNull = false),
+    to = MapType(StringType, DoubleType, valueContainsNull = true),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, DoubleType, valueContainsNull = true),
+    to = MapType(StringType, DoubleType, valueContainsNull = false),
+    expected = false)
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
+    to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
+    expected = false)
+  checkEqualsIgnoreCompatibleNullability(
+    from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
+    to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
+    expected = true)
+
+
+  checkEqualsIgnoreCompatibleNullability(
+    from = StructType(StructField("a", StringType, nullable = true) :: Nil),
+    to = StructType(StructField("a", StringType, nullable = true) :: Nil),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = StructType(StructField("a", StringType, nullable = false) :: Nil),
+    to = StructType(StructField("a", StringType, nullable = false) :: Nil),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = StructType(StructField("a", StringType, nullable = false) :: Nil),
+    to = StructType(StructField("a", StringType, nullable = true) :: Nil),
+    expected = true)
+  checkEqualsIgnoreCompatibleNullability(
+    from = StructType(StructField("a", StringType, nullable = true) :: Nil),
+    to = StructType(StructField("a", StringType, nullable = false) :: Nil),
+    expected = false)
+  checkEqualsIgnoreCompatibleNullability(
+    from = StructType(
+      StructField("a", StringType, nullable = false) ::
+      StructField("b", StringType, nullable = true) :: Nil),
+    to = StructType(
+      StructField("a", StringType, nullable = false) ::
+      StructField("b", StringType, nullable = false) :: Nil),
+    expected = false)
}

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}


private[sql] class DefaultSource
@@ -131,7 +131,7 @@ private[sql] case class JSONRelation(

  override def equals(other: Any): Boolean = other match {
    case that: JSONRelation =>
-      (this.path == that.path) && (this.schema == that.schema)
+      (this.path == that.path) && this.schema.sameType(that.schema)
    case _ => false
  }
}

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 7 additions & 2 deletions
@@ -23,6 +23,7 @@ import java.util.logging.Level
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
+import org.apache.spark.sql.types.{StructType, DataType}
import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import parquet.hadoop.metadata.CompressionCodecName
import parquet.schema.MessageType
@@ -172,9 +173,13 @@ private[sql] object ParquetRelation {
      sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED)
        .name())
    ParquetRelation.enableLogForwarding()
-    ParquetTypesConverter.writeMetaData(attributes, path, conf)
+    // This is a hack. We always set nullable/containsNull/valueContainsNull to true
+    // for the schema of parquet data.
+    val schema = StructType.fromAttributes(attributes).asNullable
+    val newAttributes = schema.toAttributes
+    ParquetTypesConverter.writeMetaData(newAttributes, path, conf)
    new ParquetRelation(path.toString, Some(conf), sqlContext) {
-      override val output = attributes
+      override val output = newAttributes
    }
  }
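
The design choice behind the hack: widening the footer schema once, at write time, means any later append can only widen nullability further, so it passes `equalsIgnoreCompatibleNullability`. A sketch of the effect on a hypothetical incoming schema:

import org.apache.spark.sql.types._

// The writer may report non-nullable columns for the incoming attributes.
val incoming = StructType(
  StructField("scores", ArrayType(DoubleType, containsNull = false), nullable = false) :: Nil)

// What ends up in the parquet metadata after this change: all flags true.
val written = incoming.asNullable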
