@@ -182,6 +182,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
case _ => false
}
}

private[spark] override def asNullable: VectorUDT = this
}

/**
@@ -110,7 +110,7 @@ private[mllib] object Loader {
assert(loadedFields.contains(field.name), s"Unable to parse model data." +
s" Expected field with name ${field.name} was missing in loaded schema:" +
s" ${loadedFields.mkString(", ")}")
- assert(loadedFields(field.name) == field.dataType,
+ assert(loadedFields(field.name).sameType(field.dataType),
s"Unable to parse model data. Expected field $field but found field" +
s" with different type: ${loadedFields(field.name)}")
}
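
Note: the switch from == to sameType matters because a schema that round-trips through storage can now come back with every nullability flag set to true (see the Parquet write-path change below). A minimal sketch of the difference, with an illustrative schema (sameType and asNullable are private[spark], so this only compiles inside Spark's own packages):

  import org.apache.spark.sql.types._

  val expected = StructType(
    StructField("weights", ArrayType(DoubleType, containsNull = false), nullable = false) :: Nil)
  val loaded = expected.asNullable  // what a save/load cycle may now yield

  expected == loaded         // false: the nullability flags differ
  expected.sameType(loaded)  // true: same structure once nullability is ignored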
@@ -120,7 +120,8 @@ case class InsertIntoTable(
override def output = child.output

override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
- case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType
+ case (childAttr, tableAttr) =>
+   DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
}
}
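
Note: childAttr.dataType is the from argument and tableAttr.dataType the to argument, so an INSERT now resolves when the child produces non-nullable nested values for a table column that permits nulls, but not the reverse. An illustrative check (hypothetical types, assuming org.apache.spark.sql.types._ is imported; the method is private[sql]):

  DataType.equalsIgnoreCompatibleNullability(
    from = ArrayType(DoubleType, containsNull = false),  // child output
    to = ArrayType(DoubleType, containsNull = true))     // table column
  // => true; swapping from and to yields false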

@@ -181,7 +181,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
- private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
@@ -198,6 +198,43 @@ object DataType {
case (left, right) => left == right
}
}

/**
* Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
*
* Compatible nullability is defined as follows:
* - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
* if and only if `to.containsNull` is true, or both of `from.containsNull` and
* `to.containsNull` are false.
* - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
* if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
* `to.valueContainsNull` are false.
* - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
* if and only if, for every pair of corresponding fields, `toField.nullable` is true, or both
* of `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
Comment (Contributor Author): We can introduce a method on DataType based on this one later (I am not sure what a good name would be; I thought about compatibleWith, but I feel it is not very accurate).

(from, to) match {
case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
(tn || !fn) &&
equalsIgnoreCompatibleNullability(fromKey, toKey) &&
equalsIgnoreCompatibleNullability(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.size == toFields.size &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
fromField.name == toField.name &&
(toField.nullable || !fromField.nullable) &&
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
}
}
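
Note: the check composes recursively, so outer and inner nullability flags are validated independently. A sketch of a nested case, in the spirit of the tests below (hypothetical values; the method is private[sql]):

  DataType.equalsIgnoreCompatibleNullability(
    from = MapType(StringType, ArrayType(IntegerType, containsNull = false),
      valueContainsNull = false),
    to = MapType(StringType, ArrayType(IntegerType, containsNull = true),
      valueContainsNull = true))
  // => true; swapping from and to yields false, since both the map's
  // valueContainsNull and the array's containsNull would tighten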


@@ -230,6 +267,17 @@ abstract class DataType {
def prettyJson: String = pretty(render(jsonValue))

def simpleString: String = typeName

/** Check if `this` and `other` are the same data type when ignoring nullability
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def sameType(other: DataType): Boolean =
DataType.equalsIgnoreNullability(this, other)

/** Returns the same data type but with all nullability fields set to true
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def asNullable: DataType
}
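
Note: asNullable is recursive by contract: every nested flag is relaxed, not just the top level. A hypothetical illustration (asNullable is private[spark]):

  val t = StructType(
    StructField("a", ArrayType(IntegerType, containsNull = false), nullable = false) :: Nil)
  val nt = t.asNullable

  nt("a").nullable                                       // true
  nt("a").dataType.asInstanceOf[ArrayType].containsNull  // true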

/**
@@ -245,6 +293,8 @@ class NullType private() extends DataType {
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
override def defaultSize: Int = 1

private[spark] override def asNullable: NullType = this
}

case object NullType extends NullType
@@ -310,6 +360,8 @@ class StringType private() extends NativeType with PrimitiveType {
* The default size of a value of the StringType is 4096 bytes.
*/
override def defaultSize: Int = 4096

private[spark] override def asNullable: StringType = this
}

case object StringType extends StringType
@@ -344,6 +396,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
* The default size of a value of the BinaryType is 4096 bytes.
*/
override def defaultSize: Int = 4096

private[spark] override def asNullable: BinaryType = this
}

case object BinaryType extends BinaryType
@@ -369,6 +423,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
* The default size of a value of the BooleanType is 1 byte.
*/
override def defaultSize: Int = 1

private[spark] override def asNullable: BooleanType = this
}

case object BooleanType extends BooleanType
@@ -399,6 +455,8 @@ class TimestampType private() extends NativeType {
* The default size of a value of the TimestampType is 12 bytes.
*/
override def defaultSize: Int = 12

private[spark] override def asNullable: TimestampType = this
}

case object TimestampType extends TimestampType
@@ -427,6 +485,8 @@ class DateType private() extends NativeType {
* The default size of a value of the DateType is 4 bytes.
*/
override def defaultSize: Int = 4

private[spark] override def asNullable: DateType = this
}

case object DateType extends DateType
@@ -485,6 +545,8 @@ class LongType private() extends IntegralType {
override def defaultSize: Int = 8

override def simpleString = "bigint"

private[spark] override def asNullable: LongType = this
}

case object LongType extends LongType
@@ -514,6 +576,8 @@ class IntegerType private() extends IntegralType {
override def defaultSize: Int = 4

override def simpleString = "int"

private[spark] override def asNullable: IntegerType = this
}

case object IntegerType extends IntegerType
@@ -543,6 +607,8 @@ class ShortType private() extends IntegralType {
override def defaultSize: Int = 2

override def simpleString = "smallint"

private[spark] override def asNullable: ShortType = this
}

case object ShortType extends ShortType
@@ -572,6 +638,8 @@ class ByteType private() extends IntegralType {
override def defaultSize: Int = 1

override def simpleString = "tinyint"

private[spark] override def asNullable: ByteType = this
}

case object ByteType extends ByteType
@@ -638,6 +706,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}

private[spark] override def asNullable: DecimalType = this
}


@@ -696,6 +766,8 @@ class DoubleType private() extends FractionalType {
* The default size of a value of the DoubleType is 8 bytes.
*/
override def defaultSize: Int = 8

private[spark] override def asNullable: DoubleType = this
}

case object DoubleType extends DoubleType
@@ -724,6 +796,8 @@ class FloatType private() extends FractionalType {
* The default size of a value of the FloatType is 4 bytes.
*/
override def defaultSize: Int = 4

private[spark] override def asNullable: FloatType = this
}

case object FloatType extends FloatType
@@ -772,6 +846,9 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
override def defaultSize: Int = 100 * elementType.defaultSize

override def simpleString = s"array<${elementType.simpleString}>"

private[spark] override def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
}


@@ -1017,6 +1094,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
*/
private[sql] def merge(that: StructType): StructType =
StructType.merge(this, that).asInstanceOf[StructType]

private[spark] override def asNullable: StructType = {
val newFields = fields.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(name, dataType.asNullable, nullable = true, metadata)
}

StructType(newFields)
}
}


@@ -1069,6 +1155,9 @@ case class MapType(
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)

override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"

private[spark] override def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
}


@@ -1122,4 +1211,10 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
* The default size of a value of the UserDefinedType is 4096 bytes.
*/
override def defaultSize: Int = 4096

/**
* For a UDT, asNullable does not change the nullability of its internal sqlType; it simply
* returns the UDT itself.
*/
private[spark] override def asNullable: UserDefinedType[UserType] = this
}
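
Note: this is presumably why the VectorUDT hunk at the top of the diff can simply return this: a UDT's sqlType describes an internal representation that the UDT itself controls, so there is no nullability to relax.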
@@ -115,4 +115,87 @@ class DataTypeSuite extends FunSuite {
checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
checkDefaultSize(structType, 812)

def checkEqualsIgnoreCompatibleNullability(
from: DataType,
to: DataType,
expected: Boolean): Unit = {
val testName =
s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})"
test(testName) {
assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected)
}
}

checkEqualsIgnoreCompatibleNullability(
from = ArrayType(DoubleType, containsNull = true),
to = ArrayType(DoubleType, containsNull = true),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = ArrayType(DoubleType, containsNull = false),
to = ArrayType(DoubleType, containsNull = false),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = ArrayType(DoubleType, containsNull = false),
to = ArrayType(DoubleType, containsNull = true),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = ArrayType(DoubleType, containsNull = true),
to = ArrayType(DoubleType, containsNull = false),
expected = false)
checkEqualsIgnoreCompatibleNullability(
from = ArrayType(DoubleType, containsNull = false),
to = ArrayType(StringType, containsNull = false),
expected = false)

checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, DoubleType, valueContainsNull = true),
to = MapType(StringType, DoubleType, valueContainsNull = true),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, DoubleType, valueContainsNull = false),
to = MapType(StringType, DoubleType, valueContainsNull = false),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, DoubleType, valueContainsNull = false),
to = MapType(StringType, DoubleType, valueContainsNull = true),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, DoubleType, valueContainsNull = true),
to = MapType(StringType, DoubleType, valueContainsNull = false),
expected = false)
checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
expected = false)
Comment (Contributor): Would be good to add another test case to show the nested case:

  checkEqualsIgnoreCompatibleNullability(
    from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
    to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
    expected = true)

Comment (Contributor Author): Done (added a test after this one).

checkEqualsIgnoreCompatibleNullability(
from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true),
to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true),
expected = true)


checkEqualsIgnoreCompatibleNullability(
from = StructType(StructField("a", StringType, nullable = true) :: Nil),
to = StructType(StructField("a", StringType, nullable = true) :: Nil),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = StructType(StructField("a", StringType, nullable = false) :: Nil),
to = StructType(StructField("a", StringType, nullable = false) :: Nil),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = StructType(StructField("a", StringType, nullable = false) :: Nil),
to = StructType(StructField("a", StringType, nullable = true) :: Nil),
expected = true)
checkEqualsIgnoreCompatibleNullability(
from = StructType(StructField("a", StringType, nullable = true) :: Nil),
to = StructType(StructField("a", StringType, nullable = false) :: Nil),
expected = false)
checkEqualsIgnoreCompatibleNullability(
from = StructType(
StructField("a", StringType, nullable = false) ::
StructField("b", StringType, nullable = true) :: Nil),
to = StructType(
StructField("a", StringType, nullable = false) ::
StructField("b", StringType, nullable = false) :: Nil),
expected = false)
}
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
- import org.apache.spark.sql.types.StructType
+ import org.apache.spark.sql.types.{DataType, StructType}


private[sql] class DefaultSource
@@ -131,7 +131,7 @@ private[sql] case class JSONRelation(

override def equals(other: Any): Boolean = other match {
case that: JSONRelation =>
- (this.path == that.path) && (this.schema == that.schema)
+ (this.path == that.path) && this.schema.sameType(that.schema)
case _ => false
}
}
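
Note: plausibly the same motivation as the Loader change above: a schema reloaded from storage is now fully nullable, so exact schema equality would spuriously distinguish otherwise identical relations.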
@@ -23,6 +23,7 @@ import java.util.logging.Level
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
import org.apache.spark.sql.types.{StructType, DataType}
import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import parquet.hadoop.metadata.CompressionCodecName
import parquet.schema.MessageType
@@ -172,9 +173,13 @@ private[sql] object ParquetRelation {
sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED)
.name())
ParquetRelation.enableLogForwarding()
- ParquetTypesConverter.writeMetaData(attributes, path, conf)
+ // This is a hack. We always set nullable/containsNull/valueContainsNull to true
+ // for the schema of Parquet data.
+ val schema = StructType.fromAttributes(attributes).asNullable
+ val newAttributes = schema.toAttributes
+ ParquetTypesConverter.writeMetaData(newAttributes, path, conf)
new ParquetRelation(path.toString, Some(conf), sqlContext) {
- override val output = attributes
+ override val output = newAttributes
Comment (Contributor Author): @liancheng @marmbrus I am also changing the nullability for our old Parquet write path to make the behavior consistent with our new write path. Let me know if there is any potential compatibility issue and we should revert this change.

Comment (Contributor): Should we also make the data types of ParquetRelation.output always nullable?

Comment (Contributor): Never mind; since we always write nullable data, it should be OK to leave ParquetRelation.output untouched.

Comment (Contributor): Verified that when merging schemas, the official Parquet implementation handles nullability (repetition levels) properly, so our change should be safe for interoperation with other systems that support Parquet schema evolution.
}
}
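
Note: the effect of the "hack" above, sketched with illustrative values: the metadata written for Parquet data records every field as nullable, regardless of what the query actually produced.

  val produced = StructType(StructField("a", IntegerType, nullable = false) :: Nil)
  produced.asNullable("a").nullable  // true: the field is written as nullable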
