Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 83 additions & 65 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,23 +226,18 @@ object DataType {
}
}

private val NoNameCheck = 0
private val CaseSensitiveNameCheck = 1
private val CaseInsensitiveNameCheck = 2
private val NoNullabilityCheck = 0
private val NullabilityCheck = 1
private val CompatibleNullabilityCheck = 2

/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
equalsIgnoreNullability(leftValueType, rightValueType)
case (StructType(leftFields), StructType(rightFields)) =>
leftFields.length == rightFields.length &&
leftFields.zip(rightFields).forall { case (l, r) =>
l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
}
case (l, r) => l == r
}
equalsDataTypes(left, right, CaseSensitiveNameCheck, NoNullabilityCheck)
}

/**
Expand All @@ -260,49 +255,15 @@ object DataType {
* of `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
(tn || !fn) &&
equalsIgnoreCompatibleNullability(fromKey, toKey) &&
equalsIgnoreCompatibleNullability(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (fromField, toField) =>
fromField.name == toField.name &&
(toField.nullable || !fromField.nullable) &&
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
equalsDataTypes(from, to, CaseSensitiveNameCheck, CompatibleNullabilityCheck)
}

/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case
* sensitivity of field names in StructType.
*/
private[sql] def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
equalsIgnoreCaseAndNullability(fromElement, toElement)

case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
equalsIgnoreCaseAndNullability(fromKey, toKey) &&
equalsIgnoreCaseAndNullability(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (l, r) =>
l.name.equalsIgnoreCase(r.name) &&
equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
equalsDataTypes(from, to, CaseInsensitiveNameCheck, NoNullabilityCheck)
}

/**
Expand All @@ -315,25 +276,82 @@ object DataType {
from: DataType,
to: DataType,
ignoreNullability: Boolean = false): Boolean = {
(from, to) match {
if (ignoreNullability) {
equalsDataTypes(from, to, NoNameCheck, NoNullabilityCheck)
} else {
equalsDataTypes(from, to, NoNameCheck, NullabilityCheck)
}
}

/** Given the fieldNames compare for equality based on nameCheckType */
private def isSameFieldName(left: String, right: String, nameCheckType: Int): Boolean = {
nameCheckType match {
case NoNameCheck => true
case CaseSensitiveNameCheck => left == right
case CaseInsensitiveNameCheck => left.toLowerCase == right.toLowerCase
}
}

/** Given the nullability of two datatypes compare for equality based on nullabilityCheckType */
private def isSameNullability(
leftNullability: Boolean,
rightNullability: Boolean,
nullabilityCheckType: Int): Boolean = {
nullabilityCheckType match {
case NoNullabilityCheck => true
case NullabilityCheck => leftNullability == rightNullability
case CompatibleNullabilityCheck => rightNullability || !leftNullability
}
}

/**
* Compare two dataTypes based on -
* nameCheckType - (NoNameCheck, CaseSensitiveNameCheck, CaseInsensitiveNameCheck)
* nullabilityCheckType - (NoNullabilityCheck, NullabilityCheck, CompatibleNullabilityCheck)
* @param left
* @param right
* @param nameCheckType
* @param nullabilityCheckType
* @return
*/
private def equalsDataTypes(
left: DataType,
right: DataType,
nameCheckType: Int,
nullabilityCheckType: Int
): Boolean = {
(left, right) match {
case (left: ArrayType, right: ArrayType) =>
equalsStructurally(left.elementType, right.elementType) &&
(ignoreNullability || left.containsNull == right.containsNull)
val sameNullability = isSameNullability(left.containsNull, right.containsNull,
nullabilityCheckType)
val sameType = equalsDataTypes(left.elementType, right.elementType,
nameCheckType, nullabilityCheckType)
sameNullability && sameType

case (left: MapType, right: MapType) =>
equalsStructurally(left.keyType, right.keyType) &&
equalsStructurally(left.valueType, right.valueType) &&
(ignoreNullability || left.valueContainsNull == right.valueContainsNull)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields)
.forall { case (l, r) =>
equalsStructurally(l.dataType, r.dataType) &&
(ignoreNullability || l.nullable == r.nullable)
}

case (fromDataType, toDataType) => fromDataType == toDataType
val sameNullability = isSameNullability(left.valueContainsNull, right.valueContainsNull,
nullabilityCheckType)
val sameKeyType = equalsDataTypes(left.keyType, right.keyType,
nameCheckType, nullabilityCheckType)
val sameValueType = equalsDataTypes(left.valueType, right.valueType,
nameCheckType, nullabilityCheckType)
sameNullability && sameKeyType && sameValueType

case (StructType(leftFields), StructType(rightFields)) =>
leftFields.length == rightFields.length &&
leftFields.zip(rightFields).forall { case (lf, rf) =>
val sameFieldName = isSameFieldName(lf.name, rf.name, nameCheckType)
val sameNullability = isSameNullability(lf.nullable, rf.nullable, nullabilityCheckType)
val sameType = equalsDataTypes(lf.dataType, rf.dataType,
nameCheckType, nullabilityCheckType)

sameFieldName && sameNullability && sameType
}

case (leftDataType, rightDataType) => leftDataType == rightDataType
}
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,88 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 12)
checkDefaultSize(structType, 20)

def checkEqualsIgnoreNullability(
from: DataType,
to: DataType,
expected: Boolean): Unit = {
val testName =
s"equalsIgnoreNullability: (from: $from, to: $to)"
test(testName) {
assert(DataType.equalsIgnoreNullability(from, to) === expected)
}
}

checkEqualsIgnoreNullability(
from = ArrayType(DoubleType, containsNull = false),
to = ArrayType(DoubleType, containsNull = true),
expected = true)
checkEqualsIgnoreNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false):: Nil
),
expected = true)
checkEqualsIgnoreNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("c", ArrayType(DoubleType, containsNull = false), nullable = false):: Nil
),
expected = false)
checkEqualsIgnoreNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("B", ArrayType(DoubleType, containsNull = false), nullable = true):: Nil
),
expected = false)
checkEqualsIgnoreNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(StringType, StringType, valueContainsNull = false),
containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(StringType, StringType, valueContainsNull = false),
containsNull = false), nullable = true):: Nil
),
expected = true)

def checkEqualsIgnoreCaseAndNullability(
from: DataType,
to: DataType,
expected: Boolean): Unit = {
val testName =
s"equalsIgnoreCaseAndNullability: (from: $from, to: $to)"
test(testName) {
assert(DataType.equalsIgnoreCaseAndNullability(from, to) === expected)
}
}

checkEqualsIgnoreCaseAndNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false),
containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("B", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false),
containsNull = false), nullable = true):: Nil
),
expected = true)

def checkEqualsIgnoreCompatibleNullability(
from: DataType,
to: DataType,
Expand Down Expand Up @@ -392,6 +474,30 @@ class DataTypeSuite extends SparkFunSuite {
StructField("a", StringType, nullable = false) ::
StructField("b", StringType, nullable = false) :: Nil),
expected = false)
checkEqualsIgnoreCompatibleNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false),
containsNull = false), nullable = true):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("B", ArrayType(MapType(IntegerType, StringType, valueContainsNull = true),
containsNull = false), nullable = false):: Nil
),
expected = false)
checkEqualsIgnoreCompatibleNullability(
from = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = false),
containsNull = false), nullable = false):: Nil
),
to = StructType(
StructField("a", DoubleType, nullable = false)::
StructField("b", ArrayType(MapType(IntegerType, StringType, valueContainsNull = true),
containsNull = false), nullable = true):: Nil
),
expected = true)

def checkCatalogString(dt: DataType): Unit = {
test(s"catalogString: $dt") {
Expand Down